# stamp_aug.py: augment 4-channel (BGR + alpha) stamp images into augmented image, mask and blend files.
import imgaug.augmenters as iaa
import numpy as np
import cv2
import os
from tqdm import tqdm
from PIL import Image

# Main augmentation pipeline, applied in the order listed below. The
# augmenters are constructed in exactly this order as well.
#   - horizontal flip with probability 0.5
#   - random crop of up to 10% (percent=(0, 0.1)), resized back (keep_size=True)
#   - 50% of the time: Gaussian blur, sigma in [0, 0.5]
#   - add an element-wise offset in [-40, -10] (per channel with p=0.5)
#   - 50% of the time: element-wise multiply by a factor in [0.7, 1.0]
#   - exactly one of: rotate by [-45, 45] degrees, or Rot90 with k in [1, 3]
#   - 70% of the time: coarse dropout with p in [0.1, 0.4]
#   - 2% of the time: motion blur (imgcorruptlike, severity 2)
_stamp_steps = [
    iaa.Fliplr(0.5),
    iaa.Crop(percent=(0, 0.1), keep_size=True),
    iaa.Sometimes(0.5, iaa.GaussianBlur(sigma=(0, 0.5))),
    iaa.AddElementwise((-40, -10), per_channel=0.5),
    iaa.Sometimes(0.5, iaa.MultiplyElementwise((0.7, 1.0))),
    iaa.OneOf([iaa.Rotate((-45, 45)), iaa.Rot90((1, 3))]),
    iaa.Sometimes(
        0.7,
        iaa.CoarseDropout(p=(0.1, 0.4), size_percent=(0.02, 0.2)),
    ),
    iaa.Sometimes(0.02, iaa.imgcorruptlike.MotionBlur(severity=2)),
]
seq = iaa.Sequential(_stamp_steps)

# Pipeline for the grayscale "blend" output: a light Gaussian blur
# (sigma in [0, 0.5]) followed by an element-wise multiplication by a factor
# in [0.8, 3.5]. Neither step is geometric, so pixel positions are unchanged.
_blend_blur = iaa.GaussianBlur(sigma=(0, 0.5))
_blend_gain = iaa.MultiplyElementwise((0.8, 3.5))
gen_blend = iaa.Sequential([_blend_blur, _blend_gain])


# Directory of source stamp images, and root directory for augmented output.
img_root = '/data1/lxl/data/ocr/stamp/src'
gen_root = '/data1/lxl/data/ocr/stamp/aug'
# Output sub-directories for the augmented image, mask and blend files.
gen_img_root = os.path.join(gen_root, 'img')
gen_mask_root = os.path.join(gen_root, 'mask')
gen_blend_root = os.path.join(gen_root, 'blend')
# exist_ok=True replaces the racy exists()/makedirs() pair. It still raises if
# one of these paths already exists but is not a directory. The old check would
# skip that case, and every later cv2.imwrite into it would fail silently.
for _out_dir in (gen_img_root, gen_mask_root, gen_blend_root):
    os.makedirs(_out_dir, exist_ok=True)

# Every entry of the source directory, in a stable (sorted) order. Entries that
# are not readable 4-channel images are skipped with a warning in the loop.
name_list = sorted(os.listdir(img_root))


def _load_stamp(path):
    """Read a 4-channel image and split it into a BGR image and a 3-channel mask.

    The file is decoded once with cv2.IMREAD_UNCHANGED so the fourth channel is
    kept. The first three channels become the image. The fourth channel is
    replicated to three channels and becomes the mask.

    Returns:
        ``(img, mask)`` as uint8 arrays of shape (H, W, 3), or ``None`` when the
        file cannot be decoded or does not have exactly four channels. This
        prevents a 3-channel image from silently using its red channel as a mask.
    """
    bgra = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if bgra is None or bgra.ndim != 3 or bgra.shape[2] != 4:
        return None
    if bgra.dtype == np.uint16:
        # 16-bit PNG: keep the high byte; a plain uint8 cast would wrap around.
        bgra = (bgra >> 8).astype(np.uint8)
    bgra = np.asarray(bgra, dtype=np.uint8)
    img = bgra[:, :, :3]
    mask = cv2.cvtColor(bgra[:, :, 3], cv2.COLOR_GRAY2BGR)
    return img, mask


def _augment_stamp(img, mask):
    """Run ``seq`` on the image and mask together, then ``gen_blend`` on a gray copy.

    Returns:
        ``(img_aug, mask_aug, blend_aug)``. ``mask_aug`` is the mask passed to
        imgaug as a segmentation map alongside the image. ``blend_aug`` is the
        grayscale version of ``img_aug`` after ``gen_blend``.
    """
    # imgaug works on batches: wrap the single sample, then unwrap it with [0].
    # Unlike squeeze(), [0] keeps the spatial axes of 1-pixel-tall or
    # 1-pixel-wide images, and it also works if imgaug returns a list.
    img_batch, mask_batch = seq(
        images=img[np.newaxis], segmentation_maps=mask[np.newaxis])
    img_aug = img_batch[0]
    mask_aug = mask_batch[0]
    blend = cv2.cvtColor(img_aug, cv2.COLOR_BGR2GRAY)
    blend_aug = gen_blend(images=blend[np.newaxis])[0]
    return img_aug, mask_aug, blend_aug


def _write_image(path, image):
    """Write ``image`` to ``path``, raising OSError if cv2.imwrite reports failure."""
    if not cv2.imwrite(path, image):
        raise OSError('cv2.imwrite failed for {}'.format(path))


# Two augmentation rounds over the whole source set. The round index is
# appended to the file stem so that rounds do not overwrite each other.
for i in range(2):
    for name in tqdm(name_list):
        # splitext keeps names with extra dots intact ("a.b.png" -> "a.b", ".png").
        stem, ext = os.path.splitext(name)
        loaded = _load_stamp(os.path.join(img_root, name))
        if loaded is None:
            tqdm.write('skipping {}: not a readable 4-channel image'.format(name))
            continue
        img_aug, mask_aug, blend_aug = _augment_stamp(*loaded)
        out_name = stem + '_' + str(i) + ext
        _write_image(os.path.join(gen_img_root, out_name), img_aug)
        _write_image(os.path.join(gen_mask_root, out_name), mask_aug)
        _write_image(os.path.join(gen_blend_root, out_name), blend_aug)