add lr_scheduler
Showing 3 changed files with 28 additions and 10 deletions
File 1/3 — solver config (YAML):

@@ -45,14 +45,14 @@ solver:
   optimizer:
     name: 'Adam'
     args:
-      lr: !!float 1e-4
-      weight_decay: !!float 5e-5
+      lr: !!float 1e-3
+      # weight_decay: !!float 5e-5

   lr_scheduler:
-    name: 'StepLR'
+    name: 'CosineLR'
     args:
-      step_size: 15
-      gamma: 0.1
+      epochs: 100
+      lrf: 0.1

   loss:
     name: 'SigmoidFocalLoss'
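Note on the new schedule: with epochs: 100 and lrf: 0.1, CosineLR anneals the learning rate along a half-cosine from its initial value down to lr * lrf. A minimal sketch of the multiplier, mirroring the lambda added in the builder below, with values worked out for the config above:

import math

def cosine_factor(epoch, epochs=100, lrf=0.1):
    # Multiplier applied to the base lr: 1.0 at epoch 0, lrf at `epochs`.
    return ((1 + math.cos(epoch * math.pi / epochs)) / 2) * (1 - lrf) + lrf

# With lr: 1e-3 as above:
# epoch 0   -> 1e-3 * 1.00 = 1.0e-3
# epoch 50  -> 1e-3 * 0.55 = 5.5e-4
# epoch 100 -> 1e-3 * 0.10 = 1.0e-4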
File 2/3 — optimizer/lr_scheduler builder (Python):

@@ -1,7 +1,9 @@
+import copy
+import math
 import torch
 import inspect
 from utils.registery import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY
-import copy
+from torch.optim.lr_scheduler import LambdaLR

 def register_torch_optimizers():
     """
@@ -24,10 +26,20 @@ def build_optimizer(cfg):

     return OPTIMIZER_REGISTRY.get(optimizer_cfg['name'])

+class CosineLR(LambdaLR):
+
+    def __init__(self, optimizer, epochs, lrf, last_epoch=-1, verbose=False):
+        lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf  # cosine
+        super(CosineLR, self).__init__(optimizer=optimizer, lr_lambda=lf, last_epoch=last_epoch, verbose=verbose)
+
+def register_cosine_lr_scheduler():
+    LR_SCHEDULER_REGISTRY.register()(CosineLR)
+
 def register_torch_lr_scheduler():
     """
     Register all lr_schedulers implemented by torch
     """
+    register_cosine_lr_scheduler()
     for module_name in dir(torch.optim.lr_scheduler):
         if module_name.startswith('__'):
             continue
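The registration uses the decorator-factory form, so LR_SCHEDULER_REGISTRY.register()(CosineLR) stores the class under its own name. utils.registery isn't shown in this diff, so here is a hypothetical simplified Registry illustrating why the double call works:

class Registry:
    # Simplified stand-in for the project's Registry class.
    def __init__(self):
        self._classes = {}

    def register(self):
        # register() returns a decorator, so both @REG.register() above a
        # class and REG.register()(SomeClass) store the class by name.
        def deco(cls):
            self._classes[cls.__name__] = cls
            return cls
        return deco

    def get(self, name):
        return self._classes[name]

LR_SCHEDULER_REGISTRY = Registry()

class CosineLR:  # stand-in; the real class subclasses LambdaLR
    pass

LR_SCHEDULER_REGISTRY.register()(CosineLR)  # same call form as in the diff
assert LR_SCHEDULER_REGISTRY.get('CosineLR') is CosineLR

Also note that register_cosine_lr_scheduler() runs before the dir(torch.optim.lr_scheduler) loop, so the custom class is registered alongside torch's built-in schedulers.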
File 3/3 — VITSolver training loop (Python):

@@ -75,12 +75,15 @@ class VITSolver(object):
                 loss_value, current = loss.item(), batch
                 self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')

-            correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
-
             self.optimizer.zero_grad()
             loss.backward()
             self.optimizer.step()

+            if self.no_other:
+                correct += self.evaluate(pred, y)
+            else:
+                correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
+
         correct /= self.train_dataset_size
         train_loss /= self.train_loader_size
         self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')
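On the new no_other branch: the flag's meaning is project-specific and not visible in this diff, but note that sigmoid is monotonic, so the two branches can only differ if evaluate() does something other than argmax or natural-point thresholding. A quick check:

import torch

logits = torch.tensor([[1.2, -0.3], [-0.8, 2.0]])
probs = torch.sigmoid(logits)

# argmax-based accuracy is invariant under the monotonic sigmoid...
assert torch.equal(logits.argmax(dim=1), probs.argmax(dim=1))
# ...and thresholding probabilities at 0.5 equals thresholding logits at 0.
assert torch.equal(probs > 0.5, logits > 0)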
@@ -103,6 +106,9 @@ class VITSolver(object):
                 loss = self.loss_fn(pred, y)
                 val_loss += loss.item()

+                if self.no_other:
+                    correct += self.evaluate(pred, y)
+                else:
                     correct += self.evaluate(torch.nn.Sigmoid()(pred), y)

         correct /= self.val_dataset_size
@@ -122,7 +128,7 @@ class VITSolver(object):
         self.logger.info('==> Start Training')
         print(self.model)

-        # lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])
+        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

         for t in range(self.epoch):
             self.logger.info(f'==> epoch {t + 1}')
@@ -131,6 +137,6 @@ class VITSolver(object):
             self.val_loop(t + 1)
             self.save_checkpoint(t + 1)

-            # lr_scheduler.step()
+            lr_scheduler.step()

         self.logger.info('==> End Training')
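End-to-end, train() now resolves the scheduler class from the config and steps it once per epoch, so the epochs: 100 in the YAML should match the training epoch count for the schedule to finish exactly at lr * lrf. A self-contained sketch of that flow (a plain dict stands in for the project's registry and build_lr_scheduler):

import math
import torch

def make_cosine(optimizer, epochs, lrf):
    # Same lambda as CosineLR in the builder above.
    lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

LR_SCHEDULERS = {'CosineLR': make_cosine}  # stand-in for LR_SCHEDULER_REGISTRY

cfg = {'solver': {'lr_scheduler': {'name': 'CosineLR',
                                   'args': {'epochs': 100, 'lrf': 0.1}}}}

optimizer = torch.optim.Adam(torch.nn.Linear(4, 2).parameters(), lr=1e-3)
sched_cfg = cfg['solver']['lr_scheduler']
lr_scheduler = LR_SCHEDULERS[sched_cfg['name']](optimizer, **sched_cfg['args'])

for epoch in range(100):      # real loop also runs train_loop/val_loop/save_checkpoint
    optimizer.step()
    lr_scheduler.step()       # one scheduler step per epoch, as in train()

print(optimizer.param_groups[0]['lr'])  # ~1.0e-4, i.e. lr * lrf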