# data_creation.py (5.23 KB)
import os
import cv2
import numpy as np
from PIL import Image
from tqdm.contrib import tzip
import threading

# Source CDLA images (train/val sub-folders) and their segmentation annotations.
src_root = '/home/mly/data/datasets/text_recognition/CDLA/CDLA_DATASET'
seg_root = '/home/mly/data/datasets/text_recognition/CDLA/CDLA_DATASET_SEG_ANNOTATIONS_FULL'
train_seg_root = os.path.join(seg_root, 'train')
val_seg_root = os.path.join(seg_root, 'val')
# Folder of background photos that the document images are pasted onto.
bg_root = '/home/mly/data/datasets/humanMatting/bg'

train_root = os.path.join(src_root, 'train')
val_root = os.path.join(src_root, 'val')

# Output root; composites, masks and edge maps go into sibling sub-folders.
# Previous output location:
# gen_root = '/home/mly/data/datasets/text_recognition/CDLA/CDLA_DATASET_SYN/'
gen_root = '/home/mly/data/datasets/text_recognition/CDLA/new_syn'
gen_img_root = os.path.join(gen_root, 'img')
gen_mask_root = os.path.join(gen_root, 'mask')
gen_edge_root = os.path.join(gen_root, 'edge')
# makedirs(exist_ok=True) also creates a missing gen_root (and any parents) and has
# no window between an existence check and the creation.
for _out_dir in (gen_img_root, gen_mask_root, gen_edge_root):
    os.makedirs(_out_dir, exist_ok=True)


def get_img_mask_list(root, exts=('.jpg', '.jpeg', '.png')):
    """Return the sorted names of the image files directly inside ``root``.

    Files are selected by their real extension (case-insensitive). A check on the
    last character alone would also admit non-images such as ``*.log`` or ``*.svg``,
    which ``cv2.imread`` cannot decode.

    Args:
        root: directory to scan (not recursive).
        exts: accepted file extensions, compared case-insensitively.

    Returns:
        Bare file names (not full paths), sorted so the order is reproducible.
    """
    wanted = tuple(ext.lower() for ext in exts)
    return sorted(name for name in os.listdir(root) if name.lower().endswith(wanted))


def get_img_mask_full_path_list(img_list, img_root, mask_root):
    """Join every file name in ``img_list`` onto both ``img_root`` and ``mask_root``.

    Returns:
        A pair ``(image_paths, mask_paths)`` of parallel lists; each mask path uses
        the same file name as its image, just under ``mask_root``.
    """
    names = list(img_list)  # materialize once so both comprehensions see every name
    image_paths = [os.path.join(img_root, name) for name in names]
    mask_paths = [os.path.join(mask_root, name) for name in names]
    return image_paths, mask_paths


def read_bg_list(root, exts=('.jpg', '.jpeg', '.png')):
    """Return full paths of the background images directly inside ``root``.

    Files are selected by their real extension (case-insensitive) rather than by a
    trailing ``'g'``, which would also pick up undecodable files like ``*.log``.

    Args:
        root: directory to scan (not recursive).
        exts: accepted file extensions, compared case-insensitively.

    Returns:
        ``os.path.join(root, name)`` for each match, sorted by file name.
    """
    wanted = tuple(ext.lower() for ext in exts)
    return [
        os.path.join(root, name)
        for name in sorted(os.listdir(root))
        if name.lower().endswith(wanted)
    ]


_FALLBACK_BG = '/home/mly/data/datasets/humanMatting/bg/办公桌_4257.jpg'


def _load_background(bg_path, size, fallback_path):
    """Read ``bg_path`` resized to ``size`` x ``size``; use ``fallback_path`` if it cannot be read."""
    bg = cv2.imread(bg_path)
    if bg is None:
        # cv2.imread reports an unreadable/undecodable file by returning None, not by raising.
        print(f'error reading {bg_path}, replace by {fallback_path}')
        bg = cv2.imread(fallback_path)
        if bg is None:
            raise FileNotFoundError(
                f'cannot read background {bg_path!r} or fallback {fallback_path!r}')
    return cv2.resize(bg, (size, size))


