init commit
0 parents
Showing
37 changed files
with
1230 additions
and
0 deletions
config/baseline.yaml
0 → 100644
# Baseline experiment: residual U-Net ('Unet' adds its tanh-bounded residual
# back onto the input image) trained with the generic BaseSolver.
seed: 3407

dataset:
  name: 'ReconData'
  args:
    data_root: '/data1/lxl/data/ocr/generate1108'
    train_anno_file: '/data1/lxl/data/ocr/generate1108/train.csv'
    val_anno_file: '/data1/lxl/data/ocr/generate1108/val.csv'
    fixed_size: 512          # images resized to fixed_size x fixed_size

dataloader:
  batch_size: 8              # per-rank batch size (DDP training)
  num_workers: 16
  pin_memory: true
  collate_fn: 'base_collate_fn'

model:
  name: 'Unet'
  args:                      # forwarded to segmentation_models_pytorch.Unet
    encoder_name: 'resnet50'
    encoder_weights: 'imagenet'
    in_channels: 3
    classes: 3
    activation: 'tanh'       # residual bounded to [-1, 1]

# NOTE(review): the builders read optimizer/lr_scheduler/loss/logger/metric
# via cfg['solver'][...], so these sections must be nested under `solver`
# (nesting reconstructed -- confirm against the original file).
solver:
  name: 'BaseSolver'
  args:
    epoch: 30

  optimizer:
    name: 'Adam'
    args:
      lr: !!float 1e-4
      weight_decay: !!float 5e-5

  lr_scheduler:
    name: 'StepLR'
    args:
      step_size: 15          # single LR decay at epoch 15 of 30
      gamma: 0.1

  loss:
    name: 'L1Loss'
    args:
      reduction: 'mean'

  logger:
    log_root: '/data1/lxl/code/ocr/removal/log'
    suffix: 'residual'       # run-name suffix inside the log dir

  metric:
    name: 'Recon'            # mean per-image PSNR / SSIM
config/skip.yaml
0 → 100644
# Skip-variant experiment: UnetSkip returns both the reconstruction and the
# raw residual as a dict, consumed by SkipSolver.
seed: 3407

dataset:
  name: 'ReconData'
  args:
    data_root: '/data1/lxl/data/ocr/generate1108'
    train_anno_file: '/data1/lxl/data/ocr/generate1108/train.csv'
    val_anno_file: '/data1/lxl/data/ocr/generate1108/val.csv'
    fixed_size: 512          # images resized to fixed_size x fixed_size

dataloader:
  batch_size: 8              # per-rank batch size (DDP training)
  num_workers: 16
  pin_memory: true
  collate_fn: 'base_collate_fn'

model:
  name: 'UnetSkip'
  args:                      # forwarded to segmentation_models_pytorch.Unet
    encoder_name: 'resnet50'
    encoder_weights: 'imagenet'
    in_channels: 3
    classes: 3
    activation: 'tanh'       # residual bounded to [-1, 1]

# NOTE(review): the builders read optimizer/lr_scheduler/loss/logger/metric
# via cfg['solver'][...], so these sections must be nested under `solver`
# (nesting reconstructed -- confirm against the original file).
solver:
  name: 'SkipSolver'
  args:
    epoch: 30

  optimizer:
    name: 'Adam'
    args:
      lr: !!float 1e-4
      weight_decay: !!float 5e-5

  lr_scheduler:
    name: 'StepLR'
    args:
      step_size: 15          # single LR decay at epoch 15 of 30
      gamma: 0.1

  loss:
    name: 'L1Loss'
    args:
      reduction: 'mean'

  logger:
    log_root: '/data1/lxl/code/ocr/removal/log'
    suffix: 'skip'           # run-name suffix inside the log dir

  metric:
    name: 'Recon'            # mean per-image PSNR / SSIM
