30cf7dbe by 周伟奇

modify evaluate

1 parent f3fbac1b
......@@ -37,7 +37,10 @@ class VITSolver(object):
self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])
@staticmethod
def evaluate(y_pred, y_true, thresholds=0.5):
def evaluate(y_pred, y_true, thresholds=0.5, no_other=False):
if no_other:
return (y_pred.argmax(1) == y_true.argmax(1)).type(torch.float).sum().item()
else:
y_pred_idx = torch.argmax(y_pred, dim=1) + 1
y_pred_is_other = (torch.amax(y_pred, dim=1) > 0.5).int()
y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!