retriever.py 13.8 KB
import re
import math


class Retriever:
    """Locate key labels and their associated values in OCR output.

    ``keys_list`` holds one tuple per key. Every element except the last is
    an accepted spelling (element 0 is the canonical key). The last element
    is a flag that enables "text starts with the canonical key" matching.

    ``values_dict`` maps an output field name to a search spec with the keys
    ``'location'`` (a list of search rules tried in order), ``'length'``
    (forwarded to the choice methods and to the suffix check) and an optional
    ``'fix_methods'`` (a list of ``(method_name, kwargs)`` post-processors
    looked up on this class).

    OCR results (``go_res``) are dicts mapping an id to
    ``((x0, y0, x1, y1, x2, y2, x3, y3), text)``. The code treats (x0, y0)
    and (x2, y2) as opposite corners; presumably the order is top-left,
    top-right, bottom-right, bottom-left (TODO confirm with the OCR
    producer). Consumed entries are popped from the dict that is passed in.
    """

    def __init__(self, keys_list=None, values_dict=None):
        # None defaults avoid sharing one mutable list/dict between instances.
        self.keys_list = [] if keys_list is None else keys_list
        self.values_dict = {} if values_dict is None else values_dict
        # Filled by search_keys: one entry (or None) per keys_list item.
        self.find_keys_list = []

    @staticmethod
    def get_theta(x0, y0, x1, y1):
        """Return (cos, sin) of the tilt of the edge (x0, y0) -> (x1, y1).

        The angle equals atan((y0 - y1) / (x1 - x0)) and so lies within
        [-pi/2, pi/2]. atan2 is used so that a vertical edge (x0 == x1) or a
        degenerate edge (identical points) no longer raises ZeroDivisionError.
        """
        dx = x1 - x0
        dy = y0 - y1
        if dx < 0:
            # Flip the vector so atan2 stays in the same half-plane as atan.
            dx, dy = -dx, -dy
        theta = math.atan2(dy, dx)
        return math.cos(theta), math.sin(theta)

    @staticmethod
    def rebuild_xy(x, y, cos, sin):
        """Rotate the point (x, y) by the angle whose cosine/sine are given."""
        rebuild_x = x * cos - y * sin
        rebuild_y = y * cos + x * sin
        return rebuild_x, rebuild_y

    def rebuild_coord(self, coord_tuple, cos, sin):
        """Rotate a flat (x, y, x, y, ...) sequence and return it as a flat list."""
        rebuild_list = []
        for idx in range(0, len(coord_tuple), 2):
            rebuild_list.extend(self.rebuild_xy(coord_tuple[idx], coord_tuple[idx + 1], cos, sin))
        return rebuild_list

    @staticmethod
    def prune_no_cn(src_str):
        """Drop every character outside the CJK range U+4E00..U+9FA5."""
        return re.sub(r'[^\u4e00-\u9fa5]+', '', src_str)

    @staticmethod
    def prune_first_char(src_str, char_set):
        """Remove the first character if it is in char_set (empty input is returned as-is)."""
        if src_str and src_str[0] in char_set:
            return src_str[1:]
        return src_str

    @staticmethod
    def prune_amount(src_str):
        """Keep only digits, ',' and '.' characters."""
        return ''.join(ch for ch in src_str if ch in (',', '.') or ch.isdigit())

    @staticmethod
    def replace_whole(src_str, replace_map):
        """Translate characters using a str.maketrans-compatible mapping."""
        return src_str.translate(str.maketrans(replace_map))

    @staticmethod
    def replace_last_char(src_str, char_set, target_char):
        """Replace the last character with target_char if it is in char_set.

        Empty input is returned unchanged.
        """
        if src_str and src_str[-1] in char_set:
            return src_str[:-1] + target_char
        return src_str

    # Choice methods: pick one candidate from several. Candidates have the layout
    # (text, x0, y0, x1, y1, x2, y2, x3, y3, ocr_id_tuple). Ties keep the earliest.

    @staticmethod
    def choice_xmin(value_list, value_length):
        """Return the candidate with the smallest x0."""
        return min(value_list, key=lambda item: item[1])

    @staticmethod
    def choice_xmax(value_list, value_length):
        """Return the candidate with the largest x0."""
        return max(value_list, key=lambda item: item[1])

    @staticmethod
    def choice_ymin(value_list, value_length):
        """Return the candidate with the smallest y0."""
        return min(value_list, key=lambda item: item[2])

    @staticmethod
    def choice_ymax(value_list, value_length):
        """Return the candidate with the largest y0."""
        return max(value_list, key=lambda item: item[2])

    @staticmethod
    def choice_merge(value_list, value_length):
        """Merge all candidates into one, concatenating text ordered by y0.

        The result keeps the candidate layout: its eight coordinates describe
        the axis-aligned box around all merged candidates, and its last
        element concatenates every candidate's OCR ids.
        """
        texts, ids, xs, ys = [], [], [], []
        for item in sorted(value_list, key=lambda cand: cand[2]):
            texts.append(item[0])
            ids.extend(item[-1])
            coords = item[1:-1]
            xs.extend(coords[0::2])
            ys.extend(coords[1::2])
        x_min, y_min, x_max, y_max = min(xs), min(ys), max(xs), max(ys)
        return (''.join(texts),
                x_min, y_min, x_max, y_min, x_max, y_max, x_min, y_max,
                tuple(ids))

    @staticmethod
    def choice_length(value_list, value_length):
        """Return the candidate whose text length is closest to value_length.

        When value_length is not an int (no expected length configured), the
        first candidate is returned.
        """
        if not isinstance(value_length, int):
            return value_list[0]
        return min(value_list, key=lambda item: abs(len(item[0]) - value_length))

    def _scope_corner(self, scope_tuple, x_pos, y_pos):
        """Return (x, y) read at x_pos/y_pos of the first located scope key.

        scope_tuple lists candidate key indices followed by a width multiplier;
        only the indices are inspected here. Returns None if none was located.
        """
        for scope_key_idx in scope_tuple[:-1]:
            found = self.find_keys_list[scope_key_idx]
            if found is not None:
                return found[x_pos], found[y_pos]
        return None

    def _rotated_key(self, key_entry):
        """Return (cos, sin, corners) for a find_keys_list entry.

        corners are the key's eight coordinates rotated into the frame that
        is aligned with the key's top edge.
        """
        src = key_entry[4:12]
        cos, sin = self.get_theta(*src[:4])
        return cos, sin, self.rebuild_coord(src, cos, sin)

    def _pick_in_window(self, go_res, cos, sin, x_min, x_max, y_min, y_max, choice_method, length):
        """Select the OCR box(es) whose rotated centre lies strictly inside the window.

        Returns None if nothing matched, the sole candidate if exactly one
        matched, otherwise the result of ``choice_<choice_method>``.
        """
        candidates = []
        for go_key_idx, ((x0, y0, x1, y1, x2, y2, x3, y3), text) in go_res.items():
            cent_x, cent_y = self.rebuild_xy(x0 + ((x2 - x0) / 2), y0 + ((y2 - y0) / 2), cos, sin)
            if x_min < cent_x < x_max and y_min < cent_y < y_max:
                candidates.append((text, x0, y0, x1, y1, x2, y2, x3, y3, (go_key_idx,)))

        if not candidates:
            return None
        if len(candidates) == 1:
            return candidates[0]
        # TODO: coordinate conversion when choosing between candidates?
        return getattr(self, 'choice_{0}'.format(choice_method))(candidates, length)

    def value_direction_left(self, go_res, key_idx, top_or_left, bottom_or_right, offset, scope_tuple, choice_method,
                             if_startswith, length):
        """Search for the value of key ``key_idx`` to the LEFT of the key box.

        Vertically, the window runs from ``top_or_left`` key-heights above the
        key to ``bottom_or_right`` key-heights below it. Its right edge is
        ``offset`` key-widths left of the key. Its left edge is the x1/y1
        corner of the first located key in ``scope_tuple[:-1]``; if none was
        located, it is ``scope_tuple[-1]`` key-widths further left.
        ``if_startswith`` is accepted for interface symmetry but unused.

        Returns a candidate tuple, or None when the key was not located or no
        OCR box falls inside the window.
        """
        key_entry = self.find_keys_list[key_idx]
        if key_entry is None:
            return None
        scope_corner = self._scope_corner(scope_tuple, 6, 7)  # scope key's x1, y1

        # Work in the key's rotated coordinate frame.
        cos, sin, (key_x0, key_y0, _, _, key_x2, key_y2, _, _) = self._rotated_key(key_entry)

        height = key_y2 - key_y0
        y_min = key_y0 - (top_or_left * height)
        y_max = key_y2 + (bottom_or_right * height)

        width = key_x2 - key_x0
        x_max = key_x0 - (offset * width)
        if scope_corner is None:
            x_min = x_max - (width * scope_tuple[-1])
        else:
            x_min = self.rebuild_xy(*scope_corner, cos, sin)[0]

        return self._pick_in_window(go_res, cos, sin, x_min, x_max, y_min, y_max, choice_method, length)

    def value_direction_right(self, go_res, key_idx, top_or_left, bottom_or_right, offset, scope_tuple, choice_method,
                              if_startswith, length):
        """Search for the value of key ``key_idx`` to the RIGHT of the key box.

        The search first tries a shortcut: if ``if_startswith`` is a str and
        the key was found as a prefix (so it carries a str suffix), that
        suffix is returned as ``(suffix, key_coords, ())``. This happens when
        ``length`` is not an int, or when it is within 2 of the suffix length.

        Otherwise a window is searched. Vertically it spans like the left
        search. Its left edge is ``offset`` key-widths right of the key. Its
        right edge is the x0/y0 corner of the first located key in
        ``scope_tuple[:-1]``; if none was located, it is ``scope_tuple[-1]``
        key-widths further right.

        Returns a candidate tuple, or None when nothing was found.
        """
        key_entry = self.find_keys_list[key_idx]
        if key_entry is None:
            return None
        scope_corner = self._scope_corner(scope_tuple, 4, 5)  # scope key's x0, y0

        suffix_key = key_entry[3]
        if isinstance(if_startswith, str) and isinstance(suffix_key, str):
            # TODO: validate and correct suffix_key; only the split case is handled for now.
            if not isinstance(length, int) or -3 < length - len(suffix_key) < 3:
                return suffix_key, tuple(key_entry[4:12]), ()

        # Work in the key's rotated coordinate frame.
        cos, sin, (key_x0, key_y0, _, _, key_x2, key_y2, _, _) = self._rotated_key(key_entry)

        height = key_y2 - key_y0
        y_min = key_y0 - (top_or_left * height)
        y_max = key_y2 + (bottom_or_right * height)

        width = key_x2 - key_x0
        x_min = key_x2 + (offset * width)
        if scope_corner is None:
            x_max = x_min + (width * scope_tuple[-1])
        else:
            x_max = self.rebuild_xy(*scope_corner, cos, sin)[0]

        return self._pick_in_window(go_res, cos, sin, x_min, x_max, y_min, y_max, choice_method, length)

    @staticmethod
    def splitext(base_str, key_str, x0, y0, x1, y1, x2, y2, x3, y3):
        """Split base_str into the key prefix and the remaining suffix.

        Returns (key_str, suffix, x1, y1, x2, y2). The coordinates are passed
        through unchanged (TODO: split the box coordinates as well).
        """
        suffix_value = base_str[len(key_str):]
        return key_str, suffix_value, x1, y1, x2, y2

    def search_keys(self, go_res):
        """Locate every configured key in go_res and store them in find_keys_list.

        Pass 1 accepts exact matches against any spelling. Pass 2, for keys
        whose last tuple element is truthy and which are still missing,
        accepts texts that start with the canonical key and records the
        remainder as the suffix. Matched entries and whitespace-only texts
        seen during pass 1 are popped from go_res.

        Each find_keys_list entry is (canonical_key, matched_key, text,
        suffix_or_None, x0, y0, x1, y1, x2, y2, x3, y3), or None if not found.
        """
        find_keys_list = [None] * len(self.keys_list)
        rm_go_key_set = set()
        done_key_idx_set = set()

        # Pass 1: exact match against any accepted spelling.
        for key_idx, key_tuple in enumerate(self.keys_list):
            for str_idx, ((x0, y0, x1, y1, x2, y2, x3, y3), text) in go_res.items():
                if len(text.strip()) == 0:  # drop blank texts
                    rm_go_key_set.add(str_idx)
                    continue
                if text in key_tuple[:-1]:
                    find_keys_list[key_idx] = (key_tuple[0], text, text, None, x0, y0, x1, y1, x2, y2, x3, y3)
                    done_key_idx_set.add(key_idx)
                    rm_go_key_set.add(str_idx)
                    break

            for go_key in rm_go_key_set:
                go_res.pop(go_key)
            rm_go_key_set.clear()

        # Pass 2: prefix match on the canonical key, for keys that allow it.
        for key_idx, key_tuple in enumerate(self.keys_list):
            if key_idx in done_key_idx_set or not key_tuple[-1]:
                continue

            for str_idx, ((x0, y0, x1, y1, x2, y2, x3, y3), text) in go_res.items():
                if text.startswith(key_tuple[0]):
                    prefix_key, suffix_value, new_x1, new_y1, new_x2, new_y2 = self.splitext(
                        text, key_tuple[0], x0, y0, x1, y1, x2, y2, x3, y3)
                    find_keys_list[key_idx] = (key_tuple[0], key_tuple[0], text, suffix_value,
                                               x0, y0, new_x1, new_y1, new_x2, new_y2, x3, y3)
                    done_key_idx_set.add(key_idx)
                    rm_go_key_set.add(str_idx)
                    break

            for go_key in rm_go_key_set:
                go_res.pop(go_key)
            rm_go_key_set.clear()

        self.find_keys_list = find_keys_list

    def search_values(self, go_res):
        """Resolve every field in values_dict and return {field: text}.

        Each rule in ``search_dict['location']`` is
        (key_idx, direction, top_or_left, bottom_or_right, offset,
        scope_tuple, choice_method, if_startswith). The rules are tried in
        order until one yields a tuple. The text is then passed through
        ``fix_methods``, and the OCR entries it used are popped from go_res so
        later fields cannot reuse them. A field with no match maps to ''.
        """
        find_value_dict = dict()
        for cn_key, search_dict in self.values_dict.items():
            # Reset per field so a previous field's match is never reused.
            value_tuple = None
            for (key_idx, direction_str, top_or_left, bottom_or_right, offset, scope_tuple, choice_method,
                 if_startswith) in search_dict['location']:
                value_tuple = getattr(self, 'value_direction_{0}'.format(direction_str))(
                    go_res,
                    key_idx,
                    top_or_left,
                    bottom_or_right,
                    offset,
                    scope_tuple,
                    choice_method,
                    if_startswith,
                    search_dict['length'],
                )
                if isinstance(value_tuple, tuple):
                    break

            if not isinstance(value_tuple, tuple):
                find_value_dict[cn_key] = ''
                continue

            fixed_str = value_tuple[0]
            for fix_method, kwargs in search_dict.get('fix_methods', []):
                fixed_str = getattr(self, fix_method)(fixed_str, **kwargs)
            find_value_dict[cn_key] = fixed_str

            # TODO: rebuild coordinates.
            for go_key in value_tuple[-1]:
                go_res.pop(go_key)

        return find_value_dict

    def extract_fields(self, go_res):
        """Locate keys, then their values; go_res is consumed in place."""
        self.search_keys(go_res)
        return self.search_values(go_res)