core/__init__.py
0 → 100644
| 1 | from .solver import build_solver |
core/data/ReconData.py
0 → 100644
| 1 | from torch.utils.data import DataLoader, Dataset | ||
| 2 | import torchvision.transforms as transforms | ||
| 3 | import torch | ||
| 4 | import pandas as pd | ||
| 5 | import os | ||
| 6 | import random | ||
| 7 | import cv2 | ||
| 8 | import albumentations as A | ||
| 9 | import albumentations.pytorch | ||
| 10 | import numpy as np | ||
| 11 | from PIL import Image | ||
| 12 | |||
| 13 | from utils.registery import DATASET_REGISTRY | ||
| 14 | |||
| 15 | |||
@DATASET_REGISTRY.register()
class ReconData(Dataset):
    """Paired image dataset for stamp removal.

    Reads image names from a CSV ('name' column); for each name loads the
    stamped input from <data_root>/img and the clean target from
    <data_root>/text_img, applies a joint albumentations transform, and
    returns the pair as tensors.
    """

    def __init__(self,
                 data_root: str = '/data1/lxl/data/ocr/generate1108',
                 anno_file: str = 'train.csv',
                 fixed_size: int = 448,
                 phase: str = 'train'):
        # data_root layout: img/ (stamped inputs), text_img/ (clean targets)
        self.data_root = data_root
        self.df = pd.read_csv(anno_file)
        self.img_root = os.path.join(data_root, 'img')
        self.gt_root = os.path.join(data_root, 'text_img')
        self.fixed_size = fixed_size

        self.phase = phase
        # pick the train or val pipeline by phase ('train' / 'val')
        transform_fn = self.__get_transform()
        self.transform = transform_fn[phase]

    def __get_transform(self):
        """Build the train/val albumentations pipelines.

        `additional_targets={'label': 'image'}` makes every transform apply
        identically to input and target.
        NOTE(review): this means the train brightness jitter darkens the
        *target* as well as the input -- confirm this is intended.
        """
        train_transform = A.Compose([
            A.Resize(height=self.fixed_size, width=self.fixed_size),
            # A.RandomBrightness(limit=(-0.5, 0), p=0.5),
            A.RandomBrightnessContrast(brightness_limit=(-0.5, 0), contrast_limit=0, p=0.5),
            # scale pixels to [0, 1] with no mean/std shift
            A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
            A.pytorch.transforms.ToTensorV2()
        ], additional_targets={'label': 'image'})

        val_transform = A.Compose([
            A.Resize(height=self.fixed_size, width=self.fixed_size),
            A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
            A.pytorch.transforms.ToTensorV2()
        ], additional_targets={'label': 'image'})

        transform_fn = {'train': train_transform, 'val': val_transform}

        return transform_fn

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return (input_tensor, target_tensor) for CSV row *idx*."""
        series = self.df.iloc[idx]
        name = series['name']

        # img = Image.open(os.path.join(self.img_root, name))
        # gt = Image.open(os.path.join(self.gt_root, name))
        img = cv2.imread(os.path.join(self.img_root, name))
        gt = cv2.imread(os.path.join(self.gt_root, name))

        # OpenCV loads BGR; convert to RGB for the model/metrics
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)

        transformed = self.transform(image=img, label=gt)
        img = transformed['image']
        label = transformed['label']

        return img, label
| 74 | |||
| 75 |
core/data/__init__.py
0 → 100644
| 1 | from .builder import build_dataloader |
core/data/builder.py
0 → 100644
| 1 | import copy | ||
| 2 | import random | ||
| 3 | import numpy as np | ||
| 4 | import torch | ||
| 5 | from torch.utils.data import DataLoader, Dataset | ||
| 6 | from torch.utils.data.distributed import DistributedSampler | ||
| 7 | from utils.registery import DATASET_REGISTRY, COLLATE_FN_REGISTRY | ||
| 8 | from .collate_fn import base_collate_fn | ||
| 9 | |||
| 10 | from .ReconData import ReconData | ||
| 11 | |||
| 12 | |||
def build_dataset(cfg):
    """Build the (train, val) dataset pair described by ``cfg['dataset']``.

    The section must provide ``name`` (a DATASET_REGISTRY key) and ``args``
    containing ``train_anno_file`` / ``val_anno_file``, which are rewritten
    into the common constructor argument ``anno_file``.

    Returns:
        Tuple of (train_dataset, val_dataset).

    Raises:
        KeyError: if the config has no 'dataset' section.
    """
    dataset_cfg = copy.deepcopy(cfg)
    try:
        dataset_cfg = dataset_cfg['dataset']
    except KeyError as err:
        # Bug fix: the original did `raise '<str>'`, which is a TypeError in
        # Python 3 and masked the real problem.  Raise a real exception.
        raise KeyError('config should contain {dataset}!') from err

    # Derive independent train/val configs: each keeps only its own
    # annotation file under the shared key 'anno_file'.
    train_cfg = copy.deepcopy(dataset_cfg)
    val_cfg = copy.deepcopy(dataset_cfg)
    train_cfg['args']['anno_file'] = train_cfg['args'].pop('train_anno_file')
    train_cfg['args'].pop('val_anno_file')
    train_cfg['args']['phase'] = 'train'
    val_cfg['args']['anno_file'] = val_cfg['args'].pop('val_anno_file')
    val_cfg['args'].pop('train_anno_file')
    val_cfg['args']['phase'] = 'val'

    dataset_cls = DATASET_REGISTRY.get(dataset_cfg['name'])
    train_data = dataset_cls(**train_cfg['args'])
    val_data = dataset_cls(**val_cfg['args'])

    return train_data, val_data
| 35 | |||
| 36 | |||
def build_dataloader(cfg):
    """Build the distributed train loader and a plain val loader.

    Bug fixes vs. the original:
      * ``dataloader_cfg`` was deep-copied and then immediately rebound to
        ``cfg['dataloader']`` itself, so ``pop('collate_fn')`` mutated the
        caller's config.  We now deepcopy the sub-dict before popping.
      * ``raise '<str>'`` (a TypeError in Python 3) replaced with KeyError.

    Raises:
        KeyError: if the config has no 'dataloader' section.
    """
    try:
        dataloader_cfg = copy.deepcopy(cfg['dataloader'])
    except KeyError as err:
        raise KeyError('config should contain {dataloader}!') from err

    train_ds, val_ds = build_dataset(cfg)
    # Shard the training set across DDP ranks.
    train_sampler = DistributedSampler(train_ds)
    collate_fn = COLLATE_FN_REGISTRY.get(dataloader_cfg.pop('collate_fn'))

    train_loader = DataLoader(train_ds,
                              sampler=train_sampler,
                              collate_fn=collate_fn,
                              **dataloader_cfg)

    # Validation is not sharded; every rank iterates the full val set.
    val_loader = DataLoader(val_ds,
                            collate_fn=collate_fn,
                            **dataloader_cfg)

    return train_loader, val_loader
| 59 |
core/data/collate_fn.py
0 → 100644
| 1 | import torch | ||
| 2 | from utils.registery import COLLATE_FN_REGISTRY | ||
| 3 | |||
| 4 | |||
@COLLATE_FN_REGISTRY.register()
def base_collate_fn(batch):
    """Collate a list of (image, label) tensor pairs into batched tensors.

    Returns:
        dict with 'image' and 'label', each of shape (N, ...).
    """
    images = torch.stack([sample_img for sample_img, _ in batch], dim=0)
    labels = torch.stack([sample_lbl for _, sample_lbl in batch], dim=0)
    return {'image': images, 'label': labels}
| 16 | |||
| 17 |
core/data/preprocess/create_anno.py
0 → 100644
| 1 | import os | ||
| 2 | import pandas as pd | ||
| 3 | import random | ||
| 4 | |||
# Build train/val annotation CSVs by randomly splitting the generated image
# directory: first 16000 shuffled names -> train, remainder -> val.
root = '/data1/lxl/data/ocr/generate1108/'
img_path = os.path.join(root, 'img')
img_list = os.listdir(img_path)
random.shuffle(img_list)
train_df = pd.DataFrame(columns=['name'])
val_df = pd.DataFrame(columns=['name'])

# NOTE(review): the 16000 split point is hard-coded; it presumes a fixed
# number of generated images -- confirm against the generation script.
train_df['name'] = img_list[:16000]
val_df['name'] = img_list[16000:]

# to_csv also writes the default integer index column; downstream readers
# only use the 'name' column, so this is harmless.
train_df.to_csv(os.path.join(root, 'train.csv'))
val_df.to_csv(os.path.join(root, 'val.csv'))
core/data/preprocess/make_bg.py
0 → 100644
| 1 | import albumentations as A | ||
| 2 | import cv2 | ||
| 3 | import os | ||
| 4 | from tqdm import tqdm | ||
| 5 | |||
# Crop small random background patches from document scans, producing five
# augmented crops per source image.
transform = A.Compose([
    A.RandomResizedCrop(height=400, width=400, scale=(0.1, 0.2)),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
])

root = '/data1/lxl/data/ocr/doc_remove_stamp/'
save_path = '/data1/lxl/data/ocr/crop_bg/'

# Collect full paths of all source images.  (The original also built an
# unused `date_dirs` list from a duplicate os.listdir call -- removed.)
img_list = [os.path.join(root, img_name) for img_name in sorted(os.listdir(root))]

print(f'src img number: {len(img_list)}')

cnt = 0
for crt_img in tqdm(img_list):
    try:
        img = cv2.imread(crt_img)
        # five independent random crops per source image
        for _ in range(5):
            transformed = transform(image=img)
            transformed_img = transformed['image']
            cv2.imwrite(os.path.join(save_path, f'{cnt :06d}.jpg'), transformed_img)
            cnt += 1
    except Exception:
        # best-effort batch job: skip unreadable / corrupt source images
        continue
| 37 |
core/data/preprocess/make_stamp_mask.py
0 → 100644
| 1 | import os | ||
| 2 | import cv2 | ||
| 3 | |||
| 4 | stamp_root = '/data1/lxl/data/ocr/stamp/src' | ||
| 5 | mask_root = '/data1/lxl/data/ocr/stamp/mask' | ||
| 6 | |||
| 7 | |||
def read_img_list(root):
    """Return the full path of every entry directly under *root*.

    Order follows os.listdir (arbitrary, filesystem-dependent).
    """
    return [os.path.join(root, entry) for entry in os.listdir(root)]
| 15 | |||
| 16 | |||
def get_mask_list(stamp_list):
    """Extract the alpha channel of each RGBA stamp and save it as a mask.

    Args:
        stamp_list: full paths to RGBA stamp images.

    Side effects:
        Writes one grayscale mask per stamp into the module-level mask_root.
    """
    for stamp in stamp_list:
        # os.path.basename is portable; the original split('/') broke on
        # non-POSIX separators.
        img_name = os.path.basename(stamp)
        img = cv2.imread(stamp, -1)  # -1: keep the alpha channel
        mask = img[:, :, -1]         # alpha plane = stamp silhouette
        cv2.imwrite(os.path.join(mask_root, img_name), mask)
| 23 | |||
| 24 | |||
if __name__ == '__main__':
    # Extract and save an alpha-channel mask for every stamp under stamp_root.
    full_path_list = read_img_list(stamp_root)
    get_mask_list(full_path_list)
core/data/preprocess/paste_stamp.py
0 → 100644
| 1 | import os | ||
| 2 | import cv2 | ||
| 3 | from PIL import Image, ImageEnhance | ||
| 4 | import numpy as np | ||
| 5 | import random | ||
| 6 | import shutil | ||
| 7 | from tqdm import tqdm | ||
| 8 | import json | ||
| 9 | import multiprocessing as mp | ||
| 10 | import imgaug.augmenters as iaa | ||
| 11 | |||
| 12 | |||
def mkdir(path):
    """Create *path* (and any missing parents); no-op if it already exists.

    ``exist_ok=True`` also closes the check-then-create race present in the
    original ``if not os.path.exists: os.makedirs`` pattern.
    """
    os.makedirs(path, exist_ok=True)
| 16 | |||
| 17 | |||
# Input roots: augmented stamp images, their hard masks and soft blend masks
# (all sharing file names), plus cropped text backgrounds.
stamp_img_root = '/data1/lxl/data/ocr/stamp/aug/img'
stamp_mask_root = '/data1/lxl/data/ocr/stamp/aug/mask'
stamp_blend_root = '/data1/lxl/data/ocr/stamp/aug/blend'
text_root = '/data1/lxl/data/ocr/crop_bg'
# Output roots for the generated training dataset.
gen_root = '/data1/lxl/data/ocr/generate1108/'
gen_img_root = os.path.join(gen_root, 'img')
gen_stamp_img_root = os.path.join(gen_root, 'stamp_img')
gen_stamp_mask_root = os.path.join(gen_root, 'stamp_mask')
gen_text_img_root = os.path.join(gen_root, 'text_img')
mkdir(gen_img_root)
mkdir(gen_text_img_root)
mkdir(gen_stamp_img_root)
mkdir(gen_stamp_mask_root)
| 31 | |||
| 32 | |||
def random_idx(s, e):
    """Return one uniform random integer drawn from [s, e)."""
    return int(np.random.randint(s, e, size=(1)))
| 37 | |||
| 38 | |||
def get_full_path_list(root):
    """Return the full paths of all entries under *root*, sorted by name."""
    return [os.path.join(root, entry) for entry in sorted(os.listdir(root))]
| 46 | |||
| 47 | |||
def gen(stamp_img, blend_mask, stamp_mask, text_img, gen_img_root, gen_text_img_root, gen_stamp_mask_root, savename):
    """Paste one stamp onto one text background and save all four outputs.

    Saves: composite (img), clean background (text_img), stamp-only image
    and stamp-only mask, all named '<savename:06d>.jpg'.

    Removed dead code from the original: unused stamp width/height locals,
    an unused mask copy, and a `stamp_dict` annotation record that was
    built but never written anywhere.
    """
    stamp_img = Image.open(stamp_img).convert("RGB")
    blend_mask = Image.open(blend_mask)
    stamp_mask = Image.open(stamp_mask).convert('L')
    text_img = Image.open(text_img).convert("RGB")
    gen_img = text_img.copy().convert("RGB")
    # Random top-left paste position.  Assumes the stamp fits inside the
    # background (text_img strictly larger than stamp_img) -- TODO confirm
    # upstream crop sizes guarantee this.
    x = random_idx(0, text_img.size[0] - stamp_img.size[0])
    y = random_idx(0, text_img.size[1] - stamp_img.size[1])

    # Composite: stamp blended onto the text via the soft blend mask.
    gen_img.paste(stamp_img, (x, y), mask=blend_mask)

    gen_stamp_img = Image.new('RGB', size=text_img.size)
    gen_stamp_img.paste(stamp_img, (x, y), mask=blend_mask)
    gen_stamp_mask = Image.new('L', size=text_img.size)
    gen_stamp_mask.paste(stamp_mask, (x, y), mask=stamp_mask)

    # NOTE(review): the stamp-image output root is the module-level global
    # `gen_stamp_img_root`, not a parameter, unlike the other three roots.
    gen_img.save(os.path.join(gen_img_root, "{:>06d}.jpg".format(savename)))
    text_img.save(os.path.join(gen_text_img_root, "{:>06d}.jpg".format(savename)))
    gen_stamp_img.save(os.path.join(gen_stamp_img_root, "{:>06d}.jpg".format(savename)))
    gen_stamp_mask.save(os.path.join(gen_stamp_mask_root, "{:>06d}.jpg".format(savename)))
| 73 | |||
| 74 | |||
def process():
    """Generate `need` composite samples by pasting random stamps onto
    random text backgrounds, fanning the work out to a process pool.
    """
    stamp_list = sorted(os.listdir(stamp_img_root))
    stamp_list_lth = len(stamp_list)
    text_list = sorted(os.listdir(text_root))
    text_list_lth = len(text_list)
    need = 20000  # number of composite images to generate
    pool = mp.Pool(processes=6)
    for i in range(0, need):
        # pick a random stamp; img/mask/blend variants share one file name
        stamp_idx = random_idx(0, stamp_list_lth)
        stamp_img_path = os.path.join(stamp_img_root, stamp_list[stamp_idx])
        stamp_mask_path = os.path.join(stamp_mask_root, stamp_list[stamp_idx])
        blend_mask_path = os.path.join(stamp_blend_root, stamp_list[stamp_idx])
        text_idx = random_idx(0, text_list_lth)
        text_img_path = os.path.join(text_root, text_list[text_idx])
        # NOTE(review): apply_async results are never collected, so worker
        # exceptions are silently dropped -- failed indices simply produce
        # no output files.
        pool.apply_async(gen, (stamp_img_path, blend_mask_path, stamp_mask_path, text_img_path, gen_img_root, gen_text_img_root, gen_stamp_mask_root, i,))
        # gen(stamp_img_path, blend_mask_path, stamp_mask_path, text_img_path, gen_img_root, gen_text_img_root, gen_stamp_mask_root, i,)
    pool.close()
    pool.join()
| 93 | |||
| 94 | |||
def main():
    """Entry point: run dataset generation."""
    process()


if __name__ == '__main__':
    main()
signatureDet @ 2dde6823
| 1 | Subproject commit 2dde682312388f5c5099d161019cd97df06315e4 |
core/data/preprocess/stamp_aug.py
0 → 100644
| 1 | import imgaug.augmenters as iaa | ||
| 2 | import numpy as np | ||
| 3 | import cv2 | ||
| 4 | import os | ||
| 5 | from tqdm import tqdm | ||
| 6 | from PIL import Image | ||
| 7 | |||
# Per-stamp augmentation pipeline: flip, crop, blur, darkening, rotation and
# coarse dropout, applied jointly to the stamp image and its alpha mask.
seq = iaa.Sequential(
    [
        iaa.Fliplr(0.5),
        iaa.Crop(percent=(0, 0.1), keep_size=True),
        iaa.Sometimes(0.5, iaa.GaussianBlur(sigma=(0, 0.5))),
        # only negative offsets: darken the stamp
        iaa.AddElementwise((-40, -10), per_channel=0.5),
        iaa.Sometimes(0.5, iaa.MultiplyElementwise((0.7, 1.0))),
        iaa.OneOf([
            iaa.Rotate((-45, 45)),
            iaa.Rot90((1, 3))
        ]),
        # coarse dropout simulates partially faded / broken stamp prints
        iaa.Sometimes(0.7, iaa.CoarseDropout(p=(0.1, 0.4), size_percent=(0.02, 0.2))),
        iaa.Sometimes(0.02, iaa.imgcorruptlike.MotionBlur(severity=2)),
    ])

# Secondary pipeline turning the grayscale stamp into a soft blend mask.
gen_blend = iaa.Sequential([
    iaa.GaussianBlur(sigma=(0, 0.5)),
    iaa.MultiplyElementwise((0.8, 3.5)),
])


img_root = '/data1/lxl/data/ocr/stamp/src'
gen_root = '/data1/lxl/data/ocr/stamp/aug'
gen_img_root = os.path.join(gen_root, 'img')
gen_mask_root = os.path.join(gen_root, 'mask')
gen_blend_root = os.path.join(gen_root, 'blend')
if not os.path.exists(gen_img_root):
    os.makedirs(gen_img_root)
if not os.path.exists(gen_mask_root):
    os.makedirs(gen_mask_root)
if not os.path.exists(gen_blend_root):
    os.makedirs(gen_blend_root)

name_list = sorted(os.listdir(img_root))

# Two augmented variants per source stamp (suffix '_0' / '_1').
for i in range(2):
    for name in tqdm(name_list):
        name_no_ex = name.split('.')[0]
        ext = name.split('.')[1]
        # source files are RGBA: first 3 channels = color, last = alpha mask
        img = cv2.imread(os.path.join(img_root, name), -1)[:, :, :3]
        mask = cv2.imread(os.path.join(img_root, name), -1)[:, :, -1]
        mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
        img = np.asarray(img, dtype=np.uint8)
        mask = np.asarray(mask, dtype=np.uint8)
        # imgaug expects a leading batch axis
        img = img[np.newaxis, :]
        mask = mask[np.newaxis, :]
        # NOTE(review): the mask is passed to `segmentation_maps` as a raw
        # uint8 array; newer imgaug versions expect SegmentationMapsOnImage
        # objects -- confirm against the pinned imgaug version.
        img_aug, mask_aug = seq(images=img, segmentation_maps=mask)
        img_aug = img_aug.squeeze()
        mask_aug = mask_aug.squeeze(0)

        blend = cv2.cvtColor(img_aug, cv2.COLOR_BGR2GRAY)
        # NOTE(review): `blend` is a single HxW image passed via the batched
        # `images=` kwarg -- verify gen_blend handles it as intended.
        blend_aug = gen_blend(images=blend)
        blend_aug = blend_aug.squeeze()

        cv2.imwrite(os.path.join(gen_img_root, (name_no_ex + '_' + str(i) + '.' + ext)), img_aug)
        cv2.imwrite(os.path.join(gen_mask_root, (name_no_ex + '_' + str(i) + '.' + ext)), mask_aug)
        cv2.imwrite(os.path.join(gen_blend_root, (name_no_ex + '_' + str(i) + '.' + ext)), blend_aug)
core/loss/__init__.py
0 → 100644
core/loss/builder.py
0 → 100644
| 1 | import torch | ||
| 2 | import inspect | ||
| 3 | from utils.registery import LOSS_REGISTRY | ||
| 4 | import copy | ||
| 5 | |||
| 6 | |||
def register_torch_loss():
    """Register every torch.nn ``*Loss`` class into LOSS_REGISTRY so loss
    names from the YAML config (e.g. 'L1Loss') resolve directly."""
    candidate_names = (
        attr for attr in dir(torch.nn)
        if not attr.startswith('__') and 'Loss' in attr
    )
    for attr in candidate_names:
        candidate = getattr(torch.nn, attr)
        if inspect.isclass(candidate) and issubclass(candidate, torch.nn.Module):
            LOSS_REGISTRY.register()(candidate)
| 14 | |||
def build_loss(cfg):
    """Instantiate the criterion described by ``cfg['solver']['loss']``.

    Registers all torch.nn loss classes first so config names resolve.

    Raises:
        KeyError: if the config has no solver.loss section.  (The original
        did ``raise '<str>'`` -- a TypeError in Python 3 -- and its deepcopy
        was dead because the variable was immediately rebound.)
    """
    register_torch_loss()
    try:
        loss_cfg = copy.deepcopy(cfg['solver']['loss'])
    except KeyError as err:
        raise KeyError('config should contain {solver.loss}!') from err

    return LOSS_REGISTRY.get(loss_cfg['name'])(**loss_cfg['args'])
core/metric/__init__.py
0 → 100644
core/metric/builder.py
0 → 100644
| 1 | from utils.registery import METRIC_REGISTRY | ||
| 2 | import copy | ||
| 3 | |||
| 4 | from .recon_metric import Recon | ||
| 5 | |||
def build_metric(cfg):
    """Instantiate the metric named in ``cfg['solver']['metric']``.

    NOTE(review): any 'args' under the metric entry are ignored -- the
    metric class is constructed with no arguments, mirroring the original.

    Raises:
        KeyError: if the config has no solver.metric section.  (The
        original did ``raise '<str>'``, a TypeError in Python 3.)
    """
    cfg = copy.deepcopy(cfg)
    try:
        metric_cfg = cfg['solver']['metric']
    except KeyError as err:
        raise KeyError('config should contain {solver.metric}!') from err

    return METRIC_REGISTRY.get(metric_cfg['name'])()
core/metric/recon_metric.py
0 → 100644
| 1 | import numpy as np | ||
| 2 | import torchmetrics | ||
| 3 | |||
| 4 | |||
| 5 | from utils.registery import METRIC_REGISTRY | ||
| 6 | |||
@METRIC_REGISTRY.register()
class Recon(object):
    """Reconstruction quality metric: batch-mean per-image PSNR and SSIM."""

    def __init__(self):
        # torchmetrics modules, used here in forward mode: each call returns
        # the metric for the pair passed in.
        # NOTE(review): these objects also accumulate internal running state
        # across calls and are never .reset() -- fine while only the forward
        # return value is used, but confirm if .compute() is ever added.
        self.psnr = torchmetrics.PeakSignalNoiseRatio()
        self.ssim = torchmetrics.StructuralSimilarityIndexMeasure()

    def __call__(self, pred, label):
        """Return {'psnr': ..., 'ssim': ...} averaged over the batch.

        Args:
            pred: predicted images; assumed (N, C, H, W) since each sample
                is re-batched via unsqueeze(0) -- TODO confirm.
            label: ground-truth images, same shape as pred.
        """
        assert pred.shape[0] == label.shape[0]

        psnr_list = list()
        ssim_list = list()

        # Score each image pair separately (unsqueeze restores batch dim).
        for i in range(len(pred)):
            psnr_list.append(self.psnr(pred[i].unsqueeze(0), label[i].unsqueeze(0)))
            ssim_list.append(self.ssim(pred[i].unsqueeze(0), label[i].unsqueeze(0)))

        psnr_result = sum(psnr_list) / len(psnr_list)
        ssim_result = sum(ssim_list) / len(ssim_list)

        return {'psnr': psnr_result, 'ssim': ssim_result}
| 28 | |||
| 29 |
core/model/__init__.py
0 → 100644
core/model/base_model.py
0 → 100644
| 1 | import torch | ||
| 2 | import torch.nn as nn | ||
| 3 | from abc import ABCMeta | ||
| 4 | import math | ||
| 5 | |||
| 6 | |||
class BaseModel(nn.Module, metaclass=ABCMeta):
    """Common base class: an nn.Module plus a shared weight-init helper."""

    def __init__(self):
        super().__init__()

    def _initialize_weights(self):
        """Xavier-init conv/linear weights (zero bias); unit weight and
        zero bias for normalization layers."""
        xavier_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)
        norm_types = (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
        for module in self.modules():
            if isinstance(module, xavier_types):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, norm_types):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
| 27 |
core/model/builder.py
0 → 100644
| 1 | import copy | ||
| 2 | from utils import MODEL_REGISTRY | ||
| 3 | |||
| 4 | from .unet import Unet | ||
| 5 | from .unet_skip import UnetSkip | ||
| 6 | |||
| 7 | |||
def build_model(cfg):
    """Instantiate the model named in ``cfg['model']`` with its args.

    Raises:
        KeyError: if the config has no 'model' section.  (The original did
        ``raise '<str>'``, which is a TypeError in Python 3.)
    """
    model_cfg = copy.deepcopy(cfg)
    try:
        model_cfg = model_cfg['model']
    except KeyError as err:
        raise KeyError('config should contain {model}') from err

    model = MODEL_REGISTRY.get(model_cfg['name'])(**model_cfg['args'])

    return model
| 18 |
core/model/unet.py
0 → 100644
| 1 | import segmentation_models_pytorch as smp | ||
| 2 | |||
| 3 | from utils.registery import MODEL_REGISTRY | ||
| 4 | from core.model.base_model import BaseModel | ||
| 5 | |||
| 6 | |||
@MODEL_REGISTRY.register()
class Unet(BaseModel):
    """Residual U-Net: predicts a (tanh-bounded) residual added to the input.

    All constructor arguments mirror segmentation_models_pytorch.Unet.
    """

    def __init__(self,
                 encoder_name: str = 'resnet50',
                 encoder_weights: str = 'imagenet',
                 in_channels: int = 3,
                 classes: int = 3,
                 activation: str = 'tanh'):
        super().__init__()
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=activation
        )
        # Bug fix: the original called self._initialize_weights() here,
        # xavier-reinitializing *every* conv -- including the encoder just
        # loaded with `encoder_weights` pretrained weights -- silently
        # discarding the pretraining.  smp initializes the decoder itself,
        # so no extra init call is needed.

    def forward(self, x):
        """Return x + residual (the reconstruction)."""
        out = x + self.model(x)
        # out = self.model(x)

        return out
| 32 |
core/model/unet_skip.py
0 → 100644
| 1 | import segmentation_models_pytorch as smp | ||
| 2 | |||
| 3 | from utils.registery import MODEL_REGISTRY | ||
| 4 | from core.model.base_model import BaseModel | ||
| 5 | |||
| 6 | |||
@MODEL_REGISTRY.register()
class UnetSkip(BaseModel):
    """Residual U-Net variant exposing both reconstruction and raw residual.

    forward() returns {'reconstruction': x + r, 'residual': r}, consumed by
    SkipSolver.  Constructor args mirror segmentation_models_pytorch.Unet.
    """

    def __init__(self,
                 encoder_name: str = 'resnet50',
                 encoder_weights: str = 'imagenet',
                 in_channels: int = 3,
                 classes: int = 3,
                 activation: str = 'tanh'):
        super().__init__()
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=activation
        )
        # Bug fix: the original called self._initialize_weights() here,
        # xavier-reinitializing every conv and thereby wiping the encoder's
        # `encoder_weights` pretrained weights.  smp initializes the decoder
        # itself, so no extra init call is needed.

    def forward(self, x):
        """Return {'reconstruction', 'residual'} for input x."""
        residual = self.model(x)
        reconstruction = x + residual

        returned_dict = {'reconstruction': reconstruction, 'residual': residual}

        return returned_dict
| 34 |
core/optimizer/__init__.py
0 → 100644
core/optimizer/builder.py
0 → 100644
| 1 | import torch | ||
| 2 | import inspect | ||
| 3 | from utils.registery import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY | ||
| 4 | import copy | ||
| 5 | |||
def register_torch_optimizers():
    """Register every torch.optim Optimizer subclass into OPTIMIZER_REGISTRY
    so optimizer names from the YAML config (e.g. 'Adam') resolve."""
    public_names = (n for n in dir(torch.optim) if not n.startswith('__'))
    for attr_name in public_names:
        candidate = getattr(torch.optim, attr_name)
        is_optimizer = inspect.isclass(candidate) and issubclass(candidate, torch.optim.Optimizer)
        if is_optimizer:
            OPTIMIZER_REGISTRY.register()(candidate)
| 16 | |||
def build_optimizer(cfg):
    """Return the optimizer *class* named in ``cfg['solver']['optimizer']``.

    Note: this returns the class, not an instance -- the solver instantiates
    it with the model parameters and the config args.

    Raises:
        KeyError: if the config has no solver.optimizer section.  (The
        original did ``raise '<str>'`` -- a TypeError in Python 3 -- and its
        deepcopy was dead because the variable was immediately rebound.)
    """
    register_torch_optimizers()
    try:
        optimizer_cfg = cfg['solver']['optimizer']
    except KeyError as err:
        raise KeyError('config should contain {solver.optimizer}!') from err

    return OPTIMIZER_REGISTRY.get(optimizer_cfg['name'])
| 27 | |||
def register_torch_lr_scheduler():
    """Register every torch lr-scheduler class into LR_SCHEDULER_REGISTRY
    so scheduler names from the YAML config (e.g. 'StepLR') resolve."""
    scheduler_base = torch.optim.lr_scheduler._LRScheduler
    for attr_name in dir(torch.optim.lr_scheduler):
        if attr_name.startswith('__'):
            continue
        candidate = getattr(torch.optim.lr_scheduler, attr_name)
        if not inspect.isclass(candidate):
            continue
        if issubclass(candidate, scheduler_base):
            LR_SCHEDULER_REGISTRY.register()(candidate)
| 39 | |||
def build_lr_scheduler(cfg):
    """Return the LR-scheduler *class* named in ``cfg['solver']['lr_scheduler']``.

    Note: returns the class; the solver instantiates it with the optimizer
    and the config args.

    Raises:
        KeyError: if the config has no solver.lr_scheduler section.  (The
        original did ``raise '<str>'`` -- a TypeError in Python 3 -- and its
        deepcopy was dead because the variable was immediately rebound.)
    """
    register_torch_lr_scheduler()
    try:
        scheduler_cfg = cfg['solver']['lr_scheduler']
    except KeyError as err:
        raise KeyError('config should contain {solver.lr_scheduler}!') from err

    return LR_SCHEDULER_REGISTRY.get(scheduler_cfg['name'])
| 50 |
core/solver/__init__.py
0 → 100644
core/solver/base_solver.py
0 → 100644
| 1 | import torch | ||
| 2 | from core.data import build_dataloader | ||
| 3 | from core.model import build_model | ||
| 4 | from core.optimizer import build_optimizer, build_lr_scheduler | ||
| 5 | from core.loss import build_loss | ||
| 6 | from core.metric import build_metric | ||
| 7 | from utils.registery import SOLVER_REGISTRY | ||
| 8 | from utils.logger import get_logger_and_log_path | ||
| 9 | import os | ||
| 10 | import copy | ||
| 11 | import datetime | ||
| 12 | from torch.nn.parallel import DistributedDataParallel | ||
| 13 | import numpy as np | ||
| 14 | import pandas as pd | ||
| 15 | import yaml | ||
| 16 | |||
| 17 | |||
| 18 | @SOLVER_REGISTRY.register() | ||
| 19 | class BaseSolver(object): | ||
| 20 | def __init__(self, cfg): | ||
| 21 | self.cfg = copy.deepcopy(cfg) | ||
| 22 | self.local_rank = torch.distributed.get_rank() | ||
| 23 | self.train_loader, self.val_loader = build_dataloader(cfg) | ||
| 24 | self.len_train_loader, self.len_val_loader = len(self.train_loader), len(self.val_loader) | ||
| 25 | self.criterion = build_loss(cfg).cuda(self.local_rank) | ||
| 26 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(build_model(cfg)) | ||
| 27 | self.model = DistributedDataParallel(model.cuda(self.local_rank), device_ids=[self.local_rank], find_unused_parameters=True) | ||
| 28 | self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args']) | ||
| 29 | self.hyper_params = cfg['solver']['args'] | ||
| 30 | crt_date = datetime.date.today().strftime('%Y-%m-%d') | ||
| 31 | self.logger, self.log_path = get_logger_and_log_path(crt_date=crt_date, **cfg['solver']['logger']) | ||
| 32 | self.metric_fn = build_metric(cfg) | ||
| 33 | try: | ||
| 34 | self.epoch = self.hyper_params['epoch'] | ||
| 35 | except Exception: | ||
| 36 | raise 'should contain epoch in {solver.args}' | ||
| 37 | if self.local_rank == 0: | ||
| 38 | self.save_dict_to_yaml(self.cfg, os.path.join(self.log_path, 'config.yaml')) | ||
| 39 | self.logger.info(self.cfg) | ||
| 40 | |||
| 41 | def train(self): | ||
| 42 | if torch.distributed.get_rank() == 0: | ||
| 43 | self.logger.info('==> Start Training') | ||
| 44 | lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args']) | ||
| 45 | |||
| 46 | for t in range(self.epoch): | ||
| 47 | self.train_loader.sampler.set_epoch(t) | ||
| 48 | if torch.distributed.get_rank() == 0: | ||
| 49 | self.logger.info(f'==> epoch {t + 1}') | ||
| 50 | self.model.train() | ||
| 51 | |||
| 52 | pred_list = list() | ||
| 53 | label_list = list() | ||
| 54 | |||
| 55 | mean_loss = 0.0 | ||
| 56 | |||
| 57 | for i, data in enumerate(self.train_loader): | ||
| 58 | self.optimizer.zero_grad() | ||
| 59 | image = data['image'].cuda(self.local_rank) | ||
| 60 | label = data['label'].cuda(self.local_rank) | ||
| 61 | |||
| 62 | pred = self.model(image) | ||
| 63 | |||
| 64 | loss = self.criterion(pred, label) | ||
| 65 | mean_loss += loss.item() | ||
| 66 | |||
| 67 | if (i == 0 or i % 200 == 0) and (torch.distributed.get_rank() == 0): | ||
| 68 | loss_value = loss.item() | ||
| 69 | self.logger.info(f'epoch: {t + 1}/{self.epoch}, iteration: {i + 1}/{self.len_train_loader}, loss: {loss_value :.4f}') | ||
| 70 | |||
| 71 | loss.backward() | ||
| 72 | self.optimizer.step() | ||
| 73 | |||
| 74 | # batch_pred = [torch.zeros_like(pred) for _ in range(torch.distributed.get_world_size())] # 1 | ||
| 75 | # torch.distributed.all_gather(batch_pred, pred) | ||
| 76 | # pred_list.append(torch.cat(batch_pred, dim=0).detach().cpu()) | ||
| 77 | |||
| 78 | # batch_label = [torch.zeros_like(label) for _ in range(torch.distributed.get_world_size())] | ||
| 79 | # torch.distributed.all_gather(batch_label, label) | ||
| 80 | # label_list.append(torch.cat(batch_label, dim=0).detach().cpu()) | ||
| 81 | |||
| 82 | # pred_list = torch.cat(pred_list, dim=0) | ||
| 83 | # label_list = torch.cat(label_list, dim=0) | ||
| 84 | # metric_dict = self.metric_fn(**{'pred': pred_list, 'label': label_list}) | ||
| 85 | mean_loss = mean_loss / self.len_train_loader | ||
| 86 | |||
| 87 | if torch.distributed.get_rank() == 0: | ||
| 88 | # self.logger.info(f"==> train mean loss: {mean_loss :.4f}, psnr: {metric_dict['psnr'] :.4f}, ssim: {metric_dict['ssim'] :.4f}") | ||
| 89 | self.logger.info(f'==> train mean loss: {mean_loss :.4f}') | ||
| 90 | self.save_checkpoint(self.model, self.cfg, self.log_path, t + 1) | ||
| 91 | self.val(t + 1) | ||
| 92 | lr_scheduler.step() | ||
| 93 | |||
| 94 | if self.local_rank == 0: | ||
| 95 | self.logger.info('==> End Training') | ||
| 96 | |||
| 97 | @torch.no_grad() | ||
| 98 | def val(self, t): | ||
| 99 | self.model.eval() | ||
| 100 | |||
| 101 | pred_list = list() | ||
| 102 | label_list = list() | ||
| 103 | |||
| 104 | for i, data in enumerate(self.val_loader): | ||
| 105 | feat = data['image'].cuda(self.local_rank) | ||
| 106 | label = data['label'].cuda(self.local_rank) | ||
| 107 | |||
| 108 | pred = self.model(feat) | ||
| 109 | |||
| 110 | pred_list.append(pred.detach().cpu()) | ||
| 111 | label_list.append(label.detach().cpu()) | ||
| 112 | |||
| 113 | pred_list = torch.cat(pred_list, dim=0) | ||
| 114 | label_list = torch.cat(label_list, dim=0) | ||
| 115 | |||
| 116 | metric_dict = self.metric_fn(**{'pred': pred_list, 'label': label_list}) | ||
| 117 | if torch.distributed.get_rank() == 0: | ||
| 118 | self.logger.info(f"==> val psnr: {metric_dict['psnr'] :.4f}, ssim: {metric_dict['ssim'] :.4f}") | ||
| 119 | |||
    def run(self):
        """Entry point used by main.py: delegates to the training loop."""
        self.train()
| 122 | |||
    @staticmethod
    def save_dict_to_yaml(dict_value, save_path):
        """Serialize *dict_value* to a YAML file at *save_path*, preserving
        key insertion order (sort_keys=False).

        NOTE(review): duplicates utils.helper.save_dict_to_yaml — consider
        reusing that helper instead of keeping two copies.
        """
        with open(save_path, 'w', encoding='utf-8') as file:
            yaml.dump(dict_value, file, sort_keys=False)
| 127 | |||
| 128 | |||
| 129 | def save_checkpoint(self, model, cfg, log_path, epoch_id): | ||
| 130 | model.eval() | ||
| 131 | torch.save(model.module.state_dict(), os.path.join(log_path, f'ckpt_epoch_{epoch_id}.pt')) |
core/solver/builder.py
0 → 100644
| 1 | from utils.registery import SOLVER_REGISTRY | ||
| 2 | import copy | ||
| 3 | |||
| 4 | from .base_solver import BaseSolver | ||
| 5 | from .skip_solver import SkipSolver | ||
| 6 | |||
| 7 | |||
def build_solver(cfg):
    """Instantiate the solver named in ``cfg['solver']['name']``.

    Args:
        cfg (dict): full run configuration; must contain a 'solver' section.

    Returns:
        A solver instance constructed with a deep copy of the full config.

    Raises:
        KeyError: if the config has no 'solver' section.
    """
    cfg = copy.deepcopy(cfg)

    # Bug fix: the original `raise 'should contain {solver}!'` raised a
    # string, which is itself a TypeError in Python 3. Raise a real
    # exception with the same intent instead.
    if 'solver' not in cfg:
        raise KeyError("config should contain key 'solver'!")
    solver_cfg = cfg['solver']

    return SOLVER_REGISTRY.get(solver_cfg['name'])(cfg)
core/solver/skip_solver.py
0 → 100644
| 1 | import torch | ||
| 2 | from core.data import build_dataloader | ||
| 3 | from core.model import build_model | ||
| 4 | from core.optimizer import build_optimizer, build_lr_scheduler | ||
| 5 | from core.loss import build_loss | ||
| 6 | from core.metric import build_metric | ||
| 7 | from utils.registery import SOLVER_REGISTRY | ||
| 8 | from utils.logger import get_logger_and_log_path | ||
| 9 | import os | ||
| 10 | import copy | ||
| 11 | import datetime | ||
| 12 | from torch.nn.parallel import DistributedDataParallel | ||
| 13 | import numpy as np | ||
| 14 | import pandas as pd | ||
| 15 | import yaml | ||
| 16 | |||
| 17 | from .base_solver import BaseSolver | ||
| 18 | |||
| 19 | |||
@SOLVER_REGISTRY.register()
class SkipSolver(BaseSolver):
    """Solver for the two-head (skip) text-removal model.

    The model's forward returns a dict with two tensors:
      - 'reconstruction': the cleaned image, supervised against the label.
      - 'residual': the removed layer, supervised against image - label.
    Both heads share the same criterion and their losses are summed.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def train(self):
        """DDP training loop: one optimizer step per batch; per-epoch
        checkpointing (rank 0), validation and LR-scheduler step."""
        if torch.distributed.get_rank() == 0:
            self.logger.info('==> Start Training')
        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            # Reshuffle the DistributedSampler each epoch.
            self.train_loader.sampler.set_epoch(t)
            if torch.distributed.get_rank() == 0:
                self.logger.info(f'==> epoch {t + 1}')
            self.model.train()

            # (removed unused pred_list/label_list accumulators from the
            # original; training metrics were never computed from them)
            mean_loss = 0.0

            for i, data in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                image = data['image'].cuda(self.local_rank)
                label = data['label'].cuda(self.local_rank)
                # Target for the residual head: the layer to be removed.
                residual = image - label

                pred = self.model(image)

                reconstruction_loss = self.criterion(pred['reconstruction'], label)
                residual_loss = self.criterion(pred['residual'], residual)
                loss = reconstruction_loss + residual_loss
                mean_loss += loss.item()

                # Log the first iteration and every 200th, rank 0 only.
                if (i == 0 or i % 200 == 0) and (torch.distributed.get_rank() == 0):
                    reconstruction_loss_value = reconstruction_loss.item()
                    residual_loss_value = residual_loss.item()
                    loss_value = loss.item()
                    self.logger.info(f'epoch: {t + 1}/{self.epoch}, iteration: {i + 1}/{self.len_train_loader}, reconstruction loss: {reconstruction_loss_value :.4f}, residual loss: {residual_loss_value :.4f}, loss: {loss_value :.4f}')

                loss.backward()
                self.optimizer.step()

            mean_loss = mean_loss / self.len_train_loader

            if torch.distributed.get_rank() == 0:
                self.logger.info(f'==> train mean loss: {mean_loss :.4f}')
                self.save_checkpoint(self.model, self.cfg, self.log_path, t + 1)
            self.val(t + 1)
            lr_scheduler.step()

        if self.local_rank == 0:
            self.logger.info('==> End Training')

    @torch.no_grad()
    def val(self, t):
        """Validate after epoch *t*: only the 'reconstruction' head is
        scored (PSNR/SSIM); metrics are logged on rank 0."""
        self.model.eval()

        pred_list = list()
        label_list = list()

        for data in self.val_loader:
            image = data['image'].cuda(self.local_rank)
            label = data['label'].cuda(self.local_rank)

            pred = self.model(image)

            # The residual head is a training-time auxiliary; validation
            # only evaluates the reconstruction output.
            pred_list.append(pred['reconstruction'].detach().cpu())
            label_list.append(label.detach().cpu())

        pred_list = torch.cat(pred_list, dim=0)
        label_list = torch.cat(label_list, dim=0)

        metric_dict = self.metric_fn(**{'pred': pred_list, 'label': label_list})
        if torch.distributed.get_rank() == 0:
            self.logger.info(f"==> val psnr: {metric_dict['psnr'] :.4f}, ssim: {metric_dict['ssim'] :.4f}")

    def run(self):
        """Entry point: run the training loop."""
        self.train()

    @staticmethod
    def save_dict_to_yaml(dict_value, save_path):
        """Serialize *dict_value* as YAML, preserving key order."""
        with open(save_path, 'w', encoding='utf-8') as file:
            yaml.dump(dict_value, file, sort_keys=False)

    def save_checkpoint(self, model, cfg, log_path, epoch_id):
        """Save the unwrapped (model.module) weights for *epoch_id*.

        NOTE(review): `cfg` is accepted but unused here, unlike
        utils.helper.save_checkpoint which also persists config.yaml —
        consider aligning the two.
        """
        model.eval()
        torch.save(model.module.state_dict(), os.path.join(log_path, f'ckpt_epoch_{epoch_id}.pt'))
