cdc6a87d by 刘晓龙

init commit

0 parents
# template
\ No newline at end of file
seed: 3407
dataset:
name: 'ReconData'
args:
data_root: '/data1/lxl/data/ocr/generate1108'
train_anno_file: '/data1/lxl/data/ocr/generate1108/train.csv'
val_anno_file: '/data1/lxl/data/ocr/generate1108/val.csv'
fixed_size: 512
dataloader:
batch_size: 8
num_workers: 16
pin_memory: true
collate_fn: 'base_collate_fn'
model:
name: 'Unet'
args:
encoder_name: 'resnet50'
encoder_weights: 'imagenet'
in_channels: 3
classes: 3
activation: 'tanh'
solver:
name: 'BaseSolver'
args:
epoch: 30
optimizer:
name: 'Adam'
args:
lr: !!float 1e-4
weight_decay: !!float 5e-5
lr_scheduler:
name: 'StepLR'
args:
step_size: 15
gamma: 0.1
loss:
name: 'L1Loss'
args:
reduction: 'mean'
logger:
log_root: '/data1/lxl/code/ocr/removal/log'
suffix: 'residual'
metric:
name: 'Recon'
seed: 3407
dataset:
name: 'ReconData'
args:
data_root: '/data1/lxl/data/ocr/generate1108'
train_anno_file: '/data1/lxl/data/ocr/generate1108/train.csv'
val_anno_file: '/data1/lxl/data/ocr/generate1108/val.csv'
fixed_size: 512
dataloader:
batch_size: 8
num_workers: 16
pin_memory: true
collate_fn: 'base_collate_fn'
model:
name: 'UnetSkip'
args:
encoder_name: 'resnet50'
encoder_weights: 'imagenet'
in_channels: 3
classes: 3
activation: 'tanh'
solver:
name: 'SkipSolver'
args:
epoch: 30
optimizer:
name: 'Adam'
args:
lr: !!float 1e-4
weight_decay: !!float 5e-5
lr_scheduler:
name: 'StepLR'
args:
step_size: 15
gamma: 0.1
loss:
name: 'L1Loss'
args:
reduction: 'mean'
logger:
log_root: '/data1/lxl/code/ocr/removal/log'
suffix: 'skip'
metric:
name: 'Recon'
from .solver import build_solver
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torch
import pandas as pd
import os
import random
import cv2
import albumentations as A
import albumentations.pytorch
import numpy as np
from PIL import Image
from utils.registery import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class ReconData(Dataset):
    """Paired dataset for stamp removal.

    Each sample is (stamped input, clean target): inputs come from
    ``<data_root>/img`` and targets from ``<data_root>/text_img`` under the
    same file name. Both are resized to ``fixed_size`` and scaled to [0, 1].
    """
    def __init__(self,
                 data_root: str = '/data1/lxl/data/ocr/generate1108',
                 anno_file: str = 'train.csv',
                 fixed_size: int = 448,
                 phase: str = 'train'):
        # anno_file is a CSV with a 'name' column listing image file names.
        self.data_root = data_root
        self.df = pd.read_csv(anno_file)
        self.img_root = os.path.join(data_root, 'img')      # stamped inputs
        self.gt_root = os.path.join(data_root, 'text_img')  # clean targets
        self.fixed_size = fixed_size
        self.phase = phase
        # Select the pipeline matching this phase ('train' or 'val').
        transform_fn = self.__get_transform()
        self.transform = transform_fn[phase]

    def __get_transform(self):
        """Build albumentations pipelines.

        ``additional_targets={'label': 'image'}`` makes every transform apply
        identically to the target, keeping the input/target pair aligned.
        """
        train_transform = A.Compose([
            A.Resize(height=self.fixed_size, width=self.fixed_size),
            # A.RandomBrightness(limit=(-0.5, 0), p=0.5),
            # Darken-only jitter (applied to both image and label, see above).
            A.RandomBrightnessContrast(brightness_limit=(-0.5, 0), contrast_limit=0, p=0.5),
            # mean 0 / std 1: this only rescales pixels to [0, 1].
            A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
            A.pytorch.transforms.ToTensorV2()
        ], additional_targets={'label': 'image'})
        val_transform = A.Compose([
            A.Resize(height=self.fixed_size, width=self.fixed_size),
            A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
            A.pytorch.transforms.ToTensorV2()
        ], additional_targets={'label': 'image'})
        transform_fn = {'train': train_transform, 'val': val_transform}
        return transform_fn

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return (input_tensor, target_tensor), both CHW float in [0, 1]."""
        series = self.df.iloc[idx]
        name = series['name']
        # img = Image.open(os.path.join(self.img_root, name))
        # gt = Image.open(os.path.join(self.gt_root, name))
        img = cv2.imread(os.path.join(self.img_root, name))
        gt = cv2.imread(os.path.join(self.gt_root, name))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)
        # One Compose call transforms both so random ops match spatially.
        transformed = self.transform(image=img, label=gt)
        img = transformed['image']
        label = transformed['label']
        return img, label
from .builder import build_dataloader
import copy
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from utils.registery import DATASET_REGISTRY, COLLATE_FN_REGISTRY
from .collate_fn import base_collate_fn
from .ReconData import ReconData
def build_dataset(cfg):
    """Create (train_dataset, val_dataset) from cfg['dataset'].

    Rewrites the shared args into per-phase kwargs (``anno_file``/``phase``)
    and instantiates the registered dataset class once per phase.

    Raises:
        KeyError: if cfg has no 'dataset' section.
    """
    dataset_cfg = copy.deepcopy(cfg)
    try:
        dataset_cfg = dataset_cfg['dataset']
    except (KeyError, TypeError) as err:
        # Fix: the original raised a bare string, which is itself a
        # TypeError in Python 3 and hides the real error.
        raise KeyError('should contain {dataset}!') from err
    train_cfg = copy.deepcopy(dataset_cfg)
    val_cfg = copy.deepcopy(dataset_cfg)
    # Each phase keeps only its own annotation file, renamed to 'anno_file'.
    train_cfg['args']['anno_file'] = train_cfg['args'].pop('train_anno_file')
    train_cfg['args'].pop('val_anno_file')
    train_cfg['args']['phase'] = 'train'
    val_cfg['args']['anno_file'] = val_cfg['args'].pop('val_anno_file')
    val_cfg['args'].pop('train_anno_file')
    val_cfg['args']['phase'] = 'val'
    train_data = DATASET_REGISTRY.get(dataset_cfg['name'])(**train_cfg['args'])
    val_data = DATASET_REGISTRY.get(dataset_cfg['name'])(**val_cfg['args'])
    return train_data, val_data
def build_dataloader(cfg):
    """Create (train_loader, val_loader) from cfg.

    Fixes two defects of the original:
    * it raised a bare string on a missing section (a TypeError in Python 3);
    * it discarded its deepcopy and popped 'collate_fn' from the caller's
      cfg, mutating the caller's config as a side effect.

    Raises:
        KeyError: if cfg has no 'dataloader' section.
    """
    try:
        dataloader_cfg = copy.deepcopy(cfg['dataloader'])
    except (KeyError, TypeError) as err:
        raise KeyError('should contain {dataloader}!') from err
    train_ds, val_ds = build_dataset(cfg)
    # Only the train loader is sharded across DDP ranks; every rank
    # iterates the full validation set.
    train_sampler = DistributedSampler(train_ds)
    collate_fn = COLLATE_FN_REGISTRY.get(dataloader_cfg.pop('collate_fn'))
    train_loader = DataLoader(train_ds,
                              sampler=train_sampler,
                              collate_fn=collate_fn,
                              **dataloader_cfg)
    val_loader = DataLoader(val_ds,
                            collate_fn=collate_fn,
                            **dataloader_cfg)
    return train_loader, val_loader
import torch
from utils.registery import COLLATE_FN_REGISTRY
@COLLATE_FN_REGISTRY.register()
def base_collate_fn(batch):
    """Stack a list of (image, label) tensor pairs into one batched dict.

    Returns {'image': (B, ...) tensor, 'label': (B, ...) tensor}.
    """
    images = torch.stack([sample[0] for sample in batch], dim=0)
    labels = torch.stack([sample[1] for sample in batch], dim=0)
    return {'image': images, 'label': labels}
import os
import pandas as pd
import random
# Split the generated images into train/val CSV lists by shuffled file name.
root = '/data1/lxl/data/ocr/generate1108/'
img_path = os.path.join(root, 'img')
img_list = os.listdir(img_path)
random.shuffle(img_list)
train_df = pd.DataFrame(columns=['name'])
val_df = pd.DataFrame(columns=['name'])
# First 16000 shuffled names -> train, the remainder -> val.
train_df['name'] = img_list[:16000]
val_df['name'] = img_list[16000:]
train_df.to_csv(os.path.join(root, 'train.csv'))
val_df.to_csv(os.path.join(root, 'val.csv'))
import albumentations as A
import cv2
import os
from tqdm import tqdm
# Build background crops for synthesis: five small random crops per clean
# document page, with flips. (Removed the unused `date_dirs` listing that
# duplicated `img_names` below.)
transform = A.Compose([
    A.RandomResizedCrop(height=400, width=400, scale=(0.1, 0.2)),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
])
root = '/data1/lxl/data/ocr/doc_remove_stamp/'
save_path = '/data1/lxl/data/ocr/crop_bg/'
img_list = list()
img_names = sorted(os.listdir(root))
for img_name in img_names:
    img_path = os.path.join(root, img_name)
    img_list.append(img_path)
print(f'src img number: {len(img_list)}')
cnt = 0
for crt_img in tqdm(img_list):
    try:
        img = cv2.imread(crt_img)
        # Five independent random crops per source page.
        for _ in range(5):
            transformed = transform(image=img)
            transformed_img = transformed['image']
            cv2.imwrite(os.path.join(save_path, f'{cnt :06d}.jpg'), transformed_img)
            cnt += 1
    except Exception:
        # Best-effort: skip unreadable/corrupt files rather than abort the run.
        continue
import os
import cv2
# Source stamps (PNGs with alpha) and output directory for extracted masks.
stamp_root = '/data1/lxl/data/ocr/stamp/src'
mask_root = '/data1/lxl/data/ocr/stamp/mask'
def read_img_list(root):
    """Return the full path of every entry directly under *root*.

    Order follows os.listdir (arbitrary), matching the original behavior.
    """
    return [os.path.join(root, entry) for entry in os.listdir(root)]
def get_mask_list(stamp_list):
    """Extract each RGBA stamp's alpha channel and save it as a mask image.

    Masks are written into the module-level ``mask_root`` under the source
    file's name.
    """
    for stamp in stamp_list:
        img_name = stamp.split('/')[-1]
        # -1 == cv2.IMREAD_UNCHANGED: keep the alpha channel when loading.
        img = cv2.imread(stamp, -1)
        # The last channel (alpha) marks the stamp's opaque pixels.
        mask = img[:, :, -1]
        cv2.imwrite(os.path.join(mask_root, img_name), mask)
if __name__ == '__main__':
    # Extract an alpha mask for every stamp found under stamp_root.
    full_path_list = read_img_list(stamp_root)
    get_mask_list(full_path_list)
import os
import cv2
from PIL import Image, ImageEnhance
import numpy as np
import random
import shutil
from tqdm import tqdm
import json
import multiprocessing as mp
import imgaug.augmenters as iaa
def mkdir(path):
    """Create *path* (including parents) if it does not exist; no-op otherwise.

    Uses ``exist_ok=True`` instead of the original check-then-create pattern,
    which could race when several processes create the same directory.
    """
    os.makedirs(path, exist_ok=True)
# Inputs: augmented stamp triplets (img / mask / blend) and background crops.
stamp_img_root = '/data1/lxl/data/ocr/stamp/aug/img'
stamp_mask_root = '/data1/lxl/data/ocr/stamp/aug/mask'
stamp_blend_root = '/data1/lxl/data/ocr/stamp/aug/blend'
text_root = '/data1/lxl/data/ocr/crop_bg'
# Outputs: the generated training set (stamped image, stamp-only image,
# stamp mask, and clean text image).
gen_root = '/data1/lxl/data/ocr/generate1108/'
gen_img_root = os.path.join(gen_root, 'img')
gen_stamp_img_root = os.path.join(gen_root, 'stamp_img')
gen_stamp_mask_root = os.path.join(gen_root, 'stamp_mask')
gen_text_img_root = os.path.join(gen_root, 'text_img')
mkdir(gen_img_root)
mkdir(gen_text_img_root)
mkdir(gen_stamp_img_root)
mkdir(gen_stamp_mask_root)
def random_idx(s, e):
    """Draw one random integer uniformly from the half-open range [s, e)."""
    return int(np.random.randint(low=s, high=e, size=1))
def get_full_path_list(root):
    """Return sorted full paths of all entries directly under *root*."""
    return [os.path.join(root, entry) for entry in sorted(os.listdir(root))]
def gen(stamp_img, blend_mask, stamp_mask, text_img, gen_img_root, gen_text_img_root, gen_stamp_mask_root, savename):
    """Paste one stamp onto one background crop and save four outputs:
    stamped image, clean background (gt), stamp-only image, and stamp mask.

    All image/mask arguments are file paths; ``savename`` is the integer id
    used for the zero-padded output file names.
    """
    stamp_img = Image.open(stamp_img).convert("RGB")
    blend_mask = Image.open(blend_mask)
    stamp_img_width, stamp_img_height = stamp_img.size  # currently unused
    stamp_mask = Image.open(stamp_mask).convert('L')
    stamp_mask_copy = stamp_mask.copy().convert('L')  # currently unused
    text_img = Image.open(text_img).convert("RGB")
    gen_img = text_img.copy().convert("RGB")
    # Random top-left corner so the stamp fits inside the background.
    # NOTE(review): assumes the background is strictly larger than the stamp;
    # random_idx would get an empty range otherwise -- confirm input sizes.
    x = random_idx(0, text_img.size[0] - stamp_img.size[0])
    y = random_idx(0, text_img.size[1] - stamp_img.size[1])
    # blend_mask acts as a per-pixel alpha, giving a translucent ink look.
    gen_img.paste(stamp_img, (x, y), mask=blend_mask)
    gen_stamp_img = Image.new('RGB', size=text_img.size)
    gen_stamp_img.paste(stamp_img, (x, y), mask=blend_mask)
    gen_stamp_mask = Image.new('L', size=text_img.size)
    gen_stamp_mask.paste(stamp_mask, (x, y), mask=stamp_mask)
    stamp_coordinate = [x, y, x + stamp_img.size[0], y + stamp_img.size[1]]
    # NOTE(review): stamp_dict is built but never written anywhere.
    stamp_dict = {'name': str(savename), 'coordinate': stamp_coordinate, 'label': ''}
    gen_img.save(os.path.join(gen_img_root, "{:>06d}.jpg".format(savename)))
    text_img.save(os.path.join(gen_text_img_root, "{:>06d}.jpg".format(savename)))
    # gen_stamp_img_root is the module-level global, not a parameter.
    gen_stamp_img.save(os.path.join(gen_stamp_img_root, "{:>06d}.jpg".format(savename)))
    gen_stamp_mask.save(os.path.join(gen_stamp_mask_root, "{:>06d}.jpg".format(savename)))
def process():
    """Generate ``need`` synthetic samples by pairing random stamps with
    random background crops, fanning the pasting out to a process pool."""
    stamp_list = sorted(os.listdir(stamp_img_root))
    stamp_list_lth = len(stamp_list)
    text_list = sorted(os.listdir(text_root))
    text_list_lth = len(text_list)
    need = 20000
    pool = mp.Pool(processes=6)
    for i in range(0, need):
        stamp_idx = random_idx(0, stamp_list_lth)
        # img/mask/blend dirs share file names, so one index selects all three.
        stamp_img_path = os.path.join(stamp_img_root, stamp_list[stamp_idx])
        stamp_mask_path = os.path.join(stamp_mask_root, stamp_list[stamp_idx])
        blend_mask_path = os.path.join(stamp_blend_root, stamp_list[stamp_idx])
        text_idx = random_idx(0, text_list_lth)
        text_img_path = os.path.join(text_root, text_list[text_idx])
        # NOTE(review): worker exceptions are silently dropped because the
        # AsyncResult returned here is never checked.
        pool.apply_async(gen, (stamp_img_path, blend_mask_path, stamp_mask_path, text_img_path, gen_img_root, gen_text_img_root, gen_stamp_mask_root, i,))
        # gen(stamp_img_path, blend_mask_path, stamp_mask_path, text_img_path, gen_img_root, gen_text_img_root, gen_stamp_mask_root, i,)
    pool.close()
    pool.join()
def main():
    """Script entry point: generate the synthetic stamp-removal dataset."""
    process()
if __name__ == '__main__':
    main()
signatureDet @ 2dde6823
Subproject commit 2dde682312388f5c5099d161019cd97df06315e4
import imgaug.augmenters as iaa
import numpy as np
import cv2
import os
from tqdm import tqdm
from PIL import Image
# Geometric/photometric augmentation for stamp images; the alpha masks are
# passed as segmentation maps so they receive the matching spatial ops.
seq = iaa.Sequential(
    [
        iaa.Fliplr(0.5),
        iaa.Crop(percent=(0, 0.1), keep_size=True),
        iaa.Sometimes(0.5, iaa.GaussianBlur(sigma=(0, 0.5))),
        # Darken pixels elementwise to vary ink density.
        iaa.AddElementwise((-40, -10), per_channel=0.5),
        iaa.Sometimes(0.5, iaa.MultiplyElementwise((0.7, 1.0))),
        iaa.OneOf([
            iaa.Rotate((-45, 45)),
            iaa.Rot90((1, 3))
        ]),
        # Coarse dropout simulates partially faded / broken stamp prints.
        iaa.Sometimes(0.7, iaa.CoarseDropout(p=(0.1, 0.4), size_percent=(0.02, 0.2))),
        iaa.Sometimes(0.02, iaa.imgcorruptlike.MotionBlur(severity=2)),
    ])
# Blend-weight augmentation: blur plus elementwise gain on the grayscale stamp.
gen_blend = iaa.Sequential([
    iaa.GaussianBlur(sigma=(0, 0.5)),
    iaa.MultiplyElementwise((0.8, 3.5)),
])
# Source stamps and the three augmented-output directories.
img_root = '/data1/lxl/data/ocr/stamp/src'
gen_root = '/data1/lxl/data/ocr/stamp/aug'
gen_img_root = os.path.join(gen_root, 'img')
gen_mask_root = os.path.join(gen_root, 'mask')
gen_blend_root = os.path.join(gen_root, 'blend')
# exist_ok=True replaces the original check-then-create pattern (race-safe).
for out_dir in (gen_img_root, gen_mask_root, gen_blend_root):
    os.makedirs(out_dir, exist_ok=True)
name_list = sorted(os.listdir(img_root))
# Two augmentation rounds -> two augmented variants per source stamp.
for i in range(2):
    for name in tqdm(name_list):
        name_no_ex = name.split('.')[0]
        ext = name.split('.')[1]
        # -1 == IMREAD_UNCHANGED: RGB channels become the stamp image,
        # the alpha channel becomes its mask.
        img = cv2.imread(os.path.join(img_root, name), -1)[:, :, :3]
        mask = cv2.imread(os.path.join(img_root, name), -1)[:, :, -1]
        mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
        img = np.asarray(img, dtype=np.uint8)
        mask = np.asarray(mask, dtype=np.uint8)
        # Add a batch axis for imgaug.
        img = img[np.newaxis, :]
        mask = mask[np.newaxis, :]
        img_aug, mask_aug = seq(images=img, segmentation_maps=mask)
        img_aug = img_aug.squeeze()
        mask_aug = mask_aug.squeeze(0)
        # Grayscale copy of the augmented stamp becomes the blend weight map.
        blend = cv2.cvtColor(img_aug, cv2.COLOR_BGR2GRAY)
        # NOTE(review): blend is passed WITHOUT the batch axis that img/mask
        # got above -- confirm gen_blend handles a single 2-D image as intended.
        blend_aug = gen_blend(images=blend)
        blend_aug = blend_aug.squeeze()
        cv2.imwrite(os.path.join(gen_img_root, (name_no_ex + '_' + str(i) + '.' + ext)), img_aug)
        cv2.imwrite(os.path.join(gen_mask_root, (name_no_ex + '_' + str(i) + '.' + ext)), mask_aug)
        cv2.imwrite(os.path.join(gen_blend_root, (name_no_ex + '_' + str(i) + '.' + ext)), blend_aug)
from .builder import build_loss
__all__ = ['BCE', 'build_loss']
import torch
import inspect
from utils.registery import LOSS_REGISTRY
import copy
def register_torch_loss():
    """Register every loss class exposed by torch.nn into LOSS_REGISTRY.

    A 'loss' is any public torch.nn attribute whose name contains 'Loss'
    and which is an nn.Module subclass.
    """
    candidates = (getattr(torch.nn, attr) for attr in dir(torch.nn)
                  if 'Loss' in attr and not attr.startswith('__'))
    for candidate in candidates:
        if inspect.isclass(candidate) and issubclass(candidate, torch.nn.Module):
            LOSS_REGISTRY.register()(candidate)
def build_loss(cfg):
    """Instantiate the loss module named in cfg['solver']['loss'].

    Registers all torch.nn losses first, then constructs the configured one
    with its 'args'.

    Raises:
        KeyError: if the config lacks a solver.loss section.
    """
    register_torch_loss()
    try:
        # Fix: the original deep-copied cfg but then indexed the ORIGINAL,
        # discarding the copy; copy the sub-config we actually use.
        loss_cfg = copy.deepcopy(cfg['solver']['loss'])
    except (KeyError, TypeError) as err:
        # Fix: raising a bare string is itself a TypeError in Python 3.
        raise KeyError('should contain {solver.loss}!') from err
    return LOSS_REGISTRY.get(loss_cfg['name'])(**loss_cfg['args'])
from .builder import build_metric
__all__ = ['build_metric']
from utils.registery import METRIC_REGISTRY
import copy
from .recon_metric import Recon
def build_metric(cfg):
    """Instantiate the metric class named in cfg['solver']['metric'].

    Raises:
        KeyError: if the config lacks a solver.metric section.
    """
    cfg = copy.deepcopy(cfg)
    try:
        metric_cfg = cfg['solver']['metric']
    except (KeyError, TypeError) as err:
        # Fix: the original raised a bare string (a TypeError in Python 3).
        raise KeyError('should contain {solver.metric}!') from err
    return METRIC_REGISTRY.get(metric_cfg['name'])()
import numpy as np
import torchmetrics
from utils.registery import METRIC_REGISTRY
@METRIC_REGISTRY.register()
class Recon(object):
    """Reconstruction quality metric: per-batch mean PSNR and SSIM."""

    def __init__(self):
        self.psnr = torchmetrics.PeakSignalNoiseRatio()
        self.ssim = torchmetrics.StructuralSimilarityIndexMeasure()

    def __call__(self, pred, label):
        """Return {'psnr': mean PSNR, 'ssim': mean SSIM} over the batch."""
        assert pred.shape[0] == label.shape[0]
        psnr_scores = []
        ssim_scores = []
        # Score each pair individually, then average.
        for p, l in zip(pred, label):
            p_b, l_b = p.unsqueeze(0), l.unsqueeze(0)
            psnr_scores.append(self.psnr(p_b, l_b))
            ssim_scores.append(self.ssim(p_b, l_b))
        return {'psnr': sum(psnr_scores) / len(psnr_scores),
                'ssim': sum(ssim_scores) / len(ssim_scores)}
from .builder import build_model
import torch
import torch.nn as nn
from abc import ABCMeta
import math
class BaseModel(nn.Module, metaclass=ABCMeta):
    """
    Common base class for project models.

    Provides a shared weight-initialization helper: Xavier-uniform for
    conv/linear weights, unit weight / zero bias for normalization layers.
    """
    def __init__(self):
        super().__init__()

    def _initialize_weights(self):
        """Re-initialize all conv, norm and linear submodules in place."""
        conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
        norm_types = (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
        for module in self.modules():
            if isinstance(module, conv_types):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, norm_types):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
import copy
from utils import MODEL_REGISTRY
from .unet import Unet
from .unet_skip import UnetSkip
def build_model(cfg):
    """Instantiate the model named in cfg['model'] with its 'args'.

    Raises:
        KeyError: if cfg has no 'model' section.
    """
    model_cfg = copy.deepcopy(cfg)
    try:
        model_cfg = model_cfg['model']
    except (KeyError, TypeError) as err:
        # Fix: the original raised a bare string (a TypeError in Python 3).
        raise KeyError('should contain {model}') from err
    model = MODEL_REGISTRY.get(model_cfg['name'])(**model_cfg['args'])
    return model
import segmentation_models_pytorch as smp
from utils.registery import MODEL_REGISTRY
from core.model.base_model import BaseModel
@MODEL_REGISTRY.register()
class Unet(BaseModel):
    """segmentation_models_pytorch Unet wrapper that predicts a residual
    which is added back onto the input (global skip connection)."""

    def __init__(self,
                 encoder_name: str = 'resnet50',
                 encoder_weights: str = 'imagenet',
                 in_channels: int = 3,
                 classes: int = 3,
                 activation: str = 'tanh'):
        super().__init__()
        backbone_cfg = dict(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=activation,
        )
        self.model = smp.Unet(**backbone_cfg)
        # NOTE: BaseModel init re-initializes every submodule, including the
        # pretrained encoder weights loaded just above.
        self._initialize_weights()

    def forward(self, x):
        """Return x plus the predicted residual."""
        residual = self.model(x)
        return x + residual
import segmentation_models_pytorch as smp
from utils.registery import MODEL_REGISTRY
from core.model.base_model import BaseModel
@MODEL_REGISTRY.register()
class UnetSkip(BaseModel):
    """segmentation_models_pytorch Unet wrapper exposing both the predicted
    residual and the reconstruction (input + residual) as a dict."""

    def __init__(self,
                 encoder_name: str = 'resnet50',
                 encoder_weights: str = 'imagenet',
                 in_channels: int = 3,
                 classes: int = 3,
                 activation: str = 'tanh'):
        super().__init__()
        backbone_cfg = dict(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=activation,
        )
        self.model = smp.Unet(**backbone_cfg)
        # NOTE: BaseModel init re-initializes every submodule, including the
        # pretrained encoder weights loaded just above.
        self._initialize_weights()

    def forward(self, x):
        """Return {'reconstruction': x + residual, 'residual': residual}."""
        residual = self.model(x)
        return {'reconstruction': x + residual, 'residual': residual}
from .builder import build_optimizer, build_lr_scheduler
__all__ = ['build_optimizer', 'build_lr_scheduler']
import torch
import inspect
from utils.registery import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY
import copy
def register_torch_optimizers():
    """
    Register all optimizers implemented by torch into OPTIMIZER_REGISTRY.
    """
    public_names = (attr for attr in dir(torch.optim) if not attr.startswith('__'))
    for attr_name in public_names:
        candidate = getattr(torch.optim, attr_name)
        if inspect.isclass(candidate) and issubclass(candidate, torch.optim.Optimizer):
            OPTIMIZER_REGISTRY.register()(candidate)
def build_optimizer(cfg):
    """Return the optimizer *class* named in cfg['solver']['optimizer'].

    The caller instantiates it with the model parameters and args.

    Raises:
        KeyError: if the config lacks a solver.optimizer section.
    """
    register_torch_optimizers()
    try:
        # Fix: the original deep-copied cfg but then indexed the ORIGINAL,
        # discarding the copy; copy the sub-config we actually use.
        optimizer_cfg = copy.deepcopy(cfg['solver']['optimizer'])
    except (KeyError, TypeError) as err:
        # Fix: raising a bare string is itself a TypeError in Python 3.
        raise KeyError('should contain {solver.optimizer}!') from err
    return OPTIMIZER_REGISTRY.get(optimizer_cfg['name'])
def register_torch_lr_scheduler():
    """
    Register all lr_schedulers implemented by torch into LR_SCHEDULER_REGISTRY.
    """
    public_names = (attr for attr in dir(torch.optim.lr_scheduler) if not attr.startswith('__'))
    for attr_name in public_names:
        candidate = getattr(torch.optim.lr_scheduler, attr_name)
        if inspect.isclass(candidate) and issubclass(candidate, torch.optim.lr_scheduler._LRScheduler):
            LR_SCHEDULER_REGISTRY.register()(candidate)
def build_lr_scheduler(cfg):
    """Return the LR-scheduler *class* named in cfg['solver']['lr_scheduler'].

    The caller instantiates it with the optimizer and args.

    Raises:
        KeyError: if the config lacks a solver.lr_scheduler section.
    """
    register_torch_lr_scheduler()
    try:
        # Fix: the original deep-copied cfg but then indexed the ORIGINAL,
        # discarding the copy; copy the sub-config we actually use.
        scheduler_cfg = copy.deepcopy(cfg['solver']['lr_scheduler'])
    except (KeyError, TypeError) as err:
        # Fix: raising a bare string is itself a TypeError in Python 3.
        raise KeyError('should contain {solver.lr_scheduler}!') from err
    return LR_SCHEDULER_REGISTRY.get(scheduler_cfg['name'])
from .builder import build_solver
__all__ = ['build_solver']
import torch
from core.data import build_dataloader
from core.model import build_model
from core.optimizer import build_optimizer, build_lr_scheduler
from core.loss import build_loss
from core.metric import build_metric
from utils.registery import SOLVER_REGISTRY
from utils.logger import get_logger_and_log_path
import os
import copy
import datetime
from torch.nn.parallel import DistributedDataParallel
import numpy as np
import pandas as pd
import yaml
@SOLVER_REGISTRY.register()
class BaseSolver(object):
    """Distributed (DDP) training/validation driver for single-output models.

    Wires dataloaders, model, loss, optimizer, LR scheduler, metric and
    logger together from the YAML config, then runs epoch-based training
    with per-epoch checkpointing (rank 0) and validation.
    """

    def __init__(self, cfg):
        self.cfg = copy.deepcopy(cfg)
        # NOTE(review): the global rank is used as a CUDA device index; this
        # is only correct for single-node runs -- confirm before multi-node use.
        self.local_rank = torch.distributed.get_rank()
        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.len_train_loader, self.len_val_loader = len(self.train_loader), len(self.val_loader)
        self.criterion = build_loss(cfg).cuda(self.local_rank)
        # Synchronize BatchNorm statistics across DDP workers.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(build_model(cfg))
        self.model = DistributedDataParallel(model.cuda(self.local_rank), device_ids=[self.local_rank], find_unused_parameters=True)
        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])
        self.hyper_params = cfg['solver']['args']
        crt_date = datetime.date.today().strftime('%Y-%m-%d')
        self.logger, self.log_path = get_logger_and_log_path(crt_date=crt_date, **cfg['solver']['logger'])
        self.metric_fn = build_metric(cfg)
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError as err:
            # Fix: the original raised a bare string (a TypeError in Python 3).
            raise KeyError('should contain epoch in {solver.args}') from err
        # Rank 0 owns the config snapshot and the banner log line.
        if self.local_rank == 0:
            self.save_dict_to_yaml(self.cfg, os.path.join(self.log_path, 'config.yaml'))
            self.logger.info(self.cfg)

    def train(self):
        """Run the full training loop; logs, checkpoints and validates per epoch."""
        if torch.distributed.get_rank() == 0:
            self.logger.info('==> Start Training')
        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])
        for t in range(self.epoch):
            # Reshuffle the DistributedSampler shards each epoch.
            self.train_loader.sampler.set_epoch(t)
            if torch.distributed.get_rank() == 0:
                self.logger.info(f'==> epoch {t + 1}')
            self.model.train()
            mean_loss = 0.0
            for i, data in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                image = data['image'].cuda(self.local_rank)
                label = data['label'].cuda(self.local_rank)
                pred = self.model(image)
                loss = self.criterion(pred, label)
                mean_loss += loss.item()
                # Log the running loss every 200 iterations on rank 0.
                if (i == 0 or i % 200 == 0) and (torch.distributed.get_rank() == 0):
                    loss_value = loss.item()
                    self.logger.info(f'epoch: {t + 1}/{self.epoch}, iteration: {i + 1}/{self.len_train_loader}, loss: {loss_value :.4f}')
                loss.backward()
                self.optimizer.step()
            mean_loss = mean_loss / self.len_train_loader
            if torch.distributed.get_rank() == 0:
                self.logger.info(f'==> train mean loss: {mean_loss :.4f}')
                self.save_checkpoint(self.model, self.cfg, self.log_path, t + 1)
            # Validation and LR stepping run on every rank.
            self.val(t + 1)
            lr_scheduler.step()
        if self.local_rank == 0:
            self.logger.info('==> End Training')

    @torch.no_grad()
    def val(self, t):
        """Evaluate on the full val set; logs PSNR/SSIM on rank 0."""
        self.model.eval()
        pred_list = list()
        label_list = list()
        for i, data in enumerate(self.val_loader):
            feat = data['image'].cuda(self.local_rank)
            label = data['label'].cuda(self.local_rank)
            pred = self.model(feat)
            pred_list.append(pred.detach().cpu())
            label_list.append(label.detach().cpu())
        pred_list = torch.cat(pred_list, dim=0)
        label_list = torch.cat(label_list, dim=0)
        metric_dict = self.metric_fn(**{'pred': pred_list, 'label': label_list})
        if torch.distributed.get_rank() == 0:
            self.logger.info(f"==> val psnr: {metric_dict['psnr'] :.4f}, ssim: {metric_dict['ssim'] :.4f}")

    def run(self):
        """Public entry point used by main.py."""
        self.train()

    @staticmethod
    def save_dict_to_yaml(dict_value, save_path):
        """Dump a dict to YAML, preserving insertion order."""
        with open(save_path, 'w', encoding='utf-8') as file:
            yaml.dump(dict_value, file, sort_keys=False)

    def save_checkpoint(self, model, cfg, log_path, epoch_id):
        """Save the unwrapped (DDP .module) state dict for the given epoch."""
        model.eval()
        torch.save(model.module.state_dict(), os.path.join(log_path, f'ckpt_epoch_{epoch_id}.pt'))
from utils.registery import SOLVER_REGISTRY
import copy
from .base_solver import BaseSolver
from .skip_solver import SkipSolver
def build_solver(cfg):
    """Instantiate the solver named in cfg['solver'] with the full config.

    Raises:
        KeyError: if cfg has no 'solver' section.
    """
    cfg = copy.deepcopy(cfg)
    try:
        solver_cfg = cfg['solver']
    except (KeyError, TypeError) as err:
        # Fix: the original raised a bare string (a TypeError in Python 3).
        raise KeyError('should contain {solver}!') from err
    return SOLVER_REGISTRY.get(solver_cfg['name'])(cfg)
import torch
from core.data import build_dataloader
from core.model import build_model
from core.optimizer import build_optimizer, build_lr_scheduler
from core.loss import build_loss
from core.metric import build_metric
from utils.registery import SOLVER_REGISTRY
from utils.logger import get_logger_and_log_path
import os
import copy
import datetime
from torch.nn.parallel import DistributedDataParallel
import numpy as np
import pandas as pd
import yaml
from .base_solver import BaseSolver
@SOLVER_REGISTRY.register()
class SkipSolver(BaseSolver):
    """BaseSolver variant for UnetSkip models.

    The model returns {'reconstruction', 'residual'}; training supervises
    both the reconstruction against the clean label and the predicted
    residual against (image - label).
    """
    def __init__(self, cfg):
        super().__init__(cfg)

    def train(self):
        """Training loop with the two-term (reconstruction + residual) loss."""
        if torch.distributed.get_rank() == 0:
            self.logger.info('==> Start Training')
        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])
        for t in range(self.epoch):
            # Reshuffle the DistributedSampler shards each epoch.
            self.train_loader.sampler.set_epoch(t)
            if torch.distributed.get_rank() == 0:
                self.logger.info(f'==> epoch {t + 1}')
            self.model.train()
            pred_list = list()   # NOTE(review): unused in train (kept from BaseSolver)
            label_list = list()  # NOTE(review): unused in train (kept from BaseSolver)
            mean_loss = 0.0
            for i, data in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                image = data['image'].cuda(self.local_rank)
                label = data['label'].cuda(self.local_rank)
                # Ground-truth residual: what must be removed from the input.
                residual = image - label
                pred = self.model(image)
                reconstruction_loss = self.criterion(pred['reconstruction'], label)
                residual_loss = self.criterion(pred['residual'], residual)
                # Equal-weight sum of the two objectives.
                loss = reconstruction_loss + residual_loss
                mean_loss += loss.item()
                if (i == 0 or i % 200 == 0) and (torch.distributed.get_rank() == 0):
                    reconstruction_loss_value = reconstruction_loss.item()
                    residual_loss_value = residual_loss.item()
                    loss_value = loss.item()
                    self.logger.info(f'epoch: {t + 1}/{self.epoch}, iteration: {i + 1}/{self.len_train_loader}, reconstruction loss: {reconstruction_loss_value :.4f}, residual loss: {residual_loss_value :.4f}, loss: {loss_value :.4f}')
                loss.backward()
                self.optimizer.step()
            mean_loss = mean_loss / self.len_train_loader
            if torch.distributed.get_rank() == 0:
                # self.logger.info(f"==> train mean loss: {mean_loss :.4f}, psnr: {metric_dict['psnr'] :.4f}, ssim: {metric_dict['ssim'] :.4f}")
                self.logger.info(f'==> train mean loss: {mean_loss :.4f}')
                self.save_checkpoint(self.model, self.cfg, self.log_path, t + 1)
            self.val(t + 1)
            lr_scheduler.step()
        if self.local_rank == 0:
            self.logger.info('==> End Training')

    @torch.no_grad()
    def val(self, t):
        """Validate using only the reconstruction output; logs PSNR/SSIM on rank 0."""
        self.model.eval()
        pred_list = list()
        label_list = list()
        for i, data in enumerate(self.val_loader):
            image = data['image'].cuda(self.local_rank)
            label = data['label'].cuda(self.local_rank)
            residual = image - label  # NOTE(review): computed but unused here
            pred = self.model(image)
            pred_list.append(pred['reconstruction'].detach().cpu())
            label_list.append(label.detach().cpu())
        pred_list = torch.cat(pred_list, dim=0)
        label_list = torch.cat(label_list, dim=0)
        metric_dict = self.metric_fn(**{'pred': pred_list, 'label': label_list})
        if torch.distributed.get_rank() == 0:
            self.logger.info(f"==> val psnr: {metric_dict['psnr'] :.4f}, ssim: {metric_dict['ssim'] :.4f}")

    def run(self):
        """Public entry point used by main.py."""
        self.train()

    @staticmethod
    def save_dict_to_yaml(dict_value, save_path):
        # NOTE(review): identical to BaseSolver.save_dict_to_yaml (duplicate).
        with open(save_path, 'w', encoding='utf-8') as file:
            yaml.dump(dict_value, file, sort_keys=False)

    def save_checkpoint(self, model, cfg, log_path, epoch_id):
        # NOTE(review): identical to BaseSolver.save_checkpoint (duplicate).
        model.eval()
        torch.save(model.module.state_dict(), os.path.join(log_path, f'ckpt_epoch_{epoch_id}.pt'))
import yaml
from core.solver import build_solver
import torch
import numpy as np
import random
import argparse
def init_seed(seed=778):
    """Seed every RNG used in training (python, numpy, torch CPU/CUDA) and
    pin cuDNN to deterministic kernels."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    # Deterministic cuDNN: reproducible at some cost in speed.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def main():
    """Entry point: parse args, load the YAML config, seed RNGs, initialize
    the NCCL process group, and run the configured solver."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./config/baseline.yaml', type=str, help='config file')
    parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
    args = parser.parse_args()
    # FullLoader can construct arbitrary Python objects declared via tags;
    # safe only because config files are trusted local files.
    cfg = yaml.load(open(args.config, 'r').read(), Loader=yaml.FullLoader)
    init_seed(cfg['seed'])
    # Expects env vars set by torch.distributed.launch (see the launch command).
    torch.distributed.init_process_group(backend='nccl')
    torch.cuda.set_device(args.local_rank)
    solver = build_solver(cfg)
    solver.run()
if __name__ == '__main__':
    main()
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import cv2
import numpy as np
import torch
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm
from core.model.unet import Unet
from core.model.unet_skip import UnetSkip
# def load_model(ckpt='/data1/lxl/code/ocr/removal/log/2022-11-10/baseline/ckpt_epoch_18.pt'):
# def load_model(ckpt='/data1/lxl/code/ocr/removal/log/2022-11-10/residual/ckpt_epoch_28.pt'):
def load_model(ckpt='./log/2022-11-11/skip/ckpt_epoch_30.pt'):
    """Build a UnetSkip matching the training config and load *ckpt* weights.

    The checkpoint is mapped to CPU; the model is returned in eval mode.
    """
    # model = Unet(
    #     encoder_name='resnet50',
    #     encoder_weights='imagenet',
    #     in_channels=3,
    #     classes=3,
    #     activation='tanh'
    # )
    model = UnetSkip(
        encoder_name='resnet50',
        encoder_weights='imagenet',
        in_channels=3,
        classes=3,
        activation='tanh'
    )
    model.load_state_dict(torch.load(ckpt, map_location='cpu'))
    model.eval()
    return model
def infer(model, img_path, gen_path):
    """Run stamp removal on one image file and write the result to *gen_path*.

    Mirrors the training preprocessing: resize to 512x512 (the configured
    fixed_size) and scale pixels to [0, 1].
    """
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (512, 512))
    img = img / 255.
    img = img.transpose(2, 0, 1).astype(np.float32)  # HWC -> CHW
    img = torch.from_numpy(img).unsqueeze(0)  # add batch axis
    # NOTE(review): no torch.no_grad() here, so autograd state is kept;
    # consider wrapping for lower memory use during inference.
    out = model(img)
    # UnetSkip returns a dict; only the reconstruction branch is saved.
    out = out['reconstruction']
    # Clamp to [0, 1] before converting back to 8-bit.
    out[out < 0] = 0
    out[out > 1] = 1
    out = out * 255
    out = out.detach().cpu().numpy().squeeze(0).transpose(1, 2, 0)
    out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
    cv2.imwrite(gen_path, out)
def infer_list():
    """Batch inference: process every image under img_root into gen_root,
    keeping the original file names."""
    img_root = '/data1/lxl/data/ocr/real/src/'
    gen_root = '/data1/lxl/data/ocr/real/removed/'
    img_list = sorted(os.listdir(img_root))
    model = load_model()
    for img in tqdm(img_list):
        img_path = os.path.join(img_root, img)
        gen_path = os.path.join(gen_root, img)
        infer(model, img_path, gen_path)
def infer_img():
    """Single-image smoke test: run the model on one sample, save ./out.jpg."""
    model = load_model()
    img_path = '../DocEnTR/real/hetong_006_00.png'
    gen_path = './out.jpg'
    infer(model, img_path, gen_path)
# Runs at import time (script-style entry; no __main__ guard).
infer_img()
CUDA_VISIBLE_DEVICES=0 nohup python -m torch.distributed.launch --master_port 8999 --nproc_per_node=1 main.py --config ./config/skip.yaml &
from .registery import *
from .logger import get_logger_and_log_path
from .helper import save_checkpoint
__all__ = [
'Registry',
'get_logger_and_log_path',
'save_checkpoint'
]
import torch
import yaml
import os
def save_dict_to_yaml(dict_value, save_path):
    """Serialize *dict_value* to *save_path* as YAML, preserving key order."""
    with open(save_path, 'w', encoding='utf-8') as out_file:
        yaml.dump(dict_value, out_file, sort_keys=False)
def save_checkpoint(model, cfg, log_path, epoch_id):
    """Save the model weights for one epoch plus a config snapshot.

    The config is written next to the weights so checkpoints stay reproducible.
    """
    save_dict_to_yaml(cfg, os.path.join(log_path, 'config.yaml'))
    # model is expected to be DDP-wrapped (has .module) -- TODO confirm callers.
    torch.save(model.module.state_dict(), os.path.join(log_path, f'ckpt_epoch_{epoch_id}.pt'))
import loguru
import copy
import os
import datetime
def get_logger_and_log_path(log_root,
                            crt_date,
                            suffix):
    """
    Get logger and log path.

    Args:
        log_root (str): root path of log
        crt_date (str): formatted date name (Y-M-D)
        suffix (str): log save name

    Returns:
        logger (loguru.logger): logger object
        log_path (str): current root log path (with suffix)
    """
    log_path = os.path.join(log_root, crt_date, suffix)
    # exist_ok avoids a race when several DDP ranks create the dir at once
    # (the original exists-then-makedirs check could throw on the loser).
    os.makedirs(log_path, exist_ok=True)
    logger_path = os.path.join(log_path, 'logfile.log')
    logger = loguru.logger
    fmt = '{time:YYYY-MM-DD at HH:mm:ss} | {message}'
    # NOTE(review): each call adds another sink to the global loguru logger;
    # calling this more than once duplicates log lines -- confirm single use.
    logger.add(logger_path, format=fmt)
    return logger, log_path
class Registry():
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj, suffix=None):
        """Insert *obj* under `name` (plus optional `_suffix`); duplicates fail."""
        if isinstance(suffix, str):
            name = f'{name}_{suffix}'
        assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
                                             f"in '{self._name}' registry!")
        self._obj_map[name] = obj

    def register(self, obj=None, suffix=None):
        """
        Register the given object under the the name `obj.__name__`.
        Can be used as either a decorator or not.
        See docstring of this class for usage.
        """
        if obj is not None:
            # Plain function-call form.
            self._do_register(obj.__name__, obj, suffix)
            return None

        # Decorator form: register the wrapped object and return it unchanged.
        def deco(func_or_class):
            self._do_register(func_or_class.__name__, func_or_class, suffix)
            return func_or_class
        return deco

    def get(self, name, suffix='soulwalker'):
        """Look up *name*, falling back to `name_suffix`; KeyError when absent."""
        found = self._obj_map.get(name)
        if found is None:
            found = self._obj_map.get(f'{name}_{suffix}')
            print(f'Name {name} is not found, use name: {name}_{suffix}!')
        if found is None:
            raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
        return found

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()
# Global singleton registries, one per component family; the build_* helpers
# look implementations up here by the 'name' field of the YAML config.
DATASET_REGISTRY = Registry('dataset')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')
OPTIMIZER_REGISTRY = Registry('optimizer')
SOLVER_REGISTRY = Registry('solver')
LR_SCHEDULER_REGISTRY = Registry('lr_scheduler')
COLLATE_FN_REGISTRY = Registry('collate_fn')
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!