cdc6a87d by 刘晓龙

init commit

0 parents
1 # template
...\ No newline at end of file ...\ No newline at end of file
# Training configuration for the residual Unet baseline (stamp removal).
# NOTE(review): nesting reconstructed from the builders' access paths
# (cfg['solver']['optimizer'], cfg['solver']['loss'], ...) — confirm
# against the original file.

seed: 3407  # global RNG seed (see init_seed in main.py)

dataset:
  name: 'ReconData'
  args:
    data_root: '/data1/lxl/data/ocr/generate1108'
    train_anno_file: '/data1/lxl/data/ocr/generate1108/train.csv'
    val_anno_file: '/data1/lxl/data/ocr/generate1108/val.csv'
    fixed_size: 512  # images resized to fixed_size x fixed_size

dataloader:
  batch_size: 8
  num_workers: 16
  pin_memory: true
  collate_fn: 'base_collate_fn'

model:
  name: 'Unet'  # residual Unet: forward returns x + model(x)
  args:
    encoder_name: 'resnet50'
    encoder_weights: 'imagenet'
    in_channels: 3
    classes: 3
    activation: 'tanh'

solver:
  name: 'BaseSolver'
  args:
    epoch: 30

  optimizer:
    name: 'Adam'
    args:
      lr: !!float 1e-4
      weight_decay: !!float 5e-5

  lr_scheduler:
    name: 'StepLR'
    args:
      step_size: 15
      gamma: 0.1

  loss:
    name: 'L1Loss'
    args:
      reduction: 'mean'

  logger:
    log_root: '/data1/lxl/code/ocr/removal/log'
    suffix: 'residual'

  metric:
    name: 'Recon'
# Training configuration for the UnetSkip variant (reconstruction + residual
# supervision via SkipSolver).
# NOTE(review): nesting reconstructed from the builders' access paths
# (cfg['solver']['optimizer'], cfg['solver']['loss'], ...) — confirm
# against the original file.

seed: 3407  # global RNG seed (see init_seed in main.py)

dataset:
  name: 'ReconData'
  args:
    data_root: '/data1/lxl/data/ocr/generate1108'
    train_anno_file: '/data1/lxl/data/ocr/generate1108/train.csv'
    val_anno_file: '/data1/lxl/data/ocr/generate1108/val.csv'
    fixed_size: 512  # images resized to fixed_size x fixed_size

dataloader:
  batch_size: 8
  num_workers: 16
  pin_memory: true
  collate_fn: 'base_collate_fn'

model:
  name: 'UnetSkip'  # forward returns {'reconstruction', 'residual'}
  args:
    encoder_name: 'resnet50'
    encoder_weights: 'imagenet'
    in_channels: 3
    classes: 3
    activation: 'tanh'

solver:
  name: 'SkipSolver'  # supervises both reconstruction and residual
  args:
    epoch: 30

  optimizer:
    name: 'Adam'
    args:
      lr: !!float 1e-4
      weight_decay: !!float 5e-5

  lr_scheduler:
    name: 'StepLR'
    args:
      step_size: 15
      gamma: 0.1

  loss:
    name: 'L1Loss'
    args:
      reduction: 'mean'

  logger:
    log_root: '/data1/lxl/code/ocr/removal/log'
    suffix: 'skip'

  metric:
    name: 'Recon'
1 from .solver import build_solver
1 from torch.utils.data import DataLoader, Dataset
2 import torchvision.transforms as transforms
3 import torch
4 import pandas as pd
5 import os
6 import random
7 import cv2
8 import albumentations as A
9 import albumentations.pytorch
10 import numpy as np
11 from PIL import Image
12
13 from utils.registery import DATASET_REGISTRY
14
15
@DATASET_REGISTRY.register()
class ReconData(Dataset):
    """Paired reconstruction dataset for stamp removal.

    Each sample is a (stamped image, clean text image) pair read from
    ``<data_root>/img/<name>`` and ``<data_root>/text_img/<name>``, where the
    file names come from the 'name' column of the annotation CSV.
    """

    def __init__(self,
                 data_root: str = '/data1/lxl/data/ocr/generate1108',
                 anno_file: str = 'train.csv',
                 fixed_size: int = 448,
                 phase: str = 'train'):
        # data_root: dataset root containing the 'img' and 'text_img' folders.
        # anno_file: CSV with a 'name' column listing the sample file names.
        # fixed_size: side length that images are resized to (square).
        # phase: 'train' or 'val' — selects the augmentation pipeline.
        self.data_root = data_root
        self.df = pd.read_csv(anno_file)
        self.img_root = os.path.join(data_root, 'img')
        self.gt_root = os.path.join(data_root, 'text_img')
        self.fixed_size = fixed_size

        self.phase = phase
        transform_fn = self.__get_transform()
        self.transform = transform_fn[phase]

    def __get_transform(self):
        """Build the train/val albumentations pipelines.

        ``additional_targets={'label': 'image'}`` makes the pipeline apply the
        same transform to the target image; Normalize with mean 0 / std 1 /
        max_pixel_value 255 scales pixels to [0, 1].
        """
        train_transform = A.Compose([
            A.Resize(height=self.fixed_size, width=self.fixed_size),
            # A.RandomBrightness(limit=(-0.5, 0), p=0.5),
            # darken-only brightness jitter (limit range is <= 0)
            A.RandomBrightnessContrast(brightness_limit=(-0.5, 0), contrast_limit=0, p=0.5),
            A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
            A.pytorch.transforms.ToTensorV2()
        ], additional_targets={'label': 'image'})

        val_transform = A.Compose([
            A.Resize(height=self.fixed_size, width=self.fixed_size),
            A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
            A.pytorch.transforms.ToTensorV2()
        ], additional_targets={'label': 'image'})

        transform_fn = {'train': train_transform, 'val': val_transform}

        return transform_fn

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Returns (input tensor, label tensor) after the phase's transform.
        series = self.df.iloc[idx]
        name = series['name']

        # img = Image.open(os.path.join(self.img_root, name))
        # gt = Image.open(os.path.join(self.gt_root, name))
        img = cv2.imread(os.path.join(self.img_root, name))
        gt = cv2.imread(os.path.join(self.gt_root, name))

        # OpenCV loads BGR; convert to RGB for the ImageNet-pretrained encoder
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)

        transformed = self.transform(image=img, label=gt)
        img = transformed['image']
        label = transformed['label']

        return img, label