main.py
0 → 100644
| 1 | import yaml | ||
| 2 | from core.solver import build_solver | ||
| 3 | import torch | ||
| 4 | import numpy as np | ||
| 5 | import random | ||
| 6 | import argparse | ||
| 7 | |||
| 8 | |||
def init_seed(seed=778):
    """Seed every RNG in use (python, numpy, torch CPU and CUDA) and force
    deterministic cuDNN behavior so runs are reproducible."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    # Trade autotuner speed for run-to-run determinism.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
| 17 | |||
def main():
    """Parse CLI args, load the YAML config, set up the DDP process group
    and run the configured solver."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./config/baseline.yaml', type=str, help='config file')
    parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
    args = parser.parse_args()

    # Bug fix: the original `yaml.load(open(...).read(), ...)` leaked the
    # file handle; a context manager closes it deterministically.
    with open(args.config, 'r') as f:
        cfg = yaml.load(f.read(), Loader=yaml.FullLoader)
    init_seed(cfg['seed'])

    torch.distributed.init_process_group(backend='nccl')
    torch.cuda.set_device(args.local_rank)

    solver = build_solver(cfg)

    solver.run()
| 33 | |||
# Script entry point (each process spawned by torch.distributed.launch
# executes main()).
if __name__ == '__main__':
    main()
