# paste_stamp.py — paste stamp images onto text images to generate a synthetic dataset.
import os
import cv2
from PIL import Image, ImageEnhance
import numpy as np
import random
import shutil
from tqdm import tqdm
import json
import multiprocessing as mp
import imgaug.augmenters as iaa


def mkdir(path):
    """Create directory *path* (and any missing parents) if it does not exist.

    Uses ``exist_ok=True`` instead of a separate existence check. This closes
    the window between "check" and "create" in which another process could
    create the directory and make ``os.makedirs`` raise. Raises
    ``FileExistsError`` if *path* exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)


# Inputs: stamp images, their masks, the masks used as paste alpha ("blend"),
# and the text-background crops the stamps are pasted onto.
stamp_img_root = '/data1/lxl/data/ocr/stamp/aug/img'
stamp_mask_root = '/data1/lxl/data/ocr/stamp/aug/mask'
stamp_blend_root = '/data1/lxl/data/ocr/stamp/aug/blend'
text_root = '/data1/lxl/data/ocr/crop_bg'

# Outputs, all located below gen_root.
gen_root = '/data1/lxl/data/ocr/generate1108/'
gen_img_root = os.path.join(gen_root, 'img')
gen_stamp_img_root = os.path.join(gen_root, 'stamp_img')
gen_stamp_mask_root = os.path.join(gen_root, 'stamp_mask')
gen_text_img_root = os.path.join(gen_root, 'text_img')

# Output directories are created when the module is imported.
for _out_dir in (gen_img_root, gen_text_img_root, gen_stamp_img_root, gen_stamp_mask_root):
    mkdir(_out_dir)
del _out_dir


def random_idx(s, e):
    """Return a random int drawn uniformly from the half-open range [s, e).

    Draws from NumPy's global RNG (``np.random``). Raises ``ValueError``
    (from ``np.random.randint``) when ``s >= e``.
    """
    # Draw a scalar directly. The previous ``size=(1)`` produced a 1-element
    # array, and calling ``int()`` on an ndim>0 array is deprecated since
    # NumPy 1.25.
    return int(np.random.randint(s, e))


def get_full_path_list(root):
    """Return every entry of directory *root* as a joined path, ordered by name."""
    return [os.path.join(root, entry) for entry in sorted(os.listdir(root))]


def _load_image(path, mode=None):
    """Open the image at *path*, fully load it, and close the file handle.

    Converts to *mode* when given; otherwise keeps the file's own mode.
    """
    with Image.open(path) as im:
        # convert()/copy() force the pixel data to load, so the returned
        # image no longer depends on the file that is closed on exit.
        return im.convert(mode) if mode else im.copy()


def gen(stamp_img, blend_mask, stamp_mask, text_img, gen_img_root, gen_text_img_root,
        gen_stamp_mask_root, savename, gen_stamp_img_dir=None):
    """Paste one stamp at a random position onto one text image and save the results.

    Four files named ``"{:>06d}.jpg".format(savename)`` are written:
    the composite (``gen_img_root``), the untouched text image
    (``gen_text_img_root``), the blended stamp alone on a black canvas
    (``gen_stamp_img_dir``), and the stamp mask on a zero canvas
    (``gen_stamp_mask_root``).

    Args:
        stamp_img: path of the stamp image (loaded as RGB).
        blend_mask: path of the mask used as paste alpha for the stamp image
            (loaded in its own mode).
        stamp_mask: path of the stamp mask (loaded as grayscale 'L').
        text_img: path of the text background image (loaded as RGB).
        gen_img_root: output directory for the composite image.
        gen_text_img_root: output directory for the text image.
        gen_stamp_mask_root: output directory for the placed stamp mask.
        savename: integer sample index used to build the output file name.
        gen_stamp_img_dir: output directory for the stamp-only image.
            Defaults to the module-level ``gen_stamp_img_root``, which was the
            only destination before this parameter existed.

    Returns:
        dict with ``'name'`` (``str(savename)``), ``'coordinate'``
        (``[x0, y0, x1, y1]`` box of the pasted stamp in text-image pixels)
        and an empty ``'label'``.

    Raises:
        ValueError: if the stamp is wider or taller than the text image.
    """
    stamp = _load_image(stamp_img, "RGB")
    blend = _load_image(blend_mask)
    mask = _load_image(stamp_mask, "L")
    text = _load_image(text_img, "RGB")

    stamp_w, stamp_h = stamp.size
    text_w, text_h = text.size
    if stamp_w > text_w or stamp_h > text_h:
        raise ValueError(
            "stamp {} ({}x{}) does not fit into text image {} ({}x{})".format(
                stamp_img, stamp_w, stamp_h, text_img, text_w, text_h))

    # Valid top-left offsets are 0..(text - stamp) inclusive. random_idx uses
    # an exclusive upper bound, hence the +1. This also keeps randint(0, 0)
    # from failing when the stamp exactly matches the text size.
    x = random_idx(0, text_w - stamp_w + 1)
    y = random_idx(0, text_h - stamp_h + 1)

    # Composite: the blend mask is the paste alpha for the stamp pixels.
    composite = text.copy()
    composite.paste(stamp, (x, y), mask=blend)

    # The same blended stamp alone on a black canvas.
    stamp_only = Image.new('RGB', size=text.size)
    stamp_only.paste(stamp, (x, y), mask=blend)

    # Stamp mask placed on a zero canvas. A plain paste copies mask values
    # exactly. Using the mask as its own alpha (the previous behavior) squared
    # intermediate values (v -> v*v/255) while leaving pure 0/255 unchanged.
    mask_canvas = Image.new('L', size=text.size)
    mask_canvas.paste(mask, (x, y))

    filename = "{:>06d}.jpg".format(savename)
    stamp_img_dir = gen_stamp_img_root if gen_stamp_img_dir is None else gen_stamp_img_dir
    composite.save(os.path.join(gen_img_root, filename))
    text.save(os.path.join(gen_text_img_root, filename))
    stamp_only.save(os.path.join(stamp_img_dir, filename))
    # NOTE(review): JPEG is lossy, so the saved mask is not exactly the pasted
    # values. Consider PNG if consumers need exact masks. Kept .jpg here so
    # output file names stay unchanged.
    mask_canvas.save(os.path.join(gen_stamp_mask_root, filename))

    return {'name': str(savename), 'coordinate': [x, y, x + stamp_w, y + stamp_h], 'label': ''}


def _seed_worker():
    """Pool initializer: reseed NumPy's global RNG from OS entropy.

    With the 'fork' start method every worker inherits the parent's NumPy RNG
    state. Without reseeding, all workers would draw the same sequence of
    random paste offsets.
    """
    np.random.seed()


def process(need=20000, processes=6):
    """Generate *need* samples in a pool of *processes* worker processes.

    For each sample index ``i`` in ``0..need-1``, picks a random stamp file
    name from ``stamp_img_root``. The same name is used to locate its mask
    (``stamp_mask_root``) and blend mask (``stamp_blend_root``). A random text
    image is picked from ``text_root``, and ``gen`` renders and saves the
    sample. Worker failures are printed rather than silently dropped.

    Args:
        need: number of samples to generate (default 20000, as before).
        processes: number of worker processes (default 6, as before).

    Raises:
        ValueError: if the stamp or text directory is empty.
    """
    stamp_list = sorted(os.listdir(stamp_img_root))
    text_list = sorted(os.listdir(text_root))
    if not stamp_list or not text_list:
        raise ValueError("no input images found in {!r} or {!r}".format(stamp_img_root, text_root))

    # Finish with close()/join() inside the with-block: Pool.__exit__ calls
    # terminate(), which would kill unfinished tasks if reached first. The
    # with-block still tears the workers down if the loop raises.
    # For debugging, call gen(...) directly instead of apply_async to run
    # serially and see full tracebacks.
    with mp.Pool(processes=processes, initializer=_seed_worker) as pool:
        for i in range(need):
            stamp_name = stamp_list[random_idx(0, len(stamp_list))]
            stamp_img_path = os.path.join(stamp_img_root, stamp_name)
            stamp_mask_path = os.path.join(stamp_mask_root, stamp_name)
            blend_mask_path = os.path.join(stamp_blend_root, stamp_name)
            text_img_path = os.path.join(text_root, text_list[random_idx(0, len(text_list))])
            # apply_async discards a worker's exception unless its result is
            # fetched, so report failures through error_callback instead.
            # idx=i binds the current index; a bare closure would see the last i.
            pool.apply_async(
                gen,
                (stamp_img_path, blend_mask_path, stamp_mask_path, text_img_path,
                 gen_img_root, gen_text_img_root, gen_stamp_mask_root, i),
                error_callback=lambda err, idx=i: print(
                    "sample {} failed: {!r}".format(idx, err)),
            )
        pool.close()
        pool.join()


def main():
    """Script entry point: run the stamp-pasting generation pipeline."""
    return process()


if __name__ == "__main__":
    main()