82a85c6d by 周伟奇

add lr_scheduler

1 parent 69e75f77
......@@ -45,14 +45,14 @@ solver:
optimizer:
name: 'Adam'
args:
lr: !!float 1e-4
weight_decay: !!float 5e-5
lr: !!float 1e-3
# weight_decay: !!float 5e-5
lr_scheduler:
name: 'StepLR'
name: 'CosineLR'
args:
step_size: 15
gamma: 0.1
epochs: 100
lrf: 0.1
loss:
name: 'SigmoidFocalLoss'
......
import copy
import math
import torch
import inspect
from utils.registery import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY
import copy
from torch.optim.lr_scheduler import LambdaLR
def register_torch_optimizers():
"""
......@@ -24,10 +26,20 @@ def build_optimizer(cfg):
return OPTIMIZER_REGISTRY.get(optimizer_cfg['name'])
class CosineLR(LambdaLR):
    """Cosine-annealing learning-rate schedule built on ``LambdaLR``.

    The multiplicative factor starts at 1.0 (epoch 0) and follows one
    half-period of a cosine curve down to ``lrf`` at epoch ``epochs``,
    so the effective lr decays from the optimizer's base lr to
    ``base_lr * lrf``.
    """

    def __init__(self, optimizer, epochs, lrf, last_epoch=-1, verbose=False):
        # Half-cosine multiplier: 1.0 at epoch 0, lrf at epoch `epochs`.
        def cosine_factor(epoch):
            return lrf + (1 - lrf) * (1 + math.cos(math.pi * epoch / epochs)) / 2

        super().__init__(optimizer=optimizer,
                         lr_lambda=cosine_factor,
                         last_epoch=last_epoch,
                         verbose=verbose)
def register_cosine_lr_scheduler():
    """Register the custom ``CosineLR`` scheduler with the LR-scheduler registry."""
    register = LR_SCHEDULER_REGISTRY.register()
    register(CosineLR)
def register_torch_lr_scheduler():
"""
Register all lr_schedulers implemented by torch
"""
register_cosine_lr_scheduler()
for module_name in dir(torch.optim.lr_scheduler):
if module_name.startswith('__'):
continue
......
......@@ -74,13 +74,16 @@ class VITSolver(object):
if batch % 100 == 0:
loss_value, current = loss.item(), batch
self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')
correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
if self.no_other:
correct += self.evaluate(pred, y)
else:
correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
correct /= self.train_dataset_size
train_loss /= self.train_loader_size
self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')
......@@ -103,7 +106,10 @@ class VITSolver(object):
loss = self.loss_fn(pred, y)
val_loss += loss.item()
correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
if self.no_other:
correct += self.evaluate(pred, y)
else:
correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
correct /= self.val_dataset_size
val_loss /= self.val_loader_size
......@@ -122,7 +128,7 @@ class VITSolver(object):
self.logger.info('==> Start Training')
print(self.model)
# lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])
lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])
for t in range(self.epoch):
self.logger.info(f'==> epoch {t + 1}')
......@@ -131,6 +137,6 @@ class VITSolver(object):
self.val_loop(t + 1)
self.save_checkpoint(t + 1)
# lr_scheduler.step()
lr_scheduler.step()
self.logger.info('==> End Training')
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!