def paste_img_and_mask(img, mask, bg, *, canvas_size=2048, min_size=448, fallback_bg=_FALLBACK_BG):
    """Paste a randomly resized image onto a background at a random position.

    The foreground is resized to a random width and height, each drawn from
    ``[min_size, canvas_size)``. It is then placed at a random offset that keeps it
    fully inside a ``canvas_size`` x ``canvas_size`` background.

    Args:
        img: path of the foreground image.
        mask: unused; the pasted rectangle itself becomes the mask. Kept so existing
            callers keep working.
        bg: path of the background image.
        canvas_size: side length of the square output canvas.
        min_size: smallest side length for the resized foreground (must be < canvas_size).
        fallback_bg: background path used when ``bg`` cannot be read.

    Returns:
        ``(composite, fg_mask, edge)``, where:
        - ``composite`` is the BGR canvas with the image pasted in.
        - ``fg_mask`` has the same shape and is uint8: 255 inside the pasted rectangle,
          0 elsewhere.
        - ``edge`` is the Canny edge map of ``fg_mask``.

    Raises:
        FileNotFoundError: if ``img`` (or both ``bg`` and ``fallback_bg``) cannot be read.
    """
    fg = cv2.imread(img)
    if fg is None:
        raise FileNotFoundError(f'cannot read image {img!r}')
    # randint without size= returns a scalar; int() of a size-1 array is deprecated in NumPy.
    width = np.random.randint(min_size, canvas_size)
    height = np.random.randint(min_size, canvas_size)
    fg = cv2.resize(fg, (width, height))  # cv2 dsize is (width, height)
    # +1 because randint's upper bound is exclusive; this lets the image reach the bottom/right edge.
    top = np.random.randint(0, canvas_size - height + 1)
    left = np.random.randint(0, canvas_size - width + 1)

    canvas = _load_background(bg, canvas_size, fallback_bg)
    fg_mask = np.zeros_like(canvas)
    canvas[top:top + height, left:left + width] = fg
    fg_mask[top:top + height, left:left + width] = 255
    edge = cv2.Canny(fg_mask, 50, 255)

    return canvas, fg_mask, edge


def generate(img_list, mask_list, gen_iter=5, *, mask_ext='.jpg'):
    """Create ``gen_iter`` synthetic composites per source image and write them to disk.

    For pass ``it`` and source image ``.../<dir>/<stem>.<ext>``, three files are written:
    - ``gen_img_root/<dir>/<stem>_<it>.jpg``
    - ``gen_mask_root/<dir>/<stem>_<it><mask_ext>``
    - ``gen_edge_root/<dir>/<stem>_<it><mask_ext>``

    Args:
        img_list: full paths of the source images.
        mask_list: paths paired element-wise with ``img_list``; forwarded to
            ``paste_img_and_mask``.
        gen_iter: number of passes over the data (one composite per image per pass).
        mask_ext: extension for the mask and edge files. ``'.jpg'`` is kept as the
            default for compatibility. JPEG is lossy and blurs the binary 0/255 values
            near boundaries, so ``'.png'`` is preferable for new data.

    Raises:
        RuntimeError: if ``bg_root`` holds no background file of at least 100 bytes.
        OSError: if an output file cannot be written.
    """
    # Drop tiny (presumably empty/truncated) backgrounds once. This replaces per-image
    # rejection sampling, which looped forever when no file qualified.
    bg_list = [path for path in read_bg_list(bg_root) if os.path.getsize(path) >= 100]
    if not bg_list:
        raise RuntimeError(f'no usable background images (>= 100 bytes) in {bg_root!r}')

    for it in range(gen_iter):
        for img_path, mask_path in tzip(img_list, mask_list):
            bg_path = bg_list[np.random.randint(len(bg_list))]
            # Output sub-folder mirrors the folder holding the source image (e.g. 'train'/'val').
            sub_dir = os.path.basename(os.path.dirname(img_path))
            # splitext keeps dots inside the stem, so 'a.b.jpg' and 'a.c.jpg' do not collide.
            stem = os.path.splitext(os.path.basename(img_path))[0]
            composite, fg_mask, edge = paste_img_and_mask(img_path, mask_path, bg_path)
            outputs = (
                (gen_img_root, composite, '.jpg'),
                (gen_mask_root, fg_mask, mask_ext),
                (gen_edge_root, edge, mask_ext),
            )
            for out_root, array, ext in outputs:
                out_dir = os.path.join(out_root, sub_dir)
                os.makedirs(out_dir, exist_ok=True)
                out_path = os.path.join(out_dir, f'{stem}_{it}{ext}')
                # cv2.imwrite reports failure via its return value instead of raising.
                if not cv2.imwrite(out_path, array):
                    raise OSError(f'failed to write {out_path!r}')


def main():
    """Build the train/val file lists, then synthesize train (6 passes) and val (4 passes)."""
    split_specs = (
        ('processing train!', train_root, train_seg_root, 6),
        ('processing val!', val_root, val_seg_root, 4),
    )

    # Resolve every split's paths before any generation starts.
    jobs = []
    for banner, src_dir, seg_dir, passes in split_specs:
        names = get_img_mask_list(root=src_dir)
        img_paths, mask_paths = get_img_mask_full_path_list(names, src_dir, seg_dir)
        jobs.append((banner, img_paths, mask_paths, passes))

    for banner, img_paths, mask_paths, passes in jobs:
        print(banner)
        generate(img_paths, mask_paths, gen_iter=passes)


if __name__ == '__main__':
    # Run the pipeline once. Each main() call processes the whole dataset, so starting it
    # from several threads would only duplicate the work and race on the same output
    # files. Parallelize per image inside generate() instead if speed is needed.
    main()