74
75
1 from .builder import build_dataloader
1 import copy
2 import random
3 import numpy as np
4 import torch
5 from torch.utils.data import DataLoader, Dataset
6 from torch.utils.data.distributed import DistributedSampler
7 from utils.registery import DATASET_REGISTRY, COLLATE_FN_REGISTRY
8 from .collate_fn import base_collate_fn
9
10 from .ReconData import ReconData
11
12
def build_dataset(cfg):
    """Build the train and val datasets from ``cfg['dataset']``.

    The shared dataset args are split into per-phase configs: each phase keeps
    only its own annotation file (renamed to ``anno_file``) and gets a
    ``phase`` flag, so both configs can be forwarded as **kwargs.

    Args:
        cfg: full config dict containing a 'dataset' section.

    Returns:
        (train_dataset, val_dataset) tuple.

    Raises:
        KeyError: if ``cfg`` has no 'dataset' section.
    """
    dataset_cfg = copy.deepcopy(cfg)
    try:
        dataset_cfg = dataset_cfg['dataset']
    except KeyError as err:
        # the original raised a plain string, which is itself a TypeError in
        # Python 3; raise a real exception instead
        raise KeyError("config should contain a 'dataset' section!") from err

    train_cfg = copy.deepcopy(dataset_cfg)
    val_cfg = copy.deepcopy(dataset_cfg)
    train_cfg['args']['anno_file'] = train_cfg['args'].pop('train_anno_file')
    train_cfg['args'].pop('val_anno_file')
    train_cfg['args']['phase'] = 'train'
    val_cfg['args']['anno_file'] = val_cfg['args'].pop('val_anno_file')
    val_cfg['args'].pop('train_anno_file')
    val_cfg['args']['phase'] = 'val'

    train_data = DATASET_REGISTRY.get(dataset_cfg['name'])(**train_cfg['args'])
    val_data = DATASET_REGISTRY.get(dataset_cfg['name'])(**val_cfg['args'])

    return train_data, val_data
34
35
36
def build_dataloader(cfg):
    """Build the distributed train loader and the plain val loader.

    Args:
        cfg: full config dict containing 'dataset' and 'dataloader' sections.

    Returns:
        (train_loader, val_loader) tuple.

    Raises:
        KeyError: if ``cfg`` has no 'dataloader' section.
    """
    # Deep-copy BEFORE selecting the sub-dict: 'collate_fn' is popped below,
    # and the original read the sub-dict straight out of ``cfg``, mutating the
    # caller's config.
    dataloader_cfg = copy.deepcopy(cfg)
    try:
        dataloader_cfg = dataloader_cfg['dataloader']
    except KeyError as err:
        raise KeyError("config should contain a 'dataloader' section!") from err

    train_ds, val_ds = build_dataset(cfg)
    # shard the training set across ranks; validation runs un-sharded
    train_sampler = DistributedSampler(train_ds)
    collate_fn = COLLATE_FN_REGISTRY.get(dataloader_cfg.pop('collate_fn'))

    train_loader = DataLoader(train_ds,
                              sampler=train_sampler,
                              collate_fn=collate_fn,
                              **dataloader_cfg)

    val_loader = DataLoader(val_ds,
                            collate_fn=collate_fn,
                            **dataloader_cfg)

    return train_loader, val_loader
59
1 import torch
2 from utils.registery import COLLATE_FN_REGISTRY
3
4
5 @COLLATE_FN_REGISTRY.register()
def base_collate_fn(batch):
    """Stack a list of (image, label) sample pairs into batched tensors.

    Args:
        batch: iterable of (image, label) tensor pairs with identical shapes.

    Returns:
        dict with 'image' and 'label' tensors of shape (B, *sample_shape).
    """
    images, labels = zip(*batch)
    # torch.stack adds the batch dim in one call — equivalent to the manual
    # per-sample unsqueeze(0) followed by cat along dim 0.
    return {'image': torch.stack(images, dim=0),
            'label': torch.stack(labels, dim=0)}
16
17
import os
import pandas as pd
import random

# One-off script: split the generated image names into train/val CSVs.
# NOTE(review): the 16000/rest split assumes ~20000 generated images (see the
# generation script's `need = 20000`) — confirm before re-running.

root = '/data1/lxl/data/ocr/generate1108/'
img_path = os.path.join(root, 'img')
img_list = os.listdir(img_path)
random.shuffle(img_list)  # shuffle before the positional split below
train_df = pd.DataFrame(columns=['name'])
val_df = pd.DataFrame(columns=['name'])

# first 16000 shuffled names -> train, remainder -> val
train_df['name'] = img_list[:16000]
val_df['name'] = img_list[16000:]

train_df.to_csv(os.path.join(root, 'train.csv'))
val_df.to_csv(os.path.join(root, 'val.csv'))
import albumentations as A
import cv2
import os
from tqdm import tqdm

# One-off script: crop random background patches out of stamp-free document
# scans; the crops become the clean 'text' backgrounds for synthesis.

# random 400x400 crops covering 10-20% of the source area, plus flips
transform = A.Compose([
    A.RandomResizedCrop(height=400, width=400, scale=(0.1, 0.2)),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
])

root = '/data1/lxl/data/ocr/doc_remove_stamp/'
save_path = '/data1/lxl/data/ocr/crop_bg/'

date_dirs = sorted(os.listdir(root))  # NOTE(review): unused — leftover?

img_list = list()

img_names = sorted(os.listdir(root))
for img_name in img_names:
    img_path = os.path.join(root, img_name)
    img_list.append(img_path)

print(f'src img number: {len(img_list)}')

cnt = 0
for crt_img in tqdm(img_list):
    try:
        img = cv2.imread(crt_img)
        # five independent random crops per source image
        for _ in range(5):
            transformed = transform(image=img)
            transformed_img = transformed['image']
            cv2.imwrite(os.path.join(save_path, f'{cnt :06d}.jpg'), transformed_img)
            cnt += 1
    except Exception:
        # best-effort generation: skip unreadable/undersized images
        continue
37
1 import os
2 import cv2
3
4 stamp_root = '/data1/lxl/data/ocr/stamp/src'
5 mask_root = '/data1/lxl/data/ocr/stamp/mask'
6
7
def read_img_list(root):
    """Return the full path of every entry directly under *root*."""
    return [os.path.join(root, entry) for entry in os.listdir(root)]
