# builder.py
import copy
import math
import torch
import inspect
from utils.registery import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY
from torch.optim.lr_scheduler import LambdaLR

def register_torch_optimizers():
    """
    Register every optimizer class exposed by ``torch.optim``.

    Each attribute name of ``torch.optim`` that is not a dunder name is
    resolved. Every resolved attribute that is a class deriving from
    ``torch.optim.Optimizer`` is handed to ``OPTIMIZER_REGISTRY``.

    Note that ``torch.optim.Optimizer`` itself passes the subclass check,
    so the base class is registered too.
    """
    optim_module = torch.optim
    candidates = (
        getattr(optim_module, attr_name)
        for attr_name in dir(optim_module)
        if not attr_name.startswith('__')
    )
    for candidate in candidates:
        if not inspect.isclass(candidate):
            continue
        if not issubclass(candidate, torch.optim.Optimizer):
            continue
        OPTIMIZER_REGISTRY.register()(candidate)

def build_optimizer(cfg):
    """
    Look up the optimizer named by ``cfg['solver']['optimizer']['name']``.

    All ``torch.optim`` optimizers are registered before the lookup.

    Args:
        cfg: Nested mapping-like config. It must contain
            ``cfg['solver']['optimizer']['name']``.

    Returns:
        Whatever ``OPTIMIZER_REGISTRY.get`` returns for that name. This is
        presumably the optimizer class itself, not an instance; TODO confirm
        against ``utils.registery``.

    Raises:
        KeyError: if ``cfg`` has no ``solver.optimizer`` section, or that
            section has no ``name`` entry.
    """
    register_torch_optimizers()
    try:
        optimizer_cfg = cfg['solver']['optimizer']
    except (LookupError, TypeError) as err:
        # A bare string cannot be raised in Python 3 (it fails with an
        # unrelated TypeError), so wrap the message in a real exception.
        raise KeyError('should contain {solver.optimizer}!') from err
    return OPTIMIZER_REGISTRY.get(optimizer_cfg['name'])

class CosineLR(LambdaLR):
    """
    ``LambdaLR`` whose factor follows a half cosine from 1 down to ``lrf``.

    At scheduler step ``x``, ``LambdaLR`` multiplies each param group's
    initial learning rate by::

        f(x) = (1 + cos(pi * x / epochs)) / 2 * (1 - lrf) + lrf

    so ``f(0) == 1`` and ``f(epochs) == lrf``. The cosine is not clamped
    after ``epochs``, so stepping further makes the factor rise again.

    Args:
        optimizer: Wrapped optimizer.
        epochs: Number of scheduler steps over which the factor decays. It
            is used as a divisor, so it must be non-zero.
        lrf: Final learning-rate factor, reached at step ``epochs``.
        last_epoch: Forwarded to ``LambdaLR``.
        verbose: Forwarded to ``LambdaLR`` only when truthy.
    """

    def __init__(self, optimizer, epochs, lrf, last_epoch=-1, verbose=False):
        def lf(x):
            # cosine
            return ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf

        # torch >= 2.2 deprecates ``verbose`` and warns whenever it is passed
        # explicitly; later releases may reject it outright. Leaving it out
        # is equivalent to the old ``verbose=False`` default, so forward it
        # only when the caller actually asked for verbose output.
        extra_kwargs = {'verbose': verbose} if verbose else {}
        super().__init__(
            optimizer=optimizer,
            lr_lambda=lf,
            last_epoch=last_epoch,
            **extra_kwargs,
        )
        
def register_cosine_lr_scheduler():
    """Add this module's ``CosineLR`` scheduler to ``LR_SCHEDULER_REGISTRY``."""
    register = LR_SCHEDULER_REGISTRY.register()
    register(CosineLR)

def register_torch_lr_scheduler():
    """
    Register ``CosineLR`` plus every LR scheduler class implemented by torch.

    Each non-dunder attribute of ``torch.optim.lr_scheduler`` that is a class
    deriving from torch's scheduler base class is handed to
    ``LR_SCHEDULER_REGISTRY``.
    """
    register_cosine_lr_scheduler()
    sched_module = torch.optim.lr_scheduler
    # Since torch 2.0 the built-in schedulers (StepLR, LambdaLR, ...) derive
    # from the public ``LRScheduler``. ``_LRScheduler`` is only a thin
    # backward-compat *subclass* of it, so an ``issubclass(.., _LRScheduler)``
    # filter would skip every built-in scheduler. Older torch only has
    # ``_LRScheduler``, so fall back to it there.
    base_cls = getattr(sched_module, 'LRScheduler', None) or sched_module._LRScheduler
    for module_name in dir(sched_module):
        if module_name.startswith('__'):
            continue

        _scheduler = getattr(sched_module, module_name)
        if inspect.isclass(_scheduler) and issubclass(_scheduler, base_cls):
            LR_SCHEDULER_REGISTRY.register()(_scheduler)

def build_lr_scheduler(cfg):
    """
    Look up the LR scheduler named by ``cfg['solver']['lr_scheduler']['name']``.

    ``CosineLR`` and torch's schedulers are registered before the lookup.

    Args:
        cfg: Nested mapping-like config. It must contain
            ``cfg['solver']['lr_scheduler']['name']``.

    Returns:
        Whatever ``LR_SCHEDULER_REGISTRY.get`` returns for that name. This is
        presumably the scheduler class itself, not an instance; TODO confirm
        against ``utils.registery``.

    Raises:
        KeyError: if ``cfg`` has no ``solver.lr_scheduler`` section, or that
            section has no ``name`` entry.
    """
    register_torch_lr_scheduler()
    try:
        scheduler_cfg = cfg['solver']['lr_scheduler']
    except (LookupError, TypeError) as err:
        # A bare string cannot be raised in Python 3 (it fails with an
        # unrelated TypeError), so wrap the message in a real exception.
        raise KeyError('should contain {solver.lr_scheduler}!') from err
    return LR_SCHEDULER_REGISTRY.get(scheduler_cfg['name'])