# create_paste_version.py
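"""Build a pasted (synthetic) version of the CDLA document dataset.

Each source page image is resized, (optionally) augmented, and pasted at a
random position onto a 2048x2048 background image; the matching segmentation
mask and a Canny edge map are written to <gen_root>/{img,mask,edge}/<split>/.
"""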
import os
import cv2
import numpy as np
from PIL import Image
# import albumentations as A
import imgaug as ia
import imgaug.augmenters as iaa
from tqdm.contrib import tzip

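# Input data: CDLA source page images (train/val splits), their segmentation
# masks, and background images used as paste targets.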
src_root = '/home/mly/data/datasets/text_recognition/CDLA/gen/'
seg_root = '/home/mly/data/datasets/text_recognition/CDLA/CDLA_DATASET_SEG_ANNOTATIONS_FULL'
train_seg_root = os.path.join(seg_root, 'train')
val_seg_root = os.path.join(seg_root, 'val')
bg_root = '/home/mly/data/datasets/humanMatting/bg'

train_root = os.path.join(src_root, 'train')
val_root = os.path.join(src_root, 'val')
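# Output root; synthesized samples are written to <gen_root>/{img,mask,edge}/<split>/.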
gen_root = '/home/mly/data/datasets/text_recognition/CDLA/CDLA_DATASET_SYN/'
gen_img_root = os.path.join(gen_root, 'img')
gen_mask_root = os.path.join(gen_root, 'mask')
gen_edge_root = os.path.join(gen_root, 'edge')
# makedirs also creates gen_root itself if it does not exist yet.
for d in (gen_img_root, gen_mask_root, gen_edge_root):
    os.makedirs(d, exist_ok=True)


def get_img_mask_list(root):
    # Keep only image files (names ending in 'g', e.g. .jpg/.png).
    img_list = [f for f in os.listdir(root) if f.endswith('g')]

    return img_list


def get_img_mask_full_path_list(img_list, img_root, mask_root):
    img_full_path_list = list()
    mask_full_path_list = list()
    for img_name in img_list:
        img_full_path_list.append(os.path.join(img_root, img_name))
        mask_full_path_list.append(os.path.join(mask_root, img_name))

    return img_full_path_list, mask_full_path_list


def read_bg_list(root):
    img_list = os.listdir(root)
    needed = list()
    for img_name in img_list:
        needed.append(os.path.join(root, img_name))

    return needed


def img_aug(img, mask):
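    # NOTE: augmentation is currently a no-op (img and mask are returned unchanged);
    # the commented-out blocks below sketch albumentations and imgaug pipelines.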
    # transform = A.Compose([
    #     A.GaussNoise(p=0.5),
    #     A.Rotate(limit=20, p=0.3),
    #     A.RandomRotate90(p=0.5),
    #     A.RandomBrightnessContrast(p=0.5),
    #     A.Affine(p=0.5),
    #     A.ElasticTransform(p=0.5)
    # ])

    # augmented = transform(image=img, mask=mask)
    # img = augmented['image']
    # mask = augmented['mask']


    # return img, mask

    # transform = iaa.Sequential([
    #     iaa.Fliplr(0.5),
    #     iaa.Flipud(0.2),
    #     iaa.Sometimes(0.4, iaa.CropAndPad(percent=(-0.3, 0.3), pad_mode=ia.ALL, pad_cval=(0, 255), keep_size=False)),
    #     iaa.OneOf([
    #         iaa.GaussianBlur((0, 3.0)),
    #         iaa.AverageBlur(k=(2, 7)),
    #         iaa.MedianBlur(k=(3, 11)),
    #     ]),
    #     iaa.Sometimes(0.2, iaa.Dropout(p=(0, 0.1), per_channel=0.5)),
    #     iaa.PerspectiveTransform(scale=(0.01, 0.3), keep_size=False),
    #     iaa.ElasticTransformation(alpha=(0, 5.0), sigma=0.25),
    #     iaa.OneOf([
    #         iaa.Rot90((1, 3), keep_size=False),
    #         iaa.Rotate((-45, 45), keep_size=False)
    #     ])
    # ])

    # img, mask = transform(images=img, segmentation_maps=mask)

    return img, mask


def paste_img_and_mask(img, mask, bg):
    # Shrink the page image/mask, then paste them at a random position
    # inside a 2048x2048 background.
    img = cv2.resize(img, (512, 512))
    # Nearest-neighbour keeps the mask's label values intact; it must be passed
    # as the interpolation kwarg (the third positional argument of cv2.resize is dst).
    mask = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)
    y_max = 2048 - img.shape[1]
    x_max = 2048 - img.shape[0]
    x = int(np.random.randint(0, x_max))  # top row of the paste region
    y = int(np.random.randint(0, y_max))  # left column of the paste region
    point = (x, y)
    bg = cv2.imread(bg)
    bg = cv2.resize(bg, (2048, 2048))
    bg_mask = np.zeros_like(bg)
    bg[point[0]: point[0] + img.shape[0], point[1]: point[1] + img.shape[1], :] = img
    bg_mask[point[0]: point[0] + img.shape[0], point[1]: point[1] + img.shape[1], :] = mask
    # Edge map of the pasted mask.
    edge = cv2.Canny(bg_mask, 50, 150)

    return bg, bg_mask, edge


def generate(img_list, mask_list, gen_iter=4):
    bg_list = read_bg_list(bg_root)
    lth = len(bg_list)
    for it in range(gen_iter):
        print(f'processing iteration: {it}')
        for img, mask in tzip(img_list, mask_list):
            img_path = img
            name = os.path.splitext(os.path.basename(img_path))[0]
            bn = os.path.basename(os.path.dirname(img_path))  # split name: 'train' or 'val'
            img = cv2.imread(img)
            mask = cv2.imread(mask)
            img, mask = img_aug(img, mask)
            # Pick a random background, skipping (near-)empty files.
            idx = int(np.random.randint(0, lth))
            while os.path.getsize(bg_list[idx]) <= 100:
                idx = int(np.random.randint(0, lth))
            img, mask, edge = paste_img_and_mask(img, mask, bg_list[idx])
            os.makedirs(os.path.join(gen_img_root, bn), exist_ok=True)
            os.makedirs(os.path.join(gen_mask_root, bn), exist_ok=True)
            os.makedirs(os.path.join(gen_edge_root, bn), exist_ok=True)
            cv2.imwrite(os.path.join(gen_img_root, bn, name + '_' + str(it) + '.jpg'), img)
            # Masks and edge maps are written as PNG; lossy JPEG compression would corrupt label values.
            cv2.imwrite(os.path.join(gen_mask_root, bn, name + '_' + str(it) + '.png'), mask)
            cv2.imwrite(os.path.join(gen_edge_root, bn, name + '_' + str(it) + '.png'), edge)


# def demo():
#     paper_img = Image.open('/home/mly/data/datasets/text_recognition/CDLA/CDLA_DATASET/train/train_5000.jpg')
#     print(paper_img.size)
#     bg_img = Image.open('/home/mly/data/datasets/humanMatting/bg/办公桌_4057.jpg')
#     bg_img = bg_img.resize((2048, 2048))
#     bg_img.paste(paper_img, (0, 0))
#     bg_img.save('./paste.jpg')


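# Pair each source image with its segmentation mask, then synthesize pasted
# samples: two passes over the train split and one over val.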
def main():
    train_img_mask_list = get_img_mask_list(root=train_root)
    val_img_mask_list = get_img_mask_list(root=val_root)
    train_img_full_path, train_mask_full_path = get_img_mask_full_path_list(train_img_mask_list, train_root, train_seg_root)
    val_img_full_path, val_mask_full_path = get_img_mask_full_path_list(val_img_mask_list, val_root, val_seg_root)
    generate(train_img_full_path, train_mask_full_path, gen_iter=2)
    generate(val_img_full_path, val_mask_full_path, gen_iter=1)


if __name__ == '__main__':
    main()