# builder.py
# 678 Bytes
import torch
import inspect
from utils.registery import LOSS_REGISTRY
import copy
# Becomes True after register_torch_loss() has run once, so repeated
# build_loss() calls do not register the same torch classes again.
_TORCH_LOSSES_REGISTERED = False


def register_torch_loss():
    """Register the loss classes exposed by ``torch.nn`` in ``LOSS_REGISTRY``.

    An attribute of ``torch.nn`` is registered when its name does not start
    with ``'__'``, contains ``'Loss'``, and it is a class deriving from
    ``torch.nn.Module`` (e.g. ``MSELoss``, ``CrossEntropyLoss``).

    Registration is performed only on the first call; later calls return
    immediately.
    """
    global _TORCH_LOSSES_REGISTERED
    if _TORCH_LOSSES_REGISTERED:
        # NOTE(review): registries of this style often reject an already
        # registered name (e.g. an assert in the register decorator) — confirm
        # in utils.registery. Skipping avoids that and redundant work.
        return
    for module_name in dir(torch.nn):
        if module_name.startswith('__') or 'Loss' not in module_name:
            continue
        _loss = getattr(torch.nn, module_name)
        # Name filtering alone is not enough: only real nn.Module classes are
        # valid loss constructors.
        if inspect.isclass(_loss) and issubclass(_loss, torch.nn.Module):
            LOSS_REGISTRY.register()(_loss)
    # Set only after the loop finishes so a failed pass can be retried.
    _TORCH_LOSSES_REGISTERED = True
def build_loss(cfg):
    """Instantiate the loss described by ``cfg['solver']['loss']``.

    The loss sub-config is read as ``{'name': ..., 'args': {...}}``: ``name``
    is looked up in ``LOSS_REGISTRY`` and ``args`` is passed to the class as
    keyword arguments. ``args`` may be omitted or empty (``None``), in which
    case the class is constructed with no keyword arguments.

    Args:
        cfg: Nested mapping holding the experiment configuration. It is not
            modified.

    Returns:
        The object produced by calling the registered class with ``args``.

    Raises:
        ValueError: If ``cfg`` has no ``solver.loss`` entry.
        KeyError: If the loss sub-config has no ``'name'`` key.
    """
    try:
        loss_cfg = cfg['solver']['loss']
    except (KeyError, TypeError) as err:
        # KeyError: a level is missing; TypeError: a level is not subscriptable
        # (e.g. ``solver`` is None). Previously this raised a bare string,
        # which is itself a TypeError in Python 3 and hid the message.
        raise ValueError('should contain {solver.loss}!') from err

    # Copy only the loss sub-config (not the whole cfg) so the loss constructor
    # does not share mutable objects with the caller's configuration.
    loss_cfg = copy.deepcopy(loss_cfg)

    # Make torch.nn's built-in losses resolvable by name.
    register_torch_loss()

    loss_cls = LOSS_REGISTRY.get(loss_cfg['name'])
    return loss_cls(**(loss_cfg.get('args') or {}))