pred.py
0 → 100644
import os
# Restrict this inference script to GPU 0; set before torch/CUDA are
# initialized so only that device is visible.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import cv2
import numpy as np
import torch
from PIL import Image  # NOTE(review): appears unused in this script — confirm
import torchvision.transforms as transforms  # NOTE(review): appears unused — confirm
from tqdm import tqdm


from core.model.unet import Unet
from core.model.unet_skip import UnetSkip
| 13 | |||
def load_model(ckpt='./log/2022-11-11/skip/ckpt_epoch_30.pt'):
    """Build a UnetSkip text-removal model and load trained weights.

    Args:
        ckpt (str): path to a checkpoint produced by the training run.

    Returns:
        The model with *ckpt* weights loaded, switched to eval mode.
    """
    model = UnetSkip(
        encoder_name='resnet50',
        encoder_weights='imagenet',
        in_channels=3,
        classes=3,
        activation='tanh'
    )

    state_dict = torch.load(ckpt, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model
| 37 | |||
| 38 | |||
def infer(model, img_path, gen_path):
    """Run text removal on one image file and write the result.

    Reads *img_path*, converts BGR->RGB, resizes to the 512x512 training
    resolution, normalizes to [0, 1], runs *model* and writes the clamped
    'reconstruction' output (scaled back to [0, 255]) to *gen_path*.
    """
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (512, 512))
    img = img / 255.
    img = img.transpose(2, 0, 1).astype(np.float32)
    img = torch.from_numpy(img).unsqueeze(0)
    # Inference only: the original built an autograd graph here for nothing.
    with torch.no_grad():
        out = model(img)
    out = out['reconstruction']
    # clamp replaces the original pair of in-place boolean-mask writes.
    out = torch.clamp(out, 0, 1) * 255
    out = out.cpu().numpy().squeeze(0).transpose(1, 2, 0)
    out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
    cv2.imwrite(gen_path, out)
