add lr_scheduler
Showing 3 changed files with 29 additions and 11 deletions
@@ -45,14 +45,14 @@ solver:
   optimizer:
     name: 'Adam'
     args:
-      lr: !!float 1e-4
-      weight_decay: !!float 5e-5
+      lr: !!float 1e-3
+      # weight_decay: !!float 5e-5

   lr_scheduler:
-    name: 'StepLR'
+    name: 'CosineLR'
     args:
-      step_size: 15
-      gamma: 0.1
+      epochs: 100
+      lrf: 0.1

   loss:
     name: 'SigmoidFocalLoss'
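The config swaps the step decay (drop by gamma every step_size epochs) for cosine annealing: the learning rate starts at lr and decays smoothly to lr * lrf over epochs epochs. A minimal standalone sketch of the curve these three values produce (the cosine_factor helper is illustrative, not part of the repo):

import math

lr, epochs, lrf = 1e-3, 100, 0.1  # values from the config above

def cosine_factor(epoch):
    # Same curve as the CosineLR lambda below: 1.0 at epoch 0, lrf at the final epoch.
    return ((1 + math.cos(epoch * math.pi / epochs)) / 2) * (1 - lrf) + lrf

for e in (0, 25, 50, 75, 100):
    print(f'epoch {e:3d}: lr = {lr * cosine_factor(e):.2e}')
# epoch   0: lr = 1.00e-03
# epoch  25: lr = 8.68e-04
# epoch  50: lr = 5.50e-04
# epoch  75: lr = 2.32e-04
# epoch 100: lr = 1.00e-04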
@@ -1,7 +1,9 @@
+import copy
+import math
 import torch
 import inspect
 from utils.registery import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY
-import copy
+from torch.optim.lr_scheduler import LambdaLR

 def register_torch_optimizers():
     """
@@ -24,10 +26,20 @@ def build_optimizer(cfg):

     return OPTIMIZER_REGISTRY.get(optimizer_cfg['name'])

+class CosineLR(LambdaLR):
+
+    def __init__(self, optimizer, epochs, lrf, last_epoch=-1, verbose=False):
+        lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf  # cosine
+        super(CosineLR, self).__init__(optimizer=optimizer, lr_lambda=lf, last_epoch=last_epoch, verbose=verbose)
+
+def register_cosine_lr_scheduler():
+    LR_SCHEDULER_REGISTRY.register()(CosineLR)
+
 def register_torch_lr_scheduler():
     """
     Register all lr_schedulers implemented by torch
     """
+    register_cosine_lr_scheduler()
     for module_name in dir(torch.optim.lr_scheduler):
         if module_name.startswith('__'):
             continue
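Since CosineLR is just a LambdaLR with a fixed multiplier, it can be exercised directly against any torch optimizer. A hedged usage sketch (the Linear model is a stand-in; inside the repo the class would instead be fetched by name from LR_SCHEDULER_REGISTRY after register_torch_lr_scheduler() runs):

import math
import torch
from torch.optim.lr_scheduler import LambdaLR

class CosineLR(LambdaLR):
    # Same class as in the diff above.
    def __init__(self, optimizer, epochs, lrf, last_epoch=-1, verbose=False):
        lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf
        super().__init__(optimizer, lr_lambda=lf, last_epoch=last_epoch, verbose=verbose)

model = torch.nn.Linear(4, 2)  # stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = CosineLR(optimizer, epochs=100, lrf=0.1)

for epoch in range(3):
    optimizer.step()    # one (dummy) training epoch
    scheduler.step()    # anneal once per epoch
    print(scheduler.get_last_lr())

Registering CosineLR before the dir(torch.optim.lr_scheduler) loop keeps built-in and custom schedulers resolvable through the same registry lookup.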
@@ -74,13 +74,16 @@ class VITSolver(object):
             if batch % 100 == 0:
                 loss_value, current = loss.item(), batch
                 self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')
-
-            correct += self.evaluate(torch.nn.Sigmoid()(pred), y)

             self.optimizer.zero_grad()
             loss.backward()
             self.optimizer.step()

+            if self.no_other:
+                correct += self.evaluate(pred, y)
+            else:
+                correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
+
         correct /= self.train_dataset_size
         train_loss /= self.train_loader_size
         self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')
@@ -103,7 +106,10 @@ class VITSolver(object):
                 loss = self.loss_fn(pred, y)
                 val_loss += loss.item()

-                correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
+                if self.no_other:
+                    correct += self.evaluate(pred, y)
+                else:
+                    correct += self.evaluate(torch.nn.Sigmoid()(pred), y)

         correct /= self.val_dataset_size
         val_loss /= self.val_loader_size
@@ -122,7 +128,7 @@ class VITSolver(object):
         self.logger.info('==> Start Training')
         print(self.model)

-        # lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])
+        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

         for t in range(self.epoch):
             self.logger.info(f'==> epoch {t + 1}')
@@ -131,6 +137,6 @@ class VITSolver(object):
             self.val_loop(t + 1)
             self.save_checkpoint(t + 1)

-            # lr_scheduler.step()
+            lr_scheduler.step()

         self.logger.info('==> End Training')
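Two behavioral notes on the solver diff: moving the evaluate call below optimizer.step() does not change its result (pred is computed before the weight update), and the new no_other branch only controls whether logits are squashed before evaluation. A hypothetical sketch of why that matters when evaluate thresholds multi-label scores at 0.5 (this evaluate body is assumed, not taken from the repo):

import torch

def evaluate(scores, y, threshold=0.5):
    # Assumed multi-label accuracy: count per-element matches after thresholding.
    return ((scores > threshold).float() == y).float().sum()

logits = torch.tensor([[2.0, -1.5], [-0.3, 0.8]])
y = torch.tensor([[1.0, 0.0], [0.0, 1.0]])

# sigmoid(x) > 0.5 exactly when x > 0, so the sigmoid branch agrees with the
# raw-logit branch only if the latter thresholds at 0.0 instead of 0.5:
print(evaluate(torch.sigmoid(logits), y))   # tensor(4.)
print(evaluate(logits, y, threshold=0.0))   # tensor(4.)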