30cf7dbe by 周伟奇

modify evaluate

1 parent f3fbac1b
...@@ -37,16 +37,19 @@ class VITSolver(object): ...@@ -37,16 +37,19 @@ class VITSolver(object):
37 self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger']) 37 self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])
38 38
39 @staticmethod 39 @staticmethod
40 def evaluate(y_pred, y_true, thresholds=0.5): 40 def evaluate(y_pred, y_true, thresholds=0.5, no_other=False):
41 y_pred_idx = torch.argmax(y_pred, dim=1) + 1 41 if no_other:
42 y_pred_is_other = (torch.amax(y_pred, dim=1) > 0.5).int() 42 return (y_pred.argmax(1) == y_true.argmax(1)).type(torch.float).sum().item()
43 y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other) 43 else:
44 44 y_pred_idx = torch.argmax(y_pred, dim=1) + 1
45 y_true_idx = torch.argmax(y_true, dim=1) + 1 45 y_pred_is_other = (torch.amax(y_pred, dim=1) > 0.5).int()
46 y_true_is_other = torch.sum(y_true, dim=1) 46 y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)
47 y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other) 47
48 48 y_true_idx = torch.argmax(y_true, dim=1) + 1
49 return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item() 49 y_true_is_other = torch.sum(y_true, dim=1)
50 y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)
51
52 return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item()
50 53
51 def train_loop(self): 54 def train_loop(self):
52 self.model.train() 55 self.model.train()
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!