# mk_new_syn.py (2.37 KB)
import os
from PIL import Image
import random
from tqdm import tqdm
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import multiprocessing as mp


def mkdir(path):
    """Create directory *path* (and any missing parents) if it is absent.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If *path* already exists but is not a directory.
    """
    # exist_ok=True replaces the exists()-then-makedirs() sequence, which
    # could fail if another process created the directory between the two calls.
    os.makedirs(path, exist_ok=True)

# Base directories: synthesized output, source foregrounds, backgrounds.
syn_root = '/home/mly/data/datasets/text_recognition/CDLA/syn'
src_root = '/home/mly/data/datasets/text_recognition/CDLA/gen/'
bg_root = '/home/mly/data/datasets/humanMatting/bg/'

# Source images and masks for the train split.
src_train_root = os.path.join(src_root, 'train')
img_train_root = os.path.join(src_train_root, 'img')
mask_train_root = os.path.join(src_train_root, 'mask')

# Source images and masks for the val split.
src_val_root = os.path.join(src_root, 'val')
img_val_root = os.path.join(src_val_root, 'img')
mask_val_root = os.path.join(src_val_root, 'mask')

# Output locations for the synthesized train / val images and masks.
img_train_syn_root = os.path.join(syn_root, 'train', 'img')
mask_train_syn_root = os.path.join(syn_root, 'train', 'mask')
img_val_syn_root = os.path.join(syn_root, 'val', 'img')
mask_val_syn_root = os.path.join(syn_root, 'val', 'mask')

# Make sure every output directory exists before anything is written.
for _syn_dir in (
    img_train_syn_root,
    mask_train_syn_root,
    img_val_syn_root,
    mask_val_syn_root,
):
    mkdir(_syn_dir)


def process(img_root, mask_root, gen_img_root, gen_mask_root, bg_list):
    """Composite every image in *img_root* onto a randomly chosen background.

    For each file name found in ``img_root``, the image and the same-named
    mask from ``mask_root`` are resized to one random size, then pasted at a
    random offset onto a 2048x2048 background picked from ``bg_list``. The
    composite is saved under the same name in ``gen_img_root``; a 2048x2048
    black canvas carrying the resized mask at the same offset is saved under
    the same name in ``gen_mask_root``.

    Args:
        img_root: Directory containing the foreground images.
        mask_root: Directory containing one mask per image, same file names.
        gen_img_root: Existing directory that receives the composites.
        gen_mask_root: Existing directory that receives the placed masks.
        bg_list: Paths of candidate background images. Entries that are not
            regular files, or are 200 bytes or smaller, are ignored.

    Raises:
        ValueError: If there are images to process but ``bg_list`` contains
            no usable background.
    """
    canvas_size = 2048           # side length of the square output canvas
    min_side, max_side = 512, 2040   # range of the random foreground size
    min_bg_bytes = 200           # files this small are treated as unusable

    name_list = os.listdir(img_root)
    if not name_list:
        return

    # Filter the backgrounds once instead of rejection-sampling per image:
    # the old retry loop never terminated when every candidate was too small
    # and re-stat'ed files on every iteration.
    usable_bgs = [
        path for path in bg_list
        if os.path.isfile(path) and os.path.getsize(path) > min_bg_bytes
    ]
    if not usable_bgs:
        raise ValueError(
            'no usable background image (larger than {} bytes) among {} '
            'candidates'.format(min_bg_bytes, len(bg_list))
        )

    for name in name_list:
        # convert() returns an independent, fully loaded image, so the
        # underlying files can be closed right away.
        with Image.open(os.path.join(img_root, name)) as src:
            img = src.convert('RGB')
        with Image.open(os.path.join(mask_root, name)) as src:
            mask = src.convert('L')
        with Image.open(random.choice(usable_bgs)) as src:
            bg = src.convert('RGB').resize((canvas_size, canvas_size))

        width = random.randint(min_side, max_side)
        height = random.randint(min_side, max_side)
        img = img.resize((width, height))
        mask = mask.resize((width, height))

        # Offsets keep the pasted region inside the canvas.
        x = random.randint(0, (canvas_size - width) - 1)
        y = random.randint(0, (canvas_size - height) - 1)

        # The mask acts as the alpha channel when compositing the foreground.
        bg.paste(img, (x, y), mask)

        # Copy the mask values verbatim. Pasting the mask with itself as the
        # alpha onto a black canvas would store mask*mask/255, which darkens
        # the soft edges created by resizing and no longer matches the alpha
        # actually used for the composite above.
        new_mask = Image.new('L', bg.size, 0)
        new_mask.paste(mask, (x, y))

        bg.save(os.path.join(gen_img_root, name))
        new_mask.save(os.path.join(gen_mask_root, name))

def main():
    """Synthesize the train and val splits using the backgrounds in ``bg_root``."""
    backgrounds = []
    for entry in os.listdir(bg_root):
        backgrounds.append(os.path.join(bg_root, entry))

    # (source images, source masks, output images, output masks) per split;
    # train is processed before val.
    splits = (
        (img_train_root, mask_train_root, img_train_syn_root, mask_train_syn_root),
        (img_val_root, mask_val_root, img_val_syn_root, mask_val_syn_root),
    )
    for src_img, src_mask, dst_img, dst_mask in splits:
        process(src_img, src_mask, dst_img, dst_mask, backgrounds)

# Run the synthesis only when this file is executed as a script.
if __name__ == '__main__':
    main()