15
16
def get_mask_list(stamp_list):
    """Extract each RGBA stamp's alpha channel and save it as a grayscale
    mask with the same file name under ``mask_root``."""
    for stamp in stamp_list:
        img_name = stamp.split('/')[-1]
        # flag -1 keeps the alpha channel; the last channel is the alpha
        img = cv2.imread(stamp, -1)
        mask = img[:, :, -1]
        cv2.imwrite(os.path.join(mask_root, img_name), mask)
23
24
25 if __name__ == '__main__':
26 full_path_list = read_img_list(stamp_root)
27 get_mask_list(full_path_list)
1 import os
2 import cv2
3 from PIL import Image, ImageEnhance
4 import numpy as np
5 import random
6 import shutil
7 from tqdm import tqdm
8 import json
9 import multiprocessing as mp
10 import imgaug.augmenters as iaa
11
12
def mkdir(path):
    """Create *path* (including parents) if it does not already exist."""
    # exist_ok avoids the check-then-create race of the original
    # os.path.exists() guard and is idempotent.
    os.makedirs(path, exist_ok=True)
16
17
18 stamp_img_root = '/data1/lxl/data/ocr/stamp/aug/img'
19 stamp_mask_root = '/data1/lxl/data/ocr/stamp/aug/mask'
20 stamp_blend_root = '/data1/lxl/data/ocr/stamp/aug/blend'
21 text_root = '/data1/lxl/data/ocr/crop_bg'
22 gen_root = '/data1/lxl/data/ocr/generate1108/'
23 gen_img_root = os.path.join(gen_root, 'img')
24 gen_stamp_img_root = os.path.join(gen_root, 'stamp_img')
25 gen_stamp_mask_root = os.path.join(gen_root, 'stamp_mask')
26 gen_text_img_root = os.path.join(gen_root, 'text_img')
27 mkdir(gen_img_root)
28 mkdir(gen_text_img_root)
29 mkdir(gen_stamp_img_root)
30 mkdir(gen_stamp_mask_root)
31
32
def random_idx(s, e):
    """Return a random int drawn uniformly from [s, e) via numpy's global RNG."""
    # np.random.randint already returns a scalar when size is omitted; the
    # original drew a size-(1,) array and converted it back to int.
    return int(np.random.randint(s, e))
37
38
def get_full_path_list(root):
    """Return the sorted full paths of the entries directly under *root*."""
    return [os.path.join(root, entry) for entry in sorted(os.listdir(root))]
46
47
def gen(stamp_img, blend_mask, stamp_mask, text_img, gen_img_root, gen_text_img_root, gen_stamp_mask_root, savename):
    """Composite one stamp onto one text background and save four outputs.

    All outputs are named ``{savename:06d}.jpg``:
      - stamped image            -> gen_img_root
      - clean text background    -> gen_text_img_root
      - stamp pasted on black    -> module-level ``gen_stamp_img_root``
        (NOTE(review): uses the global, not a parameter, unlike the others)
      - stamp mask on black      -> gen_stamp_mask_root
    """
    stamp_img = Image.open(stamp_img).convert("RGB")
    blend_mask = Image.open(blend_mask)

    # NOTE(review): stamp_img_width/height, stamp_mask_copy and stamp_dict
    # are computed but never used.
    stamp_img_width, stamp_img_height = stamp_img.size
    stamp_mask = Image.open(stamp_mask).convert('L')
    stamp_mask_copy = stamp_mask.copy().convert('L')
    text_img = Image.open(text_img).convert("RGB")
    gen_img = text_img.copy().convert("RGB")
    # random top-left corner such that the stamp fits inside the background
    x = random_idx(0, text_img.size[0] - stamp_img.size[0])
    y = random_idx(0, text_img.size[1] - stamp_img.size[1])

    # blend_mask modulates the stamp's per-pixel opacity in the composite
    gen_img.paste(stamp_img, (x, y), mask=blend_mask)

    gen_stamp_img = Image.new('RGB', size=text_img.size)
    gen_stamp_img.paste(stamp_img, (x, y), mask=blend_mask)
    gen_stamp_mask = Image.new('L', size=text_img.size)
    gen_stamp_mask.paste(stamp_mask, (x, y), mask=stamp_mask)
    stamp_coordinate = [x, y, x + stamp_img.size[0], y + stamp_img.size[1]]
    stamp_dict = {'name': str(savename), 'coordinate': stamp_coordinate, 'label': ''}

    gen_img.save(os.path.join(gen_img_root, "{:>06d}.jpg".format(savename)))
    text_img.save(os.path.join(gen_text_img_root, "{:>06d}.jpg".format(savename)))
    gen_stamp_img.save(os.path.join(gen_stamp_img_root, "{:>06d}.jpg".format(savename)))
    gen_stamp_mask.save(os.path.join(gen_stamp_mask_root, "{:>06d}.jpg".format(savename)))
73
74
def process():
    """Generate ``need`` composite samples using a 6-worker process pool.

    Each sample pairs a random augmented stamp (its image, mask and blend map
    share one file name across the three stamp dirs) with a random background
    crop; the actual compositing/saving is done by gen().
    """
    stamp_list = sorted(os.listdir(stamp_img_root))
    stamp_list_lth = len(stamp_list)
    text_list = sorted(os.listdir(text_root))
    text_list_lth = len(text_list)
    need = 20000
    pool = mp.Pool(processes=6)
    for i in range(0, need):
        stamp_idx = random_idx(0, stamp_list_lth)
        # the same file name indexes the stamp image, its mask and its blend
        stamp_img_path = os.path.join(stamp_img_root, stamp_list[stamp_idx])
        stamp_mask_path = os.path.join(stamp_mask_root, stamp_list[stamp_idx])
        blend_mask_path = os.path.join(stamp_blend_root, stamp_list[stamp_idx])
        text_idx = random_idx(0, text_list_lth)
        text_img_path = os.path.join(text_root, text_list[text_idx])
        # NOTE: apply_async without .get() swallows worker exceptions;
        # uncomment the direct call below to debug failures synchronously.
        pool.apply_async(gen, (stamp_img_path, blend_mask_path, stamp_mask_path, text_img_path, gen_img_root, gen_text_img_root, gen_stamp_mask_root, i,))
        # gen(stamp_img_path, blend_mask_path, stamp_mask_path, text_img_path, gen_img_root, gen_text_img_root, gen_stamp_mask_root, i,)
    pool.close()
    pool.join()
