# afc_contract_ocr.py
# -*- coding: utf-8 -*-
# @Author        : lk
# @Email         : 9428.al@gmail.com
# @Created Date  : 2021-06-29 17:43:46
# @Last Modified : 2021-09-07 14:11:25
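# @Description   : Turn parsed PDF text spans into per-page, top-to-bottom text boxes and extract contract fields from them.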

from .get_char import Finder
from .get_char_fsm import Finder as FSMFinder
import numpy as np


def extract_info(ocr_results):
    """Find the contract number ('合同编号') on the first page.

    Scans the entries of page '0' -- an ``index -> [polygon, text]`` mapping --
    in iteration order and keeps the first text that starts with 'CH-B'.

    Args:
        ocr_results: per-page OCR results keyed by page-number strings.
            A missing page '0' is treated as an empty page.

    Returns:
        ``{'page_1': {'合同编号': {'words': ..., 'position': ...}}}``, where
        ``words`` is the matched text and ``position`` is
        ``[polygon[0], polygon[1], polygon[2], polygon[-1]]`` (xmin, ymin,
        xmax, ymax for polygons laid out as in ``predict``). Both are None
        when nothing matches.
    """
    first_page = ocr_results.get('0', {})
    hit = next(
        ((polygon, text) for polygon, text in first_page.values() if text.startswith('CH-B')),
        None,
    )
    if hit is None:
        words, position = None, None
    else:
        polygon, words = hit
        position = [polygon[0], polygon[1], polygon[2], polygon[-1]]
    return {'page_1': {'合同编号': {'words': words, 'position': position}}}


def _merge_second_page(pdf_info):
    """Fold page '1' into page '0' and renumber the later pages, in place.

    The blocks of page '1' (if present) are appended to page '0''s ``blocks``
    list, and each page ``'k'`` for ``2 <= k < len(pdf_info)`` (counted before
    the removal) is re-keyed as ``'k-1'``.
    """
    page_count = len(pdf_info)
    second_page = pdf_info.pop('1', {})
    pdf_info['0']['blocks'].extend(second_page.get('blocks', []))
    for pno in range(2, page_count):
        # Re-inserting each key in ascending order keeps iteration order
        # '0', '1', ... when the keys started out as '0'..'n-1'.
        pdf_info[str(pno - 1)] = pdf_info.pop(str(pno))


def _page_to_ocr_result(page):
    """Return ``{index: [polygon, text]}`` for one page's text spans, top to bottom.

    ``polygon`` is the span bbox expanded to 8 int32-truncated coordinates
    ``[xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]``. ``text`` has full-width
    colons normalised to ASCII ':' and ASCII spaces removed. Spans whose text is
    empty after that normalisation are skipped.
    """
    entries = []
    for block in page['blocks']:
        # Only type-0 blocks carry 'lines'/'spans' here.
        # NOTE(review): presumably 0 = text block as in PyMuPDF's dict output — confirm.
        if block['type'] != 0:
            continue
        for line in block['lines']:
            for span in line['spans']:
                # U+FF1A FULLWIDTH COLON is written as an escape so Unicode
                # normalisation of this source file cannot collapse it to ':'
                # (which previously turned this replace into a no-op).
                text = span['text'].replace('\uff1a', ':').replace(' ', '')
                # Check emptiness after normalisation so whitespace-only spans
                # do not produce empty-text entries.
                if not text:
                    continue
                xmin, ymin, xmax, ymax = span['bbox']
                polygon = [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]
                entries.append([np.array(polygon, dtype=np.int32).tolist(), text])
    # Stable sort on the top edge (ymin); ties keep block/line/span order.
    entries.sort(key=lambda entry: entry[0][1])
    return dict(enumerate(entries))


def predict(pdf_info, is_qrs=False, is_fsm=False):
    """Build per-page OCR-style results from parsed PDF pages and extract fields.

    Args:
        pdf_info: dict mapping page-number strings ('0', '1', ...) to page dicts.
            Each page dict has a 'blocks' list; blocks have 'type' and 'lines',
            lines have 'spans', and spans have 'bbox' (4 numbers) and 'text'.
            NOTE(review): this looks like PyMuPDF ``page.get_text('dict')``
            output per page — confirm with the caller. This dict may be mutated
            in place (see the 9-page rule below), and the mutated version is
            what the finder receives.
        is_qrs: when true, the results come from ``extract_info`` instead of a
            finder.
        is_fsm: when true (and ``is_qrs`` is false), ``FSMFinder`` is used
            instead of ``Finder``.

    Returns:
        Whatever the selected extractor (``extract_info`` or the finder's
        ``get_info()``) returns.
    """
    if not is_fsm and not is_qrs and len(pdf_info) == 9:
        # 9-page documents: page '1' is folded into page '0' and the remaining
        # pages are shifted down. NOTE(review): the merged blocks keep their own
        # page coordinates, so after the y-sort they interleave with page-0
        # spans by ymin; the 9-page rule presumably targets one specific
        # contract template — confirm.
        _merge_second_page(pdf_info)

    ocr_results = {pno: _page_to_ocr_result(page) for pno, page in pdf_info.items()}

    if is_qrs:
        return extract_info(ocr_results)
    # The finder receives the whole (possibly re-paginated) PDF info as well as
    # the per-page OCR results.
    finder_cls = FSMFinder if is_fsm else Finder
    return finder_cls(pdf_info, ocr_results=ocr_results).get_info()