# retriever.py (4.99 KB)
class Retriever:
    """Rule-based lookup of field values and signature flags on a page.

    The configuration ``target_fields`` is a dict with three sections:

    * ``target_fields['keys']``: ``{field: [(key_text, direction, kwargs), ...]}``.
      Each entry names a ``key_<direction>`` method of this class; the
      entries are applied in order, every step receiving the box found by
      the previous step as its anchor.
    * ``target_fields['value']``: ``{field: (direction, kwargs, default_value)}``.
      ``direction`` names a ``value_<direction>`` method that is applied to
      the box located for the field's key.
    * ``target_fields['signature']``: ``{field: iterable_of_labels}``.

    Boxes handled internally are ``(x0, y0, x1, y1)`` tuples. The ``right`` /
    ``under`` helpers presumably assume image coordinates (x grows to the
    right, y grows downwards) -- confirm against the producer of ``go_res``.
    """

    def __init__(self, target_fields):
        """Store the configuration and pre-compute the set of key texts.

        Args:
            target_fields: Configuration dict as described on the class.
        """
        self.keys_str = 'keys'
        self.value_str = 'value'
        self.signature_str = 'signature'
        self.signature_have_str = '有'
        self.signature_have_not_str = '无'
        self.target_fields = target_fields
        self.key_text_set = self.get_key_text_set(target_fields)

    def get_key_text_set(self, target_fields):
        """Return the set of every key text mentioned in the ``keys`` section.

        Args:
            target_fields: Configuration dict as described on the class.

        Returns:
            A ``set`` of the ``key_text`` members of all configured steps.
        """
        return {
            key_text
            for key_text_list in target_fields[self.keys_str].values()
            for key_text, _, _ in key_text_list
        }

    @staticmethod
    def key_top1(coordinates_list, key_coordinates):
        """Pick the candidate box with the smallest ``y0``.

        Args:
            coordinates_list: Candidate ``(x0, y0, x1, y1)`` boxes.
            key_coordinates: Anchor box of the previous step; unused here but
                kept so that all ``key_*`` methods share one call signature.

        Returns:
            The box with the smallest ``y0`` (the first one on ties), or
            ``None`` when ``coordinates_list`` is empty.
        """
        # min() instead of list.sort(): the list is shared with later lookups
        # and must not be reordered as a side effect of this call.
        return min(coordinates_list, key=lambda box: box[1], default=None)

    @staticmethod
    def key_right(coordinates_list, key_coordinates, top_padding, bottom_padding):
        """Pick the nearest candidate box to the right of the anchor box.

        A candidate qualifies when it lies strictly inside the anchor's
        vertical band (extended by the paddings) and starts strictly to the
        right of the anchor's right edge. Among those, the one with the
        smallest ``x0`` wins (the first one on ties).

        Args:
            coordinates_list: Candidate ``(x0, y0, x1, y1)`` boxes.
            key_coordinates: Anchor ``(x0, y0, x1, y1)`` box, or ``None``.
            top_padding: Band extension above the anchor, as a multiple of
                the anchor height.
            bottom_padding: Band extension below the anchor, as a multiple of
                the anchor height.

        Returns:
            The selected box, or ``None`` when nothing qualifies.
        """
        # A single candidate is accepted without any geometric check.
        if len(coordinates_list) == 1:
            return coordinates_list[0]
        # Without an anchor there is nothing to be "right of"; report "not
        # found" instead of failing on None[...] below.
        if key_coordinates is None:
            return None

        height = key_coordinates[-1] - key_coordinates[1]
        y_min = key_coordinates[1] - (top_padding * height)
        y_max = key_coordinates[-1] + (bottom_padding * height)
        x = key_coordinates[2]

        nearest = None
        for x0, y0, x1, y1 in coordinates_list:
            if y0 > y_min and y1 < y_max and x0 > x:
                if nearest is None or x0 < nearest[0]:
                    nearest = (x0, y0, x1, y1)
        return nearest

    @staticmethod
    def value_right(go_res, key_coordinates, top_padding, bottom_padding):
        """Return the text of the nearest box to the right of the key box.

        Args:
            go_res: Dict whose values are
                ``((x0, y0, _, _, x1, y1, _, _), text)`` pairs.
            key_coordinates: Key ``(x0, y0, x1, y1)`` box.
            top_padding: Band extension above the key, as a multiple of the
                key height.
            bottom_padding: Band extension below the key, as a multiple of
                the key height.

        Returns:
            The text of the qualifying box with the smallest ``x0``, or
            ``None`` when no box qualifies.
        """
        height = key_coordinates[-1] - key_coordinates[1]
        y_min = key_coordinates[1] - (top_padding * height)
        y_max = key_coordinates[-1] + (bottom_padding * height)
        x = key_coordinates[2]

        x_min = None
        value = None
        for (x0, y0, _, _, _, y1, _, _), text in go_res.values():
            if y0 > y_min and y1 < y_max and x0 > x:
                if x_min is None or x0 < x_min:
                    x_min = x0
                    value = text
        return value

    @staticmethod
    def value_under(go_res, key_coordinates, left_padding, right_padding):
        """Return the text of the nearest box below the key box.

        Args:
            go_res: Dict whose values are
                ``((x0, y0, _, _, x1, y1, _, _), text)`` pairs.
            key_coordinates: Key ``(x0, y0, x1, y1)`` box.
            left_padding: Column extension to the left of the key, as a
                multiple of the key width.
            right_padding: Column extension to the right of the key, as a
                multiple of the key width.

        Returns:
            The text of the qualifying box with the smallest ``y0``, or
            ``None`` when no box qualifies.
        """
        width = key_coordinates[2] - key_coordinates[0]
        x_min = key_coordinates[0] - (width * left_padding)
        x_max = key_coordinates[2] + (width * right_padding)
        y = key_coordinates[-1]

        y_min = None
        value = None
        for (x0, y0, _, _, x1, _, _, _), text in go_res.values():
            if x0 > x_min and x1 < x_max and y0 > y:
                if y_min is None or y0 < y_min:
                    y_min = y0
                    value = text
        return value

    def _collect_key_boxes(self, go_res):
        """Map every configured key text found in ``go_res`` to its boxes."""
        key_text_info = dict()
        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
            if text in self.key_text_set:
                key_text_info.setdefault(text, list()).append((x0, y0, x1, y1))
        return key_text_info

    def _locate_keys(self, key_text_info):
        """Resolve each field's chain of key steps to one final key box."""
        key_coordinates_info = dict()
        for field, key_text_list in self.target_fields[self.keys_str].items():
            anchor = None
            for key_text, direction, kwargs in key_text_list:
                candidates = key_text_info.get(key_text)
                if not candidates:
                    # Key text absent from the page: give up on this field.
                    break
                located = getattr(self, 'key_{0}'.format(direction))(
                    candidates,
                    anchor,
                    **kwargs)
                if not isinstance(located, tuple):
                    break
                anchor = located
            else:
                # Only reached when every step of the chain succeeded.
                key_coordinates_info[field] = anchor
        return key_coordinates_info

    def _extract_values(self, go_res, key_coordinates_info):
        """Look up every value field, falling back to its default."""
        res = dict()
        value_rules = self.target_fields[self.value_str]
        for field, (direction, kwargs, default_value) in value_rules.items():
            key_coordinates = key_coordinates_info.get(field)
            if not isinstance(key_coordinates, tuple):
                res[field] = default_value
                # Keep going: one unresolved key must not drop the remaining
                # fields from the result.
                continue
            value = getattr(self, 'value_{0}'.format(direction))(
                go_res,
                key_coordinates,
                **kwargs
            )
            res[field] = value if isinstance(value, str) else default_value
        return res

    def _match_signatures(self, signature_res_list):
        """Flag each signature field as present or absent."""
        remaining = dict()
        for signature_dict in signature_res_list:
            label = signature_dict['label']
            remaining[label] = remaining.get(label, 0) + 1

        res = dict()
        for field, signature_type_set in self.target_fields[self.signature_str].items():
            # Default first, so a field with no configured labels is still
            # reported.
            res[field] = self.signature_have_not_str
            for signature_type in signature_type_set:
                if remaining.get(signature_type, 0) > 0:
                    res[field] = self.signature_have_str
                    # Each detected signature can satisfy only one field.
                    remaining[signature_type] -= 1
                    break
        return res

    def get_target_fields(self, go_res, signature_res_list):
        """Extract all configured value and signature fields.

        Args:
            go_res: Dict whose values are
                ``((x0, y0, _, _, x1, y1, _, _), text)`` pairs.
            signature_res_list: Iterable of dicts, each with a ``'label'``
                entry naming a detected signature type.

        Returns:
            A dict mapping every field of the ``value`` section to the
            extracted text (or its configured default), and every field of
            the ``signature`` section to ``'有'`` (present) or ``'无'``
            (absent).
        """
        # Find the boxes of all configured key texts.
        key_text_info = self._collect_key_boxes(go_res)
        # Resolve each field's key chain to a single key box.
        key_coordinates_info = self._locate_keys(key_text_info)
        # Look up the field values relative to their key boxes.
        res = self._extract_values(go_res, key_coordinates_info)
        # Match detected signatures against the signature fields.
        res.update(self._match_signatures(signature_res_list))
        return res