93
94
95 def main():
96 process()
97
98 if __name__ == '__main__':
99 main()
signatureDet @ 2dde6823
1 Subproject commit 2dde682312388f5c5099d161019cd97df06315e4
import imgaug.augmenters as iaa
import numpy as np
import cv2
import os
from tqdm import tqdm
from PIL import Image

# One-off script: augment the raw RGBA stamp images into (img, mask, blend)
# triplets consumed by the synthesis pipeline.

# Joint augmentation of the stamp image and its alpha-derived mask; the mask
# is passed as a segmentation map so it only receives the geometric part.
seq = iaa.Sequential(
    [
        iaa.Fliplr(0.5),
        iaa.Crop(percent=(0, 0.1), keep_size=True),
        iaa.Sometimes(0.5, iaa.GaussianBlur(sigma=(0, 0.5))),
        iaa.AddElementwise((-40, -10), per_channel=0.5),
        iaa.Sometimes(0.5, iaa.MultiplyElementwise((0.7, 1.0))),
        iaa.OneOf([
            iaa.Rotate((-45, 45)),
            iaa.Rot90((1, 3))
        ]),
        # punch random holes in the stamp to imitate worn/partial imprints
        iaa.Sometimes(0.7, iaa.CoarseDropout(p=(0.1, 0.4), size_percent=(0.02, 0.2))),
        iaa.Sometimes(0.02, iaa.imgcorruptlike.MotionBlur(severity=2)),
    ])

# derives the grayscale 'blend' opacity map from the augmented stamp
gen_blend = iaa.Sequential([
    iaa.GaussianBlur(sigma=(0, 0.5)),
    iaa.MultiplyElementwise((0.8, 3.5)),
])


img_root = '/data1/lxl/data/ocr/stamp/src'
gen_root = '/data1/lxl/data/ocr/stamp/aug'
gen_img_root = os.path.join(gen_root, 'img')
gen_mask_root = os.path.join(gen_root, 'mask')
gen_blend_root = os.path.join(gen_root, 'blend')
if not os.path.exists(gen_img_root):
    os.makedirs(gen_img_root)
if not os.path.exists(gen_mask_root):
    os.makedirs(gen_mask_root)
if not os.path.exists(gen_blend_root):
    os.makedirs(gen_blend_root)

name_list = sorted(os.listdir(img_root))

# two augmented variants per source stamp
for i in range(2):
    for name in tqdm(name_list):
        name_no_ex = name.split('.')[0]
        ext = name.split('.')[1]
        # flag -1 keeps the alpha channel: RGB planes -> img, alpha -> mask
        img = cv2.imread(os.path.join(img_root, name), -1)[:, :, :3]
        mask = cv2.imread(os.path.join(img_root, name), -1)[:, :, -1]
        mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
        img = np.asarray(img, dtype=np.uint8)
        mask = np.asarray(mask, dtype=np.uint8)
        # imgaug expects a leading batch dimension
        img = img[np.newaxis, :]
        mask = mask[np.newaxis, :]
        img_aug, mask_aug = seq(images=img, segmentation_maps=mask)
        img_aug = img_aug.squeeze()
        mask_aug = mask_aug.squeeze(0)

        blend = cv2.cvtColor(img_aug, cv2.COLOR_BGR2GRAY)
        # NOTE(review): `blend` is passed without a batch axis here, unlike
        # `img` above — confirm this is intended.
        blend_aug = gen_blend(images=blend)
        blend_aug = blend_aug.squeeze()

        cv2.imwrite(os.path.join(gen_img_root, (name_no_ex + '_' + str(i) + '.' + ext)), img_aug)
        cv2.imwrite(os.path.join(gen_mask_root, (name_no_ex + '_' + str(i) + '.' + ext)), mask_aug)
        cv2.imwrite(os.path.join(gen_blend_root, (name_no_ex + '_' + str(i) + '.' + ext)), blend_aug)
1 from .builder import build_loss
2
3 __all__ = ['BCE', 'build_loss']
1 import torch
2 import inspect
3 from utils.registery import LOSS_REGISTRY
4 import copy
5
6
def register_torch_loss():
    """Register every ``torch.nn`` class whose name contains 'Loss'."""
    candidates = (
        getattr(torch.nn, attr) for attr in dir(torch.nn)
        if 'Loss' in attr and not attr.startswith('__')
    )
    for loss_cls in candidates:
        if inspect.isclass(loss_cls) and issubclass(loss_cls, torch.nn.Module):
            LOSS_REGISTRY.register()(loss_cls)
14
def build_loss(cfg):
    """Instantiate the loss from ``cfg['solver']['loss']`` (name + args).

    Raises:
        KeyError: if the config has no ``solver.loss`` section.
    """
    register_torch_loss()
    loss_cfg = copy.deepcopy(cfg)

    try:
        loss_cfg = loss_cfg['solver']['loss']
    except KeyError as err:
        # raising a plain string is itself a TypeError in Python 3
        raise KeyError("config should contain a 'solver.loss' section!") from err

    return LOSS_REGISTRY.get(loss_cfg['name'])(**loss_cfg['args'])
1 from .builder import build_metric
2
3 __all__ = ['build_metric']
1 from utils.registery import METRIC_REGISTRY
2 import copy
3
4 from .recon_metric import Recon
5
def build_metric(cfg):
    """Instantiate the metric named by ``cfg['solver']['metric']['name']``.

    The metric class is constructed without arguments.

    Raises:
        KeyError: if the config has no ``solver.metric`` section.
    """
    cfg = copy.deepcopy(cfg)
    try:
        metric_cfg = cfg['solver']['metric']
    except KeyError as err:
        # raising a plain string is itself a TypeError in Python 3
        raise KeyError("config should contain a 'solver.metric' section!") from err

    return METRIC_REGISTRY.get(metric_cfg['name'])()
