modify evaluate
Showing
1 changed file
with
4 additions
and
1 deletions
... | @@ -37,7 +37,10 @@ class VITSolver(object): | ... | @@ -37,7 +37,10 @@ class VITSolver(object): |
37 | self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger']) | 37 | self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger']) |
38 | 38 | ||
39 | @staticmethod | 39 | @staticmethod |
40 | def evaluate(y_pred, y_true, thresholds=0.5): | 40 | def evaluate(y_pred, y_true, thresholds=0.5, no_other=False): |
41 | if no_other: | ||
42 | return (y_pred.argmax(1) == y_true.argmax(1)).type(torch.float).sum().item() | ||
43 | else: | ||
41 | y_pred_idx = torch.argmax(y_pred, dim=1) + 1 | 44 | y_pred_idx = torch.argmax(y_pred, dim=1) + 1 |
42 | y_pred_is_other = (torch.amax(y_pred, dim=1) > 0.5).int() | 45 | y_pred_is_other = (torch.amax(y_pred, dim=1) > 0.5).int() |
43 | y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other) | 46 | y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other) | ... | ... |
-
Please register or sign in to post a comment