# retriever.py 11 KB
import re


class HMHRetriever:
    """Pull a fixed set of contract fields out of the first page's text lines.

    Each text line of page ``'0'`` is checked against a few regular
    expressions. Every field named in ``search_fields_list`` is always present
    in the result: fields that were not found fall back to their default value
    and to a copy of ``default_position``.
    """

    # Patterns are compiled once; the pattern strings themselves are the
    # contract with the document layout and must not be altered casually.
    _NAME_ID_COMPANY_RE = re.compile(r'姓名(.*)证件号码(.*)与(.*公司)')
    _CONTRACT_NO_RE = re.compile(r'合同编号.*(CH-B\d*-\d*).*')
    _SIGN_DATE_RE = re.compile(r'(.*).*签署日期.*(\d{4}-\d{2}-\d{2})')

    # Characters deleted from the captured name: ideographic space only.
    _NAME_STRIP_TABLE = str.maketrans('', '', '\u3000')
    # Characters deleted from the captured ID number: ideographic space plus a
    # closing parenthesis in both ASCII and full-width (U+FF09) form. The
    # previous code chained two identical ASCII replacements, so the second
    # was a no-op; presumably the full-width form was the intended target.
    _ID_STRIP_TABLE = str.maketrans('', '', '\u3000)\uff09')

    def __init__(self):
        # Keys used inside each per-field result dict.
        self.words_str = 'words'
        self.position_str = 'location'
        # Value reported for the signature field whenever a signing-date line
        # is found ("有" = "present").
        self.fix_hava_str = '有'
        self.default_position = [0, 0, 0, 0]
        # (field name, default value) pairs. get_target_fields addresses these
        # entries by index, so the order must not change.
        self.search_fields_list = [
            ('借款/承租人姓名', ''),  # borrower / lessee name
            ('证件号码', ''),  # ID number
            ('渠道', ''),  # channel (the captured "...公司" company name)
            ('合同编号', ''),  # contract number
            ('借款人签字/盖章', '无'),  # borrower signature / seal ("无" = "none")
        ]

    def _entry(self, words, position):
        """Build one per-field result dict."""
        return {self.words_str: words, self.position_str: position}

    def get_target_fields(self, pdf_text_list):
        """Extract the configured fields from page ``'0'``.

        Args:
            pdf_text_list: dict mapping page-number strings to lists of
                ``(bbox, text)`` pairs. NOTE: the ``'0'`` entry is removed
                (popped) from this dict as a side effect, exactly as before.

        Returns:
            ``{"words_result": {field_name: {"words": str, "location": bbox}}}``.
            For found fields ``location`` is the bbox of the matching line;
            otherwise it is a fresh copy of ``default_position``.
        """
        result = dict()
        fields = self.search_fields_list

        for bbox, text in pdf_text_list.pop(str(0), []):
            # Name, ID number and company are captured from a single line.
            if fields[0][0] not in result:
                match = self._NAME_ID_COMPANY_RE.search(text)
                if match is not None:
                    raw_name, raw_id, company = match.groups()
                    result[fields[0][0]] = self._entry(
                        raw_name.translate(self._NAME_STRIP_TABLE).strip(), bbox)
                    result[fields[1][0]] = self._entry(
                        raw_id.translate(self._ID_STRIP_TABLE).strip(), bbox)
                    result[fields[2][0]] = self._entry(company, bbox)

            # Contract number: accepted only when the line yields exactly one
            # candidate.
            if fields[3][0] not in result:
                contract_numbers = self._CONTRACT_NO_RE.findall(text)
                if len(contract_numbers) == 1:
                    result[fields[3][0]] = self._entry(contract_numbers[0], bbox)

            # Signature: the presence of a signing-date line is reported as
            # the fixed "present" marker rather than the captured text.
            if fields[4][0] not in result:
                if self._SIGN_DATE_RE.search(text) is not None:
                    result[fields[4][0]] = self._entry(self.fix_hava_str, bbox)

        # Fill in everything that was not found. Each entry gets its own copy
        # of the default position so that a caller mutating one result cannot
        # corrupt the instance default or the other defaulted fields.
        for find_key, default_value in fields:
            if find_key not in result:
                result[find_key] = self._entry(
                    default_value, list(self.default_position))

        return {"words_result": result}

