add load model
Showing
2 changed files
with
18 additions
and
9 deletions
1 | import os | ||
2 | import copy | 1 | import copy |
2 | import os | ||
3 | |||
3 | import torch | 4 | import torch |
4 | 5 | ||
5 | from model import build_model | ||
6 | from data import build_dataloader | 6 | from data import build_dataloader |
7 | from optimizer import build_optimizer, build_lr_scheduler | ||
8 | from loss import build_loss | 7 | from loss import build_loss |
8 | from model import build_model | ||
9 | from optimizer import build_lr_scheduler, build_optimizer | ||
9 | from utils import SOLVER_REGISTRY, get_logger_and_log_dir | 10 | from utils import SOLVER_REGISTRY, get_logger_and_log_dir |
10 | 11 | ||
11 | 12 | ||
... | @@ -30,6 +31,7 @@ class VITSolver(object): | ... | @@ -30,6 +31,7 @@ class VITSolver(object): |
30 | 31 | ||
31 | self.hyper_params = cfg['solver']['args'] | 32 | self.hyper_params = cfg['solver']['args'] |
32 | self.no_other = self.hyper_params['no_other'] | 33 | self.no_other = self.hyper_params['no_other'] |
34 | self.base_on = self.hyper_params['base_on'] | ||
33 | try: | 35 | try: |
34 | self.epoch = self.hyper_params['epoch'] | 36 | self.epoch = self.hyper_params['epoch'] |
35 | except Exception: | 37 | except Exception: |
... | @@ -62,9 +64,8 @@ class VITSolver(object): | ... | @@ -62,9 +64,8 @@ class VITSolver(object): |
62 | if self.no_other: | 64 | if self.no_other: |
63 | pred = torch.nn.Softmax(dim=1)(self.model(X)) | 65 | pred = torch.nn.Softmax(dim=1)(self.model(X)) |
64 | else: | 66 | else: |
65 | pred = torch.nn.Sigmoid(self.model(X)) | 67 | # pred = torch.nn.Sigmoid()(self.model(X)) |
66 | 68 | pred = self.model(X) | |
67 | correct += self.evaluate(pred, y) | ||
68 | 69 | ||
69 | # loss = self.loss_fn(pred, y, reduction="mean") | 70 | # loss = self.loss_fn(pred, y, reduction="mean") |
70 | loss = self.loss_fn(pred, y) | 71 | loss = self.loss_fn(pred, y) |
... | @@ -73,6 +74,8 @@ class VITSolver(object): | ... | @@ -73,6 +74,8 @@ class VITSolver(object): |
73 | if batch % 100 == 0: | 74 | if batch % 100 == 0: |
74 | loss_value, current = loss.item(), batch | 75 | loss_value, current = loss.item(), batch |
75 | self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}') | 76 | self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}') |
77 | |||
78 | correct += self.evaluate(torch.nn.Sigmoid()(pred), y) | ||
76 | 79 | ||
77 | self.optimizer.zero_grad() | 80 | self.optimizer.zero_grad() |
78 | loss.backward() | 81 | loss.backward() |
... | @@ -94,13 +97,14 @@ class VITSolver(object): | ... | @@ -94,13 +97,14 @@ class VITSolver(object): |
94 | if self.no_other: | 97 | if self.no_other: |
95 | pred = torch.nn.Softmax(dim=1)(self.model(X)) | 98 | pred = torch.nn.Softmax(dim=1)(self.model(X)) |
96 | else: | 99 | else: |
97 | pred = torch.nn.Sigmoid(self.model(X)) | 100 | # pred = torch.nn.Sigmoid()(self.model(X)) |
98 | 101 | pred = self.model(X) | |
99 | correct += self.evaluate(pred, y) | ||
100 | 102 | ||
101 | loss = self.loss_fn(pred, y) | 103 | loss = self.loss_fn(pred, y) |
102 | val_loss += loss.item() | 104 | val_loss += loss.item() |
103 | 105 | ||
106 | correct += self.evaluate(torch.nn.Sigmoid()(pred), y) | ||
107 | |||
104 | correct /= self.val_dataset_size | 108 | correct /= self.val_dataset_size |
105 | val_loss /= self.val_loader_size | 109 | val_loss /= self.val_loader_size |
106 | 110 | ||
... | @@ -111,6 +115,10 @@ class VITSolver(object): | ... | @@ -111,6 +115,10 @@ class VITSolver(object): |
111 | torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt')) | 115 | torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt')) |
112 | 116 | ||
113 | def run(self): | 117 | def run(self): |
118 | if isinstance(self.base_on, str) and os.path.exists(self.base_on): | ||
119 | self.model.load_state_dict(torch.load(self.base_on)) | ||
120 | self.logger.info(f'==> Load Model from {self.base_on}') | ||
121 | |||
114 | self.logger.info('==> Start Training') | 122 | self.logger.info('==> Start Training') |
115 | print(self.model) | 123 | print(self.model) |
116 | 124 | ... | ... |
-
Please register or sign in to post a comment