| 54 | |||
def infer_list():
    """Batch inference: run the model over every image in *img_root*,
    writing results under *gen_root* with the same file names."""
    img_root = '/data1/lxl/data/ocr/real/src/'
    gen_root = '/data1/lxl/data/ocr/real/removed/'
    # Robustness fix: make sure the output directory exists before writing.
    os.makedirs(gen_root, exist_ok=True)
    img_list = sorted(os.listdir(img_root))
    model = load_model()
    for img in tqdm(img_list):
        img_path = os.path.join(img_root, img)
        gen_path = os.path.join(gen_root, img)
        infer(model, img_path, gen_path)
| 64 | |||
def infer_img():
    """Single-image smoke test: run the model on one hard-coded real
    sample and write the result to ./out.jpg."""
    model = load_model()
    img_path = '../DocEnTR/real/hetong_006_00.png'
    gen_path = './out.jpg'
    infer(model, img_path, gen_path)
| 70 | |||
| 71 | infer_img() |
run.sh
0 → 100644
| 1 | CUDA_VISIBLE_DEVICES=0 nohup python -m torch.distributed.launch --master_port 8999 --nproc_per_node=1 main.py --config ./config/skip.yaml & |
utils/__init__.py
0 → 100644
utils/helper.py
0 → 100644
| 1 | import torch | ||
| 2 | import yaml | ||
| 3 | import os | ||
| 4 | |||
| 5 | |||
def save_dict_to_yaml(dict_value, save_path):
    """Dump *dict_value* as YAML to *save_path*, preserving key order."""
    serialized = yaml.dump(dict_value, sort_keys=False)
    with open(save_path, 'w', encoding='utf-8') as fp:
        fp.write(serialized)