class Retriever:
    """Locate configured field values on PDF pages relative to keyword boxes.

    ``target_fields`` is a dict keyed by page-number string (``'-1'`` means
    "the last page"). Each page entry holds:

    * ``'keys'``: ``{field: [(key_text, key_re_tuple, direction, kwargs), ...]}``
      - a chain of keywords; each one is resolved with ``key_<direction>``
      relative to the box resolved for the previous keyword in the chain.
    * ``'value'``: ``{field: (source, direction, kwargs, default_value)}``
      - how to find the value with ``value_<direction>`` relative to the
      final keyword box; ``source`` equal to ``'text'`` searches the page
      text, anything else searches the page image results.

    Boxes are ``(x0, y0, x1, y1)`` sequences.
    """

    def __init__(self, target_fields):
        # Keys of the per-page configuration dict.
        self.keys_str = 'keys'
        self.value_str = 'value'
        # Value of ``source`` that selects the text layer.
        self.text_str = 'text'
        # Keys used inside each per-field result dict.
        self.words_str = 'words'
        self.position_str = 'position'
        self.default_position = [-1, -1, -1, -1]
        self.target_fields = target_fields
        # Per value_type character substitutions applied to found values.
        self.replace_map = {
            'int': {
                '(': '0'
            }
        }

    @staticmethod
    def _in_region(box, region, rigorous):
        """Tell whether ``box`` lies inside ``region``.

        With ``rigorous`` the whole box must be strictly inside the region;
        otherwise only the box centre has to be strictly inside.
        """
        x0, y0, x1, y1 = box
        x_min, y_min, x_max, y_max = region
        if rigorous:
            return x_min < x0 and x1 < x_max and y_min < y0 and y1 < y_max
        cent_x = x0 + ((x1 - x0) / 2)
        cent_y = y0 + ((y1 - y0) / 2)
        return x_min < cent_x < x_max and y_min < cent_y < y_max

    def _apply_replace_map(self, value, value_type):
        """Apply the ``replace_map`` substitutions for ``value_type`` to a string value.

        Non-string values and unknown value types are returned unchanged.
        """
        if isinstance(value_type, str) and value_type in self.replace_map and isinstance(value, str):
            return value.translate(str.maketrans(self.replace_map.get(value_type, {})))
        return value

    @staticmethod
    def key_top1(coordinates_list, key_coordinates):
        """Keyword lookup direction "topmost": return the box with the smallest y0.

        ``key_coordinates`` is accepted for signature parity with the other
        ``key_*`` methods and is not used. The input list is left untouched
        (it used to be sorted in place, which reordered lists that are shared
        between lookups). Raises ``IndexError`` on an empty list.
        """
        return sorted(coordinates_list, key=lambda box: box[1])[0]

    def key_right(self, coordinates_list, key_coordinates, offset_tuple, rigorous=False):
        """Keyword lookup direction "right": pick the leftmost candidate in the search region.

        Args:
            coordinates_list: candidate boxes for this keyword.
            key_coordinates: box of the previous keyword in the chain, or None.
            offset_tuple: region offsets, see ``get_target_bbox``.
            rigorous: see ``_in_region``.

        Returns:
            The only candidate if there is just one; the topmost candidate
            when there is no previous keyword box or nothing lies in the
            region; otherwise the in-region candidate with the smallest x0 as
            an ``(x0, y0, x1, y1)`` tuple.
        """
        if len(coordinates_list) == 1:
            return coordinates_list[0]

        # Without a previous keyword box there is no region to search in:
        # fall back to the topmost candidate.
        if key_coordinates is None:
            return self.key_top1(coordinates_list, key_coordinates)

        region = self.get_target_bbox(key_coordinates, offset_tuple)

        x_min_find, find_key_coordinates = None, None
        for x0, y0, x1, y1 in coordinates_list:
            if not self._in_region((x0, y0, x1, y1), region, rigorous):
                continue
            if x_min_find is None or x0 < x_min_find:
                x_min_find = x0
                find_key_coordinates = (x0, y0, x1, y1)

        if find_key_coordinates is None:
            return self.key_top1(coordinates_list, key_coordinates)
        return find_key_coordinates

    def value_right(self, search_list, key_coordinates, offset_tuple, value_type=None, rigorous=False):
        """Value lookup direction "right": leftmost non-blank text in the search region.

        Args:
            search_list: iterable of ``((x0, y0, x1, y1), text)`` pairs.
            key_coordinates: keyword box the region is derived from.
            offset_tuple: region offsets, see ``get_target_bbox``.
            value_type: optional ``replace_map`` key whose substitutions are
                applied to the found text.
            rigorous: see ``_in_region``.

        Returns:
            ``(value, coordinates)``; both are None when nothing was found.
        """
        region = self.get_target_bbox(key_coordinates, offset_tuple)

        x_min_find, value, coordinates = None, None, None
        for (x0, y0, x1, y1), text in search_list:
            if not self._in_region((x0, y0, x1, y1), region, rigorous):
                continue
            if x_min_find is None or x0 < x_min_find:
                # Blank entries never become the value and never narrow the search.
                if len(text.strip()) > 0:
                    x_min_find = x0
                    value = text
                    coordinates = (x0, y0, x1, y1)

        return self._apply_replace_map(value, value_type), coordinates

    def value_under(self, search_list, key_coordinates, offset_tuple, value_type=None, append=False, rigorous=False):
        """Value lookup direction "under": non-blank text in the region, in reading order.

        Args:
            search_list: iterable of ``((x0, y0, x1, y1), text)`` pairs.
            key_coordinates: keyword box the region is derived from.
            offset_tuple: region offsets, see ``get_target_bbox``.
            value_type: optional ``replace_map`` key whose substitutions are
                applied to the found text.
            append: when true, concatenate every in-region text (sorted by
                y0, then x0) instead of returning only the first one.
            rigorous: see ``_in_region``.

        Returns:
            ``(value, coordinates)`` where ``coordinates`` is the box of the
            first in-region entry; ``(None, None)`` when nothing was found.
        """
        region = self.get_target_bbox(key_coordinates, offset_tuple)

        find_list = [
            (x0, y0, x1, y1, text)
            for (x0, y0, x1, y1), text in search_list
            if self._in_region((x0, y0, x1, y1), region, rigorous) and len(text.strip()) > 0
        ]
        if not find_list:
            return None, None

        # Top-to-bottom, then left-to-right.
        find_list.sort(key=lambda item: (item[1], item[0]))
        coordinates = find_list[0][:-1]
        if append:
            value = ''.join([text for _, _, _, _, text in find_list])
        else:
            value = find_list[0][-1]

        return self._apply_replace_map(value, value_type), coordinates

    @staticmethod
    def get_target_bbox(key_coordinates, offset_tuple):
        """Expand a keyword box into a search region.

        ``offset_tuple`` is ``(offset_xmin, offset_xmax, offset_ymin,
        offset_ymax)``, each a multiple of the keyword box width (x offsets)
        or height (y offsets). Positive values grow the region outward; a
        negative min offset moves that edge inward (e.g. ``offset_xmin=-1``
        starts the region at the keyword's right edge).

        Returns:
            ``(x_min, y_min, x_max, y_max)``.
        """
        offset_xmin, offset_xmax, offset_ymin, offset_ymax = offset_tuple

        width = key_coordinates[2] - key_coordinates[0]
        height = key_coordinates[-1] - key_coordinates[1]

        x_min = key_coordinates[0] - (width * offset_xmin)
        x_max = key_coordinates[2] + (width * offset_xmax)
        y_min = key_coordinates[1] - (height * offset_ymin)
        y_max = key_coordinates[-1] + (height * offset_ymax)
        return x_min, y_min, x_max, y_max

    def _default_entry(self, default_value):
        """Build the result for a field that could not be found.

        The position is a copy so callers mutating a result cannot corrupt
        ``default_position`` or other defaulted fields.
        """
        return {
            self.words_str: default_value,
            self.position_str: list(self.default_position),
        }

    def get_target_fields(self, pdf_text_list, pdf_img_list):
        """Extract every configured field from every configured page.

        Args:
            pdf_text_list: dict mapping page-number strings to lists of
                ``((x0, y0, x1, y1), text)`` pairs from the text layer.
            pdf_img_list: dict of the same shape, used for fields whose
                ``source`` is not ``'text'``.

        Returns:
            ``{page_key: {field: {"words": str, "position": [x0, y0, x1, y1]}}}``.
            ``page_key`` is ``'page_<n + 1>'`` for an explicit page number and
            the fixed key ``'page_12'`` for the ``'-1'`` (last page) entry.
            Fields that cannot be resolved get their configured default value
            and a copy of ``default_position``.
        """
        pdf_result = dict()

        for pno_str, fields_dict in self.target_fields.items():
            is_last_pno = pno_str == '-1'
            # Resolve "last page" to the highest page number present. With no
            # pages at all the lookups below simply find nothing, so every
            # field falls back to its default instead of max() raising.
            if is_last_pno and pdf_text_list:
                pno_str = str(max(int(page_no) for page_no in pdf_text_list.keys()))

            page_text = pdf_text_list.get(pno_str, [])

            # Collect the boxes of every text entry matching each keyword.
            key_text_info = dict()
            for key_text_list in fields_dict[self.keys_str].values():
                for key_text, key_re_tuple, _, _ in key_text_list:
                    for (x0, y0, x1, y1), text in page_text:
                        for key_re in key_re_tuple:
                            if re.match(key_re, text):
                                key_text_info.setdefault(key_text, list()).append((x0, y0, x1, y1))

            # Resolve each field's keyword chain to a single box.
            key_coordinates_info = dict()
            for field, key_text_list in fields_dict[self.keys_str].items():
                last_key_coordinates = None
                for key_text, _, direction, kwargs in key_text_list:
                    if key_text not in key_text_info:
                        # A missing link resets the chain.
                        last_key_coordinates = None
                        continue
                    last_key_coordinates = getattr(self, 'key_{0}'.format(direction))(
                        key_text_info[key_text],
                        last_key_coordinates,
                        **kwargs)

                key_coordinates_info[field] = last_key_coordinates

            # Look up each field's value relative to its keyword box.
            page_result = dict()
            for field, (source, direction, kwargs, default_value) in fields_dict[self.value_str].items():
                if not isinstance(key_coordinates_info.get(field), tuple):
                    page_result[field] = self._default_entry(default_value)
                    continue

                if source == self.text_str:
                    search_list = page_text
                else:
                    search_list = pdf_img_list.get(pno_str, [])
                value, coordinates = getattr(self, 'value_{0}'.format(direction))(
                    search_list,
                    key_coordinates_info[field],
                    **kwargs
                )
                if not isinstance(value, str):
                    page_result[field] = self._default_entry(default_value)
                else:
                    page_result[field] = {
                        self.words_str: value,
                        self.position_str: list(coordinates),
                    }

            page_key = 'page_12' if is_last_pno else 'page_{0}'.format(int(pno_str) + 1)
            pdf_result[page_key] = page_result

        return pdf_result