import copy
import inspect
import math

import torch
from torch.optim.lr_scheduler import LambdaLR

from utils.registery import LR_SCHEDULER_REGISTRY, OPTIMIZER_REGISTRY


def register_torch_optimizers():
    """
    Register all optimizers implemented by torch
    """
    for module_name in dir(torch.optim):
        if module_name.startswith('__'):
            continue
        _optim = getattr(torch.optim, module_name)
        # Register concrete optimizer classes only; skip the abstract
        # torch.optim.Optimizer base class itself.
        if (inspect.isclass(_optim) and issubclass(_optim, torch.optim.Optimizer)
                and _optim is not torch.optim.Optimizer):
            OPTIMIZER_REGISTRY.register()(_optim)


def build_optimizer(cfg):
    """
    Look up the optimizer class named in cfg['solver']['optimizer']['name'].
    """
    register_torch_optimizers()
    optimizer_cfg = copy.deepcopy(cfg)
    try:
        # Index into the copy so the caller's cfg is never mutated.
        optimizer_cfg = optimizer_cfg['solver']['optimizer']
    except (KeyError, TypeError) as e:
        # Raising a bare string is invalid in Python 3; raise a real exception.
        raise KeyError('cfg should contain solver.optimizer!') from e
    return OPTIMIZER_REGISTRY.get(optimizer_cfg['name'])


class CosineLR(LambdaLR):
    """
    Cosine-annealing schedule: decays the learning rate from its initial
    value down to lrf * initial value over `epochs` epochs.
    """

    def __init__(self, optimizer, epochs, lrf, last_epoch=-1, verbose=False):
        lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf  # cosine
        super(CosineLR, self).__init__(optimizer=optimizer,
                                       lr_lambda=lf,
                                       last_epoch=last_epoch,
                                       verbose=verbose)


def register_cosine_lr_scheduler():
    LR_SCHEDULER_REGISTRY.register()(CosineLR)


def register_torch_lr_scheduler():
    """
    Register all lr_schedulers implemented by torch
    """
    register_cosine_lr_scheduler()
    for module_name in dir(torch.optim.lr_scheduler):
        if module_name.startswith('__'):
            continue
        _scheduler = getattr(torch.optim.lr_scheduler, module_name)
        # Register concrete scheduler classes only; skip the _LRScheduler
        # base class itself.
        if (inspect.isclass(_scheduler)
                and issubclass(_scheduler, torch.optim.lr_scheduler._LRScheduler)
                and _scheduler is not torch.optim.lr_scheduler._LRScheduler):
            LR_SCHEDULER_REGISTRY.register()(_scheduler)


def build_lr_scheduler(cfg):
    """
    Look up the scheduler class named in cfg['solver']['lr_scheduler']['name'].
    """
    register_torch_lr_scheduler()
    scheduler_cfg = copy.deepcopy(cfg)
    try:
        scheduler_cfg = scheduler_cfg['solver']['lr_scheduler']
    except (KeyError, TypeError) as e:
        raise KeyError('cfg should contain solver.lr_scheduler!') from e
    return LR_SCHEDULER_REGISTRY.get(scheduler_cfg['name'])
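
# ---------------------------------------------------------------------------
# Sketch (assumption): utils.registery is not shown in this file. A minimal
# Registry compatible with the register()/get() calls above might look like
# the commented sketch below; the repo's actual implementation may differ.
# ---------------------------------------------------------------------------
# class Registry:
#     def __init__(self, name):
#         self._name = name
#         self._obj_map = {}
#
#     def register(self):
#         # Returns a decorator, so it can be used as REGISTRY.register()(cls).
#         def deco(obj):
#             self._obj_map[obj.__name__] = obj
#             return obj
#         return deco
#
#     def get(self, name):
#         if name not in self._obj_map:
#             raise KeyError(f'{name} not found in {self._name} registry!')
#         return self._obj_map[name]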
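
# ---------------------------------------------------------------------------
# Usage sketch (illustrative): the config layout below matches what
# build_optimizer/build_lr_scheduler expect; the model and hyperparameters
# are hypothetical stand-ins, not values from this repo.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    cfg = {
        'solver': {
            'optimizer': {'name': 'SGD'},
            'lr_scheduler': {'name': 'CosineLR'},
        }
    }

    model = nn.Linear(10, 2)  # stand-in model for illustration

    # The builders return registered classes; instantiate them with the usual
    # torch signatures.
    optimizer = build_optimizer(cfg)(model.parameters(), lr=0.1, momentum=0.9)
    scheduler = build_lr_scheduler(cfg)(optimizer, epochs=100, lrf=0.01)

    for _ in range(3):
        optimizer.step()   # would normally follow loss.backward()
        scheduler.step()   # advances the cosine schedule by one epoch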