82a85c6d by 周伟奇

add lr_scheduler

1 parent 69e75f77
...@@ -45,14 +45,14 @@ solver: ...@@ -45,14 +45,14 @@ solver:
45 optimizer: 45 optimizer:
46 name: 'Adam' 46 name: 'Adam'
47 args: 47 args:
48 lr: !!float 1e-4 48 lr: !!float 1e-3
49 weight_decay: !!float 5e-5 49 # weight_decay: !!float 5e-5
50 50
51 lr_scheduler: 51 lr_scheduler:
52 name: 'StepLR' 52 name: 'CosineLR'
53 args: 53 args:
54 step_size: 15 54 epochs: 100
55 gamma: 0.1 55 lrf: 0.1
56 56
57 loss: 57 loss:
58 name: 'SigmoidFocalLoss' 58 name: 'SigmoidFocalLoss'
......
1 import copy
2 import math
1 import torch 3 import torch
2 import inspect 4 import inspect
3 from utils.registery import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY 5 from utils.registery import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY
4 import copy 6 from torch.optim.lr_scheduler import LambdaLR
5 7
6 def register_torch_optimizers(): 8 def register_torch_optimizers():
7 """ 9 """
...@@ -24,10 +26,20 @@ def build_optimizer(cfg): ...@@ -24,10 +26,20 @@ def build_optimizer(cfg):
24 26
25 return OPTIMIZER_REGISTRY.get(optimizer_cfg['name']) 27 return OPTIMIZER_REGISTRY.get(optimizer_cfg['name'])
26 28
class CosineLR(LambdaLR):
    """Cosine-annealing learning-rate schedule built on ``LambdaLR``.

    The multiplicative lr factor decays from 1.0 down to ``lrf`` along half
    a cosine wave over ``epochs`` steps:

        factor(x) = ((1 + cos(x * pi / epochs)) / 2) * (1 - lrf) + lrf

    so ``factor(0) == 1.0`` and ``factor(epochs) == lrf``.

    Args:
        optimizer: wrapped torch optimizer whose param-group lrs are scaled.
        epochs: total number of scheduler steps over which to anneal.
        lrf: final lr multiplier reached when ``x == epochs``.
        last_epoch: index of the last epoch (standard scheduler argument).
        verbose: forwarded to ``LambdaLR`` only when truthy — the kwarg is
            deprecated (and later removed) in newer torch, so the default
            path avoids passing it at all.
    """

    def __init__(self, optimizer, epochs, lrf, last_epoch=-1, verbose=False):
        # Named function instead of an assigned lambda (PEP 8 E731);
        # it closes over `epochs`/`lrf` captured at construction time.
        def cosine_factor(x):
            return ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf

        # Only forward `verbose` when explicitly enabled, for compatibility
        # with torch versions where the parameter is deprecated/removed.
        extra = {'verbose': verbose} if verbose else {}
        super().__init__(optimizer, lr_lambda=cosine_factor,
                         last_epoch=last_epoch, **extra)
34
def register_cosine_lr_scheduler():
    """Add the custom ``CosineLR`` schedule to the lr-scheduler registry."""
    LR_SCHEDULER_REGISTRY.register()(CosineLR)
37
27 def register_torch_lr_scheduler(): 38 def register_torch_lr_scheduler():
28 """ 39 """
29 Register all lr_schedulers implemented by torch 40 Register all lr_schedulers implemented by torch
30 """ 41 """
42 register_cosine_lr_scheduler()
31 for module_name in dir(torch.optim.lr_scheduler): 43 for module_name in dir(torch.optim.lr_scheduler):
32 if module_name.startswith('__'): 44 if module_name.startswith('__'):
33 continue 45 continue
......
...@@ -75,12 +75,15 @@ class VITSolver(object): ...@@ -75,12 +75,15 @@ class VITSolver(object):
75 loss_value, current = loss.item(), batch 75 loss_value, current = loss.item(), batch
76 self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}') 76 self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')
77 77
78 correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
79
80 self.optimizer.zero_grad() 78 self.optimizer.zero_grad()
81 loss.backward() 79 loss.backward()
82 self.optimizer.step() 80 self.optimizer.step()
83 81
82 if self.no_other:
83 correct += self.evaluate(pred, y)
84 else:
85 correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
86
84 correct /= self.train_dataset_size 87 correct /= self.train_dataset_size
85 train_loss /= self.train_loader_size 88 train_loss /= self.train_loader_size
86 self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}') 89 self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')
...@@ -103,6 +106,9 @@ class VITSolver(object): ...@@ -103,6 +106,9 @@ class VITSolver(object):
103 loss = self.loss_fn(pred, y) 106 loss = self.loss_fn(pred, y)
104 val_loss += loss.item() 107 val_loss += loss.item()
105 108
109 if self.no_other:
110 correct += self.evaluate(pred, y)
111 else:
106 correct += self.evaluate(torch.nn.Sigmoid()(pred), y) 112 correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
107 113
108 correct /= self.val_dataset_size 114 correct /= self.val_dataset_size
...@@ -122,7 +128,7 @@ class VITSolver(object): ...@@ -122,7 +128,7 @@ class VITSolver(object):
122 self.logger.info('==> Start Training') 128 self.logger.info('==> Start Training')
123 print(self.model) 129 print(self.model)
124 130
125 # lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args']) 131 lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])
126 132
127 for t in range(self.epoch): 133 for t in range(self.epoch):
128 self.logger.info(f'==> epoch {t + 1}') 134 self.logger.info(f'==> epoch {t + 1}')
...@@ -131,6 +137,6 @@ class VITSolver(object): ...@@ -131,6 +137,6 @@ class VITSolver(object):
131 self.val_loop(t + 1) 137 self.val_loop(t + 1)
132 self.save_checkpoint(t + 1) 138 self.save_checkpoint(t + 1)
133 139
134 # lr_scheduler.step() 140 lr_scheduler.step()
135 141
136 self.logger.info('==> End Training') 142 self.logger.info('==> End Training')
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!