# create_dataset2.py
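"""Build dataset variant 2 for motor-vehicle-invoice key-field extraction.

Pipeline, as implemented below: clean the general-OCR JSON results, collect
the texts frequent enough to serve as anchors, match each labelled field
centre to its nearest OCR box, and write fixed-length feature/target samples
plus CSV annotation files for the train and valid splits.
"""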
import json
import os
import random
import uuid

import cv2
import pandas as pd
from tools import get_file_paths, load_json
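
# `get_file_paths` and `load_json` come from a local `tools` module that is
# not shown here. A minimal sketch of what they are assumed to do
# (hypothetical reimplementations, kept as comments so the real module stays
# authoritative):
#
#     def get_file_paths(dir_path, suffix_list):
#         paths = []
#         for root, _, files in os.walk(dir_path):
#             for name in files:
#                 if os.path.splitext(name)[1] in suffix_list:
#                     paths.append(os.path.join(root, name))
#         return paths
#
#     def load_json(json_path):
#         with open(json_path) as fp:
#             return json.load(fp)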


def clean_go_res(go_res_dir):
    """Clean the general-OCR JSON files in place.

    Drops entries whose recognised text is blank, sorts the remaining boxes
    top-to-bottom then left-to-right, and rewrites each file as a list of
    (points, text) pairs (the raw files are keyed dicts).

    Returns:
        (max_seq_count, seq_lens_mean, max_seq_file_name)
    """
    max_seq_count = None
    max_seq_file_name = None
    seq_sum = 0
    file_count = 0

    go_res_json_paths = get_file_paths(go_res_dir, ['.json', ])
    for go_res_json_path in go_res_json_paths:
        print('Info: start {0}'.format(go_res_json_path))
    
        remove_key_set = set()
        go_res = load_json(go_res_json_path)
        for key, (_, text) in go_res.items():
            # Queue entries whose recognised text is blank for removal.
            if text.strip() == '':
                remove_key_set.add(key)
                print('Info: drop blank-text entry {0}'.format(key))
        
        for del_key in remove_key_set:
            del go_res[del_key]

        # Sort boxes top-to-bottom, then left-to-right, by their first point.
        go_res_list = sorted(go_res.values(), key=lambda x: (x[0][1], x[0][0]))

        with open(go_res_json_path, 'w') as fp:
            json.dump(go_res_list, fp)
            print('Info: rewrite {0}'.format(go_res_json_path))

        seq_sum += len(go_res_list)
        file_count += 1
        if max_seq_count is None or len(go_res_list) > max_seq_count:
            max_seq_count = len(go_res_list)
            max_seq_file_name = go_res_json_path

    seq_lens_mean = seq_sum // file_count  # integer mean sequence length
    return max_seq_count, seq_lens_mean, max_seq_file_name

def text_statistics(go_res_dir):
    """
    Args:
        go_res_dir: str directory of general-OCR JSON files
    Returns: list of (text, count) for the most frequent texts

    Note: iterates `go_res.values()`, so it expects the raw keyed-dict files;
    run it before clean_go_res rewrites them as lists.
    """
    json_count = 0
    text_dict = {}
    go_res_json_paths = get_file_paths(go_res_dir, ['.json', ])
    for go_res_json_path in go_res_json_paths:
        print('Info: start {0}'.format(go_res_json_path))
        json_count += 1
        go_res = load_json(go_res_json_path)
        for _, text in go_res.values():
            if text in text_dict:
                text_dict[text] += 1
            else:
                text_dict[text] = 1
    top_text_list = []
    # Sort texts by occurrence count, descending.
    for text, count in sorted(text_dict.items(), key=lambda x: x[1], reverse=True):
        if text == '':
            continue
        # Discard texts that appear in no more than a third of the files.
        if count <= json_count // 3:
            break
        top_text_list.append((text, count))
    return top_text_list
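
# Example of text_statistics output (taken from the hand-copied run recorded
# in `filter_from_top_text_list` under __main__):
#
#     [('机器编号', 496), ('购买方名称', 496), ('合格证号', 495), ...]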

def build_anno_file(dataset_dir, anno_file_path):
    """Write the shuffled sample file names to a CSV annotation file."""
    img_list = os.listdir(dataset_dir)
    random.shuffle(img_list)
    df = pd.DataFrame(columns=['name'])
    df['name'] = img_list
    df.to_csv(anno_file_path)
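
# With pandas' default to_csv settings this yields an index column plus the
# shuffled names, roughly (sample names are illustrative):
#
#     ,name
#     0,0a1b2c3d-....json
#     1,4e5f6a7b-....json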

def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save_dir):
    """
    Args:
        img_dir: str image directory
        go_res_dir: str directory of general-OCR JSON results
        label_dir: str directory of labelled JSON files (labelme-style shapes)
        top_text_list: list of (text, count) for the most frequent texts
        skip_list: list of image file names to skip
        save_dir: str directory the dataset is written to
    """
    # Treat an existing save_dir as an already-built dataset.
    if os.path.exists(save_dir):
        return
    os.makedirs(save_dir)

    # The ten target fields, in order: issue date, invoice code, printed
    # number, vehicle type, phone, engine number, VIN, account number,
    # bank of deposit, total in figures.
    group_cn_list = ['开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写']
    # group_id values used in the labelled JSON, aligned index-for-index with group_cn_list.
    test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28]

    for img_name in sorted(os.listdir(img_dir)):
        if img_name in skip_list:
            print('Info: skip {0}'.format(img_name))
            continue

        print('Info: start {0}'.format(img_name))
        image_path = os.path.join(img_dir, img_name)
        img = cv2.imread(image_path)
        h, w, _ = img.shape
        base_image_name, _ = os.path.splitext(img_name)
        go_res_json_path = os.path.join(go_res_dir, '{0}.json'.format(base_image_name))
        go_res_list = load_json(go_res_json_path)

        valid_lens = len(go_res_list)  # number of real (non-padding) OCR boxes

        # Index of the first OCR box matching each frequent (anchor) text.
        top_text_idx_set = set()
        for top_text, _ in top_text_list:
            for go_idx, (_, text) in enumerate(go_res_list):
                if text == top_text:
                    top_text_idx_set.add(go_idx)
                    break

        label_json_path = os.path.join(label_dir, '{0}.json'.format(base_image_name))
        label_res = load_json(label_json_path)

        # Centre of each target field's labelled polygon, or None when the
        # field is absent (the for-else appends None only if no shape matched).
        group_list = []
        for group_id in test_group_id:
            for item in label_res.get("shapes", []):
                if item.get("group_id") == group_id:
                    x_list = [point[0] for point in item['points']]
                    y_list = [point[1] for point in item['points']]
                    group_list.append([(min(x_list) + max(x_list)) / 2, (min(y_list) + max(y_list)) / 2])
                    break
            else:
                group_list.append(None)
        
        # Centre of each OCR quadrilateral (axis-aligned bounding-box midpoint).
        go_center_list = []
        for (x0, y0, x1, y1, x2, y2, x3, y3), _ in go_res_list:
            xmin, xmax = min(x0, x1, x2, x3), max(x0, x1, x2, x3)
            ymin, ymax = min(y0, y1, y2, y3), max(y0, y1, y2, y3)
            go_center_list.append(((xmin + xmax) / 2, (ymin + ymax) / 2))
        
        # Assign each labelled field centre to the nearest unclaimed OCR box
        # (Manhattan distance), skipping anchor-text boxes.
        label_idx_dict = dict()
        for label_idx, label_center_list in enumerate(group_list):
            if isinstance(label_center_list, list):
                min_go_key = None
                min_length = None
                for go_idx, (go_x_center, go_y_center) in enumerate(go_center_list):
                    if go_idx in top_text_idx_set or go_idx in label_idx_dict:
                        continue
                    length = abs(go_x_center - label_center_list[0]) + abs(go_y_center - label_center_list[1])
                    if min_go_key is None or length < min_length:
                        min_go_key = go_idx
                        min_length = length
                if min_go_key is not None:
                    label_idx_dict[min_go_key] = label_idx
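        # Note: this is a greedy, order-dependent assignment: fields earlier
        # in test_group_id claim their nearest box first, and each box is used
        # at most once. A globally optimal matching (e.g. the Hungarian
        # algorithm) could differ; the greedy pass is what this script uses.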
        
        # Fixed-length (200) inputs and one-hot targets; 200 exceeds the
        # longest observed sequence (152, see the notes in __main__).
        X = list()
        y_true = list()
        for i in range(200):
            if i >= valid_lens:
                # Zero padding past the last real OCR box.
                X.append([0.] * 9)
                y_true.append([0] * 10)
                continue
            (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[i]
            # Feature row: [is_anchor_text, eight coords normalised by image size].
            is_anchor = 1. if i in top_text_idx_set else 0.
            X.append([is_anchor, x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
            # Target row: one-hot over the ten fields, all zeros when unmatched
            # (anchor boxes are never matched, so they stay all-zero too).
            base_label_list = [0] * 10
            if i in label_idx_dict:
                base_label_list[label_idx_dict[i]] = 1
            y_true.append(base_label_list)

        all_data = [X, y_true, valid_lens]

        with open(os.path.join(save_dir, '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, img_name))), 'w') as fp:
            json.dump(all_data, fp)
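
        # A minimal read-back sketch, assuming a sample path produced above
        # (`sample_path` is hypothetical): X is 200 rows of 9 floats, y_true
        # is 200 one-hot rows of 10 ints, valid_lens counts the real boxes:
        #
        #     X, y_true, valid_lens = load_json(sample_path)
        #     assert len(X) == len(y_true) == 200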

        # print('top text find:')
        # for i in top_text_idx_set:
        #     _, text = go_res_list[i]
        #     print(text)

        # print('-------------')
        # print('label value find:')
        # for k, v in label_idx_dict.items():
        #     _, text = go_res_list[k]
        #     print('{0}: {1}'.format(group_cn_list[v], text)) 

        # break


if __name__ == '__main__':
    base_dir = '/Users/zhouweiqi/Downloads/gcfp/data'
    go_dir = os.path.join(base_dir, 'go_res')
    dataset_save_dir = os.path.join(base_dir, 'dataset2')
    label_dir = os.path.join(base_dir, 'labeled')

    train_go_path = os.path.join(go_dir, 'train')
    train_image_path = os.path.join(label_dir, 'train', 'image')
    train_label_path = os.path.join(label_dir, 'train', 'label')
    train_dataset_dir = os.path.join(dataset_save_dir, 'train')
    train_anno_file_path = os.path.join(dataset_save_dir, 'train.csv')

    valid_go_path = os.path.join(go_dir, 'valid')
    valid_image_path = os.path.join(label_dir, 'valid', 'image')
    valid_label_path = os.path.join(label_dir, 'valid', 'label')
    valid_dataset_dir = os.path.join(dataset_save_dir, 'valid')
    valid_anno_file_path = os.path.join(dataset_save_dir, 'valid.csv')

    # max_seq_lens, seq_lens_mean, max_seq_file_name = clean_go_res(go_dir)
    # print(max_seq_lens) # 152
    # print(max_seq_file_name) # CH-B101805176_page_2_img_0.json
    # print(seq_lens_mean) # 92

    # top_text_list = text_statistics(go_dir)
    # for t in top_text_list:
    #     print(t)

    # Hand-curated (text, count) pairs copied from a text_statistics() run.
    filter_from_top_text_list = [
        ('机器编号', 496),
        ('购买方名称', 496),
        ('合格证号', 495),
        ('进口证明书号', 495),
        ('机打代码', 494),
        ('车辆类型', 492),
        ('完税凭证号码', 492),
        ('机打号码', 491),
        ('发动机号码', 491),
        ('主管税务', 491),
        ('价税合计', 489),
        ('机关及代码', 489),
        ('销货单位名称', 486),
        ('厂牌型号', 485),
        ('产地', 485),
        ('商检单号', 483),
        ('电话', 476),
        ('开户银行', 472),
        ('车辆识别代号/车架号码', 463),
        ('身份证号码', 454),
        ('吨位', 452),
        ('备注:一车一票', 439),
        ('地', 432),
        ('账号', 431),
        ('统一社会信用代码/', 424),
        ('限乘人数', 404),
        ('税额', 465),
        ('址', 392)
    ]

    skip_list_train = [
        'CH-B101910792-page-12.jpg',
        'CH-B101655312-page-13.jpg',
        'CH-B102278656.jpg',
        'CH-B101846620_page_1_img_0.jpg',
        'CH-B103062528-0.jpg',
        'CH-B102613120-3.jpg',
        'CH-B102997980-3.jpg',
        'CH-B102680060-3.jpg',
        # 'CH-B102995500-2.jpg',  # no labelled value
    ]

    skip_list_valid = [
        'CH-B102897920-2.jpg',
        'CH-B102551284-0.jpg',
        'CH-B102879376-2.jpg',
        'CH-B101509488-page-16.jpg',
        'CH-B102708352-2.jpg',
    ]

    build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)
    build_anno_file(train_dataset_dir, train_anno_file_path)

    build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir)
    build_anno_file(valid_dataset_dir, valid_anno_file_path)