# retriever.py 8.25 KB
import re


class Retriever:

    def __init__(self, target_fields):
        self.keys_str = 'keys'
        self.value_str = 'value'
        self.signature_str = 'signature'
        self.signature_have_str = '有'
        self.signature_have_not_str = '无'
        self.target_fields = target_fields
        # self.key_text_set = self.get_key_text_set(target_fields)
        self.replace_map = {
            'int': {
                '(': '0'
            }
        }

    # def get_key_text_set(self, target_fields):
    #     # 关键词集合
    #     key_text_set = set()
    #     for key_text_list in target_fields[self.keys_str].values():
    #         for key_text, key_re, _, _ in key_text_list:
    #             key_text_set.add(key_text)
    #     return key_text_set

    @staticmethod
    def key_top1(coordinates_list, key_coordinates):
        # 关键词查找方向:最上面
        coordinates_list.sort(key=lambda x: x[1])
        return coordinates_list[0]

    @staticmethod
    def key_right(coordinates_list, key_coordinates, top_padding, bottom_padding, scope):
        # 关键词查找方向:右侧
        if len(coordinates_list) == 1:
            return coordinates_list[0]
        height = key_coordinates[-1] - key_coordinates[1]
        y_min = key_coordinates[1] - (top_padding * height)
        y_max = key_coordinates[-1] + (bottom_padding * height)

        width = key_coordinates[2] - key_coordinates[0]
        x_min = key_coordinates[2]
        x_max = key_coordinates[2] + (width * scope)

        x_min_find = None
        key_coordinates = None
        for x0, y0, x1, y1 in coordinates_list:
            cent_x = x0 + ((x1 - x0) / 2)
            cent_y = y0 + ((y1 - y0) / 2)
            if x_min < cent_x < x_max and y_min < cent_y < y_max:
                if x_min_find is None or x0 < x_min_find:
                    x_min_find = x0
                    key_coordinates = (x0, y0, x1, y1)
        return key_coordinates

    def value_right(self, go_res, key_coordinates, top_padding, bottom_padding, scope, value_type=None):
        # 字段值查找方向:右侧
        height = key_coordinates[-1] - key_coordinates[1]
        y_min = key_coordinates[1] - (top_padding * height)
        y_max = key_coordinates[-1] + (bottom_padding * height)

        width = key_coordinates[2] - key_coordinates[0]
        x_min = key_coordinates[2]
        x_max = key_coordinates[2] + (width * scope)

        x_min_find = None
        value = None
        coordinates = None
        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
            cent_x = x0 + ((x1 - x0) / 2)
            cent_y = y0 + ((y1 - y0) / 2)
            if x_min < cent_x < x_max and y_min < cent_y < y_max:
                if x_min_find is None or x0 < x_min_find:
                    if len(text.strip()) > 0:
                        x_min_find = x0
                        value = text
                        coordinates = (x0, y0, x1, y1)

        if isinstance(value_type, str) and value_type in self.replace_map and isinstance(value, str):
            new_value = value.translate(str.maketrans(self.replace_map.get(value_type, {})))
            return new_value, coordinates
        return value, coordinates

    @staticmethod
    def value_under(go_res, key_coordinates, left_padding, right_padding, scope, value_type=None):
        # 字段值查找方向:下方
        width = key_coordinates[2] - key_coordinates[0]
        x_min = key_coordinates[0] - (width * left_padding)
        x_max = key_coordinates[2] + (width * right_padding)

        height = key_coordinates[-1] - key_coordinates[1]
        y_min = key_coordinates[-1]
        y_max = key_coordinates[-1] + (height * scope)

        y_min_find = None
        value = None
        coordinates = None
        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
            cent_x = x0 + ((x1 - x0)/2)
            cent_y = y0 + ((y1 - y0)/2)
            if x_min < cent_x < x_max and y_min < cent_y < y_max:
                if y_min_find is None or y0 < y_min_find:
                    if len(text.strip()) > 0:
                        y_min_find = y0
                        value = text
                        coordinates = (x0, y0, x1, y1)
        return value, coordinates

    @staticmethod
    def rebuild_res(value_res, coordinates_res, is_signature=False):
        words_result = dict()
        for key, value in value_res.items():
            if is_signature:
                coordinates_dict = coordinates_res.get(key, dict())
                x0 = coordinates_dict.get('xmin', -1)
                y0 = coordinates_dict.get('ymin', -1)
                x1 = coordinates_dict.get('xmax', -1)
                y1 = coordinates_dict.get('ymax', -1)
            else:
                x0, y0, x1, y1 = coordinates_res.get(key, (-1, -1, -1, -1))
            words_result[key] = {
                'words': value,
                'score': -1 if not is_signature and x0 == -1 else 1,
                'location': {
                    'left': x0,
                    'top': y0,
                    'width': x1-x0,
                    'height': y1-y0,
                }
            }
        return words_result

    def get_target_fields(self, go_res, signature_res_list):
        # 搜索关键词
        key_text_info = dict()
        for key_text_list in self.target_fields[self.keys_str].values():
            for key_text, key_re, _, _ in key_text_list:
                for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
                    if re.match(key_re, text):
                        key_text_info.setdefault(key_text, list()).append((x0, y0, x1, y1))

            # if text in self.key_text_set:
            #     key_text_info.setdefault(text, list()).append((x0, y0, x1, y1))

        # 搜索关键词
        key_coordinates_info = dict()
        for field, key_text_list in self.target_fields[self.keys_str].items():
            pre_key_coordinates = None
            for key_text, _, direction, kwargs in key_text_list:
                if key_text not in key_text_info:
                    break
                key_coordinates = getattr(self, 'key_{0}'.format(direction))(
                    key_text_info[key_text],
                    pre_key_coordinates,
                    **kwargs)
                if not isinstance(key_coordinates, tuple):
                    break
                pre_key_coordinates = key_coordinates
            else:
                key_coordinates_info[field] = pre_key_coordinates

        # 搜索字段值
        value_res = dict()
        coordinates_res = dict()
        for field, (direction, kwargs, default_value) in self.target_fields[self.value_str].items():
            if not isinstance(key_coordinates_info.get(field), tuple):
                value_res[field] = default_value
                continue
            value, coordinates = getattr(self, 'value_{0}'.format(direction))(
                go_res,
                key_coordinates_info[field],
                **kwargs
            )
            if not isinstance(value, str):
                value_res[field] = default_value
            else:
                value_res[field] = value
                coordinates_res[field] = coordinates

        # 搜索签章
        tmp_signature_info = dict()
        signature_coordinates_res = dict()
        signature_value_res = dict()
        for signature_dict in signature_res_list:
            tmp_signature_info.setdefault(signature_dict['label'], list()).append(signature_dict['location'])

        for field, signature_type_set in self.target_fields[self.signature_str].items():
            for signature_type in signature_type_set:
                if len(tmp_signature_info.get(signature_type, [])) > 0:
                    signature_value_res[field] = self.signature_have_str
                    signature_coordinates_res[field] = tmp_signature_info[signature_type].pop(0)
                    break
                else:
                    signature_value_res[field] = self.signature_have_not_str

        words_result = self.rebuild_res(value_res, coordinates_res)
        words_result_signature = self.rebuild_res(signature_value_res, signature_coordinates_res, True)
        words_result.update(words_result_signature)

        # signature_value_res.update(value_res)
        # return signature_value_res
        return {'words_result': words_result}