add eval
Showing
3 changed files
with
54 additions
and
5 deletions
... | @@ -7,6 +7,7 @@ from solver.builder import build_solver | ... | @@ -7,6 +7,7 @@ from solver.builder import build_solver |
7 | def main(): | 7 | def main(): |
8 | parser = argparse.ArgumentParser() | 8 | parser = argparse.ArgumentParser() |
9 | parser.add_argument('--config', default='./config/mlp.yaml', type=str, help='config file') | 9 | parser.add_argument('--config', default='./config/mlp.yaml', type=str, help='config file') |
10 | parser.add_argument('-e', '--eval', action="store_true") | ||
10 | args = parser.parse_args() | 11 | args = parser.parse_args() |
11 | 12 | ||
12 | cfg = yaml.load(open(args.config, 'r').read(), Loader=yaml.FullLoader) | 13 | cfg = yaml.load(open(args.config, 'r').read(), Loader=yaml.FullLoader) |
... | @@ -14,6 +15,10 @@ def main(): | ... | @@ -14,6 +15,10 @@ def main(): |
14 | # print(torch.cuda.is_available()) | 15 | # print(torch.cuda.is_available()) |
15 | 16 | ||
16 | solver = build_solver(cfg) | 17 | solver = build_solver(cfg) |
18 | |||
19 | if args.eval: | ||
20 | solver.evaluate() | ||
21 | else: | ||
17 | solver.run() | 22 | solver.run() |
18 | 23 | ||
19 | 24 | ... | ... |
... | @@ -8,6 +8,7 @@ from loss import build_loss | ... | @@ -8,6 +8,7 @@ from loss import build_loss |
8 | from model import build_model | 8 | from model import build_model |
9 | from optimizer import build_lr_scheduler, build_optimizer | 9 | from optimizer import build_lr_scheduler, build_optimizer |
10 | from utils import SOLVER_REGISTRY, get_logger_and_log_dir | 10 | from utils import SOLVER_REGISTRY, get_logger_and_log_dir |
11 | from sklearn.metrics import confusion_matrix, accuracy_score, classification_report | ||
11 | 12 | ||
12 | 13 | ||
13 | @SOLVER_REGISTRY.register() | 14 | @SOLVER_REGISTRY.register() |
... | @@ -32,6 +33,7 @@ class VITSolver(object): | ... | @@ -32,6 +33,7 @@ class VITSolver(object): |
32 | self.hyper_params = cfg['solver']['args'] | 33 | self.hyper_params = cfg['solver']['args'] |
33 | self.no_other = self.hyper_params['no_other'] | 34 | self.no_other = self.hyper_params['no_other'] |
34 | self.base_on = self.hyper_params['base_on'] | 35 | self.base_on = self.hyper_params['base_on'] |
36 | self.model_path = self.hyper_params['model_path'] | ||
35 | try: | 37 | try: |
36 | self.epoch = self.hyper_params['epoch'] | 38 | self.epoch = self.hyper_params['epoch'] |
37 | except Exception: | 39 | except Exception: |
... | @@ -39,7 +41,7 @@ class VITSolver(object): | ... | @@ -39,7 +41,7 @@ class VITSolver(object): |
39 | 41 | ||
40 | self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger']) | 42 | self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger']) |
41 | 43 | ||
42 | def evaluate(self, y_pred, y_true, thresholds=0.5): | 44 | def accuracy(self, y_pred, y_true, thresholds=0.5): |
43 | if self.no_other: | 45 | if self.no_other: |
44 | return (y_pred.argmax(1) == y_true.argmax(1)).type(torch.float).sum().item() | 46 | return (y_pred.argmax(1) == y_true.argmax(1)).type(torch.float).sum().item() |
45 | else: | 47 | else: |
... | @@ -80,9 +82,9 @@ class VITSolver(object): | ... | @@ -80,9 +82,9 @@ class VITSolver(object): |
80 | self.optimizer.step() | 82 | self.optimizer.step() |
81 | 83 | ||
82 | if self.no_other: | 84 | if self.no_other: |
83 | correct += self.evaluate(pred, y) | 85 | correct += self.accuracy(pred, y) |
84 | else: | 86 | else: |
85 | correct += self.evaluate(torch.nn.Sigmoid()(pred), y) | 87 | correct += self.accuracy(torch.nn.Sigmoid()(pred), y) |
86 | 88 | ||
87 | correct /= self.train_dataset_size | 89 | correct /= self.train_dataset_size |
88 | train_loss /= self.train_loader_size | 90 | train_loss /= self.train_loader_size |
... | @@ -107,9 +109,9 @@ class VITSolver(object): | ... | @@ -107,9 +109,9 @@ class VITSolver(object): |
107 | val_loss += loss.item() | 109 | val_loss += loss.item() |
108 | 110 | ||
109 | if self.no_other: | 111 | if self.no_other: |
110 | correct += self.evaluate(pred, y) | 112 | correct += self.accuracy(pred, y) |
111 | else: | 113 | else: |
112 | correct += self.evaluate(torch.nn.Sigmoid()(pred), y) | 114 | correct += self.accuracy(torch.nn.Sigmoid()(pred), y) |
113 | 115 | ||
114 | correct /= self.val_dataset_size | 116 | correct /= self.val_dataset_size |
115 | val_loss /= self.val_loader_size | 117 | val_loss /= self.val_loader_size |
... | @@ -140,3 +142,44 @@ class VITSolver(object): | ... | @@ -140,3 +142,44 @@ class VITSolver(object): |
140 | lr_scheduler.step() | 142 | lr_scheduler.step() |
141 | 143 | ||
142 | self.logger.info('==> End Training') | 144 | self.logger.info('==> End Training') |
145 | |||
146 | def evaluate(self): | ||
147 | if isinstance(self.model_path, str) and os.path.exists(self.model_path): | ||
148 | self.model.load_state_dict(torch.load(self.model_path)) | ||
149 | self.logger.info(f'==> Load Model from {self.model_path}') | ||
150 | else: | ||
151 | return | ||
152 | |||
153 | self.model.eval() | ||
154 | |||
155 | label_true_list = [] | ||
156 | label_pred_list = [] | ||
157 | for X, y in self.val_loader: | ||
158 | X, y_true = X.to(self.device), y.to(self.device) | ||
159 | |||
160 | if self.no_other: | ||
161 | pred = torch.nn.Softmax(dim=1)(self.model(X)) | ||
162 | else: | ||
163 | # pred = torch.nn.Sigmoid()(self.model(X)) | ||
164 | pred = self.model(X) | ||
165 | |||
166 | y_pred = torch.nn.Sigmoid()(pred) | ||
167 | |||
168 | y_pred_idx = torch.argmax(y_pred, dim=1) + 1 | ||
169 | y_pred_is_other = (torch.amax(y_pred, dim=1) > 0.5).int() | ||
170 | y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other) | ||
171 | |||
172 | y_true_idx = torch.argmax(y_true, dim=1) + 1 | ||
173 | y_true_is_other = torch.sum(y_true, dim=1) | ||
174 | y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other) | ||
175 | |||
176 | label_true_list.extend(y_true_rebuild.cpu().numpy().tolist()) | ||
177 | label_pred_list.extend(y_pred_rebuild.cpu().numpy().tolist()) | ||
178 | |||
179 | |||
180 | acc = accuracy_score(label_true_list, label_pred_list) | ||
181 | cm = confusion_matrix(label_true_list, label_pred_list) | ||
182 | report = classification_report(label_true_list, label_pred_list) | ||
183 | print(acc) | ||
184 | print(cm) | ||
185 | print(report) | ... | ... |
-
Please register or sign in to post a comment