fb5f4ba1 by 周伟奇

add no other

1 parent 30cf7dbe
......@@ -39,6 +39,7 @@ solver:
name: 'VITSolver'
args:
epoch: 100
no_other: false
optimizer:
name: 'Adam'
......
......@@ -29,6 +29,7 @@ class VITSolver(object):
self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])
self.hyper_params = cfg['solver']['args']
self.no_other = self.hyper_params['no_other']
try:
self.epoch = self.hyper_params['epoch']
except Exception:
......@@ -36,9 +37,8 @@ class VITSolver(object):
self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])
@staticmethod
def evaluate(y_pred, y_true, thresholds=0.5, no_other=False):
if no_other:
def evaluate(self, y_pred, y_true, thresholds=0.5):
if self.no_other:
return (y_pred.argmax(1) == y_true.argmax(1)).type(torch.float).sum().item()
else:
y_pred_idx = torch.argmax(y_pred, dim=1) + 1
......@@ -59,7 +59,10 @@ class VITSolver(object):
for batch, (X, y) in enumerate(self.train_loader):
X, y = X.to(self.device), y.to(self.device)
if self.no_other:
pred = torch.nn.Softmax(dim=1)(self.model(X))
else:
pred = torch.nn.Sigmoid(self.model(X))
correct += self.evaluate(pred, y)
......@@ -88,7 +91,10 @@ class VITSolver(object):
for X, y in self.val_loader:
X, y = X.to(self.device), y.to(self.device)
if self.no_other:
pred = torch.nn.Softmax(dim=1)(self.model(X))
else:
pred = torch.nn.Sigmoid(self.model(X))
correct += self.evaluate(pred, y)
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!