Initial commit (no parent commits): 37 changed files, 1230 additions, 0 deletions.
config/baseline.yaml
0 → 100644
# Baseline config: residual U-Net trained to reconstruct the stamp-free image.
# NOTE(review): the builders in core/ read optimizer/lr_scheduler/loss/metric/logger
# from cfg['solver'][...], while this file declares those sections at the top
# level — confirm the config loader nests them under `solver` before use.

seed: 3407

dataset:
  name: 'ReconData'
  args:
    data_root: '/data1/lxl/data/ocr/generate1108'
    train_anno_file: '/data1/lxl/data/ocr/generate1108/train.csv'
    val_anno_file: '/data1/lxl/data/ocr/generate1108/val.csv'
    fixed_size: 512            # images are resized to fixed_size x fixed_size

dataloader:
  batch_size: 8                # per-process batch size (DDP training)
  num_workers: 16
  pin_memory: true
  collate_fn: 'base_collate_fn'

model:
  name: 'Unet'
  args:
    encoder_name: 'resnet50'
    encoder_weights: 'imagenet'
    in_channels: 3
    classes: 3
    activation: 'tanh'         # tanh bounds the predicted residual to [-1, 1]

solver:
  name: 'BaseSolver'
  args:
    epoch: 30

optimizer:
  name: 'Adam'
  args:
    lr: !!float 1e-4           # !!float forces scientific notation to parse as float
    weight_decay: !!float 5e-5

lr_scheduler:
  name: 'StepLR'
  args:
    step_size: 15
    gamma: 0.1

loss:
  name: 'L1Loss'
  args:
    reduction: 'mean'

logger:
  log_root: '/data1/lxl/code/ocr/removal/log'
  suffix: 'residual'           # run tag appended to the log directory name

metric:
  name: 'Recon'                # batch-averaged PSNR / SSIM
config/skip.yaml
0 → 100644
# Skip variant: UnetSkip returns both the reconstruction and the raw residual,
# and SkipSolver supervises the two terms separately.
# NOTE(review): as in baseline.yaml, the solver-side builders read
# cfg['solver'][optimizer|lr_scheduler|loss|metric|logger] — confirm the config
# loader nests these top-level sections under `solver`.

seed: 3407

dataset:
  name: 'ReconData'
  args:
    data_root: '/data1/lxl/data/ocr/generate1108'
    train_anno_file: '/data1/lxl/data/ocr/generate1108/train.csv'
    val_anno_file: '/data1/lxl/data/ocr/generate1108/val.csv'
    fixed_size: 512            # images are resized to fixed_size x fixed_size

dataloader:
  batch_size: 8                # per-process batch size (DDP training)
  num_workers: 16
  pin_memory: true
  collate_fn: 'base_collate_fn'

model:
  name: 'UnetSkip'
  args:
    encoder_name: 'resnet50'
    encoder_weights: 'imagenet'
    in_channels: 3
    classes: 3
    activation: 'tanh'         # tanh bounds the predicted residual to [-1, 1]

solver:
  name: 'SkipSolver'
  args:
    epoch: 30

optimizer:
  name: 'Adam'
  args:
    lr: !!float 1e-4           # !!float forces scientific notation to parse as float
    weight_decay: !!float 5e-5

lr_scheduler:
  name: 'StepLR'
  args:
    step_size: 15
    gamma: 0.1

loss:
  name: 'L1Loss'
  args:
    reduction: 'mean'

logger:
  log_root: '/data1/lxl/code/ocr/removal/log'
  suffix: 'skip'               # run tag appended to the log directory name

metric:
  name: 'Recon'                # batch-averaged PSNR / SSIM
core/__init__.py
0 → 100644
1 | from .solver import build_solver |
core/data/ReconData.py
0 → 100644
1 | from torch.utils.data import DataLoader, Dataset | ||
2 | import torchvision.transforms as transforms | ||
3 | import torch | ||
4 | import pandas as pd | ||
5 | import os | ||
6 | import random | ||
7 | import cv2 | ||
8 | import albumentations as A | ||
9 | import albumentations.pytorch | ||
10 | import numpy as np | ||
11 | from PIL import Image | ||
12 | |||
13 | from utils.registery import DATASET_REGISTRY | ||
14 | |||
15 | |||
@DATASET_REGISTRY.register()
class ReconData(Dataset):
    """Paired stamp-removal dataset.

    Each sample is (stamped image, clean text image). Both images share the
    same file name, taken from the 'name' column of the annotation CSV, and
    live under ``<data_root>/img`` and ``<data_root>/text_img`` respectively.
    """

    def __init__(self,
                 data_root: str = '/data1/lxl/data/ocr/generate1108',
                 anno_file: str = 'train.csv',
                 fixed_size: int = 448,
                 phase: str = 'train'):
        self.data_root = data_root
        self.df = pd.read_csv(anno_file)
        self.img_root = os.path.join(data_root, 'img')
        self.gt_root = os.path.join(data_root, 'text_img')
        self.fixed_size = fixed_size

        self.phase = phase
        transform_fn = self.__get_transform()
        self.transform = transform_fn[phase]

    def __get_transform(self):
        """Build the train/val albumentations pipelines.

        'label' is registered as an additional image target so the clean
        image receives exactly the same (random) transform as the input.
        """
        train_transform = A.Compose([
            A.Resize(height=self.fixed_size, width=self.fixed_size),
            # darken only (brightness_limit <= 0); contrast untouched
            A.RandomBrightnessContrast(brightness_limit=(-0.5, 0), contrast_limit=0, p=0.5),
            # mean 0 / std 1: only scales pixels to [0, 1], no channel whitening
            A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
            A.pytorch.transforms.ToTensorV2()
        ], additional_targets={'label': 'image'})

        val_transform = A.Compose([
            A.Resize(height=self.fixed_size, width=self.fixed_size),
            A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
            A.pytorch.transforms.ToTensorV2()
        ], additional_targets={'label': 'image'})

        return {'train': train_transform, 'val': val_transform}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        name = self.df.iloc[idx]['name']

        img_path = os.path.join(self.img_root, name)
        gt_path = os.path.join(self.gt_root, name)
        img = cv2.imread(img_path)
        gt = cv2.imread(gt_path)
        # cv2.imread returns None (instead of raising) for missing/corrupt
        # files; fail loudly here rather than with a cryptic cvtColor error.
        if img is None:
            raise FileNotFoundError(f'failed to read image: {img_path}')
        if gt is None:
            raise FileNotFoundError(f'failed to read ground truth: {gt_path}')

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)

        transformed = self.transform(image=img, label=gt)

        return transformed['image'], transformed['label']
