add load model
Showing
2 changed files
with
18 additions
and
9 deletions
| 1 | import os | ||
| 2 | import copy | 1 | import copy |
| 2 | import os | ||
| 3 | |||
| 3 | import torch | 4 | import torch |
| 4 | 5 | ||
| 5 | from model import build_model | ||
| 6 | from data import build_dataloader | 6 | from data import build_dataloader |
| 7 | from optimizer import build_optimizer, build_lr_scheduler | ||
| 8 | from loss import build_loss | 7 | from loss import build_loss |
| 8 | from model import build_model | ||
| 9 | from optimizer import build_lr_scheduler, build_optimizer | ||
| 9 | from utils import SOLVER_REGISTRY, get_logger_and_log_dir | 10 | from utils import SOLVER_REGISTRY, get_logger_and_log_dir |
| 10 | 11 | ||
| 11 | 12 | ||
| ... | @@ -30,6 +31,7 @@ class VITSolver(object): | ... | @@ -30,6 +31,7 @@ class VITSolver(object): |
| 30 | 31 | ||
| 31 | self.hyper_params = cfg['solver']['args'] | 32 | self.hyper_params = cfg['solver']['args'] |
| 32 | self.no_other = self.hyper_params['no_other'] | 33 | self.no_other = self.hyper_params['no_other'] |
| 34 | self.base_on = self.hyper_params['base_on'] | ||
| 33 | try: | 35 | try: |
| 34 | self.epoch = self.hyper_params['epoch'] | 36 | self.epoch = self.hyper_params['epoch'] |
| 35 | except Exception: | 37 | except Exception: |
| ... | @@ -62,9 +64,8 @@ class VITSolver(object): | ... | @@ -62,9 +64,8 @@ class VITSolver(object): |
| 62 | if self.no_other: | 64 | if self.no_other: |
| 63 | pred = torch.nn.Softmax(dim=1)(self.model(X)) | 65 | pred = torch.nn.Softmax(dim=1)(self.model(X)) |
| 64 | else: | 66 | else: |
| 65 | pred = torch.nn.Sigmoid(self.model(X)) | 67 | # pred = torch.nn.Sigmoid()(self.model(X)) |
| 66 | 68 | pred = self.model(X) | |
| 67 | correct += self.evaluate(pred, y) | ||
| 68 | 69 | ||
| 69 | # loss = self.loss_fn(pred, y, reduction="mean") | 70 | # loss = self.loss_fn(pred, y, reduction="mean") |
| 70 | loss = self.loss_fn(pred, y) | 71 | loss = self.loss_fn(pred, y) |
| ... | @@ -74,6 +75,8 @@ class VITSolver(object): | ... | @@ -74,6 +75,8 @@ class VITSolver(object): |
| 74 | loss_value, current = loss.item(), batch | 75 | loss_value, current = loss.item(), batch |
| 75 | self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}') | 76 | self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}') |
| 76 | 77 | ||
| 78 | correct += self.evaluate(torch.nn.Sigmoid()(pred), y) | ||
| 79 | |||
| 77 | self.optimizer.zero_grad() | 80 | self.optimizer.zero_grad() |
| 78 | loss.backward() | 81 | loss.backward() |
| 79 | self.optimizer.step() | 82 | self.optimizer.step() |
| ... | @@ -94,13 +97,14 @@ class VITSolver(object): | ... | @@ -94,13 +97,14 @@ class VITSolver(object): |
| 94 | if self.no_other: | 97 | if self.no_other: |
| 95 | pred = torch.nn.Softmax(dim=1)(self.model(X)) | 98 | pred = torch.nn.Softmax(dim=1)(self.model(X)) |
| 96 | else: | 99 | else: |
| 97 | pred = torch.nn.Sigmoid(self.model(X)) | 100 | # pred = torch.nn.Sigmoid()(self.model(X)) |
| 98 | 101 | pred = self.model(X) | |
| 99 | correct += self.evaluate(pred, y) | ||
| 100 | 102 | ||
| 101 | loss = self.loss_fn(pred, y) | 103 | loss = self.loss_fn(pred, y) |
| 102 | val_loss += loss.item() | 104 | val_loss += loss.item() |
| 103 | 105 | ||
| 106 | correct += self.evaluate(torch.nn.Sigmoid()(pred), y) | ||
| 107 | |||
| 104 | correct /= self.val_dataset_size | 108 | correct /= self.val_dataset_size |
| 105 | val_loss /= self.val_loader_size | 109 | val_loss /= self.val_loader_size |
| 106 | 110 | ||
| ... | @@ -111,6 +115,10 @@ class VITSolver(object): | ... | @@ -111,6 +115,10 @@ class VITSolver(object): |
| 111 | torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt')) | 115 | torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt')) |
| 112 | 116 | ||
| 113 | def run(self): | 117 | def run(self): |
| 118 | if isinstance(self.base_on, str) and os.path.exists(self.base_on): | ||
| 119 | self.model.load_state_dict(torch.load(self.base_on)) | ||
| 120 | self.logger.info(f'==> Load Model from {self.base_on}') | ||
| 121 | |||
| 114 | self.logger.info('==> Start Training') | 122 | self.logger.info('==> Start Training') |
| 115 | print(self.model) | 123 | print(self.model) |
| 116 | 124 | ... | ... |
-
Please register or sign in to post a comment