modify evaluate
Showing
1 changed file
with
13 additions
and
10 deletions
... | @@ -37,16 +37,19 @@ class VITSolver(object): | ... | @@ -37,16 +37,19 @@ class VITSolver(object): |
37 | self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger']) | 37 | self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger']) |
38 | 38 | ||
39 | @staticmethod | 39 | @staticmethod |
40 | def evaluate(y_pred, y_true, thresholds=0.5): | 40 | def evaluate(y_pred, y_true, thresholds=0.5, no_other=False): |
41 | y_pred_idx = torch.argmax(y_pred, dim=1) + 1 | 41 | if no_other: |
42 | y_pred_is_other = (torch.amax(y_pred, dim=1) > 0.5).int() | 42 | return (y_pred.argmax(1) == y_true.argmax(1)).type(torch.float).sum().item() |
43 | y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other) | 43 | else: |
44 | 44 | y_pred_idx = torch.argmax(y_pred, dim=1) + 1 | |
45 | y_true_idx = torch.argmax(y_true, dim=1) + 1 | 45 | y_pred_is_other = (torch.amax(y_pred, dim=1) > 0.5).int() |
46 | y_true_is_other = torch.sum(y_true, dim=1) | 46 | y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other) |
47 | y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other) | 47 | |
48 | 48 | y_true_idx = torch.argmax(y_true, dim=1) + 1 | |
49 | return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item() | 49 | y_true_is_other = torch.sum(y_true, dim=1) |
50 | y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other) | ||
51 | |||
52 | return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item() | ||
50 | 53 | ||
51 | def train_loop(self): | 54 | def train_loop(self): |
52 | self.model.train() | 55 | self.model.train() | ... | ... |
-
Please register or sign in to post a comment