74 | |||
75 |
core/data/__init__.py
0 → 100644
1 | from .builder import build_dataloader |
core/data/builder.py
0 → 100644
1 | import copy | ||
2 | import random | ||
3 | import numpy as np | ||
4 | import torch | ||
5 | from torch.utils.data import DataLoader, Dataset | ||
6 | from torch.utils.data.distributed import DistributedSampler | ||
7 | from utils.registery import DATASET_REGISTRY, COLLATE_FN_REGISTRY | ||
8 | from .collate_fn import base_collate_fn | ||
9 | |||
10 | from .ReconData import ReconData | ||
11 | |||
12 | |||
def build_dataset(cfg):
    """Create the (train, val) dataset pair described by cfg['dataset'].

    The shared dataset args carry both train_anno_file and val_anno_file;
    each split keeps only its own annotation file and receives a 'phase' flag.

    Raises:
        KeyError: if the config does not contain a 'dataset' section.
    """
    dataset_cfg = copy.deepcopy(cfg)
    try:
        dataset_cfg = dataset_cfg['dataset']
    except (KeyError, TypeError) as err:
        # the old code raised a plain string here, which is itself a
        # TypeError in Python 3; raise a real, catchable exception instead
        raise KeyError("config should contain 'dataset'") from err

    train_cfg = copy.deepcopy(dataset_cfg)
    val_cfg = copy.deepcopy(dataset_cfg)
    train_cfg['args']['anno_file'] = train_cfg['args'].pop('train_anno_file')
    train_cfg['args'].pop('val_anno_file')
    train_cfg['args']['phase'] = 'train'
    val_cfg['args']['anno_file'] = val_cfg['args'].pop('val_anno_file')
    val_cfg['args'].pop('train_anno_file')
    val_cfg['args']['phase'] = 'val'

    dataset_cls = DATASET_REGISTRY.get(dataset_cfg['name'])
    train_data = dataset_cls(**train_cfg['args'])
    val_data = dataset_cls(**val_cfg['args'])

    return train_data, val_data
34 | |||
35 | |||
36 | |||
def build_dataloader(cfg):
    """Create the distributed train DataLoader and the plain val DataLoader.

    Raises:
        KeyError: if the config does not contain a 'dataloader' section.
    """
    try:
        # deepcopy before mutating: the old code aliased the caller's dict and
        # pop('collate_fn') below destroyed it, breaking any later call that
        # reused the same cfg object
        dataloader_cfg = copy.deepcopy(cfg['dataloader'])
    except (KeyError, TypeError) as err:
        # raising a plain string (as before) is itself a TypeError in Python 3
        raise KeyError("config should contain 'dataloader'") from err

    train_ds, val_ds = build_dataset(cfg)
    # shard the training set across DDP ranks; validation stays unsharded
    train_sampler = DistributedSampler(train_ds)
    collate_fn = COLLATE_FN_REGISTRY.get(dataloader_cfg.pop('collate_fn'))

    train_loader = DataLoader(train_ds,
                              sampler=train_sampler,
                              collate_fn=collate_fn,
                              **dataloader_cfg)

    val_loader = DataLoader(val_ds,
                            collate_fn=collate_fn,
                            **dataloader_cfg)

    return train_loader, val_loader
59 |
core/data/collate_fn.py
0 → 100644
1 | import torch | ||
2 | from utils.registery import COLLATE_FN_REGISTRY | ||
3 | |||
4 | |||
@COLLATE_FN_REGISTRY.register()
def base_collate_fn(batch):
    """Stack a list of (image, label) tensor pairs into one batched dict.

    Returns {'image': (B, ...) tensor, 'label': (B, ...) tensor}.
    """
    images = torch.stack([pair[0] for pair in batch], dim=0)
    labels = torch.stack([pair[1] for pair in batch], dim=0)
    return {'image': images, 'label': labels}
16 | |||
17 |
core/data/preprocess/create_anno.py
0 → 100644
"""Split the generated images into train/val annotation CSVs."""
import os
import random

import pandas as pd

ROOT = '/data1/lxl/data/ocr/generate1108/'
TRAIN_SIZE = 16000
SEED = 3407  # same seed as the training configs, so the split is reproducible


def main():
    """Shuffle the image list deterministically and write train/val CSVs."""
    # sort before shuffling: os.listdir order is arbitrary, so shuffling the
    # sorted list with a fixed seed makes the split reproducible across runs
    img_list = sorted(os.listdir(os.path.join(ROOT, 'img')))
    random.Random(SEED).shuffle(img_list)

    train_df = pd.DataFrame({'name': img_list[:TRAIN_SIZE]})
    val_df = pd.DataFrame({'name': img_list[TRAIN_SIZE:]})

    # index=False: the loader only reads the 'name' column, and the default
    # unnamed index column would otherwise be written into the CSV
    train_df.to_csv(os.path.join(ROOT, 'train.csv'), index=False)
    val_df.to_csv(os.path.join(ROOT, 'val.csv'), index=False)


if __name__ == '__main__':
    main()
core/data/preprocess/make_bg.py
0 → 100644
"""Crop random background patches from document scans for synthetic data."""
import os

import albumentations as A
import cv2
from tqdm import tqdm

# Small random crops (10-20% of the source area) resized to 400x400.
transform = A.Compose([
    A.RandomResizedCrop(height=400, width=400, scale=(0.1, 0.2)),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
])

root = '/data1/lxl/data/ocr/doc_remove_stamp/'
save_path = '/data1/lxl/data/ocr/crop_bg/'
os.makedirs(save_path, exist_ok=True)

img_list = [os.path.join(root, name) for name in sorted(os.listdir(root))]
print(f'src img number: {len(img_list)}')

cnt = 0
for crt_img in tqdm(img_list):
    img = cv2.imread(crt_img)
    if img is None:
        # unreadable / non-image entry: report and move on instead of the old
        # bare `except Exception: continue`, which silently hid every error
        print(f'skip unreadable file: {crt_img}')
        continue
    # five independent random crops per source document
    for _ in range(5):
        transformed_img = transform(image=img)['image']
        cv2.imwrite(os.path.join(save_path, f'{cnt:06d}.jpg'), transformed_img)
        cnt += 1
37 |
core/data/preprocess/make_stamp_mask.py
0 → 100644
1 | import os | ||
2 | import cv2 | ||
3 | |||
# RGBA stamp sources and the directory where their extracted alpha masks go.
stamp_root = '/data1/lxl/data/ocr/stamp/src'
mask_root = '/data1/lxl/data/ocr/stamp/mask'
6 | |||
7 | |||
def read_img_list(root):
    """Return the full path of every entry directly under *root*."""
    return [os.path.join(root, name) for name in os.listdir(root)]
15 | |||
16 | |||
def get_mask_list(stamp_list):
    """Extract the alpha channel of each RGBA stamp and save it to mask_root.

    The mask keeps the source file name, so stamps and masks stay paired.

    Raises:
        FileNotFoundError: if a stamp image cannot be read.
    """
    for stamp in stamp_list:
        # os.path.basename instead of split('/'): portable across separators
        img_name = os.path.basename(stamp)
        img = cv2.imread(stamp, -1)  # -1 = IMREAD_UNCHANGED, keeps alpha
        if img is None:
            # cv2.imread returns None on failure; fail loudly instead of
            # crashing later on the channel slice
            raise FileNotFoundError(f'failed to read stamp image: {stamp}')
        mask = img[:, :, -1]  # last channel of the RGBA input is alpha
        cv2.imwrite(os.path.join(mask_root, img_name), mask)
