# builder.py
import torch
import inspect
from utils.registery import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY
import copy
def register_torch_optimizers():
    """
    Register every concrete optimizer implemented by torch.

    Scans ``torch.optim`` and registers each class that subclasses
    ``torch.optim.Optimizer`` into ``OPTIMIZER_REGISTRY`` under its class
    name (e.g. 'Adam', 'SGD'), so configs can select optimizers by name.
    """
    for module_name in dir(torch.optim):
        # Skip dunder attributes (__name__, __doc__, ...).
        if module_name.startswith('__'):
            continue
        _optim = getattr(torch.optim, module_name)
        # Register concrete optimizers only; the abstract base class
        # Optimizer also satisfies issubclass, so exclude it explicitly.
        if (inspect.isclass(_optim)
                and issubclass(_optim, torch.optim.Optimizer)
                and _optim is not torch.optim.Optimizer):
            OPTIMIZER_REGISTRY.register()(_optim)
def build_optimizer(cfg):
    """
    Look up an optimizer class from the config.

    Args:
        cfg: nested config mapping; must contain ``cfg['solver']['optimizer']``
            with a ``'name'`` key naming a registered optimizer class.

    Returns:
        The optimizer *class* (not an instance) from ``OPTIMIZER_REGISTRY``.

    Raises:
        KeyError: if the config lacks ``solver.optimizer``.
    """
    register_torch_optimizers()
    # NOTE: the original code raised a plain string here, which itself
    # raises "TypeError: exceptions must derive from BaseException" and
    # hides the intended message; raise a real exception instead.
    try:
        optimizer_cfg = cfg['solver']['optimizer']
    except (KeyError, TypeError) as err:
        raise KeyError('config should contain {solver.optimizer}!') from err
    return OPTIMIZER_REGISTRY.get(optimizer_cfg['name'])
def register_torch_lr_scheduler():
    """
    Register every concrete LR scheduler implemented by torch.

    Scans ``torch.optim.lr_scheduler`` and registers each scheduler class
    into ``LR_SCHEDULER_REGISTRY`` under its class name (e.g. 'StepLR'),
    so configs can select schedulers by name.
    """
    # torch >= 2.0 exposes the public base class LRScheduler; older
    # versions only have the private alias _LRScheduler. Prefer public.
    base_cls = getattr(torch.optim.lr_scheduler, 'LRScheduler',
                       torch.optim.lr_scheduler._LRScheduler)
    for module_name in dir(torch.optim.lr_scheduler):
        # Skip dunder attributes (__name__, __doc__, ...).
        if module_name.startswith('__'):
            continue
        _scheduler = getattr(torch.optim.lr_scheduler, module_name)
        # Register concrete schedulers only; the base class itself also
        # satisfies issubclass, so exclude it explicitly.
        if (inspect.isclass(_scheduler)
                and issubclass(_scheduler, base_cls)
                and _scheduler is not base_cls):
            LR_SCHEDULER_REGISTRY.register()(_scheduler)
def build_lr_scheduler(cfg):
    """
    Look up an LR scheduler class from the config.

    Args:
        cfg: nested config mapping; must contain
            ``cfg['solver']['lr_scheduler']`` with a ``'name'`` key naming
            a registered scheduler class.

    Returns:
        The scheduler *class* (not an instance) from ``LR_SCHEDULER_REGISTRY``.

    Raises:
        KeyError: if the config lacks ``solver.lr_scheduler``.
    """
    register_torch_lr_scheduler()
    # NOTE: the original code raised a plain string here, which itself
    # raises "TypeError: exceptions must derive from BaseException" and
    # hides the intended message; raise a real exception instead.
    try:
        scheduler_cfg = cfg['solver']['lr_scheduler']
    except (KeyError, TypeError) as err:
        raise KeyError('config should contain {solver.lr_scheduler}!') from err
    return LR_SCHEDULER_REGISTRY.get(scheduler_cfg['name'])