| 9 | |||
| 10 | |||
def save_checkpoint(model, cfg, log_path, epoch_id):
    """Persist the run config and the model weights for one epoch.

    Args:
        model: DDP-wrapped model — weights are read from ``model.module``,
            so a bare (unwrapped) module would fail here.
        cfg (dict): full run configuration, written as config.yaml.
        log_path (str): output directory (assumed to exist).
        epoch_id (int): epoch index used in the checkpoint file name.
    """
    save_dict_to_yaml(cfg, os.path.join(log_path, 'config.yaml'))
    torch.save(model.module.state_dict(), os.path.join(log_path, f'ckpt_epoch_{epoch_id}.pt'))
utils/logger.py
0 → 100644
| 1 | import loguru | ||
| 2 | import copy | ||
| 3 | import os | ||
| 4 | import datetime | ||
| 5 | |||
def get_logger_and_log_path(log_root,
                            crt_date,
                            suffix):
    """
    get logger and log path

    Args:
        log_root (str): root path of log
        crt_date (str): formatted date name (Y-M-D)
        suffix (str): log save name

    Returns:
        logger (loguru.logger): logger object
        log_path (str): current root log path (with suffix)
    """
    log_path = os.path.join(log_root, crt_date, suffix)
    # exist_ok avoids the check-then-create race of the original
    # `if not exists: makedirs` when several DDP ranks start at once.
    os.makedirs(log_path, exist_ok=True)

    logger_path = os.path.join(log_path, 'logfile.log')
    logger = loguru.logger
    fmt = '{time:YYYY-MM-DD at HH:mm:ss} | {message}'
    logger.add(logger_path, format=fmt)

    return logger, log_path