23 | |||
24 | |||
if __name__ == '__main__':
    # Extract an alpha mask for every stamp image under stamp_root.
    full_path_list = read_img_list(stamp_root)
    get_mask_list(full_path_list)
core/data/preprocess/paste_stamp.py
0 → 100644
1 | import os | ||
2 | import cv2 | ||
3 | from PIL import Image, ImageEnhance | ||
4 | import numpy as np | ||
5 | import random | ||
6 | import shutil | ||
7 | from tqdm import tqdm | ||
8 | import json | ||
9 | import multiprocessing as mp | ||
10 | import imgaug.augmenters as iaa | ||
11 | |||
12 | |||
def mkdir(path):
    """Create *path* (including parents) when it does not exist yet."""
    os.makedirs(path, exist_ok=True)
16 | |||
17 | |||
# Source directories: augmented stamps plus their alpha masks and blend
# (opacity) masks from stamp_aug.py, and the cropped text backgrounds
# produced by make_bg.py.
stamp_img_root = '/data1/lxl/data/ocr/stamp/aug/img'
stamp_mask_root = '/data1/lxl/data/ocr/stamp/aug/mask'
stamp_blend_root = '/data1/lxl/data/ocr/stamp/aug/blend'
text_root = '/data1/lxl/data/ocr/crop_bg'
# Output directories for the generated composites and their ground truth.
gen_root = '/data1/lxl/data/ocr/generate1108/'
gen_img_root = os.path.join(gen_root, 'img')
gen_stamp_img_root = os.path.join(gen_root, 'stamp_img')
gen_stamp_mask_root = os.path.join(gen_root, 'stamp_mask')
gen_text_img_root = os.path.join(gen_root, 'text_img')
mkdir(gen_img_root)
mkdir(gen_text_img_root)
mkdir(gen_stamp_img_root)
mkdir(gen_stamp_mask_root)
31 | |||
32 | |||
def random_idx(s, e):
    """Draw a single uniform random integer from the half-open range [s, e)."""
    return int(np.random.randint(s, e))
37 | |||
38 | |||
def get_full_path_list(root):
    """Return the sorted full paths of every entry directly under *root*."""
    return [os.path.join(root, name) for name in sorted(os.listdir(root))]
46 | |||
47 | |||
def gen(stamp_img, blend_mask, stamp_mask, text_img, gen_img_root, gen_text_img_root, gen_stamp_mask_root, savename):
    """Paste one stamp onto one text background and save the composite, the
    clean background, the stamp-only image and the stamp mask.

    All leading arguments are file paths / output directories; `savename` is
    the integer sample id used for the zero-padded output file names.
    """
    stamp_img = Image.open(stamp_img).convert("RGB")
    blend_mask = Image.open(blend_mask)  # opacity mask used when blending the stamp

    stamp_img_width, stamp_img_height = stamp_img.size  # NOTE(review): unused
    stamp_mask = Image.open(stamp_mask).convert('L')
    stamp_mask_copy = stamp_mask.copy().convert('L')  # NOTE(review): unused
    text_img = Image.open(text_img).convert("RGB")
    gen_img = text_img.copy().convert("RGB")
    # top-left paste position; random_idx raises for a negative upper bound,
    # so this assumes every background is larger than the stamp — TODO confirm
    x = random_idx(0, text_img.size[0] - stamp_img.size[0])
    y = random_idx(0, text_img.size[1] - stamp_img.size[1])

    gen_img.paste(stamp_img, (x, y), mask=blend_mask)

    # the stamp rendered alone on a black canvas, same geometry as the composite
    gen_stamp_img = Image.new('RGB', size=text_img.size)
    gen_stamp_img.paste(stamp_img, (x, y), mask=blend_mask)
    gen_stamp_mask = Image.new('L', size=text_img.size)
    gen_stamp_mask.paste(stamp_mask, (x, y), mask=stamp_mask)
    stamp_coordinate = [x, y, x + stamp_img.size[0], y + stamp_img.size[1]]
    # NOTE(review): stamp_dict is built but never written anywhere
    stamp_dict = {'name': str(savename), 'coordinate': stamp_coordinate, 'label': ''}

    # NOTE(review): gen_stamp_img_root below is the module-level global, not a
    # parameter, even though the other three output roots are passed in.
    gen_img.save(os.path.join(gen_img_root, "{:>06d}.jpg".format(savename)))
    text_img.save(os.path.join(gen_text_img_root, "{:>06d}.jpg".format(savename)))
    gen_stamp_img.save(os.path.join(gen_stamp_img_root, "{:>06d}.jpg".format(savename)))
    gen_stamp_mask.save(os.path.join(gen_stamp_mask_root, "{:>06d}.jpg".format(savename)))
73 | |||
74 | |||
def process():
    """Generate `need` synthetic composites by pasting random stamps onto
    random text backgrounds, fanned out over a multiprocessing pool.
    """
    stamp_list = sorted(os.listdir(stamp_img_root))
    stamp_list_lth = len(stamp_list)
    text_list = sorted(os.listdir(text_root))
    text_list_lth = len(text_list)
    need = 20000  # number of composites to generate
    pool = mp.Pool(processes=6)
    for i in range(0, need):
        # a stamp image, its alpha mask and its blend mask share one file name
        stamp_idx = random_idx(0, stamp_list_lth)
        stamp_img_path = os.path.join(stamp_img_root, stamp_list[stamp_idx])
        stamp_mask_path = os.path.join(stamp_mask_root, stamp_list[stamp_idx])
        blend_mask_path = os.path.join(stamp_blend_root, stamp_list[stamp_idx])
        text_idx = random_idx(0, text_list_lth)
        text_img_path = os.path.join(text_root, text_list[text_idx])
        # NOTE(review): apply_async results are never collected and no
        # error_callback is set, so any failure inside gen() is silently lost.
        pool.apply_async(gen, (stamp_img_path, blend_mask_path, stamp_mask_path, text_img_path, gen_img_root, gen_text_img_root, gen_stamp_mask_root, i,))
        # gen(stamp_img_path, blend_mask_path, stamp_mask_path, text_img_path, gen_img_root, gen_text_img_root, gen_stamp_mask_root, i,)
    pool.close()
    pool.join()
93 | |||
94 | |||
def main():
    """Entry point: generate the synthetic stamped-document dataset."""
    process()


if __name__ == '__main__':
    main()
signatureDet @ 2dde6823
1 | Subproject commit 2dde682312388f5c5099d161019cd97df06315e4 |
core/data/preprocess/stamp_aug.py
0 → 100644
"""Augment RGBA stamp images: emit augmented stamp, mask and blend variants."""
import os

