# make_seg_anno.py 2.99 KB
import os
import numpy as np
import cv2
import json
from tqdm import tqdm
from math import ceil
import pandas as pd


# Dataset root; main() processes its 'train' and 'val' sub-directories.
src_root = '/home/mly/data/datasets/text_recognition/CDLA/CDLA_DATASET'
# Output root; process() writes one mask image per sample below it.
anno_root = '/home/mly/data/datasets/text_recognition/CDLA/CDLA_DATASET_SEG_ANNOTATIONS_FULL'
# Directory gen_csv() writes its '<split>.csv' listings into.
csv_root = '/home/mly/data/datasets/text_recognition/CDLA/CDLA_DATASET_SEG_ANNOTATIONS_FULL/csv'

# Create the output tree up front. makedirs(exist_ok=True) replaces the
# check-then-mkdir pattern: it also creates anno_root (and any missing
# parent) and does not fail when the directory already exists.
os.makedirs(os.path.join(anno_root, 'train'), exist_ok=True)
os.makedirs(os.path.join(anno_root, 'val'), exist_ok=True)


def read_list(root):
    """Return the unique file stems found directly inside ``root``.

    The stem is the entry name with its last extension removed, so
    ``page.v2.jpg`` and ``page.v2.json`` both contribute ``page.v2``.

    Args:
        root: Directory to list.

    Returns:
        A sorted list of unique stems, so the order is deterministic
        between runs.
    """
    # splitext strips only the final extension; splitting on the first '.'
    # would truncate names that themselves contain dots.
    stems = {os.path.splitext(entry)[0] for entry in os.listdir(root)}
    return sorted(stems)


def read_dict(file):
    """Load a JSON annotation file and return the value of its 'shapes' key.

    Args:
        file: Path of the JSON file to read.

    Returns:
        Whatever is stored under the top-level ``'shapes'`` key.

    Raises:
        KeyError: If the JSON object has no ``'shapes'`` key.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # The context manager closes the handle even if parsing fails; the
    # encoding is pinned so reading does not depend on the system locale.
    with open(file, encoding='utf-8') as f:
        annotation = json.load(f)

    return annotation['shapes']


def parse_shapes_dict(shape_dict_list):
    """Flatten the 'points' of every shape into a single list.

    Args:
        shape_dict_list: Iterable of mappings, each with a ``'points'``
            entry that is itself an iterable of points.

    Returns:
        A new list holding every point of every shape, in input order.
    """
    # One pass; repeated ``list + list`` would re-copy the accumulated
    # result for every shape.
    return [
        point
        for shape_dict in shape_dict_list
        for point in shape_dict['points']
    ]


def get_seg_points(points_list):
    """Return the integer bounding box enclosing all points.

    Args:
        points_list: Iterable of points; index 0 is read as x and index 1
            as y.

    Returns:
        ``[min_x, min_y, max_x, max_y]`` where the minima are rounded down
        and the maxima are rounded up, so the box always contains every
        input point.

    Raises:
        ValueError: If ``points_list`` contains no points.
    """
    # Local import: only ``ceil`` is imported at the top of the file.
    from math import floor

    xs = [point[0] for point in points_list]
    ys = [point[1] for point in points_list]
    if not xs:
        raise ValueError('points_list is empty; cannot compute a bounding box')

    # floor (not int) for the minima: int() truncates toward zero, which
    # would move a negative minimum inside the true extent.
    return [floor(min(xs)), floor(min(ys)), ceil(max(xs)), ceil(max(ys))]

def gen_mask(h, w, c):
    """Create an ``(h, w, c)`` uint8 array with every element set to 255.

    Args:
        h: Size of the first axis.
        w: Size of the second axis.
        c: Size of the third axis.

    Returns:
        A ``numpy.uint8`` array of shape ``(h, w, c)`` filled with 255.
    """
    # Allocate directly as uint8 instead of building a float64 array
    # (8x the memory) and casting it afterwards.
    return np.full((h, w, c), 255, dtype=np.uint8)

def process(root):
    """Write a mask image for every annotated sample in ``root``.

    For each stem returned by ``read_list(root)`` the function looks at
    ``<stem>.json`` and ``<stem>.jpg`` inside ``root``. When the JSON file
    exists and is non-empty and the image can be read, a mask produced by
    ``gen_mask`` with the image's shape is written to
    ``<anno_root>/<last component of root>/<stem>.jpg``.

    Samples with a missing or empty JSON file, or an unreadable image, are
    reported on stdout and skipped instead of aborting the whole run.

    NOTE: the JSON content is not parsed here; it only gates whether a mask
    is written. The mask depends solely on the image shape.

    Args:
        root: Directory holding the ``.jpg`` / ``.json`` pairs.
    """
    # normpath drops a trailing slash so the last component is never empty.
    sub_dir = os.path.basename(os.path.normpath(root))
    out_dir = os.path.join(anno_root, sub_dir)
    # Created once, outside the loop; exist_ok makes re-runs harmless.
    os.makedirs(out_dir, exist_ok=True)

    for name in tqdm(read_list(root)):
        # Check the annotation first so images without one are not decoded.
        json_path = os.path.join(root, name + '.json')
        if not os.path.isfile(json_path):
            print(f'{name}.json is missing!')
            continue
        if os.path.getsize(json_path) == 0:
            print(f'{name}.json size is 0!')
            continue

        # cv2.imread signals failure by returning None rather than raising.
        img = cv2.imread(os.path.join(root, name + '.jpg'))
        if img is None:
            print(f'{name}.jpg could not be read!')
            continue

        h, w, c = img.shape
        cv2.imwrite(os.path.join(out_dir, name + '.jpg'), gen_mask(h, w, c))


def gen_csv(root):
    """Write a CSV listing the entries of ``root``.

    The sorted entry names of ``root`` are stored in a single ``path``
    column and saved as ``<csv_root>/<last component of root>.csv``
    (pandas' default index column is included).

    Args:
        root: Directory whose entries are listed.
    """
    # normpath drops a trailing slash so the CSV is never named just '.csv'.
    sub_dir = os.path.basename(os.path.normpath(root))
    img_name_list = sorted(os.listdir(root))
    df = pd.DataFrame({'path': img_name_list})

    # Nothing else creates csv_root, so make sure it exists before writing.
    os.makedirs(csv_root, exist_ok=True)
    df.to_csv(os.path.join(csv_root, sub_dir + '.csv'))

def main():
    """Run mask generation on the 'train' split, then on the 'val' split."""
    stages = (
        ('processing train...', 'train'),
        ('processing val...', 'val'),
    )
    for banner, split in stages:
        print(banner)
        process(os.path.join(src_root, split))
        # CSV listing is currently disabled:
        # gen_csv(os.path.join(anno_root, split))

# Run the pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()