# create_dataset.py: build the training/validation dataset from general-OCR results and labeled JSON files.
import copy
import json
import os
import random
import uuid
from collections import Counter

import cv2
import pandas as pd

from tools import get_file_paths, load_json


def text_statistics(go_res_dir):
    """Rank the texts found in general-OCR results by how often they occur.

    Every ``.json`` file returned by ``get_file_paths(go_res_dir, ['.json'])`` is
    loaded; its values are ``(coordinates, text)`` pairs. Every occurrence of a
    text is counted, so a text repeated inside one file counts several times.

    Args:
        go_res_dir: str directory of general-OCR JSON files
    Returns:
        list of ``(text, count)`` tuples sorted by count (descending; ties keep
        first-seen order). The empty string and every text whose count is at most
        ``json_count // 3`` (about a third of the number of JSON files) are dropped.
    """
    json_count = 0
    text_counter = Counter()
    for go_res_json_path in get_file_paths(go_res_dir, ['.json', ]):
        print('Info: start {0}'.format(go_res_json_path))
        json_count += 1
        go_res = load_json(go_res_json_path)
        text_counter.update(text for _, text in go_res.values())

    # Drop rare texts: keep only those counted MORE than json_count // 3 times,
    # i.e. roughly present in over a third of the files.
    min_count = json_count // 3
    return [
        (text, count)
        for text, count in text_counter.most_common()
        if text != '' and count > min_count
    ]

def build_anno_file(dataset_dir, anno_file_path):
    """Write a CSV listing every entry of ``dataset_dir`` in random order.

    The CSV has an unnamed integer index column (0..n-1) followed by a single
    ``name`` column. The order comes from ``random.shuffle``; seed the ``random``
    module beforehand for a reproducible file.

    Args:
        dataset_dir: str directory whose entries are listed
        anno_file_path: str path of the CSV file to write
    """
    entry_names = os.listdir(dataset_dir)
    random.shuffle(entry_names)
    pd.DataFrame({'name': entry_names}).to_csv(anno_file_path)

def _normalize_quad(quad, w, h):
    """Return the 8 corner coordinates of ``quad`` divided by the image size.

    ``quad`` is ``(x0, y0, x1, y1, x2, y2, x3, y3)``; x values are divided by
    ``w`` and y values by ``h``.
    """
    x0, y0, x1, y1, x2, y2, x3, y3 = quad
    return [x0 / w, y0 / h, x1 / w, y1 / h, x2 / w, y2 / h, x3 / w, y3 / h]


def _bbox_center(xs, ys):
    """Return ``[x, y]``, the center of the axis-aligned box around the points."""
    x_min, x_max = min(xs), max(xs)
    y_min, y_max = min(ys), max(ys)
    return [x_min + (x_max - x_min) / 2, y_min + (y_max - y_min) / 2]


def _label_centers(label_res, group_ids):
    """For each id in ``group_ids``, return the bbox center of the first labeled
    shape with that ``group_id``, or ``None`` when no shape carries it."""
    shapes = label_res.get("shapes", [])
    centers = []
    for group_id in group_ids:
        for item in shapes:
            if item.get("group_id") == group_id:
                points = item['points']
                centers.append(_bbox_center([p[0] for p in points], [p[1] for p in points]))
                break
        else:
            centers.append(None)
    return centers


def _nearest_go_key(center, go_centers, used_keys):
    """Return the key of the unused OCR box whose center is closest to ``center``
    (Manhattan distance), or ``None`` if every candidate is already used.
    On ties the earliest candidate wins."""
    best_key = None
    best_dist = None
    for go_x, go_y, go_key in go_centers:
        if go_key in used_keys:
            continue
        dist = abs(go_x - center[0]) + abs(go_y - center[1])
        if best_key is None or dist < best_dist:
            best_key = go_key
            best_dist = dist
    return best_key