import cv2
import imgaug.augmenters as iaa
import numpy as np
from tqdm import tqdm

# Geometric + photometric augmentation applied jointly to the stamp image
# and (geometrically only, via segmentation_maps) to its alpha mask.
seq = iaa.Sequential(
    [
        iaa.Fliplr(0.5),
        iaa.Crop(percent=(0, 0.1), keep_size=True),
        iaa.Sometimes(0.5, iaa.GaussianBlur(sigma=(0, 0.5))),
        iaa.AddElementwise((-40, -10), per_channel=0.5),
        iaa.Sometimes(0.5, iaa.MultiplyElementwise((0.7, 1.0))),
        iaa.OneOf([
            iaa.Rotate((-45, 45)),
            iaa.Rot90((1, 3))
        ]),
        iaa.Sometimes(0.7, iaa.CoarseDropout(p=(0.1, 0.4), size_percent=(0.02, 0.2))),
        iaa.Sometimes(0.02, iaa.imgcorruptlike.MotionBlur(severity=2)),
    ])

# Turns the grayscale augmented stamp into a blend (opacity) mask.
gen_blend = iaa.Sequential([
    iaa.GaussianBlur(sigma=(0, 0.5)),
    iaa.MultiplyElementwise((0.8, 3.5)),
])


img_root = '/data1/lxl/data/ocr/stamp/src'
gen_root = '/data1/lxl/data/ocr/stamp/aug'
gen_img_root = os.path.join(gen_root, 'img')
gen_mask_root = os.path.join(gen_root, 'mask')
gen_blend_root = os.path.join(gen_root, 'blend')
for out_dir in (gen_img_root, gen_mask_root, gen_blend_root):
    os.makedirs(out_dir, exist_ok=True)

name_list = sorted(os.listdir(img_root))

for i in range(2):  # two augmented variants per source stamp
    for name in tqdm(name_list):
        # os.path.splitext handles names containing extra dots, unlike the
        # old name.split('.') which broke on e.g. 'stamp.v2.png'
        name_no_ex, dot_ext = os.path.splitext(name)
        # decode the file once with IMREAD_UNCHANGED (-1) and slice the color
        # and alpha planes, instead of reading the same file from disk twice
        rgba = cv2.imread(os.path.join(img_root, name), -1)
        if rgba is None:
            print(f'skip unreadable stamp: {name}')
            continue
        img = np.asarray(rgba[:, :, :3], dtype=np.uint8)
        mask = cv2.cvtColor(np.asarray(rgba[:, :, -1], dtype=np.uint8), cv2.COLOR_GRAY2BGR)
        # imgaug expects a batch axis
        img = img[np.newaxis, :]
        mask = mask[np.newaxis, :]
        img_aug, mask_aug = seq(images=img, segmentation_maps=mask)
        img_aug = img_aug.squeeze()
        mask_aug = mask_aug.squeeze(0)

        blend = cv2.cvtColor(img_aug, cv2.COLOR_BGR2GRAY)
        blend_aug = gen_blend(images=blend).squeeze()

        out_name = f'{name_no_ex}_{i}{dot_ext}'
        cv2.imwrite(os.path.join(gen_img_root, out_name), img_aug)
        cv2.imwrite(os.path.join(gen_mask_root, out_name), mask_aug)
        cv2.imwrite(os.path.join(gen_blend_root, out_name), blend_aug)
core/loss/__init__.py
0 → 100644
core/loss/builder.py
0 → 100644
1 | import torch | ||
2 | import inspect | ||
3 | from utils.registery import LOSS_REGISTRY | ||
4 | import copy | ||
5 | |||
6 | |||
def register_torch_loss():
    """Register every torch.nn *Loss* class into LOSS_REGISTRY."""
    for module_name in dir(torch.nn):
        if module_name.startswith('__') or 'Loss' not in module_name:
            continue
        _loss = getattr(torch.nn, module_name)
        if inspect.isclass(_loss) and issubclass(_loss, torch.nn.Module):
            LOSS_REGISTRY.register()(_loss)


def build_loss(cfg):
    """Instantiate the loss declared at cfg['solver']['loss'].

    Raises:
        KeyError: if the config does not contain solver.loss.
    """
    try:
        loss_cfg = cfg['solver']['loss']
    except (KeyError, TypeError) as err:
        # the old code raised a plain string here, which is itself a
        # TypeError in Python 3; raise a real, catchable exception instead
        raise KeyError("config should contain 'solver.loss'") from err

    # register only after the config is validated, keeping the error path
    # side-effect free (the old dead deepcopy of cfg is dropped as well)
    register_torch_loss()
    return LOSS_REGISTRY.get(loss_cfg['name'])(**loss_cfg['args'])
core/metric/__init__.py
0 → 100644
core/metric/builder.py
0 → 100644
1 | from utils.registery import METRIC_REGISTRY | ||
2 | import copy | ||
3 | |||
4 | from .recon_metric import Recon | ||
5 | |||
def build_metric(cfg):
    """Instantiate the metric declared at cfg['solver']['metric'].

    Raises:
        KeyError: if the config does not contain solver.metric.
    """
    try:
        metric_cfg = cfg['solver']['metric']
    except (KeyError, TypeError) as err:
        # the old code raised a plain string here, which is itself a
        # TypeError in Python 3; raise a real, catchable exception instead
        raise KeyError("config should contain 'solver.metric'") from err

    # forward optional constructor arguments; existing configs that only give
    # a name keep working because args defaults to an empty dict
    return METRIC_REGISTRY.get(metric_cfg['name'])(**metric_cfg.get('args', {}))
core/metric/recon_metric.py
0 → 100644
1 | import numpy as np | ||
2 | import torchmetrics | ||
3 | |||
4 | |||
5 | from utils.registery import METRIC_REGISTRY | ||
6 | |||
@METRIC_REGISTRY.register()
class Recon(object):
    """Batch-averaged PSNR / SSIM between predictions and targets."""

    def __init__(self):
        # torchmetrics modules; each sample is evaluated independently below
        self.psnr = torchmetrics.PeakSignalNoiseRatio()
        self.ssim = torchmetrics.StructuralSimilarityIndexMeasure()

    def __call__(self, pred, label):
        """Return {'psnr': ..., 'ssim': ...} averaged over the batch axis."""
        assert pred.shape[0] == label.shape[0]

        count = pred.shape[0]
        psnr_total = 0.0
        ssim_total = 0.0
        # evaluate one sample at a time; each metric expects a batch dim
        for sample_pred, sample_label in zip(pred, label):
            p = sample_pred.unsqueeze(0)
            t = sample_label.unsqueeze(0)
            psnr_total = psnr_total + self.psnr(p, t)
            ssim_total = ssim_total + self.ssim(p, t)

        return {'psnr': psnr_total / count, 'ssim': ssim_total / count}
