69e75f77 by 周伟奇

add load model

1 parent fb5f4ba1
@@ -40,6 +40,7 @@ solver:
   args:
     epoch: 100
     no_other: false
+    base_on: null
 
 optimizer:
   name: 'Adam'
...
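The new base_on key defaults to null, so existing configs keep training from scratch; pointing it at a checkpoint path is what triggers the warm start added in run() further down. A minimal sketch of reading the field, assuming the config is plain YAML loaded with PyYAML (the file name cfg.yaml and the .get fallback are illustrative, not from the repo):

import yaml

with open('cfg.yaml') as f:                      # illustrative path
    cfg = yaml.safe_load(f)

base_on = cfg['solver']['args'].get('base_on')   # None -> train from scratch
if base_on:
    print(f'will warm-start from {base_on}')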
@@ -1,9 +1,10 @@
-import os
 import copy
+import os
+
 import torch
 
-from model import build_model
 from data import build_dataloader
-from optimizer import build_optimizer, build_lr_scheduler
 from loss import build_loss
+from model import build_model
+from optimizer import build_lr_scheduler, build_optimizer
 from utils import SOLVER_REGISTRY, get_logger_and_log_dir
@@ -30,6 +31,7 @@ class VITSolver(object):
 
         self.hyper_params = cfg['solver']['args']
         self.no_other = self.hyper_params['no_other']
+        self.base_on = self.hyper_params['base_on']
         try:
             self.epoch = self.hyper_params['epoch']
         except Exception:
@@ -62,9 +64,8 @@ class VITSolver(object):
             if self.no_other:
                 pred = torch.nn.Softmax(dim=1)(self.model(X))
             else:
-                pred = torch.nn.Sigmoid(self.model(X))
-
-            correct += self.evaluate(pred, y)
+                # pred = torch.nn.Sigmoid()(self.model(X))
+                pred = self.model(X)
 
             # loss = self.loss_fn(pred, y, reduction="mean")
             loss = self.loss_fn(pred, y)
@@ -73,6 +74,8 @@ class VITSolver(object):
             if batch % 100 == 0:
                 loss_value, current = loss.item(), batch
                 self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')
+
+            correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
 
             self.optimizer.zero_grad()
             loss.backward()
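The change above also fixes a bug: the old pred = torch.nn.Sigmoid(self.model(X)) passed the logits to the Sigmoid constructor instead of applying the module, and that output then went into the loss. Now the loss consumes raw logits and Sigmoid is applied only when computing accuracy via evaluate(). A minimal sketch of that split, assuming a multi-label setup with BCEWithLogitsLoss (an assumption here; the real loss_fn comes from build_loss):

import torch

loss_fn = torch.nn.BCEWithLogitsLoss()           # assumed; stands in for build_loss()

logits = torch.randn(4, 10)                      # stand-in for self.model(X)
targets = torch.randint(0, 2, (4, 10)).float()

loss = loss_fn(logits, targets)                  # loss takes raw logits
probs = torch.nn.Sigmoid()(logits)               # probabilities only for evaluation
accuracy = ((probs > 0.5).float() == targets).float().mean()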
@@ -94,13 +97,14 @@ class VITSolver(object):
                 if self.no_other:
                     pred = torch.nn.Softmax(dim=1)(self.model(X))
                 else:
-                    pred = torch.nn.Sigmoid(self.model(X))
-
-                correct += self.evaluate(pred, y)
+                    # pred = torch.nn.Sigmoid()(self.model(X))
+                    pred = self.model(X)
 
                 loss = self.loss_fn(pred, y)
                 val_loss += loss.item()
 
+                correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
+
         correct /= self.val_dataset_size
         val_loss /= self.val_loader_size
 
@@ -111,6 +115,10 @@ class VITSolver(object):
         torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))
 
     def run(self):
+        if isinstance(self.base_on, str) and os.path.exists(self.base_on):
+            self.model.load_state_dict(torch.load(self.base_on))
+            self.logger.info(f'==> Load Model from {self.base_on}')
+
         self.logger.info('==> Start Training')
         print(self.model)
 
...
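With base_on wired through the config and __init__, run() now restores weights before the first epoch. A minimal sketch of the same restore pattern in isolation; map_location is an addition of this sketch (not in the commit) and is commonly used so checkpoints saved on GPU also load on CPU-only machines:

import os
import torch

def maybe_load(model, base_on, log=print):
    # Mirrors the guard added in run(): only load when base_on is an existing path.
    if isinstance(base_on, str) and os.path.exists(base_on):
        state = torch.load(base_on, map_location='cpu')  # map_location is an assumption
        model.load_state_dict(state)
        log(f'==> Load Model from {base_on}')
    return model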