1 import numpy as np
2 import torchmetrics
3
4
5 from utils.registery import METRIC_REGISTRY
6
@METRIC_REGISTRY.register()
class Recon(object):
    """Reconstruction quality metric: per-sample-averaged PSNR and SSIM."""

    def __init__(self):
        # torchmetrics modules; called per sample in __call__
        self.psnr = torchmetrics.PeakSignalNoiseRatio()
        self.ssim = torchmetrics.StructuralSimilarityIndexMeasure()

    def __call__(self, pred, label):
        """Return ``{'psnr': ..., 'ssim': ...}`` averaged over the batch.

        pred/label: batched image tensors with equal batch size; each sample
        is scored individually (unsqueezed to a batch of 1) and the scores
        are averaged.
        """
        assert pred.shape[0] == label.shape[0]

        psnr_list = list()
        ssim_list = list()

        for i in range(len(pred)):
            psnr_list.append(self.psnr(pred[i].unsqueeze(0), label[i].unsqueeze(0)))
            ssim_list.append(self.ssim(pred[i].unsqueeze(0), label[i].unsqueeze(0)))

        psnr_result = sum(psnr_list) / len(psnr_list)
        ssim_result = sum(ssim_list) / len(ssim_list)

        return {'psnr': psnr_result, 'ssim': ssim_result}
28
29
1 from .builder import build_model
2
1 import torch
2 import torch.nn as nn
3 from abc import ABCMeta
4 import math
5
6
class BaseModel(nn.Module, metaclass=ABCMeta):
    """Common base for project models; provides shared weight initialization."""

    def __init__(self):
        super().__init__()

    def _initialize_weights(self):
        """Xavier-init conv/linear weights, unit/zero-init norm layers,
        zero-init all biases."""
        conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
        norm_types = (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
        for module in self.modules():
            if isinstance(module, conv_types):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, norm_types):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
27
1 import copy
2 from utils import MODEL_REGISTRY
3
4 from .unet import Unet
5 from .unet_skip import UnetSkip
6
7
def build_model(cfg):
    """Instantiate the model from ``cfg['model']`` (registry name + kwargs).

    Raises:
        KeyError: if the config has no 'model' section.
    """
    model_cfg = copy.deepcopy(cfg)
    try:
        model_cfg = model_cfg['model']
    except KeyError as err:
        # raising a plain string is itself a TypeError in Python 3
        raise KeyError("config should contain a 'model' section!") from err

    model = MODEL_REGISTRY.get(model_cfg['name'])(**model_cfg['args'])

    return model
18
1 import segmentation_models_pytorch as smp
2
3 from utils.registery import MODEL_REGISTRY
4 from core.model.base_model import BaseModel
5
6
@MODEL_REGISTRY.register()
class Unet(BaseModel):
    """Residual Unet for stamp removal.

    Wraps segmentation_models_pytorch's Unet; forward() returns
    ``x + model(x)`` so the backbone only has to learn the residual.
    """

    def __init__(self,
                 encoder_name: str = 'resnet50',
                 encoder_weights: str = 'imagenet',
                 in_channels: int = 3,
                 classes: int = 3,
                 activation: str = 'tanh'):
        super().__init__()
        # default 'tanh' activation bounds the predicted residual to [-1, 1]
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=activation
        )

        # NOTE(review): this re-initializes ALL weights, including the
        # pretrained encoder loaded just above — confirm this is intended.
        self._initialize_weights()

    def forward(self, x):
        # skip connection: network output is a residual on the input image
        out = x + self.model(x)
        # out = self.model(x)

        return out
31
32
1 import segmentation_models_pytorch as smp
2
3 from utils.registery import MODEL_REGISTRY
4 from core.model.base_model import BaseModel
5
6
@MODEL_REGISTRY.register()
class UnetSkip(BaseModel):
    """Residual Unet that also exposes the predicted residual.

    Same backbone as Unet, but forward() returns a dict with both the
    reconstruction (input + residual) and the raw residual so the solver can
    supervise them separately (see SkipSolver).
    """

    def __init__(self,
                 encoder_name: str = 'resnet50',
                 encoder_weights: str = 'imagenet',
                 in_channels: int = 3,
                 classes: int = 3,
                 activation: str = 'tanh'):
        super().__init__()
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=activation
        )

        # NOTE(review): re-initializes all weights, including the pretrained
        # encoder loaded just above — confirm this is intended.
        self._initialize_weights()

    def forward(self, x):
        residual = self.model(x)
        reconstruction = x + residual

        returned_dict = {'reconstruction': reconstruction, 'residual': residual}

        return returned_dict
33
34
1 from .builder import build_optimizer, build_lr_scheduler
2
3 __all__ = ['build_optimizer', 'build_lr_scheduler']
1 import torch
2 import inspect
3 from utils.registery import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY
4 import copy
5
def register_torch_optimizers():
    """Register every optimizer class shipped with ``torch.optim``."""
    for attr in dir(torch.optim):
        if attr.startswith('__'):
            continue
        candidate = getattr(torch.optim, attr)
        is_optimizer = inspect.isclass(candidate) and issubclass(candidate, torch.optim.Optimizer)
        if is_optimizer:
            OPTIMIZER_REGISTRY.register()(candidate)
16
def build_optimizer(cfg):
    """Return the optimizer *class* named in ``cfg['solver']['optimizer']``.

    The caller instantiates it with the model parameters and the configured
    args (see BaseSolver.__init__).

    Raises:
        KeyError: if the config has no ``solver.optimizer`` section.
    """
    register_torch_optimizers()
    optimizer_cfg = copy.deepcopy(cfg)

    try:
        optimizer_cfg = optimizer_cfg['solver']['optimizer']
    except KeyError as err:
        # raising a plain string is itself a TypeError in Python 3
        raise KeyError("config should contain a 'solver.optimizer' section!") from err

    return OPTIMIZER_REGISTRY.get(optimizer_cfg['name'])
27
def register_torch_lr_scheduler():
    """Register every scheduler class shipped with ``torch.optim.lr_scheduler``."""
    base_cls = torch.optim.lr_scheduler._LRScheduler
    for attr in dir(torch.optim.lr_scheduler):
        if attr.startswith('__'):
            continue
        candidate = getattr(torch.optim.lr_scheduler, attr)
        if inspect.isclass(candidate) and issubclass(candidate, base_cls):
            LR_SCHEDULER_REGISTRY.register()(candidate)
39
def build_lr_scheduler(cfg):
    """Return the scheduler *class* named in ``cfg['solver']['lr_scheduler']``.

    The caller instantiates it with the optimizer and the configured args
    (see BaseSolver.train).

    Raises:
        KeyError: if the config has no ``solver.lr_scheduler`` section.
    """
    register_torch_lr_scheduler()
    scheduler_cfg = copy.deepcopy(cfg)

    try:
        scheduler_cfg = scheduler_cfg['solver']['lr_scheduler']
    except KeyError as err:
        # raising a plain string is itself a TypeError in Python 3
        raise KeyError("config should contain a 'solver.lr_scheduler' section!") from err

    return LR_SCHEDULER_REGISTRY.get(scheduler_cfg['name'])