29 |
core/model/__init__.py
0 → 100644
core/model/base_model.py
0 → 100644
1 | import torch | ||
2 | import torch.nn as nn | ||
3 | from abc import ABCMeta | ||
4 | import math | ||
5 | |||
6 | |||
class BaseModel(nn.Module, metaclass=ABCMeta):
    """Common base class providing Xavier/constant weight initialization."""

    def __init__(self):
        super().__init__()

    def _initialize_weights(self):
        """Init conv/linear weights with Xavier uniform, norm layers with 1/0."""
        conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
        norm_types = (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
        for module in self.modules():
            if isinstance(module, norm_types):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, conv_types) or isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
27 |
core/model/builder.py
0 → 100644
1 | import copy | ||
2 | from utils import MODEL_REGISTRY | ||
3 | |||
4 | from .unet import Unet | ||
5 | from .unet_skip import UnetSkip | ||
6 | |||
7 | |||
def build_model(cfg):
    """Build the model declared at cfg['model'] from the model registry.

    Raises:
        KeyError: if the config does not contain a 'model' section.
    """
    model_cfg = copy.deepcopy(cfg)
    try:
        model_cfg = model_cfg['model']
    except (KeyError, TypeError) as err:
        # the old code raised a plain string here, which is itself a
        # TypeError in Python 3; raise a real, catchable exception instead
        raise KeyError("config should contain 'model'") from err

    return MODEL_REGISTRY.get(model_cfg['name'])(**model_cfg['args'])
18 |
core/model/unet.py
0 → 100644
1 | import segmentation_models_pytorch as smp | ||
2 | |||
3 | from utils.registery import MODEL_REGISTRY | ||
4 | from core.model.base_model import BaseModel | ||
5 | |||
6 | |||
@MODEL_REGISTRY.register()
class Unet(BaseModel):
    """Residual U-Net for stamp removal.

    Wraps segmentation_models_pytorch's Unet; forward returns ``x + model(x)``,
    i.e. the network predicts a residual (tanh-bounded when activation='tanh')
    that is added back onto the input image.
    """

    def __init__(self,
                 encoder_name: str = 'resnet50',
                 encoder_weights: str = 'imagenet',
                 in_channels: int = 3,
                 classes: int = 3,
                 activation: str = 'tanh'):
        super().__init__()
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=activation
        )

        # NOTE(review): BaseModel._initialize_weights re-initializes every
        # submodule, overwriting the ImageNet-pretrained encoder weights that
        # were loaded just above — confirm this is intended.
        self._initialize_weights()

    def forward(self, x):
        # skip connection: add the predicted residual onto the input
        out = x + self.model(x)

        return out
32 |
core/model/unet_skip.py
0 → 100644
1 | import segmentation_models_pytorch as smp | ||
2 | |||
3 | from utils.registery import MODEL_REGISTRY | ||
4 | from core.model.base_model import BaseModel | ||
5 | |||
6 | |||
@MODEL_REGISTRY.register()
class UnetSkip(BaseModel):
    """Residual U-Net variant that also exposes the raw predicted residual.

    forward returns {'reconstruction': x + residual, 'residual': residual} so
    SkipSolver can supervise both terms separately.
    """

    def __init__(self,
                 encoder_name: str = 'resnet50',
                 encoder_weights: str = 'imagenet',
                 in_channels: int = 3,
                 classes: int = 3,
                 activation: str = 'tanh'):
        super().__init__()
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=activation
        )

        # NOTE(review): BaseModel._initialize_weights re-initializes every
        # submodule, overwriting the ImageNet-pretrained encoder weights that
        # were loaded just above — confirm this is intended.
        self._initialize_weights()

    def forward(self, x):
        # the network predicts the stamp residual; adding it back onto the
        # input yields the reconstructed (stamp-free) image
        residual = self.model(x)
        reconstruction = x + residual

        returned_dict = {'reconstruction': reconstruction, 'residual': residual}

        return returned_dict
34 |
core/optimizer/__init__.py
0 → 100644
core/optimizer/builder.py
0 → 100644
1 | import torch | ||
2 | import inspect | ||
3 | from utils.registery import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY | ||
4 | import copy | ||
5 | |||
def register_torch_optimizers():
    """Register every concrete optimizer class from torch.optim."""
    for module_name in dir(torch.optim):
        if module_name.startswith('__'):
            continue
        _optim = getattr(torch.optim, module_name)
        # skip the abstract Optimizer base class itself; register only the
        # concrete optimizers a config could reasonably name
        if inspect.isclass(_optim) and issubclass(_optim, torch.optim.Optimizer) \
                and _optim is not torch.optim.Optimizer:
            OPTIMIZER_REGISTRY.register()(_optim)


def build_optimizer(cfg):
    """Return the optimizer *class* named at cfg['solver']['optimizer'].

    The caller instantiates it with model parameters plus the config args.

    Raises:
        KeyError: if the config does not contain solver.optimizer.
    """
    try:
        optimizer_cfg = cfg['solver']['optimizer']
    except (KeyError, TypeError) as err:
        # the old code raised a plain string here, which is itself a
        # TypeError in Python 3; raise a real, catchable exception instead
        raise KeyError("config should contain 'solver.optimizer'") from err

    # register only after the config is validated (the old dead deepcopy of
    # the full cfg is dropped — it was overwritten immediately)
    register_torch_optimizers()
    return OPTIMIZER_REGISTRY.get(optimizer_cfg['name'])
27 | |||
def register_torch_lr_scheduler():
    """Register every concrete LR scheduler from torch.optim.lr_scheduler."""
    for module_name in dir(torch.optim.lr_scheduler):
        if module_name.startswith('__'):
            continue

        _scheduler = getattr(torch.optim.lr_scheduler, module_name)
        # skip the private _LRScheduler base itself; register only concrete
        # schedulers a config could reasonably name
        if inspect.isclass(_scheduler) \
                and issubclass(_scheduler, torch.optim.lr_scheduler._LRScheduler) \
                and _scheduler is not torch.optim.lr_scheduler._LRScheduler:
            LR_SCHEDULER_REGISTRY.register()(_scheduler)


def build_lr_scheduler(cfg):
    """Return the LR scheduler *class* named at cfg['solver']['lr_scheduler'].

    The caller instantiates it with the optimizer plus the config args.

    Raises:
        KeyError: if the config does not contain solver.lr_scheduler.
    """
    try:
        scheduler_cfg = cfg['solver']['lr_scheduler']
    except (KeyError, TypeError) as err:
        # the old code raised a plain string here, which is itself a
        # TypeError in Python 3; raise a real, catchable exception instead
        raise KeyError("config should contain 'solver.lr_scheduler'") from err

    # register only after the config is validated (the old dead deepcopy of
    # the full cfg is dropped — it was overwritten immediately)
    register_torch_lr_scheduler()
    return LR_SCHEDULER_REGISTRY.get(scheduler_cfg['name'])
