paste_stamp.py
3.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import cv2
from PIL import Image, ImageEnhance
import numpy as np
import random
import shutil
from tqdm import tqdm
import json
import multiprocessing as mp
import imgaug.augmenters as iaa
def mkdir(path):
    """Create directory *path* (and any missing parents) if it does not exist.

    Uses ``exist_ok=True`` instead of a separate existence check, so there is
    no window between "check" and "create" in which another process could
    create the directory and make ``os.makedirs`` raise.
    """
    os.makedirs(path, exist_ok=True)
# Input directories: stamp images, stamp masks, blend masks (used as the paste
# alpha for the stamp) and the background text images.
stamp_img_root = '/data1/lxl/data/ocr/stamp/aug/img'
stamp_mask_root = '/data1/lxl/data/ocr/stamp/aug/mask'
stamp_blend_root = '/data1/lxl/data/ocr/stamp/aug/blend'
text_root = '/data1/lxl/data/ocr/crop_bg'

# Output directories, all under one generation root.
gen_root = '/data1/lxl/data/ocr/generate1108/'
gen_img_root, gen_stamp_img_root, gen_stamp_mask_root, gen_text_img_root = (
    os.path.join(gen_root, sub_dir)
    for sub_dir in ('img', 'stamp_img', 'stamp_mask', 'text_img')
)

# Create the output directories at import time (same order as before).
mkdir(gen_img_root)
mkdir(gen_text_img_root)
mkdir(gen_stamp_img_root)
mkdir(gen_stamp_mask_root)
def random_idx(s, e):
    """Return a random ``int`` drawn uniformly from the half-open range [s, e).

    Uses NumPy's global (legacy) RNG.

    Raises:
        ValueError: if ``e <= s`` (raised by ``np.random.randint``).
    """
    # Draw a scalar directly. The previous ``size=(1)`` produced a 1-element
    # 1-D array, and calling int() on an ndim>0 array is deprecated since
    # NumPy 1.25.
    return int(np.random.randint(s, e))
def get_full_path_list(root):
    """Return every entry of directory *root* as a joined path, sorted by entry name."""
    return [os.path.join(root, entry) for entry in sorted(os.listdir(root))]
def gen(stamp_img, blend_mask, stamp_mask, text_img, gen_img_root, gen_text_img_root, gen_stamp_mask_root, savename):
    """Paste one stamp onto one text image at a random offset and save the results.

    Args:
        stamp_img: path of the stamp image (loaded as RGB).
        blend_mask: path of the mask used as paste alpha for the stamp. It is
            loaded in its file's own mode. Presumably it has the stamp's size
            and a PIL mask mode ("1", "L", "RGBA", ...) — TODO confirm for
            the data set; ``Image.paste`` raises ValueError otherwise.
        stamp_mask: path of the stamp mask (loaded as single-channel "L").
        text_img: path of the background text image (loaded as RGB).
        gen_img_root: output dir for the text image with the stamp pasted on.
        gen_text_img_root: output dir for an RGB copy of the clean text image.
        gen_stamp_mask_root: output dir for the stamp mask placed on a black
            canvas of the text image's size.
        savename: integer sample index. Every output file is named
            ``"{:>06d}.jpg".format(savename)``.

    The stamp alone on a black canvas is written to the module-level
    ``gen_stamp_img_root`` (it is not a parameter of this function).

    Raises:
        ValueError: if the stamp is wider or taller than the text image.
    """
    # Load everything eagerly inside ``with`` so each file handle is released
    # right away; this runs once per sample inside long-lived pool workers.
    with Image.open(stamp_img) as im:
        stamp = im.convert("RGB")
    with Image.open(blend_mask) as im:
        blend = im.copy()  # keep the file's own mode, as before
    with Image.open(stamp_mask) as im:
        mask = im.convert('L')
    with Image.open(text_img) as im:
        text = im.convert("RGB")

    text_w, text_h = text.size
    stamp_w, stamp_h = stamp.size
    if stamp_w > text_w or stamp_h > text_h:
        raise ValueError('stamp {} is {}x{} but text image {} is only {}x{}'.format(
            stamp_img, stamp_w, stamp_h, text_img, text_w, text_h))

    # Valid top-left offsets are 0..(text - stamp) inclusive. The +1 makes the
    # flush-right/bottom position reachable and keeps the range non-empty when
    # the stamp exactly matches the text image's size.
    x = random_idx(0, text_w - stamp_w + 1)
    y = random_idx(0, text_h - stamp_h + 1)

    gen_img = text.copy()
    gen_img.paste(stamp, (x, y), mask=blend)

    gen_stamp_img = Image.new('RGB', size=text.size)
    gen_stamp_img.paste(stamp, (x, y), mask=blend)

    # NOTE(review): pasting the mask through itself blends it with itself, so
    # values strictly between 0 and 255 come out darker; 0/255 masks are
    # unaffected. Confirm this is intended.
    gen_stamp_mask = Image.new('L', size=text.size)
    gen_stamp_mask.paste(mask, (x, y), mask=mask)

    # NOTE(review): JPEG is lossy, so the saved mask will not keep exact pixel
    # values; PNG would. The .jpg names are kept so output file names do not change.
    filename = "{:>06d}.jpg".format(savename)
    gen_img.save(os.path.join(gen_img_root, filename))
    text.save(os.path.join(gen_text_img_root, filename))
    gen_stamp_img.save(os.path.join(gen_stamp_img_root, filename))
    gen_stamp_mask.save(os.path.join(gen_stamp_mask_root, filename))