50
1 from .builder import build_solver
2
3 __all__ = ['build_solver']
1 import torch
2 from core.data import build_dataloader
3 from core.model import build_model
4 from core.optimizer import build_optimizer, build_lr_scheduler
5 from core.loss import build_loss
6 from core.metric import build_metric
7 from utils.registery import SOLVER_REGISTRY
8 from utils.logger import get_logger_and_log_path
9 import os
10 import copy
11 import datetime
12 from torch.nn.parallel import DistributedDataParallel
13 import numpy as np
14 import pandas as pd
15 import yaml
16
17
@SOLVER_REGISTRY.register()
class BaseSolver(object):
    """Distributed (DDP) training loop for single-output reconstruction models.

    Builds dataloaders, model, loss, optimizer, metric and logger from the
    config dict; assumes torch.distributed is already initialized (main.py).
    """

    def __init__(self, cfg):
        self.cfg = copy.deepcopy(cfg)
        self.local_rank = torch.distributed.get_rank()
        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.len_train_loader, self.len_val_loader = len(self.train_loader), len(self.val_loader)
        self.criterion = build_loss(cfg).cuda(self.local_rank)
        # convert BN to SyncBN so batch statistics are shared across ranks
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(build_model(cfg))
        self.model = DistributedDataParallel(model.cuda(self.local_rank), device_ids=[self.local_rank], find_unused_parameters=True)
        # build_optimizer returns the optimizer *class*; instantiate it here
        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])
        self.hyper_params = cfg['solver']['args']
        crt_date = datetime.date.today().strftime('%Y-%m-%d')
        self.logger, self.log_path = get_logger_and_log_path(crt_date=crt_date, **cfg['solver']['logger'])
        self.metric_fn = build_metric(cfg)
        try:
            self.epoch = self.hyper_params['epoch']
        except Exception:
            # NOTE(review): raising a str is itself a TypeError in Python 3;
            # this should raise a real Exception instance.
            raise 'should contain epoch in {solver.args}'
        if self.local_rank == 0:
            # rank 0 persists the resolved config next to the checkpoints
            self.save_dict_to_yaml(self.cfg, os.path.join(self.log_path, 'config.yaml'))
            self.logger.info(self.cfg)

    def train(self):
        """Run ``self.epoch`` epochs; checkpoint and validate after each one."""
        if torch.distributed.get_rank() == 0:
            self.logger.info('==> Start Training')
        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            # reshuffle the distributed sampler's shards each epoch
            self.train_loader.sampler.set_epoch(t)
            if torch.distributed.get_rank() == 0:
                self.logger.info(f'==> epoch {t + 1}')
            self.model.train()

            pred_list = list()
            label_list = list()

            mean_loss = 0.0

            for i, data in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                image = data['image'].cuda(self.local_rank)
                label = data['label'].cuda(self.local_rank)

                pred = self.model(image)

                loss = self.criterion(pred, label)
                mean_loss += loss.item()

                # periodic progress logging, rank 0 only
                if (i == 0 or i % 200 == 0) and (torch.distributed.get_rank() == 0):
                    loss_value = loss.item()
                    self.logger.info(f'epoch: {t + 1}/{self.epoch}, iteration: {i + 1}/{self.len_train_loader}, loss: {loss_value :.4f}')

                loss.backward()
                self.optimizer.step()

                # (disabled) all-rank gathering of preds/labels for train metrics
                # batch_pred = [torch.zeros_like(pred) for _ in range(torch.distributed.get_world_size())] # 1
                # torch.distributed.all_gather(batch_pred, pred)
                # pred_list.append(torch.cat(batch_pred, dim=0).detach().cpu())

                # batch_label = [torch.zeros_like(label) for _ in range(torch.distributed.get_world_size())]
                # torch.distributed.all_gather(batch_label, label)
                # label_list.append(torch.cat(batch_label, dim=0).detach().cpu())

            # pred_list = torch.cat(pred_list, dim=0)
            # label_list = torch.cat(label_list, dim=0)
            # metric_dict = self.metric_fn(**{'pred': pred_list, 'label': label_list})
            mean_loss = mean_loss / self.len_train_loader

            if torch.distributed.get_rank() == 0:
                # self.logger.info(f"==> train mean loss: {mean_loss :.4f}, psnr: {metric_dict['psnr'] :.4f}, ssim: {metric_dict['ssim'] :.4f}")
                self.logger.info(f'==> train mean loss: {mean_loss :.4f}')
                self.save_checkpoint(self.model, self.cfg, self.log_path, t + 1)
            # NOTE(review): indentation reconstructed from a flattened dump —
            # confirm val() and lr_scheduler.step() run on every rank.
            self.val(t + 1)
            lr_scheduler.step()

        if self.local_rank == 0:
            self.logger.info('==> End Training')

    @torch.no_grad()
    def val(self, t):
        """Compute PSNR/SSIM over the whole validation set; rank 0 logs them."""
        self.model.eval()

        pred_list = list()
        label_list = list()

        for i, data in enumerate(self.val_loader):
            feat = data['image'].cuda(self.local_rank)
            label = data['label'].cuda(self.local_rank)

            pred = self.model(feat)

            # metrics run on CPU over the accumulated tensors
            pred_list.append(pred.detach().cpu())
            label_list.append(label.detach().cpu())

        pred_list = torch.cat(pred_list, dim=0)
        label_list = torch.cat(label_list, dim=0)

        metric_dict = self.metric_fn(**{'pred': pred_list, 'label': label_list})
        if torch.distributed.get_rank() == 0:
            self.logger.info(f"==> val psnr: {metric_dict['psnr'] :.4f}, ssim: {metric_dict['ssim'] :.4f}")

    def run(self):
        """Entry point used by main.py."""
        self.train()

    @staticmethod
    def save_dict_to_yaml(dict_value, save_path):
        """Dump a dict to YAML, preserving insertion order."""
        with open(save_path, 'w', encoding='utf-8') as file:
            yaml.dump(dict_value, file, sort_keys=False)

    def save_checkpoint(self, model, cfg, log_path, epoch_id):
        """Save the unwrapped (model.module) weights as ckpt_epoch_<n>.pt."""
        model.eval()
        torch.save(model.module.state_dict(), os.path.join(log_path, f'ckpt_epoch_{epoch_id}.pt'))