50 |
core/solver/__init__.py
0 → 100644
core/solver/base_solver.py
0 → 100644
1 | import torch | ||
2 | from core.data import build_dataloader | ||
3 | from core.model import build_model | ||
4 | from core.optimizer import build_optimizer, build_lr_scheduler | ||
5 | from core.loss import build_loss | ||
6 | from core.metric import build_metric | ||
7 | from utils.registery import SOLVER_REGISTRY | ||
8 | from utils.logger import get_logger_and_log_path | ||
9 | import os | ||
10 | import copy | ||
11 | import datetime | ||
12 | from torch.nn.parallel import DistributedDataParallel | ||
13 | import numpy as np | ||
14 | import pandas as pd | ||
15 | import yaml | ||
16 | |||
17 | |||
@SOLVER_REGISTRY.register()
class BaseSolver(object):
    """DistributedDataParallel training / validation driver.

    Builds every component (dataloaders, model, loss, optimizer, scheduler,
    logger, metric) from the config dict via the registry builders. Assumes
    torch.distributed has been initialized before construction and that one
    process runs per GPU.
    """

    def __init__(self, cfg):
        self.cfg = copy.deepcopy(cfg)
        # NOTE(review): get_rank() returns the *global* rank, which is then
        # used as a CUDA device index — correct only for single-node runs.
        self.local_rank = torch.distributed.get_rank()
        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.len_train_loader, self.len_val_loader = len(self.train_loader), len(self.val_loader)
        self.criterion = build_loss(cfg).cuda(self.local_rank)
        # share BatchNorm statistics across ranks before wrapping with DDP
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(build_model(cfg))
        self.model = DistributedDataParallel(model.cuda(self.local_rank), device_ids=[self.local_rank], find_unused_parameters=True)
        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])
        self.hyper_params = cfg['solver']['args']
        crt_date = datetime.date.today().strftime('%Y-%m-%d')
        # NOTE(review): reads cfg['solver']['logger'] (and solver.optimizer /
        # solver.lr_scheduler / solver.loss / solver.metric elsewhere), while
        # the shipped YAML declares those sections at the top level — verify
        # the config loader nests them under 'solver'.
        self.logger, self.log_path = get_logger_and_log_path(crt_date=crt_date, **cfg['solver']['logger'])
        self.metric_fn = build_metric(cfg)
        try:
            self.epoch = self.hyper_params['epoch']
        except Exception:
            # NOTE(review): raising a plain str is itself a TypeError in
            # Python 3 — this should be `raise KeyError(...)`.
            raise 'should contain epoch in {solver.args}'
        if self.local_rank == 0:
            # persist the resolved config next to the logs for reproducibility
            self.save_dict_to_yaml(self.cfg, os.path.join(self.log_path, 'config.yaml'))
            self.logger.info(self.cfg)

    def train(self):
        """Full training loop: one optimizer step per batch; at the end of
        each epoch rank 0 saves a checkpoint and runs validation."""
        if torch.distributed.get_rank() == 0:
            self.logger.info('==> Start Training')
        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            # re-seed the DistributedSampler so each epoch gets a new shuffle
            self.train_loader.sampler.set_epoch(t)
            if torch.distributed.get_rank() == 0:
                self.logger.info(f'==> epoch {t + 1}')
            self.model.train()

            # kept for the (currently disabled) all_gather train metrics below
            pred_list = list()
            label_list = list()

            mean_loss = 0.0

            for i, data in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                image = data['image'].cuda(self.local_rank)
                label = data['label'].cuda(self.local_rank)

                pred = self.model(image)

                loss = self.criterion(pred, label)
                mean_loss += loss.item()

                # periodic progress logging from rank 0 only
                if (i == 0 or i % 200 == 0) and (torch.distributed.get_rank() == 0):
                    loss_value = loss.item()
                    self.logger.info(f'epoch: {t + 1}/{self.epoch}, iteration: {i + 1}/{self.len_train_loader}, loss: {loss_value :.4f}')

                loss.backward()
                self.optimizer.step()

                # batch_pred = [torch.zeros_like(pred) for _ in range(torch.distributed.get_world_size())]  # 1
                # torch.distributed.all_gather(batch_pred, pred)
                # pred_list.append(torch.cat(batch_pred, dim=0).detach().cpu())

                # batch_label = [torch.zeros_like(label) for _ in range(torch.distributed.get_world_size())]
                # torch.distributed.all_gather(batch_label, label)
                # label_list.append(torch.cat(batch_label, dim=0).detach().cpu())

            # pred_list = torch.cat(pred_list, dim=0)
            # label_list = torch.cat(label_list, dim=0)
            # metric_dict = self.metric_fn(**{'pred': pred_list, 'label': label_list})
            mean_loss = mean_loss / self.len_train_loader  # rank-local epoch mean

            if torch.distributed.get_rank() == 0:
                # self.logger.info(f"==> train mean loss: {mean_loss :.4f}, psnr: {metric_dict['psnr'] :.4f}, ssim: {metric_dict['ssim'] :.4f}")
                self.logger.info(f'==> train mean loss: {mean_loss :.4f}')
                self.save_checkpoint(self.model, self.cfg, self.log_path, t + 1)
                # NOTE(review): validation runs on rank 0 only; other ranks
                # proceed to the next epoch without an explicit barrier.
                self.val(t + 1)
            lr_scheduler.step()

        if self.local_rank == 0:
            self.logger.info('==> End Training')

    @torch.no_grad()
    def val(self, t):
        """Evaluate PSNR/SSIM on the (unsharded) validation loader.

        Args:
            t: 1-based epoch index (currently unused in the body).
        """
        self.model.eval()

        pred_list = list()
        label_list = list()

        for i, data in enumerate(self.val_loader):
            feat = data['image'].cuda(self.local_rank)
            label = data['label'].cuda(self.local_rank)

            pred = self.model(feat)

            pred_list.append(pred.detach().cpu())
            label_list.append(label.detach().cpu())

        pred_list = torch.cat(pred_list, dim=0)
        label_list = torch.cat(label_list, dim=0)

        metric_dict = self.metric_fn(**{'pred': pred_list, 'label': label_list})
        if torch.distributed.get_rank() == 0:
            self.logger.info(f"==> val psnr: {metric_dict['psnr'] :.4f}, ssim: {metric_dict['ssim'] :.4f}")

    def run(self):
        """Public entry point used by the launcher."""
        self.train()

    @staticmethod
    def save_dict_to_yaml(dict_value, save_path):
        """Dump a (config) dict to YAML, preserving key order."""
        with open(save_path, 'w', encoding='utf-8') as file:
            yaml.dump(dict_value, file, sort_keys=False)

    def save_checkpoint(self, model, cfg, log_path, epoch_id):
        """Save the unwrapped (DDP .module) state dict for this epoch.

        Note: switches the model to eval mode; train() re-enables train mode
        at the start of the next epoch.
        """
        model.eval()
        torch.save(model.module.state_dict(), os.path.join(log_path, f'ckpt_epoch_{epoch_id}.pt'))