def _reseed_worker():
    """Pool initializer: give each worker process its own NumPy RNG state."""
    # With the 'fork' start method every worker inherits a byte-identical copy
    # of the parent's global NumPy RNG, so the workers would draw identical
    # sequences of paste offsets in gen(). Seeding with no argument pulls
    # fresh entropy from the OS.
    np.random.seed()


def process(need=20000, processes=6):
    """Generate ``need`` stamped samples using a pool of ``processes`` workers.

    For each sample index i, a random stamp file name and a random text-image
    file name are picked in this (parent) process, and ``gen`` is run on them
    in a worker with ``savename=i``. The same stamp file name is looked up in
    the stamp image, stamp mask and blend mask directories, so those three
    directories are assumed to hold parallel file sets — TODO confirm.

    A failing sample does not abort the run. Its exception is printed, and a
    failure count is printed at the end.

    Raises:
        ValueError: if the stamp image directory or the text directory is empty.
    """
    stamp_list = sorted(os.listdir(stamp_img_root))
    text_list = sorted(os.listdir(text_root))
    if not stamp_list or not text_list:
        raise ValueError('no input images found in {} or {}'.format(stamp_img_root, text_root))

    with mp.Pool(processes=processes, initializer=_reseed_worker) as pool:
        results = []
        for i in range(need):
            stamp_name = stamp_list[random_idx(0, len(stamp_list))]
            text_name = text_list[random_idx(0, len(text_list))]
            results.append(pool.apply_async(gen, (
                os.path.join(stamp_img_root, stamp_name),
                os.path.join(stamp_blend_root, stamp_name),
                os.path.join(stamp_mask_root, stamp_name),
                os.path.join(text_root, text_name),
                gen_img_root, gen_text_img_root, gen_stamp_mask_root, i,
            )))
        pool.close()

        # Retrieve every result: an unretrieved AsyncResult silently discards
        # the worker's exception, which previously hid all failures.
        failures = 0
        for i, res in enumerate(results):
            try:
                res.get()
            except Exception as err:  # best-effort batch: report and continue
                failures += 1
                print('sample {:>06d} failed: {!r}'.format(i, err))
        pool.join()

    if failures:
        print('{} of {} samples failed'.format(failures, need))
def main():
    """Script entry point: run the generation job with process()'s default settings."""
    process()
# Start generation only when run as a script, not when the module is imported
# (e.g. by multiprocessing workers under the 'spawn' start method).
if __name__ == '__main__':
    main()