fb5f4ba1 by 周伟奇

add no other

1 parent 30cf7dbe
...@@ -39,6 +39,7 @@ solver: ...@@ -39,6 +39,7 @@ solver:
39 name: 'VITSolver' 39 name: 'VITSolver'
40 args: 40 args:
41 epoch: 100 41 epoch: 100
42 no_other: false
42 43
43 optimizer: 44 optimizer:
44 name: 'Adam' 45 name: 'Adam'
......
...@@ -29,6 +29,7 @@ class VITSolver(object): ...@@ -29,6 +29,7 @@ class VITSolver(object):
29 self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args']) 29 self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])
30 30
31 self.hyper_params = cfg['solver']['args'] 31 self.hyper_params = cfg['solver']['args']
32 self.no_other = self.hyper_params['no_other']
32 try: 33 try:
33 self.epoch = self.hyper_params['epoch'] 34 self.epoch = self.hyper_params['epoch']
34 except Exception: 35 except Exception:
...@@ -36,9 +37,8 @@ class VITSolver(object): ...@@ -36,9 +37,8 @@ class VITSolver(object):
36 37
37 self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger']) 38 self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])
38 39
39 @staticmethod 40 def evaluate(self, y_pred, y_true, thresholds=0.5):
40 def evaluate(y_pred, y_true, thresholds=0.5, no_other=False): 41 if self.no_other:
41 if no_other:
42 return (y_pred.argmax(1) == y_true.argmax(1)).type(torch.float).sum().item() 42 return (y_pred.argmax(1) == y_true.argmax(1)).type(torch.float).sum().item()
43 else: 43 else:
44 y_pred_idx = torch.argmax(y_pred, dim=1) + 1 44 y_pred_idx = torch.argmax(y_pred, dim=1) + 1
...@@ -59,7 +59,10 @@ class VITSolver(object): ...@@ -59,7 +59,10 @@ class VITSolver(object):
59 for batch, (X, y) in enumerate(self.train_loader): 59 for batch, (X, y) in enumerate(self.train_loader):
60 X, y = X.to(self.device), y.to(self.device) 60 X, y = X.to(self.device), y.to(self.device)
61 61
62 if self.no_other:
62 pred = torch.nn.Softmax(dim=1)(self.model(X)) 63 pred = torch.nn.Softmax(dim=1)(self.model(X))
64 else:
65 pred = torch.nn.Sigmoid(self.model(X))
63 66
64 correct += self.evaluate(pred, y) 67 correct += self.evaluate(pred, y)
65 68
...@@ -88,7 +91,10 @@ class VITSolver(object): ...@@ -88,7 +91,10 @@ class VITSolver(object):
88 for X, y in self.val_loader: 91 for X, y in self.val_loader:
89 X, y = X.to(self.device), y.to(self.device) 92 X, y = X.to(self.device), y.to(self.device)
90 93
94 if self.no_other:
91 pred = torch.nn.Softmax(dim=1)(self.model(X)) 95 pred = torch.nn.Softmax(dim=1)(self.model(X))
96 else:
97 pred = torch.nn.Sigmoid(self.model(X))
92 98
93 correct += self.evaluate(pred, y) 99 correct += self.evaluate(pred, y)
94 100
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!