core/solver/builder.py
0 → 100644
1 | from utils.registery import SOLVER_REGISTRY | ||
2 | import copy | ||
3 | |||
4 | from .base_solver import BaseSolver | ||
5 | from .skip_solver import SkipSolver | ||
6 | |||
7 | |||
def build_solver(cfg):
    """Build the solver declared at cfg['solver'] from the solver registry.

    Raises:
        KeyError: if the config does not contain a 'solver' section.
    """
    cfg = copy.deepcopy(cfg)

    try:
        solver_cfg = cfg['solver']
    except (KeyError, TypeError) as err:
        # the old code raised a plain string here, which is itself a
        # TypeError in Python 3; raise a real, catchable exception instead
        raise KeyError("config should contain 'solver'") from err

    return SOLVER_REGISTRY.get(solver_cfg['name'])(cfg)
core/solver/skip_solver.py
0 → 100644
1 | import torch | ||
2 | from core.data import build_dataloader | ||
3 | from core.model import build_model | ||
4 | from core.optimizer import build_optimizer, build_lr_scheduler | ||
5 | from core.loss import build_loss | ||
6 | from core.metric import build_metric | ||
7 | from utils.registery import SOLVER_REGISTRY | ||
8 | from utils.logger import get_logger_and_log_path | ||
9 | import os | ||
10 | import copy | ||
11 | import datetime | ||
12 | from torch.nn.parallel import DistributedDataParallel | ||
13 | import numpy as np | ||
14 | import pandas as pd | ||
15 | import yaml | ||
16 | |||
17 | from .base_solver import BaseSolver | ||
18 | |||
19 | |||
@SOLVER_REGISTRY.register()
class SkipSolver(BaseSolver):
    """Solver for the skip-connection Unet with two output heads.

    The model returns a dict with 'reconstruction' (the cleaned image) and
    'residual' (the removed content); both heads are supervised with the
    configured criterion, the residual target being `image - label`.

    `run`, `save_dict_to_yaml` and `save_checkpoint` were byte-identical
    copies of the BaseSolver versions and are now simply inherited.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def train(self):
        """Run the training loop: per-iteration optimisation, periodic rank-0
        logging, and per-epoch checkpointing, validation and LR stepping."""
        if torch.distributed.get_rank() == 0:
            self.logger.info('==> Start Training')
        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            # Re-seed the distributed sampler so shuffling differs each epoch.
            self.train_loader.sampler.set_epoch(t)
            if torch.distributed.get_rank() == 0:
                self.logger.info(f'==> epoch {t + 1}')
            self.model.train()

            mean_loss = 0.0

            for i, data in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                image = data['image'].cuda(self.local_rank)
                label = data['label'].cuda(self.local_rank)
                # Residual target: the content that must be removed from the image.
                residual = image - label

                pred = self.model(image)

                reconstruction_loss = self.criterion(pred['reconstruction'], label)
                residual_loss = self.criterion(pred['residual'], residual)
                loss = reconstruction_loss + residual_loss
                mean_loss += loss.item()

                if (i == 0 or i % 200 == 0) and (torch.distributed.get_rank() == 0):
                    reconstruction_loss_value = reconstruction_loss.item()
                    residual_loss_value = residual_loss.item()
                    loss_value = loss.item()
                    self.logger.info(f'epoch: {t + 1}/{self.epoch}, iteration: {i + 1}/{self.len_train_loader}, reconstruction loss: {reconstruction_loss_value :.4f}, residual loss: {residual_loss_value :.4f}, loss: {loss_value :.4f}')

                loss.backward()
                self.optimizer.step()

            mean_loss = mean_loss / self.len_train_loader

            if torch.distributed.get_rank() == 0:
                self.logger.info(f'==> train mean loss: {mean_loss :.4f}')
                self.save_checkpoint(self.model, self.cfg, self.log_path, t + 1)
                self.val(t + 1)
            lr_scheduler.step()

        if self.local_rank == 0:
            self.logger.info('==> End Training')

    @torch.no_grad()
    def val(self, t):
        """Validate the reconstruction head and log PSNR/SSIM on rank 0.

        Args:
            t (int): 1-based epoch index (unused; kept for interface parity).
        """
        self.model.eval()

        pred_list = list()
        label_list = list()

        for data in self.val_loader:
            image = data['image'].cuda(self.local_rank)
            label = data['label'].cuda(self.local_rank)

            pred = self.model(image)

            # Only the reconstruction head is scored during validation.
            pred_list.append(pred['reconstruction'].detach().cpu())
            label_list.append(label.detach().cpu())

        pred_list = torch.cat(pred_list, dim=0)
        label_list = torch.cat(label_list, dim=0)

        metric_dict = self.metric_fn(**{'pred': pred_list, 'label': label_list})
        if torch.distributed.get_rank() == 0:
            self.logger.info(f"==> val psnr: {metric_dict['psnr'] :.4f}, ssim: {metric_dict['ssim'] :.4f}")
main.py
0 → 100644
1 | import yaml | ||
2 | from core.solver import build_solver | ||
3 | import torch | ||
4 | import numpy as np | ||
5 | import random | ||
6 | import argparse | ||
7 | |||
8 | |||
def init_seed(seed=778):
    """Seed every RNG source (python, numpy, torch CPU and CUDA) and force
    deterministic cuDNN so training runs are reproducible.

    Args:
        seed (int): seed value applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Trade autotuner speed for bitwise-reproducible convolutions.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
