30cf7dbe by 周伟奇

modify evaluate

1 parent f3fbac1b
......@@ -37,16 +37,19 @@ class VITSolver(object):
self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])
@staticmethod
def evaluate(y_pred, y_true, thresholds=0.5):
y_pred_idx = torch.argmax(y_pred, dim=1) + 1
y_pred_is_other = (torch.amax(y_pred, dim=1) > 0.5).int()
y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)
y_true_idx = torch.argmax(y_true, dim=1) + 1
y_true_is_other = torch.sum(y_true, dim=1)
y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)
return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item()
def evaluate(y_pred, y_true, thresholds=0.5, no_other=False):
if no_other:
return (y_pred.argmax(1) == y_true.argmax(1)).type(torch.float).sum().item()
else:
y_pred_idx = torch.argmax(y_pred, dim=1) + 1
y_pred_is_other = (torch.amax(y_pred, dim=1) > 0.5).int()
y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)
y_true_idx = torch.argmax(y_true, dim=1) + 1
y_true_is_other = torch.sum(y_true, dim=1)
y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)
return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item()
def train_loop(self):
self.model.train()
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!