1 from utils.registery import SOLVER_REGISTRY
2 import copy
3
4 from .base_solver import BaseSolver
5 from .skip_solver import SkipSolver
6
7
def build_solver(cfg):
    """Instantiate the solver named by ``cfg['solver']['name']`` with the full cfg.

    Raises:
        KeyError: if the config has no 'solver' section.
    """
    cfg = copy.deepcopy(cfg)

    try:
        solver_cfg = cfg['solver']
    except KeyError as err:
        # raising a plain string is itself a TypeError in Python 3
        raise KeyError("config should contain a 'solver' section!") from err

    return SOLVER_REGISTRY.get(solver_cfg['name'])(cfg)
1 import torch
2 from core.data import build_dataloader
3 from core.model import build_model
4 from core.optimizer import build_optimizer, build_lr_scheduler
5 from core.loss import build_loss
6 from core.metric import build_metric
7 from utils.registery import SOLVER_REGISTRY
8 from utils.logger import get_logger_and_log_path
9 import os
10 import copy
11 import datetime
12 from torch.nn.parallel import DistributedDataParallel
13 import numpy as np
14 import pandas as pd
15 import yaml
16
17 from .base_solver import BaseSolver
18
19
@SOLVER_REGISTRY.register()
class SkipSolver(BaseSolver):
    """Training loop for UnetSkip: supervises both the reconstruction and the
    predicted residual (residual target = input - label)."""

    def __init__(self, cfg):
        super().__init__(cfg)

    def train(self):
        """Like BaseSolver.train, but the model returns a dict and the loss is
        reconstruction loss + residual loss (same criterion for both)."""
        if torch.distributed.get_rank() == 0:
            self.logger.info('==> Start Training')
        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.train_loader.sampler.set_epoch(t)
            if torch.distributed.get_rank() == 0:
                self.logger.info(f'==> epoch {t + 1}')
            self.model.train()

            pred_list = list()
            label_list = list()

            mean_loss = 0.0

            for i, data in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                image = data['image'].cuda(self.local_rank)
                label = data['label'].cuda(self.local_rank)
                # residual supervision target: what the model must remove
                residual = image - label

                pred = self.model(image)

                reconstruction_loss = self.criterion(pred['reconstruction'], label)
                residual_loss = self.criterion(pred['residual'], residual)
                loss = reconstruction_loss + residual_loss
                mean_loss += loss.item()

                # periodic progress logging, rank 0 only
                if (i == 0 or i % 200 == 0) and (torch.distributed.get_rank() == 0):
                    reconstruction_loss_value = reconstruction_loss.item()
                    residual_loss_value = residual_loss.item()
                    loss_value = loss.item()
                    self.logger.info(f'epoch: {t + 1}/{self.epoch}, iteration: {i + 1}/{self.len_train_loader}, reconstruction loss: {reconstruction_loss_value :.4f}, residual loss: {residual_loss_value :.4f}, loss: {loss_value :.4f}')

                loss.backward()
                self.optimizer.step()

            mean_loss = mean_loss / self.len_train_loader

            if torch.distributed.get_rank() == 0:
                # self.logger.info(f"==> train mean loss: {mean_loss :.4f}, psnr: {metric_dict['psnr'] :.4f}, ssim: {metric_dict['ssim'] :.4f}")
                self.logger.info(f'==> train mean loss: {mean_loss :.4f}')
                self.save_checkpoint(self.model, self.cfg, self.log_path, t + 1)
            # NOTE(review): indentation reconstructed from a flattened dump —
            # confirm val() and lr_scheduler.step() run on every rank.
            self.val(t + 1)
            lr_scheduler.step()

        if self.local_rank == 0:
            self.logger.info('==> End Training')

    @torch.no_grad()
    def val(self, t):
        """Score only the 'reconstruction' output with PSNR/SSIM; rank 0 logs."""
        self.model.eval()

        pred_list = list()
        label_list = list()

        for i, data in enumerate(self.val_loader):
            image = data['image'].cuda(self.local_rank)
            label = data['label'].cuda(self.local_rank)
            residual = image - label  # NOTE(review): computed but unused here

            pred = self.model(image)

            pred_list.append(pred['reconstruction'].detach().cpu())
            label_list.append(label.detach().cpu())

        pred_list = torch.cat(pred_list, dim=0)
        label_list = torch.cat(label_list, dim=0)

        metric_dict = self.metric_fn(**{'pred': pred_list, 'label': label_list})
        if torch.distributed.get_rank() == 0:
            self.logger.info(f"==> val psnr: {metric_dict['psnr'] :.4f}, ssim: {metric_dict['ssim'] :.4f}")

    def run(self):
        """Entry point used by main.py."""
        self.train()

    @staticmethod
    def save_dict_to_yaml(dict_value, save_path):
        """Dump a dict to YAML, preserving insertion order.
        (Identical duplicate of BaseSolver.save_dict_to_yaml.)"""
        with open(save_path, 'w', encoding='utf-8') as file:
            yaml.dump(dict_value, file, sort_keys=False)

    def save_checkpoint(self, model, cfg, log_path, epoch_id):
        """Save the unwrapped (model.module) weights as ckpt_epoch_<n>.pt.
        (Identical duplicate of BaseSolver.save_checkpoint.)"""
        model.eval()
        torch.save(model.module.state_dict(), os.path.join(log_path, f'ckpt_epoch_{epoch_id}.pt'))
1 import yaml
2 from core.solver import build_solver
3 import torch
4 import numpy as np
5 import random
6 import argparse
7
8
def init_seed(seed=778):
    """Seed python/numpy/torch RNGs and force deterministic cuDNN behavior."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def main():
    """Parse args, load the YAML config, init distributed, run the solver."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./config/baseline.yaml', type=str, help='config file')
    # --local_rank is injected by torch.distributed.launch (see run command)
    parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
    args = parser.parse_args()

    cfg = yaml.load(open(args.config, 'r').read(), Loader=yaml.FullLoader)
    init_seed(cfg['seed'])

    torch.distributed.init_process_group(backend='nccl')
    torch.cuda.set_device(args.local_rank)

    solver = build_solver(cfg)

    solver.run()


if __name__ == '__main__':
    main()