17 | |||
def main():
    """Entry point: parse CLI args, load the YAML config, seed all RNGs,
    initialise the NCCL process group and run the configured solver."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./config/baseline.yaml', type=str, help='config file')
    parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
    args = parser.parse_args()

    # Fix: the original `yaml.load(open(...).read(), ...)` never closed the
    # file handle; safe_load suffices for plain config files (standard tags
    # such as `!!float` are resolved) and avoids arbitrary object construction.
    with open(args.config, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    init_seed(cfg['seed'])

    torch.distributed.init_process_group(backend='nccl')
    torch.cuda.set_device(args.local_rank)

    solver = build_solver(cfg)

    solver.run()


if __name__ == '__main__':
    main()
pred.py
0 → 100644
1 | import os | ||
2 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' | ||
3 | import cv2 | ||
4 | import numpy as np | ||
5 | import torch | ||
6 | from PIL import Image | ||
7 | import torchvision.transforms as transforms | ||
8 | from tqdm import tqdm | ||
9 | |||
10 | |||
11 | from core.model.unet import Unet | ||
12 | from core.model.unet_skip import UnetSkip | ||
13 | |||
def load_model(ckpt='./log/2022-11-11/skip/ckpt_epoch_30.pt'):
    """Build a UnetSkip and restore its weights from a training checkpoint.

    Args:
        ckpt (str): path to a state-dict file produced by the solver.

    Returns:
        UnetSkip: model in eval mode with weights loaded on CPU.
    """
    model = UnetSkip(
        encoder_name='resnet50',
        encoder_weights='imagenet',
        in_channels=3,
        classes=3,
        activation='tanh'
    )

    state_dict = torch.load(ckpt, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()

    return model
37 | |||
38 | |||
def infer(model, img_path, gen_path):
    """Run text removal on one image file and write the cleaned result.

    Args:
        model: network returning a dict whose 'reconstruction' entry is a
            (1, 3, 512, 512) tensor with values expected in [0, 1].
        img_path (str): input image path.
        gen_path (str): output image path.

    Raises:
        FileNotFoundError: if the input image cannot be read.
    """
    img = cv2.imread(img_path)
    # Fix: cv2.imread silently returns None on a bad path; fail loudly instead
    # of crashing later inside cvtColor.
    if img is None:
        raise FileNotFoundError(f'cannot read image: {img_path}')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (512, 512))
    img = img / 255.
    img = img.transpose(2, 0, 1).astype(np.float32)
    img = torch.from_numpy(img).unsqueeze(0)
    with torch.no_grad():  # pure inference; no autograd bookkeeping needed
        out = model(img)['reconstruction']
    out = out.clamp(0, 1) * 255
    out = out.cpu().numpy().squeeze(0).transpose(1, 2, 0)
    # Fix: round and cast explicitly so imwrite receives 8-bit data rather
    # than a float array.
    out = np.round(out).astype(np.uint8)
    out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
    cv2.imwrite(gen_path, out)
54 | |||
def infer_list():
    """Batch inference: clean every image under img_root into gen_root,
    keeping the original file names."""
    img_root = '/data1/lxl/data/ocr/real/src/'
    gen_root = '/data1/lxl/data/ocr/real/removed/'
    model = load_model()
    for name in tqdm(sorted(os.listdir(img_root))):
        infer(model, os.path.join(img_root, name), os.path.join(gen_root, name))
64 | |||
def infer_img():
    """Single-image smoke test: run the model on one sample and write ./out.jpg."""
    model = load_model()
    img_path = '../DocEnTR/real/hetong_006_00.png'
    gen_path = './out.jpg'
    infer(model, img_path, gen_path)


if __name__ == '__main__':
    # Fix: the bare top-level call ran inference whenever this module was
    # imported; guard it so the script still works but importing is side-effect free.
    infer_img()
run.sh
0 → 100644
# Launch single-GPU distributed training in the background with the skip config;
# output goes to nohup.out.
CUDA_VISIBLE_DEVICES=0 nohup python -m torch.distributed.launch --master_port 8999 --nproc_per_node=1 main.py --config ./config/skip.yaml &
utils/__init__.py
0 → 100644
utils/helper.py
0 → 100644
1 | import torch | ||
2 | import yaml | ||
3 | import os | ||
4 | |||
5 | |||
def save_dict_to_yaml(dict_value, save_path):
    """Write a dict to `save_path` as YAML, keeping key insertion order.

    Args:
        dict_value (dict): data to serialize (typically the run config).
        save_path (str): destination file path.
    """
    with open(save_path, 'w', encoding='utf-8') as fp:
        yaml.dump(dict_value, fp, sort_keys=False)
9 | |||
10 | |||
def save_checkpoint(model, cfg, log_path, epoch_id):
    """Dump the run config and the model weights into `log_path`.

    Args:
        model: wrapped model exposing `.module` (e.g. DistributedDataParallel).
        cfg (dict): experiment config, saved as config.yaml alongside weights.
        log_path (str): directory receiving config.yaml and the checkpoint.
        epoch_id (int): epoch number embedded in the checkpoint file name.
    """
    save_dict_to_yaml(cfg, os.path.join(log_path, 'config.yaml'))
    ckpt_file = os.path.join(log_path, f'ckpt_epoch_{epoch_id}.pt')
    torch.save(model.module.state_dict(), ckpt_file)
utils/logger.py
0 → 100644
1 | import loguru | ||
2 | import copy | ||
3 | import os | ||
4 | import datetime | ||
5 | |||
def get_logger_and_log_path(log_root,
                            crt_date,
                            suffix):
    """
    Get a loguru logger writing to `<log_root>/<crt_date>/<suffix>/logfile.log`.

    Args:
        log_root (str): root path of log
        crt_date (str): formatted date name (Y-M-D)
        suffix (str): log save name

    Returns:
        logger (loguru.logger): logger object
        log_path (str): current root log path (with suffix)
    """
    log_path = os.path.join(log_root, crt_date, suffix)
    # Fix: the exists()-then-makedirs() pair raced when several DDP processes
    # started at once; exist_ok makes directory creation idempotent.
    os.makedirs(log_path, exist_ok=True)

    logger_path = os.path.join(log_path, 'logfile.log')
    logger = loguru.logger
    fmt = '{time:YYYY-MM-DD at HH:mm:ss} | {message}'
    # NOTE(review): every call adds another sink, so calling this twice would
    # duplicate log lines — acceptable for a once-per-run setup; confirm callers.
    logger.add(logger_path, format=fmt)

    return logger, log_path
utils/registery.py
0 → 100644
class Registry():
    """
    A name -> object mapping that lets third-party users plug custom
    modules (datasets, models, solvers, ...) into the framework.
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj, suffix=None):
        # An optional string suffix is appended to disambiguate entries.
        key = f'{name}_{suffix}' if isinstance(suffix, str) else name
        assert key not in self._obj_map, (f"An object named '{key}' was already registered "
                                          f"in '{self._name}' registry!")
        self._obj_map[key] = obj

    def register(self, obj=None, suffix=None):
        """
        Register `obj` (or the decorated callable) under its `__name__`.
        Works both as `@registry.register()` and as a plain function call.
        """
        if obj is not None:
            # plain function-call form
            self._do_register(obj.__name__, obj, suffix)
            return None

        # decorator form: capture and return the decorated callable unchanged
        def deco(func_or_class):
            self._do_register(func_or_class.__name__, func_or_class, suffix)
            return func_or_class

        return deco

    def get(self, name, suffix='soulwalker'):
        """Look up `name`, falling back to `name_<suffix>`; KeyError if absent."""
        hit = self._obj_map.get(name)
        if hit is None:
            # Announce the fallback attempt before retrying with the suffix.
            print(f'Name {name} is not found, use name: {name}_{suffix}!')
            hit = self._obj_map.get(f'{name}_{suffix}')
        if hit is None:
            raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
        return hit

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()
61 | |||
# Global singleton registries, one per pluggable component category; modules
# register implementations into these at import time.
DATASET_REGISTRY = Registry('dataset')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')
OPTIMIZER_REGISTRY = Registry('optimizer')
SOLVER_REGISTRY = Registry('solver')
LR_SCHEDULER_REGISTRY = Registry('lr_scheduler')
COLLATE_FN_REGISTRY = Registry('collate_fn')
-
Please register or sign in to post a comment