.DS_Store
logs/
## Intro
A test of an information-structuring scheme based on coordinate classification.
## Usage
```
pip install -r requirements.txt
python3 main.py --config=path/to/config.yaml
```
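By default `main.py` reads `./config/mlp.yaml`, so the `--config` flag can be omitted when using the bundled MLP configuration.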
seed: 3407

dataset:
  name: 'CoordinatesData'
  args:
    data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset'
    train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/train.csv'
    val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/valid.csv'

dataloader:
  batch_size: 32
  num_workers: 4
  pin_memory: true
  shuffle: true

model:
  name: 'MLPModel'
  args:
    activation: 'relu'

solver:
  name: 'MLPSolver'
  args:
    epoch: 100
  optimizer:
    name: 'Adam'
    args:
      lr: !!float 1e-4
      weight_decay: !!float 5e-5
  lr_scheduler:
    name: 'StepLR'
    args:
      step_size: 15
      gamma: 0.1
  loss:
    name: 'SigmoidFocalLoss'
    # name: 'CrossEntropyLoss'
    args:
      reduction: "mean"
  logger:
    log_root: '/Users/zhouweiqi/Downloads/test/logs'
    suffix: 'mlp'
import os
import json

import torch
from torch.utils.data import DataLoader, Dataset
import pandas as pd

from utils.registery import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class CoordinatesData(Dataset):

    def __init__(self,
                 data_root: str = '/Users/zhouweiqi/Downloads/gcfp/data/dataset',
                 anno_file: str = 'train.csv',
                 phase: str = 'train'):
        self.data_root = data_root
        self.df = pd.read_csv(anno_file)
        self.phase = phase

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        series = self.df.iloc[idx]
        name = series['name']
        with open(os.path.join(self.data_root, self.phase, name), 'r') as fp:
            input_coordinates_list, label_list = json.load(fp)
        input_coordinates = torch.tensor(input_coordinates_list)
        label = torch.tensor(label_list).float()
        return input_coordinates, label
from .builder import build_dataloader
__all__ = ['build_dataloader']
import copy

from torch.utils.data import DataLoader

from utils.registery import DATASET_REGISTRY
from .CoordinatesData import CoordinatesData


def build_dataset(cfg):
    dataset_cfg = copy.deepcopy(cfg)
    try:
        dataset_cfg = dataset_cfg['dataset']
    except Exception:
        raise KeyError('config should contain a "dataset" section!')

    train_cfg = copy.deepcopy(dataset_cfg)
    val_cfg = copy.deepcopy(dataset_cfg)

    train_cfg['args']['anno_file'] = train_cfg['args'].pop('train_anno_file')
    train_cfg['args'].pop('val_anno_file', None)
    train_cfg['args']['phase'] = 'train'

    val_cfg['args']['anno_file'] = val_cfg['args'].pop('val_anno_file')
    val_cfg['args'].pop('train_anno_file', None)
    val_cfg['args']['phase'] = 'valid'

    train_data = DATASET_REGISTRY.get(cfg['dataset']['name'])(**train_cfg['args'])
    val_data = DATASET_REGISTRY.get(cfg['dataset']['name'])(**val_cfg['args'])

    return train_data, val_data


def build_dataloader(cfg):
    dataloader_cfg = copy.deepcopy(cfg)
    try:
        dataloader_cfg = cfg['dataloader']
    except Exception:
        raise KeyError('config should contain a "dataloader" section!')

    train_ds, val_ds = build_dataset(cfg)

    train_loader = DataLoader(train_ds, **dataloader_cfg)
    val_loader = DataLoader(val_ds, **dataloader_cfg)

    return train_loader, val_loader
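A minimal usage sketch (assuming the YAML config shown above is saved as `config/mlp.yaml` and the dataset JSON files it points to have already been generated):
```
# Sketch: load the YAML config and build both loaders.
import yaml

from data import build_dataloader

with open('config/mlp.yaml', 'r') as fp:
    cfg = yaml.load(fp.read(), Loader=yaml.FullLoader)

train_loader, val_loader = build_dataloader(cfg)
coords, labels = next(iter(train_loader))
print(coords.shape, labels.shape)  # e.g. torch.Size([32, 29, 8]) torch.Size([32, 5])
```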
import os
import cv2
import uuid
import json
import random
import copy
import pandas as pd

from tools import get_file_paths, load_json


def text_statistics(go_res_dir):
    """
    Args:
        go_res_dir: str  directory of general-OCR JSON results
    Returns: list  the most frequent texts and their counts
    """
    json_count = 0
    text_dict = {}
    go_res_json_paths = get_file_paths(go_res_dir, ['.json', ])
    for go_res_json_path in go_res_json_paths:
        print('Info: start {0}'.format(go_res_json_path))
        json_count += 1
        go_res = load_json(go_res_json_path)
        for _, text in go_res.values():
            if text in text_dict:
                text_dict[text] += 1
            else:
                text_dict[text] = 1

    top_text_list = []
    # sort by count, descending
    for text, count in sorted(text_dict.items(), key=lambda x: x[1], reverse=True):
        if text == '':
            continue
        # discard texts that appear in no more than a third of the JSON files
        if count <= json_count // 3:
            break
        top_text_list.append((text, count))

    return top_text_list
def build_anno_file(dataset_dir, anno_file_path):
    img_list = os.listdir(dataset_dir)
    random.shuffle(img_list)
    df = pd.DataFrame(columns=['name'])
    df['name'] = img_list
    df.to_csv(anno_file_path)
def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save_dir):
    """
    Args:
        img_dir: str         image directory
        go_res_dir: str      directory of general-OCR JSON results
        label_dir: str       directory of annotation JSON files
        top_text_list: list  the most frequent texts and their counts
        skip_list: list      image names to skip
        save_dir: str        output directory of the generated dataset
    """
    # if os.path.exists(save_dir):
    #     return
    # else:
    #     os.makedirs(save_dir, exist_ok=True)

    count = 0
    un_count = 0
    top_text_count = len(top_text_list)
    for img_name in sorted(os.listdir(img_dir)):
        if img_name in skip_list:
            print('Info: skip {0}'.format(img_name))
            continue
        print('Info: start {0}'.format(img_name))

        image_path = os.path.join(img_dir, img_name)
        img = cv2.imread(image_path)
        h, w, _ = img.shape

        base_image_name, _ = os.path.splitext(img_name)
        go_res_json_path = os.path.join(go_res_dir, '{0}.json'.format(base_image_name))
        go_res = load_json(go_res_json_path)

        input_key_list = []
        not_found_count = 0
        go_key_set = set()
        for top_text, _ in top_text_list:
            for go_key, ((x0, y0, x1, y1, x2, y2, x3, y3), text) in go_res.items():
                if text == top_text:
                    input_key_list.append([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
                    go_key_set.add(go_key)
                    break
            else:
                not_found_count += 1
                input_key_list.append([0, 0, 0, 0, 0, 0, 0, 0])

        if not_found_count >= top_text_count // 3:
            print('Info: skip {0} : {1}/{2}'.format(img_name, not_found_count, top_text_count))
            continue

        label_json_path = os.path.join(label_dir, '{0}.json'.format(base_image_name))
        label_res = load_json(label_json_path)

        # invoice date / invoice code / machine-printed number / vehicle type / phone
        test_group_id = [1, 2, 5, 9, 20]
        group_list = []
        for group_id in test_group_id:
            for item in label_res.get("shapes", []):
                if item.get("group_id") == group_id:
                    x_list = []
                    y_list = []
                    for point in item['points']:
                        x_list.append(point[0])
                        y_list.append(point[1])
                    group_list.append([min(x_list) + (max(x_list) - min(x_list))/2, min(y_list) + (max(y_list) - min(y_list))/2])
                    break
            else:
                group_list.append(None)

        go_center_list = []
        for go_key, ((x0, y0, x1, y1, x2, y2, x3, y3), _) in go_res.items():
            if go_key in go_key_set:
                continue
            xmin = min(x0, x1, x2, x3)
            ymin = min(y0, y1, y2, y3)
            xmax = max(x0, x1, x2, x3)
            ymax = max(y0, y1, y2, y3)
            xcenter = xmin + (xmax - xmin)/2
            ycenter = ymin + (ymax - ymin)/2
            go_center_list.append([xcenter, ycenter, go_key])

        group_go_key_list = []
        for label_center_list in group_list:
            if isinstance(label_center_list, list):
                min_go_key = None
                min_length = None
                for go_x_center, go_y_center, go_key in go_center_list:
                    if go_key in go_key_set:
                        continue
                    length = abs(go_x_center - label_center_list[0]) + abs(go_y_center - label_center_list[1])
                    if min_go_key is None or length < min_length:
                        min_go_key = go_key
                        min_length = length
                if min_go_key is not None:
                    go_key_set.add(min_go_key)
                    group_go_key_list.append(min_go_key)
                else:
                    group_go_key_list.append(None)
            else:
                group_go_key_list.append(None)

        src_label_list = [0 for _ in test_group_id]
        for idx, find_go_key in enumerate(group_go_key_list):
            if find_go_key is None:
                continue
            (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res[find_go_key]
            input_list = copy.deepcopy(input_key_list)
            input_list.append([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
            input_label = copy.deepcopy(src_label_list)
            input_label[idx] = 1
            # with open(os.path.join(save_dir, '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, '{0}-{1}'.format(img_name, find_go_key)))), 'w') as fp:
            #     json.dump([input_list, input_label], fp)
            count += 1

        for go_key, ((x0, y0, x1, y1, x2, y2, x3, y3), _) in go_res.items():
            if go_key in go_key_set:
                continue
            input_list = copy.deepcopy(input_key_list)
            input_list.append([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
            # with open(os.path.join(save_dir, '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, '{0}-{1}'.format(img_name, go_key)))), 'w') as fp:
            #     json.dump([input_list, src_label_list], fp)
            un_count += 1

        # break

    print(count)
    print(un_count)
if __name__ == '__main__':
    base_dir = '/Users/zhouweiqi/Downloads/gcfp/data'

    go_dir = os.path.join(base_dir, 'go_res')
    dataset_save_dir = os.path.join(base_dir, 'dataset')
    label_dir = os.path.join(base_dir, 'labeled')

    train_go_path = os.path.join(go_dir, 'train')
    train_image_path = os.path.join(label_dir, 'train', 'image')
    train_label_path = os.path.join(label_dir, 'train', 'label')
    train_dataset_dir = os.path.join(dataset_save_dir, 'train')
    train_anno_file_path = os.path.join(dataset_save_dir, 'train.csv')

    valid_go_path = os.path.join(go_dir, 'valid')
    valid_image_path = os.path.join(label_dir, 'valid', 'image')
    valid_label_path = os.path.join(label_dir, 'valid', 'label')
    valid_dataset_dir = os.path.join(dataset_save_dir, 'valid')
    valid_anno_file_path = os.path.join(dataset_save_dir, 'valid.csv')

    # top_text_list = text_statistics(go_dir)
    # for t in top_text_list:
    #     print(t)

    filter_from_top_text_list = [
        ('机器编号', 496),
        ('购买方名称', 496),
        ('合格证号', 495),
        ('进口证明书号', 495),
        ('机打代码', 494),
        ('车辆类型', 492),
        ('完税凭证号码', 492),
        ('机打号码', 491),
        ('发动机号码', 491),
        ('主管税务', 491),
        ('价税合计', 489),
        ('机关及代码', 489),
        ('销货单位名称', 486),
        ('厂牌型号', 485),
        ('产地', 485),
        ('商检单号', 483),
        ('电话', 476),
        ('开户银行', 472),
        ('车辆识别代号/车架号码', 463),
        ('身份证号码', 454),
        ('吨位', 452),
        ('备注:一车一票', 439),
        ('地', 432),
        ('账号', 431),
        ('统一社会信用代码/', 424),
        ('限乘人数', 404),
        ('税额', 465),
        ('址', 392)
    ]

    skip_list_train = [
        'CH-B101910792-page-12.jpg',
        'CH-B101655312-page-13.jpg',
        'CH-B102278656.jpg',
        'CH-B101846620_page_1_img_0.jpg',
        'CH-B103062528-0.jpg',
        'CH-B102613120-3.jpg',
        'CH-B102997980-3.jpg',
        'CH-B102680060-3.jpg',
        # 'CH-B102995500-2.jpg',  # no value annotated
    ]

    skip_list_valid = [
        'CH-B102897920-2.jpg',
        'CH-B102551284-0.jpg',
        'CH-B102879376-2.jpg',
        'CH-B101509488-page-16.jpg',
        'CH-B102708352-2.jpg',
    ]

    # build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)
    build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir)

    # build_anno_file(train_dataset_dir, train_anno_file_path)
    # build_anno_file(valid_dataset_dir, valid_anno_file_path)
import json
import os


def get_exclude_paths(input_path, exclude_list=[]):
    """
    Args:
        input_path: str     target directory
        exclude_list: list  files or directories to exclude, relative to input_path
    Returns: set  absolute paths of the excluded files or directories
    """
    exclude_paths_set = set()
    if os.path.isdir(input_path):
        for path in exclude_list:
            abs_path = path if os.path.isabs(path) else os.path.join(input_path, path)
            if not os.path.exists(abs_path):
                print('Warning: exclude path not exists: {0}'.format(abs_path))
                continue
            exclude_paths_set.add(abs_path)
    return exclude_paths_set


def get_file_paths(input_path, suffix_list, exclude_list=[]):
    """
    Args:
        input_path: str     target directory
        suffix_list: list   file suffixes to search for
        exclude_list: list  files or directories to exclude, relative to input_path
    Returns: generator  absolute paths of the matching files
    """
    exclude_paths_set = get_exclude_paths(input_path, exclude_list)
    for parent, _, filenames in os.walk(input_path):
        if parent in exclude_paths_set:
            print('Info: exclude path: {0}'.format(parent))
            continue
        for filename in filenames:
            for suffix in suffix_list:
                if filename.endswith(suffix):
                    file_path = os.path.join(parent, filename)
                    break
            else:
                continue
            if file_path in exclude_paths_set:
                print('Info: exclude path: {0}'.format(file_path))
                continue
            yield file_path


def load_json(json_path):
    """
    Args:
        json_path: str  path of a JSON file
    Returns: obj  the parsed JSON object
    """
    with open(json_path, 'r') as fp:
        output = json.load(fp)
    return output
from .builder import build_loss
__all__ = ['build_loss']
import copy
import inspect

import torch
from torchvision.ops import sigmoid_focal_loss

from utils.registery import LOSS_REGISTRY


class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):

    def __init__(self,
                 weight=None,
                 size_average=None,
                 reduce=None,
                 reduction: str = 'mean',
                 alpha: float = 0.25,
                 gamma: float = 2):
        super().__init__(weight, size_average, reduce, reduction)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        return sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, self.reduction)


def register_sigmoid_focal_loss():
    LOSS_REGISTRY.register()(SigmoidFocalLoss)


def register_torch_loss():
    for module_name in dir(torch.nn):
        if module_name.startswith('__') or 'Loss' not in module_name:
            continue
        _loss = getattr(torch.nn, module_name)
        if inspect.isclass(_loss) and issubclass(_loss, torch.nn.Module):
            LOSS_REGISTRY.register()(_loss)


def build_loss(cfg):
    register_sigmoid_focal_loss()
    register_torch_loss()
    loss_cfg = copy.deepcopy(cfg)
    try:
        loss_cfg = cfg['solver']['loss']
    except Exception:
        raise KeyError('config should contain a "solver.loss" section!')
    # return sigmoid_focal_loss
    return LOSS_REGISTRY.get(loss_cfg['name'])(**loss_cfg['args'])
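A small sanity-check sketch of `build_loss` (the config fragment below is a hand-written stand-in for the `solver.loss` block of the full config; `sigmoid_focal_loss` applies the sigmoid internally, so it expects unnormalized scores):
```
import torch

from loss import build_loss

cfg = {'solver': {'loss': {'name': 'SigmoidFocalLoss', 'args': {'reduction': 'mean'}}}}
criterion = build_loss(cfg)

scores = torch.randn(4, 5)          # 4 samples, 5 target fields
targets = torch.zeros(4, 5)
targets[0, 2] = 1.0                 # one positive field in the first sample
print(criterion(scores, targets))   # scalar loss tensor
```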
import argparse

import torch
import yaml

from solver.builder import build_solver


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./config/mlp.yaml', type=str, help='config file')
    args = parser.parse_args()

    cfg = yaml.load(open(args.config, 'r').read(), Loader=yaml.FullLoader)
    # print(cfg)
    # print(torch.cuda.is_available())

    solver = build_solver(cfg)
    solver.run()


if __name__ == '__main__':
    main()
from .builder import build_model
__all__ = ['build_model']
import copy

from utils import MODEL_REGISTRY
from .mlp import MLPModel


def build_model(cfg):
    model_cfg = copy.deepcopy(cfg)
    try:
        model_cfg = model_cfg['model']
    except Exception:
        raise KeyError('config should contain a "model" section!')
    model = MODEL_REGISTRY.get(model_cfg['name'])(**model_cfg['args'])
    return model
from abc import ABCMeta

import torch.nn as nn

from utils.registery import MODEL_REGISTRY


@MODEL_REGISTRY.register()
class MLPModel(nn.Module):

    def __init__(self, activation):
        super().__init__()
        self.activation_fn = activation
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(29*8, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 5),
            nn.Sigmoid(),
        )
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
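The stack flattens the 29 normalized quadrilaterals (8 coordinates each, 232 inputs in total) into 5 sigmoid scores, one per target field; the `activation` argument is stored but not currently used. A quick shape check (a sketch):
```
import torch

from model import build_model

model = build_model({'model': {'name': 'MLPModel', 'args': {'activation': 'relu'}}})
x = torch.rand(2, 29, 8)   # batch of 2 samples, 29 boxes x 8 normalized coords
print(model(x).shape)      # torch.Size([2, 5]), values in (0, 1)
```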
from .builder import build_optimizer, build_lr_scheduler
__all__ = ['build_optimizer', 'build_lr_scheduler']
import copy
import inspect

import torch

from utils.registery import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY


def register_torch_optimizers():
    """
    Register all optimizers implemented by torch
    """
    for module_name in dir(torch.optim):
        if module_name.startswith('__'):
            continue
        _optim = getattr(torch.optim, module_name)
        if inspect.isclass(_optim) and issubclass(_optim, torch.optim.Optimizer):
            OPTIMIZER_REGISTRY.register()(_optim)


def build_optimizer(cfg):
    register_torch_optimizers()
    optimizer_cfg = copy.deepcopy(cfg)
    try:
        optimizer_cfg = cfg['solver']['optimizer']
    except Exception:
        raise KeyError('config should contain a "solver.optimizer" section!')
    return OPTIMIZER_REGISTRY.get(optimizer_cfg['name'])


def register_torch_lr_scheduler():
    """
    Register all lr_schedulers implemented by torch
    """
    for module_name in dir(torch.optim.lr_scheduler):
        if module_name.startswith('__'):
            continue
        _scheduler = getattr(torch.optim.lr_scheduler, module_name)
        if inspect.isclass(_scheduler) and issubclass(_scheduler, torch.optim.lr_scheduler._LRScheduler):
            LR_SCHEDULER_REGISTRY.register()(_scheduler)


def build_lr_scheduler(cfg):
    register_torch_lr_scheduler()
    scheduler_cfg = copy.deepcopy(cfg)
    try:
        scheduler_cfg = cfg['solver']['lr_scheduler']
    except Exception:
        raise KeyError('config should contain a "solver.lr_scheduler" section!')
    return LR_SCHEDULER_REGISTRY.get(scheduler_cfg['name'])
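Note that `build_optimizer` and `build_lr_scheduler` return the registered class, not an instance; the caller instantiates it with the model parameters (or the optimizer) plus the `args` block from the config, exactly as `MLPSolver` does. A sketch:
```
import torch

from optimizer import build_optimizer, build_lr_scheduler

cfg = {'solver': {'optimizer': {'name': 'Adam', 'args': {'lr': 1e-4, 'weight_decay': 5e-5}},
                  'lr_scheduler': {'name': 'StepLR', 'args': {'step_size': 15, 'gamma': 0.1}}}}

params = [torch.nn.Parameter(torch.zeros(1))]   # stand-in for model.parameters()
optimizer = build_optimizer(cfg)(params, **cfg['solver']['optimizer']['args'])
scheduler = build_lr_scheduler(cfg)(optimizer, **cfg['solver']['lr_scheduler']['args'])
```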
torch==1.13.0
torchvision==0.14.0
PyYaml==6.0
loguru==0.6.0
pandas==1.5.2
opencv-python==4.6.0.66
from .builder import build_solver
__all__ = ['build_solver']
import copy

from utils.registery import SOLVER_REGISTRY
from .mlp_solver import MLPSolver


def build_solver(cfg):
    cfg = copy.deepcopy(cfg)
    try:
        solver_cfg = cfg['solver']
    except Exception:
        raise KeyError('config should contain a "solver" section!')
    return SOLVER_REGISTRY.get(solver_cfg['name'])(cfg)
import os
import copy

import torch

from model import build_model
from data import build_dataloader
from optimizer import build_optimizer, build_lr_scheduler
from loss import build_loss
from utils import SOLVER_REGISTRY, get_logger_and_log_dir


@SOLVER_REGISTRY.register()
class MLPSolver(object):

    def __init__(self, cfg):
        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        # BatchNorm ?
        self.model = build_model(cfg)

        self.loss_fn = build_loss(cfg)

        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        try:
            self.epoch = self.hyper_params['epoch']
        except Exception:
            raise KeyError('should contain epoch in "solver.args"!')

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    @staticmethod
    def evaluate(y_pred, y_true, thresholds=0.5):
        # Rebuild a single class index per sample: argmax + 1 if the max score
        # clears the threshold, otherwise 0 ("other"); the label is rebuilt the
        # same way (an all-zero label row also maps to 0), then counts matches.
        y_pred_idx = torch.argmax(y_pred, dim=1) + 1
        y_pred_is_other = (torch.amax(y_pred, dim=1) > thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        y_true_idx = torch.argmax(y_true, dim=1) + 1
        y_true_is_other = torch.sum(y_true, dim=1)
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item()

    def train_loop(self):
        self.model.train()

        train_loss = 0
        for batch, (X, y) in enumerate(self.train_loader):
            pred = self.model(X)

            # loss = self.loss_fn(pred, y, reduction="mean")
            loss = self.loss_fn(pred, y)
            train_loss += loss.item()

            if batch % 100 == 0:
                loss_value, current = loss.item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        train_loss /= self.train_loader_size
        self.logger.info(f'train mean loss: {train_loss :.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        self.model.eval()

        val_loss, correct = 0, 0
        for X, y in self.val_loader:
            pred = self.model(X)

            correct += self.evaluate(pred, y)

            loss = self.loss_fn(pred, y)
            val_loss += loss.item()

        correct /= self.val_dataset_size
        val_loss /= self.val_loader_size

        self.logger.info(f"val accuracy: {correct :.4f}, val loss: {val_loss :.4f}")

    def save_checkpoint(self, epoch_id):
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        self.logger.info('==> Start Training')
        print(self.model)

        # lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')
            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)
            # lr_scheduler.step()

        self.logger.info('==> End Training')

        # for X, y in self.train_loader:
        #     print(X.size())
        #     print(y.size())
        #     pred = self.model(X)
        #     print(pred)
        #     print(y)
        #     loss = self.loss_fn(pred, y, reduction="mean")
        #     print(loss)
        #     break

        # y_true = [
        #     [0, 1, 0],
        #     [0, 1, 0],
        #     [0, 0, 1],
        #     [0, 0, 0],
        # ]
        # y_pred = [
        #     [0.1, 0.8, 0.9],
        #     [0.2, 0.8, 0.1],
        #     [0.2, 0.1, 0.85],
        #     [0.2, 0.6, 0.1],
        # ]
        # acc_num = self.evaluate(torch.tensor(y_pred), torch.tensor(y_true))
from .registery import *
from .logger import get_logger_and_log_dir
__all__ = [
'Registry',
'get_logger_and_log_dir',
]
import os
import datetime

import loguru


def get_logger_and_log_dir(log_root, suffix):
    """
    get logger and log path

    Args:
        log_root (str): root path of log
        suffix (str): log save name
    Returns:
        logger (loguru.logger): logger object
        log_path (str): current root log path (with suffix)
    """
    crt_date = datetime.date.today().strftime('%Y-%m-%d')
    log_dir = os.path.join(log_root, crt_date, suffix)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    logger_path = os.path.join(log_dir, 'logfile.log')
    fmt = '{time:YYYY-MM-DD at HH:mm:ss} | {message}'
    logger = loguru.logger
    logger.add(logger_path, format=fmt)

    return logger, log_dir
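Given the `logger` block from the config, a usage sketch (log files land under `<log_root>/<YYYY-MM-DD>/<suffix>/logfile.log`; the paths here are placeholders):
```
from utils import get_logger_and_log_dir

logger, log_dir = get_logger_and_log_dir(log_root='./logs', suffix='mlp')
logger.info('hello')   # also appended to os.path.join(log_dir, 'logfile.log')
print(log_dir)         # e.g. ./logs/2023-01-01/mlp
```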
class Registry():
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj, suffix=None):
        if isinstance(suffix, str):
            name = name + '_' + suffix

        assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
                                             f"in '{self._name}' registry!")
        self._obj_map[name] = obj

    def register(self, obj=None, suffix=None):
        """
        Register the given object under the name `obj.__name__`.
        Can be used as either a decorator or not.
        See docstring of this class for usage.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class):
                name = func_or_class.__name__
                self._do_register(name, func_or_class, suffix)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__
        self._do_register(name, obj, suffix)

    def get(self, name, suffix='soulwalker'):
        ret = self._obj_map.get(name)
        if ret is None:
            ret = self._obj_map.get(name + '_' + suffix)
            print(f'Name {name} is not found, use name: {name}_{suffix}!')
        if ret is None:
            raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
        return ret

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()


DATASET_REGISTRY = Registry('dataset')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')
OPTIMIZER_REGISTRY = Registry('optimizer')
SOLVER_REGISTRY = Registry('solver')
LR_SCHEDULER_REGISTRY = Registry('lr_scheduler')
COLLATE_FN_REGISTRY = Registry('collate_fn')
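A sketch of how the registries are used elsewhere in the repo: decorate a class to register it under its own name, then look it up by that name (the `TinyModel` class below is a hypothetical example):
```
from utils.registery import MODEL_REGISTRY

@MODEL_REGISTRY.register()
class TinyModel:
    pass

assert 'TinyModel' in MODEL_REGISTRY
print(MODEL_REGISTRY.get('TinyModel') is TinyModel)   # True
```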