utils/registery.py
0 → 100644
class Registry():
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj, suffix=None):
        # A string suffix becomes part of the registered key: '<name>_<suffix>'.
        key = f'{name}_{suffix}' if isinstance(suffix, str) else name
        assert key not in self._obj_map, (f"An object named '{key}' was already registered "
                                          f"in '{self._name}' registry!")
        self._obj_map[key] = obj

    def register(self, obj=None, suffix=None):
        """
        Register the given object under the the name `obj.__name__`.
        Can be used as either a decorator or not.
        See docstring of this class for usage.
        """
        if obj is not None:
            # Plain function-call usage: register immediately.
            self._do_register(obj.__name__, obj, suffix)
            return

        # Decorator usage: register on application, return the target unchanged.
        def deco(func_or_class):
            self._do_register(func_or_class.__name__, func_or_class, suffix)
            return func_or_class

        return deco

    def get(self, name, suffix='soulwalker'):
        # Look up the plain name first; fall back to '<name>_<suffix>'.
        found = self._obj_map.get(name)
        if found is None:
            found = self._obj_map.get(f'{name}_{suffix}')
            print(f'Name {name} is not found, use name: {name}_{suffix}!')
        if found is None:
            raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
        return found

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()
| 60 | |||
| 61 | |||
# Global singleton registries, one per pluggable component type.
# Components decorate themselves with e.g. @SOLVER_REGISTRY.register()
# and builders resolve config 'name' fields through Registry.get().
DATASET_REGISTRY = Registry('dataset')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')
OPTIMIZER_REGISTRY = Registry('optimizer')
SOLVER_REGISTRY = Registry('solver')
LR_SCHEDULER_REGISTRY = Registry('lr_scheduler')
COLLATE_FN_REGISTRY = Registry('collate_fn')
-
Please register or sign in to post a comment