add no other
Showing
2 changed files
with
10 additions
and
3 deletions
... | @@ -29,6 +29,7 @@ class VITSolver(object): | ... | @@ -29,6 +29,7 @@ class VITSolver(object): |
29 | self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args']) | 29 | self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args']) |
30 | 30 | ||
31 | self.hyper_params = cfg['solver']['args'] | 31 | self.hyper_params = cfg['solver']['args'] |
32 | self.no_other = self.hyper_params['no_other'] | ||
32 | try: | 33 | try: |
33 | self.epoch = self.hyper_params['epoch'] | 34 | self.epoch = self.hyper_params['epoch'] |
34 | except Exception: | 35 | except Exception: |
... | @@ -36,9 +37,8 @@ class VITSolver(object): | ... | @@ -36,9 +37,8 @@ class VITSolver(object): |
36 | 37 | ||
37 | self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger']) | 38 | self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger']) |
38 | 39 | ||
39 | @staticmethod | 40 | def evaluate(self, y_pred, y_true, thresholds=0.5): |
40 | def evaluate(y_pred, y_true, thresholds=0.5, no_other=False): | 41 | if self.no_other: |
41 | if no_other: | ||
42 | return (y_pred.argmax(1) == y_true.argmax(1)).type(torch.float).sum().item() | 42 | return (y_pred.argmax(1) == y_true.argmax(1)).type(torch.float).sum().item() |
43 | else: | 43 | else: |
44 | y_pred_idx = torch.argmax(y_pred, dim=1) + 1 | 44 | y_pred_idx = torch.argmax(y_pred, dim=1) + 1 |
... | @@ -59,7 +59,10 @@ class VITSolver(object): | ... | @@ -59,7 +59,10 @@ class VITSolver(object): |
59 | for batch, (X, y) in enumerate(self.train_loader): | 59 | for batch, (X, y) in enumerate(self.train_loader): |
60 | X, y = X.to(self.device), y.to(self.device) | 60 | X, y = X.to(self.device), y.to(self.device) |
61 | 61 | ||
62 | if self.no_other: | ||
62 | pred = torch.nn.Softmax(dim=1)(self.model(X)) | 63 | pred = torch.nn.Softmax(dim=1)(self.model(X)) |
64 | else: | ||
65 | pred = torch.nn.Sigmoid(self.model(X)) | ||
63 | 66 | ||
64 | correct += self.evaluate(pred, y) | 67 | correct += self.evaluate(pred, y) |
65 | 68 | ||
... | @@ -88,7 +91,10 @@ class VITSolver(object): | ... | @@ -88,7 +91,10 @@ class VITSolver(object): |
88 | for X, y in self.val_loader: | 91 | for X, y in self.val_loader: |
89 | X, y = X.to(self.device), y.to(self.device) | 92 | X, y = X.to(self.device), y.to(self.device) |
90 | 93 | ||
94 | if self.no_other: | ||
91 | pred = torch.nn.Softmax(dim=1)(self.model(X)) | 95 | pred = torch.nn.Softmax(dim=1)(self.model(X)) |
96 | else: | ||
97 | pred = torch.nn.Sigmoid(self.model(X)) | ||
92 | 98 | ||
93 | correct += self.evaluate(pred, y) | 99 | correct += self.evaluate(pred, y) |
94 | 100 | ... | ... |
-
Please register or sign in to post a comment