b3694ec8 by 周伟奇

add eval

1 parent 82a85c6d
...@@ -41,6 +41,7 @@ solver: ...@@ -41,6 +41,7 @@ solver:
41 epoch: 100 41 epoch: 100
42 no_other: false 42 no_other: false
43 base_on: null 43 base_on: null
44 model_path: null
44 45
45 optimizer: 46 optimizer:
46 name: 'Adam' 47 name: 'Adam'
......
...@@ -7,6 +7,7 @@ from solver.builder import build_solver ...@@ -7,6 +7,7 @@ from solver.builder import build_solver
7 def main(): 7 def main():
8 parser = argparse.ArgumentParser() 8 parser = argparse.ArgumentParser()
9 parser.add_argument('--config', default='./config/mlp.yaml', type=str, help='config file') 9 parser.add_argument('--config', default='./config/mlp.yaml', type=str, help='config file')
10 parser.add_argument('-e', '--eval', action="store_true")
10 args = parser.parse_args() 11 args = parser.parse_args()
11 12
12 cfg = yaml.load(open(args.config, 'r').read(), Loader=yaml.FullLoader) 13 cfg = yaml.load(open(args.config, 'r').read(), Loader=yaml.FullLoader)
...@@ -14,6 +15,10 @@ def main(): ...@@ -14,6 +15,10 @@ def main():
14 # print(torch.cuda.is_available()) 15 # print(torch.cuda.is_available())
15 16
16 solver = build_solver(cfg) 17 solver = build_solver(cfg)
18
19 if args.eval:
20 solver.evaluate()
21 else:
17 solver.run() 22 solver.run()
18 23
19 24
......
...@@ -8,6 +8,7 @@ from loss import build_loss ...@@ -8,6 +8,7 @@ from loss import build_loss
8 from model import build_model 8 from model import build_model
9 from optimizer import build_lr_scheduler, build_optimizer 9 from optimizer import build_lr_scheduler, build_optimizer
10 from utils import SOLVER_REGISTRY, get_logger_and_log_dir 10 from utils import SOLVER_REGISTRY, get_logger_and_log_dir
11 from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
11 12
12 13
13 @SOLVER_REGISTRY.register() 14 @SOLVER_REGISTRY.register()
...@@ -32,6 +33,7 @@ class VITSolver(object): ...@@ -32,6 +33,7 @@ class VITSolver(object):
32 self.hyper_params = cfg['solver']['args'] 33 self.hyper_params = cfg['solver']['args']
33 self.no_other = self.hyper_params['no_other'] 34 self.no_other = self.hyper_params['no_other']
34 self.base_on = self.hyper_params['base_on'] 35 self.base_on = self.hyper_params['base_on']
36 self.model_path = self.hyper_params['model_path']
35 try: 37 try:
36 self.epoch = self.hyper_params['epoch'] 38 self.epoch = self.hyper_params['epoch']
37 except Exception: 39 except Exception:
...@@ -39,7 +41,7 @@ class VITSolver(object): ...@@ -39,7 +41,7 @@ class VITSolver(object):
39 41
40 self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger']) 42 self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])
41 43
42 def evaluate(self, y_pred, y_true, thresholds=0.5): 44 def accuracy(self, y_pred, y_true, thresholds=0.5):
43 if self.no_other: 45 if self.no_other:
44 return (y_pred.argmax(1) == y_true.argmax(1)).type(torch.float).sum().item() 46 return (y_pred.argmax(1) == y_true.argmax(1)).type(torch.float).sum().item()
45 else: 47 else:
...@@ -80,9 +82,9 @@ class VITSolver(object): ...@@ -80,9 +82,9 @@ class VITSolver(object):
80 self.optimizer.step() 82 self.optimizer.step()
81 83
82 if self.no_other: 84 if self.no_other:
83 correct += self.evaluate(pred, y) 85 correct += self.accuracy(pred, y)
84 else: 86 else:
85 correct += self.evaluate(torch.nn.Sigmoid()(pred), y) 87 correct += self.accuracy(torch.nn.Sigmoid()(pred), y)
86 88
87 correct /= self.train_dataset_size 89 correct /= self.train_dataset_size
88 train_loss /= self.train_loader_size 90 train_loss /= self.train_loader_size
...@@ -107,9 +109,9 @@ class VITSolver(object): ...@@ -107,9 +109,9 @@ class VITSolver(object):
107 val_loss += loss.item() 109 val_loss += loss.item()
108 110
109 if self.no_other: 111 if self.no_other:
110 correct += self.evaluate(pred, y) 112 correct += self.accuracy(pred, y)
111 else: 113 else:
112 correct += self.evaluate(torch.nn.Sigmoid()(pred), y) 114 correct += self.accuracy(torch.nn.Sigmoid()(pred), y)
113 115
114 correct /= self.val_dataset_size 116 correct /= self.val_dataset_size
115 val_loss /= self.val_loader_size 117 val_loss /= self.val_loader_size
...@@ -140,3 +142,44 @@ class VITSolver(object): ...@@ -140,3 +142,44 @@ class VITSolver(object):
140 lr_scheduler.step() 142 lr_scheduler.step()
141 143
142 self.logger.info('==> End Training') 144 self.logger.info('==> End Training')
145
146 def evaluate(self):
147 if isinstance(self.model_path, str) and os.path.exists(self.model_path):
148 self.model.load_state_dict(torch.load(self.model_path))
149 self.logger.info(f'==> Load Model from {self.model_path}')
150 else:
151 return
152
153 self.model.eval()
154
155 label_true_list = []
156 label_pred_list = []
157 for X, y in self.val_loader:
158 X, y_true = X.to(self.device), y.to(self.device)
159
160 if self.no_other:
161 pred = torch.nn.Softmax(dim=1)(self.model(X))
162 else:
163 # pred = torch.nn.Sigmoid()(self.model(X))
164 pred = self.model(X)
165
166 y_pred = torch.nn.Sigmoid()(pred)
167
168 y_pred_idx = torch.argmax(y_pred, dim=1) + 1
169 y_pred_is_other = (torch.amax(y_pred, dim=1) > 0.5).int()
170 y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)
171
172 y_true_idx = torch.argmax(y_true, dim=1) + 1
173 y_true_is_other = torch.sum(y_true, dim=1)
174 y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)
175
176 label_true_list.extend(y_true_rebuild.cpu().numpy().tolist())
177 label_pred_list.extend(y_pred_rebuild.cpu().numpy().tolist())
178
179
180 acc = accuracy_score(label_true_list, label_pred_list)
181 cm = confusion_matrix(label_true_list, label_pred_list)
182 report = classification_report(label_true_list, label_pred_list)
183 print(acc)
184 print(cm)
185 print(report)
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!