1 import os
2 os.environ['CUDA_VISIBLE_DEVICES'] = '0'
3 import cv2
4 import numpy as np
5 import torch
6 from PIL import Image
7 import torchvision.transforms as transforms
8 from tqdm import tqdm
9
10
11 from core.model.unet import Unet
12 from core.model.unet_skip import UnetSkip
13
# def load_model(ckpt='/data1/lxl/code/ocr/removal/log/2022-11-10/baseline/ckpt_epoch_18.pt'):
# def load_model(ckpt='/data1/lxl/code/ocr/removal/log/2022-11-10/residual/ckpt_epoch_28.pt'):
def load_model(ckpt='./log/2022-11-11/skip/ckpt_epoch_30.pt'):
    """Build a UnetSkip and load the given checkpoint for CPU inference."""
    # model = Unet(
    #     encoder_name='resnet50',
    #     encoder_weights='imagenet',
    #     in_channels=3,
    #     classes=3,
    #     activation='tanh'
    # )

    model = UnetSkip(
        encoder_name='resnet50',
        encoder_weights='imagenet',
        in_channels=3,
        classes=3,
        activation='tanh'
    )

    model.load_state_dict(torch.load(ckpt, map_location='cpu'))
    model.eval()

    return model


def infer(model, img_path, gen_path):
    """Remove the stamp from one image and write the result to *gen_path*.

    Mirrors the training preprocessing: BGR->RGB, resize to 512x512, scale to
    [0, 1], CHW, batch of 1. Expects a model that returns a dict with a
    'reconstruction' entry (UnetSkip).
    """
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (512, 512))
    img = img / 255.
    img = img.transpose(2, 0, 1).astype(np.float32)
    img = torch.from_numpy(img).unsqueeze(0)
    out = model(img)
    out = out['reconstruction']
    # clamp to [0, 1] before converting back to 8-bit BGR for saving
    out[out < 0] = 0
    out[out > 1] = 1
    out = out * 255
    out = out.detach().cpu().numpy().squeeze(0).transpose(1, 2, 0)
    out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
    cv2.imwrite(gen_path, out)


def infer_list():
    """Run inference over every image in img_root, mirroring file names."""
    img_root = '/data1/lxl/data/ocr/real/src/'
    gen_root = '/data1/lxl/data/ocr/real/removed/'
    img_list = sorted(os.listdir(img_root))
    model = load_model()
    for img in tqdm(img_list):
        img_path = os.path.join(img_root, img)
        gen_path = os.path.join(gen_root, img)
        infer(model, img_path, gen_path)


def infer_img():
    """Single-image smoke test."""
    model = load_model()
    img_path = '../DocEnTR/real/hetong_006_00.png'
    gen_path = './out.jpg'
    infer(model, img_path, gen_path)


infer_img()
1 CUDA_VISIBLE_DEVICES=0 nohup python -m torch.distributed.launch --master_port 8999 --nproc_per_node=1 main.py --config ./config/skip.yaml &
1 from .registery import *
2 from .logger import get_logger_and_log_path
3 from .helper import save_checkpoint
4
5 __all__ = [
6 'Registry',
7 'get_logger_and_log_path',
8 'save_checkpoint'
9 ]
1 import torch
2 import yaml
3 import os
4
5
def save_dict_to_yaml(dict_value, save_path):
    """Dump a (config) dict to *save_path* as YAML, preserving key order."""
    with open(save_path, 'w', encoding='utf-8') as file:
        yaml.dump(dict_value, file, sort_keys=False)


def save_checkpoint(model, cfg, log_path, epoch_id):
    """Save the config plus the model weights for one epoch.

    ``model.module`` assumes *model* is wrapped (e.g. DistributedDataParallel).
    """
    save_dict_to_yaml(cfg, os.path.join(log_path, 'config.yaml'))
    torch.save(model.module.state_dict(), os.path.join(log_path, f'ckpt_epoch_{epoch_id}.pt'))
1 import loguru
2 import copy
3 import os
4 import datetime
5
def get_logger_and_log_path(log_root,
                            crt_date,
                            suffix):
    """
    get logger and log path

    Args:
        log_root (str): root path of log
        crt_date (str): formated date name (Y-M-D)
        suffix (str): log save name

    Returns:
        logger (loguru.logger): logger object
        log_path (str): current root log path (with suffix)
    """
    # layout: <log_root>/<YYYY-MM-DD>/<suffix>/logfile.log
    log_path = os.path.join(log_root, crt_date, suffix)
    if not os.path.exists(log_path):
        os.makedirs(log_path)

    logger_path = os.path.join(log_path, 'logfile.log')
    logger = loguru.logger
    fmt = '{time:YYYY-MM-DD at HH:mm:ss} | {message}'
    # NOTE: loguru.logger is a global singleton; each call adds another sink
    logger.add(logger_path, format=fmt)

    return logger, log_path
class Registry():
    """Name -> object mapping, so third-party users can plug in custom
    components (datasets, models, losses, ...) and look them up by name."""

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj, suffix=None):
        # an optional string suffix is appended to the registration key
        key = name + '_' + suffix if isinstance(suffix, str) else name
        assert key not in self._obj_map, (f"An object named '{key}' was already registered "
                                          f"in '{self._name}' registry!")
        self._obj_map[key] = obj

    def register(self, obj=None, suffix=None):
        """Register *obj* under ``obj.__name__`` (optionally suffixed).

        Usable either as a decorator factory (``@registry.register()``) or as
        a plain function call (``registry.register(obj)``).
        """
        if obj is not None:
            # plain function-call form
            self._do_register(obj.__name__, obj, suffix)
            return

        # decorator form
        def deco(func_or_class):
            self._do_register(func_or_class.__name__, func_or_class, suffix)
            return func_or_class

        return deco

    def get(self, name, suffix='soulwalker'):
        """Look up *name*; fall back to ``<name>_<suffix>`` before failing."""
        ret = self._obj_map.get(name)
        if ret is None:
            ret = self._obj_map.get(name + '_' + suffix)
            print(f'Name {name} is not found, use name: {name}_{suffix}!')
        if ret is None:
            raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
        return ret

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()
60
61
62 DATASET_REGISTRY = Registry('dataset')
63 MODEL_REGISTRY = Registry('model')
64 LOSS_REGISTRY = Registry('loss')
65 METRIC_REGISTRY = Registry('metric')
66 OPTIMIZER_REGISTRY = Registry('optimizer')
67 SOLVER_REGISTRY = Registry('solver')
68 LR_SCHEDULER_REGISTRY = Registry('lr_scheduler')
69 COLLATE_FN_REGISTRY = Registry('collate_fn')
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!