# create_dataset2.py
import copy
import json
import os
import random
import uuid

import cv2
import pandas as pd
import numpy as np
from shapely.geometry import Polygon, MultiPoint
from tools import get_file_paths, load_json
from word2vec import jwq_word2vec, simple_word2vec

def bbox_iou(go_bbox, label_bbox, mode='iou'):
    # Convex hull of each point set; for a quadrilateral the corners come back
    # ordered top-left, bottom-left, bottom-right, top-right (closing at top-left).
    go_poly = Polygon(go_bbox).convex_hull
    label_poly = Polygon(label_bbox).convex_hull
    if not go_poly.is_valid or not label_poly.is_valid:
        print('formatting errors for boxes!!!!')
        return 0
    if go_poly.area == 0 or label_poly.area == 0:
        return 0

    # Note: despite the name, this returns intersection / area(go_bbox) --
    # the fraction of the OCR box covered by the label box. The `mode`
    # argument is only consumed by the commented-out variants below.
    inter = go_poly.intersection(label_poly).area
    return inter / go_poly.area
    
    # if mode == 'iou':
    #     union = go_poly.area + label_poly.area - inter
    # elif mode =='tiou':
    #     union_poly = np.concatenate((go_bbox, label_bbox))   # merge the two boxes' coordinates into an 8x2 array
    #     union = MultiPoint(union_poly).convex_hull.area
    #     # coors = MultiPoint(union_poly).convex_hull.wkt
    # elif mode == 'giou':
    #     union_poly = np.concatenate((go_bbox, label_bbox))
    #     union = MultiPoint(union_poly).envelope.area
    #     # coors = MultiPoint(union_poly).envelope.wkt
    # elif mode == 'r_giou':
    #     union_poly = np.concatenate((go_bbox, label_bbox))
    #     union = MultiPoint(union_poly).minimum_rotated_rectangle.area
    #     # coors = MultiPoint(union_poly).minimum_rotated_rectangle.wkt
    # else:
    #     raise Exception('incorrect mode!')

    # if union == 0:
    #     return 0
    # else:
    #     return inter / union
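
# Usage sketch for bbox_iou (hypothetical coordinates): two unit squares
# overlapping by half give intersection / area(go_bbox) = 0.5:
#
#   bbox_iou([[0, 0], [1, 0], [1, 1], [0, 1]],
#            [[0.5, 0], [1.5, 0], [1.5, 1], [0.5, 1]])  # -> 0.5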



def clean_go_res(go_res_dir):
    """Drop empty-text entries and sort each go_res JSON top-to-bottom, left-to-right."""
    go_res_json_paths = get_file_paths(go_res_dir, ['.json', ])
    for go_res_json_path in go_res_json_paths:
        print('Info: start {0}'.format(go_res_json_path))

        remove_idx_set = set()
        src_go_res_list = load_json(go_res_json_path)
        for idx, (_, text) in enumerate(src_go_res_list):
            if text.strip() == '':
                remove_idx_set.add(idx)
                print(text)

        if len(remove_idx_set) > 0:
            # Delete in descending index order so earlier deletions do not
            # shift the positions of entries still to be removed.
            for del_idx in sorted(remove_idx_set, reverse=True):
                del src_go_res_list[del_idx]

        # Sort top-to-bottom, then left-to-right by the first corner.
        go_res_list = sorted(src_go_res_list, key=lambda x: (x[0][1], x[0][0]))

        with open(go_res_json_path, 'w') as fp:
            json.dump(go_res_list, fp)
            print('Rewrite {0}'.format(go_res_json_path))
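
# Assumed go_res entry format, inferred from how entries are unpacked
# throughout this file: ([x0, y0, x1, y1, x2, y2, x3, y3], text), i.e. an
# 8-value quadrilateral followed by the recognized text.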

def char_length_statistics(go_res_dir):
    max_char_length = None
    target_file_name = None
    go_res_json_paths = get_file_paths(go_res_dir, ['.json', ])
    for go_res_json_path in go_res_json_paths:
        print('Info: start {0}'.format(go_res_json_path))
        src_go_res_list = load_json(go_res_json_path)
        for _, text in src_go_res_list:
            if max_char_length is None or len(text.strip()) > max_char_length:
                max_char_length = len(text.strip())
                target_file_name = go_res_json_path
    return max_char_length, target_file_name
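
# Per the commented-out run in __main__, the longest text observed was 72
# chars (train/CH-B103053828-4.json).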

def bbox_statistics(go_res_dir):
    max_seq_count = None
    max_seq_file_name = None
    seq_sum = 0
    file_count = 0

    go_res_json_paths = get_file_paths(go_res_dir, ['.json', ])
    for go_res_json_path in go_res_json_paths:
        print('Info: start {0}'.format(go_res_json_path)) 

        go_res_list = load_json(go_res_json_path)
        seq_sum += len(go_res_list)
        file_count += 1
        if max_seq_count is None or len(go_res_list) > max_seq_count:
            max_seq_count = len(go_res_list)
            max_seq_file_name = go_res_json_path

    seq_lens_mean = seq_sum // file_count 
    return max_seq_count, seq_lens_mean, max_seq_file_name
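
# Per the commented-out run in __main__: at most 152 boxes per file (mean 92),
# which motivates the fixed padding length of 160 used in build_dataset below.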

def text_statistics(go_res_dir):
    """
    Args:
        go_res_dir: str directory holding the general-OCR JSON results
    Returns: list of (text, count) pairs for the most frequently seen texts
    """
    json_count = 0
    text_dict = {}
    go_res_json_paths = get_file_paths(go_res_dir, ['.json', ])
    for go_res_json_path in go_res_json_paths:
        print('Info: start {0}'.format(go_res_json_path))
        json_count += 1
        go_res = load_json(go_res_json_path)
        for _, text in go_res:  # each entry is (bbox, text); see clean_go_res
            if text in text_dict:
                text_dict[text] += 1
            else:
                text_dict[text] = 1
    top_text_list = []
    # Sort by occurrence count, descending.
    for text, count in sorted(text_dict.items(), key=lambda x: x[1], reverse=True):
        if text == '':
            continue
        # Discard texts appearing in no more than a third of the files;
        # the list is sorted by count, so we can stop at the first one.
        if count <= json_count // 3:
            break
        top_text_list.append((text, count))
    return top_text_list
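
# Example output shape (see filter_from_top_text_list in __main__):
#   [('机器编号', 496), ('购买方名称', 496), ...]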

def build_anno_file(dataset_dir, anno_file_path):
    img_list = os.listdir(dataset_dir)
    random.shuffle(img_list)
    df = pd.DataFrame({'name': img_list})
    df.to_csv(anno_file_path)
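
# Note: to_csv keeps the default integer index as the first CSV column;
# downstream readers are assumed to rely on it (pass index=False otherwise).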

def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save_dir, is_create_map=False):
    """
    Args:
        img_dir: str image directory
        go_res_dir: str directory holding the general-OCR JSON results
        label_dir: str directory holding the annotation JSON files
        top_text_list: list of (text, count) pairs for the most frequent texts
        skip_list: list of image file names to skip
        save_dir: str output directory for the dataset
        is_create_map: bool also write a create_map.json mapping each image to
            its output file and the matched texts (for inspection)
    """
    if os.path.exists(save_dir):
        return
    else:
        os.makedirs(save_dir, exist_ok=True)

    # Target fields, in order: invoice date, invoice code, machine-printed
    # number, vehicle type, phone, engine number, VIN, account number,
    # bank, amount in figures.
    group_cn_list = ['开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写']
    # labelme group_ids corresponding, position by position, to the fields above
    test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28]

    create_map = {}
    for img_name in sorted(os.listdir(img_dir)):
        if img_name in skip_list:
            print('Info: skip {0}'.format(img_name))
            continue

        print('Info: start {0}'.format(img_name))
        image_path = os.path.join(img_dir, img_name)
        img = cv2.imread(image_path)
        h, w, _ = img.shape
        base_image_name, _ = os.path.splitext(img_name)
        go_res_json_path = os.path.join(go_res_dir, '{0}.json'.format(base_image_name))
        go_res_list = load_json(go_res_json_path)

        valid_lens = len(go_res_list)

        top_text_idx_set = set()
        for top_text, _ in top_text_list:
            for go_idx, (_, text) in enumerate(go_res_list):
                if text == top_text:
                    top_text_idx_set.add(go_idx)
                    break

        label_json_path = os.path.join(label_dir, '{0}.json'.format(base_image_name))
        label_res = load_json(label_json_path)

        # For each target group_id, take the first labelled shape that matches
        # and flatten its points; append None when the field is absent (the
        # for/else fires only when the inner loop never breaks).
        group_list = []
        for group_id in test_group_id:
            for item in label_res.get("shapes", []):
                if item.get("group_id") == group_id:
                    label_bbox = list()
                    for point in item['points']:
                        label_bbox.extend(point)
                    group_list.append(label_bbox)
                    break
            else:
                group_list.append(None)
        
        # Greedily assign each OCR box to a labelled field. A labelme rectangle
        # appears to store two corner points, so label_bbox is [x0, y0, x1, y1]
        # and is expanded to four corners below; a match requires at least half
        # of the OCR box area to be covered by the label box.
        label_idx_dict = dict()
        for label_idx, label_bbox in enumerate(group_list):
            if isinstance(label_bbox, list):
                for go_idx, (go_bbox, _) in enumerate(go_res_list):
                    if go_idx in top_text_idx_set or go_idx in label_idx_dict:
                        continue
                    go_bbox_rebuild = [
                        [go_bbox[0], go_bbox[1]],
                        [go_bbox[2], go_bbox[3]],
                        [go_bbox[4], go_bbox[5]],
                        [go_bbox[6], go_bbox[7]],
                    ]
                    label_bbox_rebuild = [
                        [label_bbox[0], label_bbox[1]],
                        [label_bbox[2], label_bbox[1]],
                        [label_bbox[2], label_bbox[3]],
                        [label_bbox[0], label_bbox[3]],
                    ]
                    iou = bbox_iou(go_bbox_rebuild, label_bbox_rebuild)
                    if iou >= 0.5:
                        label_idx_dict[go_idx] = label_idx 
        
        X = list()
        y_true = list()
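        # Feature layout per box (14 dims total): [is-top-text flag]
        # + simple_word2vec(text) (5 dims, inferred from the arithmetic below)
        # + the 8 corner coordinates normalized by image width/height.
        # Sequences are padded to a fixed 160 boxes (max observed was 152).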

        # text_vec_max_lens = 15 * 50
        # dim = 1 + 5 + 8 + text_vec_max_lens 
        dim = 1 + 5 + 8
        num_classes = 10
        for i in range(160):
            if i >= valid_lens:  # padding slot beyond the real boxes
                X.append([0. for _ in range(dim)])
                y_true.append([0 for _ in range(num_classes)])

            elif i in top_text_idx_set:  # frequent "header" text: flag = 1., all-zero label
                (x0, y0, x1, y1, x2, y2, x3, y3), text = go_res_list[i]
                feature_vec = [1.]
                feature_vec.extend(simple_word2vec(text))
                feature_vec.extend([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
                # feature_vec.extend(jwq_word2vec(text, text_vec_max_lens))
                X.append(feature_vec)

                y_true.append([0 for _ in range(num_classes)])

            elif i in label_idx_dict:  # matched a labelled field: one-hot label row
                (x0, y0, x1, y1, x2, y2, x3, y3), text = go_res_list[i]
                feature_vec = [0.]
                feature_vec.extend(simple_word2vec(text))
                feature_vec.extend([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
                # feature_vec.extend(jwq_word2vec(text, text_vec_max_lens))
                X.append(feature_vec)

                base_label_list = [0 for _ in range(num_classes)]
                base_label_list[label_idx_dict[i]] = 1
                y_true.append(base_label_list)
            else:  # ordinary unmatched box: flag = 0., all-zero label
                (x0, y0, x1, y1, x2, y2, x3, y3), text = go_res_list[i]
                feature_vec = [0.]
                feature_vec.extend(simple_word2vec(text))
                feature_vec.extend([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
                # feature_vec.extend(jwq_word2vec(text, text_vec_max_lens))
                X.append(feature_vec)

                y_true.append([0 for _ in range(num_classes)])

        # One sample per image: [X (160 x 14), y_true (160 x 10 one-hot rows), valid_lens]
        all_data = [X, y_true, valid_lens]

        # Deterministic file name derived from the image name via uuid3.
        save_json_name = '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, img_name))
        with open(os.path.join(save_dir, save_json_name), 'w') as fp:
            json.dump(all_data, fp)

        if is_create_map:
            create_map[img_name] = {
                'x_y_valid_lens': save_json_name, 
                'find_top_text': [go_res_list[i][-1] for i in top_text_idx_set],
                'find_value': {go_res_list[k][-1]: group_cn_list[v] for k, v in label_idx_dict.items()}
            }
    
    if create_map:
        with open(os.path.join(os.path.dirname(save_dir), 'create_map.json'), 'w') as fp:
            json.dump(create_map, fp) 



if __name__ == '__main__':
    base_dir = '/Users/zhouweiqi/Downloads/gcfp/data'
    go_dir = os.path.join(base_dir, 'go_res')
    dataset_save_dir = os.path.join(base_dir, 'dataset160x14-pro-all-valid')
    label_dir = os.path.join(base_dir, 'labeled')

    train_go_path = os.path.join(go_dir, 'train')
    train_image_path = os.path.join(label_dir, 'train', 'image')
    train_label_path = os.path.join(label_dir, 'train', 'label')
    train_dataset_dir = os.path.join(dataset_save_dir, 'train')
    train_anno_file_path = os.path.join(dataset_save_dir, 'train.csv')

    valid_go_path = os.path.join(go_dir, 'valid')
    valid_image_path = os.path.join(label_dir, 'valid', 'image')
    valid_label_path = os.path.join(label_dir, 'valid', 'label')
    valid_dataset_dir = os.path.join(dataset_save_dir, 'valid')
    valid_anno_file_path = os.path.join(dataset_save_dir, 'valid.csv')

    # max_seq_lens, seq_lens_mean, max_seq_file_name = bbox_statistics(go_dir)
    # print(max_seq_lens) # 152
    # print(max_seq_file_name) # train/CH-B101805176_page_2_img_0.json
    # print(seq_lens_mean) # 92

    # max_char_lens, target_file_name = char_length_statistics(go_dir)
    # print(max_char_lens) # 72
    # print(target_file_name) # train/CH-B103053828-4.json

    # top_text_list = text_statistics(go_dir)
    # for t in top_text_list:
    #     print(t)

    filter_from_top_text_list = [
        ('机器编号', 496),
        ('购买方名称', 496),
        ('合格证号', 495),
        ('进口证明书号', 495),
        ('机打代码', 494),
        ('车辆类型', 492),
        ('完税凭证号码', 492),
        ('机打号码', 491),
        ('发动机号码', 491),
        ('主管税务', 491),
        ('价税合计', 489),
        ('机关及代码', 489),
        ('销货单位名称', 486),
        ('厂牌型号', 485),
        ('产地', 485),
        ('商检单号', 483),
        ('电话', 476),
        ('开户银行', 472),
        ('车辆识别代号/车架号码', 463),
        ('身份证号码', 454),
        ('吨位', 452),
        ('备注:一车一票', 439),
        ('地', 432),
        ('账号', 431),
        ('统一社会信用代码/', 424),
        ('限乘人数', 404),
        ('税额', 465),
        ('址', 392)
    ]
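    # The list above is a saved run of text_statistics(go_dir); counts are
    # occurrences across the go_res JSON files.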

    skip_list_train = [
        'CH-B101910792-page-12.jpg',
        'CH-B101655312-page-13.jpg',
        'CH-B102278656.jpg',
        'CH-B101846620_page_1_img_0.jpg',
        'CH-B103062528-0.jpg',
        'CH-B102613120-3.jpg',
        'CH-B102997980-3.jpg',
        'CH-B102680060-3.jpg',
        # 'CH-B102995500-2.jpg',  # no value fields found
    ]

    skip_list_valid = [
        # 'CH-B102897920-2.jpg',
        # 'CH-B102551284-0.jpg',
        # 'CH-B102879376-2.jpg',
        # 'CH-B101509488-page-16.jpg',
        # 'CH-B102708352-2.jpg',
    ]

    build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)
    build_anno_file(train_dataset_dir, train_anno_file_path)

    build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir, True)
    build_anno_file(valid_dataset_dir, valid_anno_file_path)

    # print(simple_word2vec(' fd2jk接口 额24;叁‘,。测ADF壹试!¥? '))
    # print(jwq_word2vec('发', 15*50))
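
    # Reading back one produced sample (hypothetical path placeholder):
    #
    #   with open(os.path.join(train_dataset_dir, '<uuid>.json')) as fp:
    #       X, y_true, valid_lens = json.load(fp)
    #   assert len(X) == 160 and len(X[0]) == 14 and len(y_true[0]) == 10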