def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save_dir,
                  group_ids=(1, 2, 5, 9, 20)):
    """Build samples from general-OCR results and labeled JSON files.

    For every image in ``img_dir`` (sorted by name, minus ``skip_list``):

    1. Each anchor text of ``top_text_list`` is looked up in the OCR result. The
       normalized quad of the first matching box is recorded, or all zeros when
       none matches. Images missing too many anchors are skipped.
    2. The center of the labeled shape for each id in ``group_ids`` is matched,
       greedily, to the nearest OCR box not yet used.
    3. Every non-anchor OCR box yields one sample: the anchor quads plus its own
       normalized quad. Matched boxes get a one-hot label over ``group_ids``;
       all remaining boxes get an all-zero label.

    Writing the samples to ``save_dir`` is currently disabled (see the commented
    ``json.dump`` calls); the function only counts them and prints both totals.

    Args:
        img_dir: str image directory
        go_res_dir: str directory of general-OCR JSON results, ``<image stem>.json``
        label_dir: str directory of labeled JSON files, ``<image stem>.json``
        top_text_list: list of (text, count) anchors; only the texts are used and
            their order fixes the feature layout
        skip_list: list of image file names to skip
        save_dir: str dataset output directory (unused while writing is disabled)
        group_ids: label ``group_id`` values to extract. The default (1, 2, 5, 9, 20)
            is issue date, invoice code, machine-printed number, vehicle type, phone.
    """
    # if os.path.exists(save_dir):
    #     return
    # else:
    #     os.makedirs(save_dir, exist_ok=True)

    count = 0
    un_count = 0
    top_text_count = len(top_text_list)
    # Skip an image when about a third of the anchors are missing. The floor of 1
    # stops short (or empty) anchor lists, where ``top_text_count // 3`` is 0,
    # from skipping every image.
    max_not_found = max(1, top_text_count // 3)
    for img_name in sorted(os.listdir(img_dir)):
        if img_name in skip_list:
            print('Info: skip {0}'.format(img_name))
            continue

        print('Info: start {0}'.format(img_name))
        image_path = os.path.join(img_dir, img_name)
        img = cv2.imread(image_path)
        # cv2.imread returns None (it does not raise) for unreadable or non-image
        # directory entries.
        if img is None:
            print('Warning: cannot read {0}'.format(img_name))
            continue
        h, w = img.shape[:2]
        base_image_name, _ = os.path.splitext(img_name)
        go_res_json_path = os.path.join(go_res_dir, '{0}.json'.format(base_image_name))
        go_res = load_json(go_res_json_path)

        # Anchor features: the first OCR box whose text equals each anchor text.
        input_key_list = []
        not_found_count = 0
        go_key_set = set()  # OCR keys already consumed (anchors, then matched labels)
        for top_text, _ in top_text_list:
            for go_key, (quad, text) in go_res.items():
                if text == top_text:
                    input_key_list.append(_normalize_quad(quad, w, h))
                    go_key_set.add(go_key)
                    break
            else:
                not_found_count += 1
                input_key_list.append([0, 0, 0, 0, 0, 0, 0, 0])
        if not_found_count >= max_not_found:
            print('Info: skip {0} : {1}/{2}'.format(img_name, not_found_count, top_text_count))
            continue

        label_json_path = os.path.join(label_dir, '{0}.json'.format(base_image_name))
        label_res = load_json(label_json_path)
        group_centers = _label_centers(label_res, group_ids)

        # Centers of the OCR boxes that are not anchors.
        go_center_list = []
        for go_key, ((x0, y0, x1, y1, x2, y2, x3, y3), _) in go_res.items():
            if go_key in go_key_set:
                continue
            xcenter, ycenter = _bbox_center((x0, x1, x2, x3), (y0, y1, y2, y3))
            go_center_list.append([xcenter, ycenter, go_key])

        # Greedy matching in group_ids order; each OCR box is claimed at most once.
        group_go_key_list = []
        for label_center in group_centers:
            find_go_key = None
            if label_center is not None:
                find_go_key = _nearest_go_key(label_center, go_center_list, go_key_set)
                if find_go_key is not None:
                    go_key_set.add(find_go_key)
            group_go_key_list.append(find_go_key)

        # Positive samples: one-hot label at the matched field's index.
        for idx, find_go_key in enumerate(group_go_key_list):
            if find_go_key is None:
                continue
            quad, _ = go_res[find_go_key]
            input_list = input_key_list + [_normalize_quad(quad, w, h)]
            input_label = [0] * len(group_ids)
            input_label[idx] = 1
            # with open(os.path.join(save_dir, '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, '{0}-{1}'.format(img_name, find_go_key)))), 'w') as fp:
            #     json.dump([input_list, input_label], fp)
            count += 1

        # Negative samples: every remaining OCR box, with an all-zero label.
        src_label_list = [0] * len(group_ids)
        for go_key, (quad, _) in go_res.items():
            if go_key in go_key_set:
                continue
            input_list = input_key_list + [_normalize_quad(quad, w, h)]
            # with open(os.path.join(save_dir, '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, '{0}-{1}'.format(img_name, go_key)))), 'w') as fp:
            #     json.dump([input_list, src_label_list], fp)
            un_count += 1

    print(count)
    print(un_count)


if __name__ == '__main__':
    # Root of the raw data. Set the GCFP_DATA_DIR environment variable to point
    # elsewhere; the default is the original author's local path.
    base_dir = os.environ.get('GCFP_DATA_DIR', '/Users/zhouweiqi/Downloads/gcfp/data')
    go_dir = os.path.join(base_dir, 'go_res')
    dataset_save_dir = os.path.join(base_dir, 'dataset')
    label_dir = os.path.join(base_dir, 'labeled')

    train_go_path = os.path.join(go_dir, 'train')
    train_image_path = os.path.join(label_dir, 'train', 'image')
    train_label_path = os.path.join(label_dir, 'train', 'label')
    train_dataset_dir = os.path.join(dataset_save_dir, 'train')
    train_anno_file_path = os.path.join(dataset_save_dir, 'train.csv')

    valid_go_path = os.path.join(go_dir, 'valid')
    valid_image_path = os.path.join(label_dir, 'valid', 'image')
    valid_label_path = os.path.join(label_dir, 'valid', 'label')
    valid_dataset_dir = os.path.join(dataset_save_dir, 'valid')
    valid_anno_file_path = os.path.join(dataset_save_dir, 'valid.csv')

    # top_text_list = text_statistics(go_dir)
    # for t in top_text_list:
    #     print(t)

    # Anchor texts with their counts (presumably copied from a text_statistics
    # run). build_dataset ignores the counts, but emits one anchor feature per
    # entry IN THIS ORDER, so reordering the list changes the feature layout.
    filter_from_top_text_list = [
        ('机器编号', 496),
        ('购买方名称', 496),
        ('合格证号', 495),
        ('进口证明书号', 495),
        ('机打代码', 494),
        ('车辆类型', 492),
        ('完税凭证号码', 492),
        ('机打号码', 491),
        ('发动机号码', 491),
        ('主管税务', 491),
        ('价税合计', 489),
        ('机关及代码', 489),
        ('销货单位名称', 486),
        ('厂牌型号', 485),
        ('产地', 485),
        ('商检单号', 483),
        ('电话', 476),
        ('开户银行', 472),
        ('车辆识别代号/车架号码', 463),
        ('身份证号码', 454),
        ('吨位', 452),
        ('备注:一车一票', 439),
        ('地', 432),
        ('账号', 431),
        ('统一社会信用代码/', 424),
        ('限乘人数', 404),
        ('税额', 465),
        ('址', 392)
    ]

    # Images excluded by hand from each split (reason not recorded here).
    skip_list_train = [
        'CH-B101910792-page-12.jpg',
        'CH-B101655312-page-13.jpg',
        'CH-B102278656.jpg',
        'CH-B101846620_page_1_img_0.jpg',
        'CH-B103062528-0.jpg',
        'CH-B102613120-3.jpg',
        'CH-B102997980-3.jpg',
        'CH-B102680060-3.jpg',
        # 'CH-B102995500-2.jpg',  # no value
    ]

    skip_list_valid = [
        'CH-B102897920-2.jpg',
        'CH-B102551284-0.jpg',
        'CH-B102879376-2.jpg',
        'CH-B101509488-page-16.jpg',
        'CH-B102708352-2.jpg',
    ]

    # build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)

    build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir)

    # build_anno_file(train_dataset_dir, train_anno_file_path)
    # build_anno_file(valid_dataset_dir, valid_anno_file_path)