3e58f6b0 by 周伟奇

modify evaluate

1 parent 40ca6fe1
......@@ -16,7 +16,7 @@ dataloader:
model:
name: 'SLTransformer'
args:
seq_lens: 200
seq_lens: 160
num_classes: 10
embed_dim: 9
depth: 6
......
......@@ -166,7 +166,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
X = list()
y_true = list()
for i in range(200):
for i in range(160):
if i >= valid_lens:
X.append([0., 0., 0., 0., 0., 0., 0., 0., 0.])
y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
......
......@@ -65,7 +65,7 @@ class SLSolver(object):
train_loss = torch.zeros(1).to(self.device)
correct = torch.zeros(1).to(self.device)
for batch, (X, y, valid_lens) in enumerate(self.train_loader):
X, y = X.to(self.device), y.to(self.device)
X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)
pred = self.model(X, valid_lens)
# [batch_size, seq_len, num_classes]
......@@ -97,7 +97,7 @@ class SLSolver(object):
val_loss = torch.zeros(1).to(self.device)
correct = torch.zeros(1).to(self.device)
for X, y, valid_lens in self.val_loader:
X, y = X.to(self.device), y.to(self.device)
X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)
# pred = torch.nn.Sigmoid()(self.model(X))
pred = self.model(X, valid_lens)
......@@ -162,26 +162,31 @@ class SLSolver(object):
label_true_list = []
label_pred_list = []
for X, y in self.val_loader:
X, y_true = X.to(self.device), y.to(self.device)
for X, y, valid_lens in self.val_loader:
X, y_true, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)
# pred = torch.nn.Sigmoid()(self.model(X))
pred = self.model(X)
y_pred = self.model(X, valid_lens)
y_pred = torch.nn.Sigmoid()(pred)
y_pred_idx = torch.argmax(y_pred, dim=1) + 1
y_pred_is_other = (torch.amax(y_pred, dim=1) > 0.5).int()
# [batch_size, seq_len, num_classes]
y_pred_sigmoid = torch.nn.Sigmoid()(y_pred)
# [batch_size, seq_len]
y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1
# [batch_size, seq_len]
y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > 0.5).int()
y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)
y_true_idx = torch.argmax(y_true, dim=1) + 1
y_true_is_other = torch.sum(y_true, dim=1)
y_true_idx = torch.argmax(y_true, dim=-1) + 1
y_true_is_other = torch.sum(y_true, dim=-1).int()
y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)
label_true_list.extend(y_true_rebuild.cpu().numpy().tolist())
label_pred_list.extend(y_pred_rebuild.cpu().numpy().tolist())
# masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1)
for idx, seq_result in enumerate(y_true_rebuild.cpu().numpy().tolist()):
label_true_list.extend(seq_result[: valid_lens.cpu().numpy()[idx]])
for idx, seq_result in enumerate(y_pred_rebuild.cpu().numpy().tolist()):
label_pred_list.extend(seq_result[: valid_lens.cpu().numpy()[idx]])
acc = accuracy_score(label_true_list, label_pred_list)
cm = confusion_matrix(label_true_list, label_pred_list)
report = classification_report(label_true_list, label_pred_list)
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!