modify evaluate
Showing
3 changed files
with
21 additions
and
16 deletions
| ... | @@ -166,7 +166,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save | ... | @@ -166,7 +166,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save |
| 166 | 166 | ||
| 167 | X = list() | 167 | X = list() |
| 168 | y_true = list() | 168 | y_true = list() |
| 169 | for i in range(200): | 169 | for i in range(160): |
| 170 | if i >= valid_lens: | 170 | if i >= valid_lens: |
| 171 | X.append([0., 0., 0., 0., 0., 0., 0., 0., 0.]) | 171 | X.append([0., 0., 0., 0., 0., 0., 0., 0., 0.]) |
| 172 | y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) | 172 | y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) | ... | ... |
| ... | @@ -65,7 +65,7 @@ class SLSolver(object): | ... | @@ -65,7 +65,7 @@ class SLSolver(object): |
| 65 | train_loss = torch.zeros(1).to(self.device) | 65 | train_loss = torch.zeros(1).to(self.device) |
| 66 | correct = torch.zeros(1).to(self.device) | 66 | correct = torch.zeros(1).to(self.device) |
| 67 | for batch, (X, y, valid_lens) in enumerate(self.train_loader): | 67 | for batch, (X, y, valid_lens) in enumerate(self.train_loader): |
| 68 | X, y = X.to(self.device), y.to(self.device) | 68 | X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device) |
| 69 | 69 | ||
| 70 | pred = self.model(X, valid_lens) | 70 | pred = self.model(X, valid_lens) |
| 71 | # [batch_size, seq_len, num_classes] | 71 | # [batch_size, seq_len, num_classes] |
| ... | @@ -97,7 +97,7 @@ class SLSolver(object): | ... | @@ -97,7 +97,7 @@ class SLSolver(object): |
| 97 | val_loss = torch.zeros(1).to(self.device) | 97 | val_loss = torch.zeros(1).to(self.device) |
| 98 | correct = torch.zeros(1).to(self.device) | 98 | correct = torch.zeros(1).to(self.device) |
| 99 | for X, y, valid_lens in self.val_loader: | 99 | for X, y, valid_lens in self.val_loader: |
| 100 | X, y = X.to(self.device), y.to(self.device) | 100 | X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device) |
| 101 | 101 | ||
| 102 | # pred = torch.nn.Sigmoid()(self.model(X)) | 102 | # pred = torch.nn.Sigmoid()(self.model(X)) |
| 103 | pred = self.model(X, valid_lens) | 103 | pred = self.model(X, valid_lens) |
| ... | @@ -162,26 +162,31 @@ class SLSolver(object): | ... | @@ -162,26 +162,31 @@ class SLSolver(object): |
| 162 | 162 | ||
| 163 | label_true_list = [] | 163 | label_true_list = [] |
| 164 | label_pred_list = [] | 164 | label_pred_list = [] |
| 165 | for X, y in self.val_loader: | 165 | for X, y, valid_lens in self.val_loader: |
| 166 | X, y_true = X.to(self.device), y.to(self.device) | 166 | X, y_true, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device) |
| 167 | 167 | ||
| 168 | # pred = torch.nn.Sigmoid()(self.model(X)) | 168 | # pred = torch.nn.Sigmoid()(self.model(X)) |
| 169 | pred = self.model(X) | 169 | y_pred = self.model(X, valid_lens) |
| 170 | 170 | ||
| 171 | y_pred = torch.nn.Sigmoid()(pred) | 171 | # [batch_size, seq_len, num_classes] |
| 172 | 172 | y_pred_sigmoid = torch.nn.Sigmoid()(y_pred) | |
| 173 | y_pred_idx = torch.argmax(y_pred, dim=1) + 1 | 173 | # [batch_size, seq_len] |
| 174 | y_pred_is_other = (torch.amax(y_pred, dim=1) > 0.5).int() | 174 | y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1 |
| 175 | # [batch_size, seq_len] | ||
| 176 | y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > 0.5).int() | ||
| 175 | y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other) | 177 | y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other) |
| 176 | 178 | ||
| 177 | y_true_idx = torch.argmax(y_true, dim=1) + 1 | 179 | y_true_idx = torch.argmax(y_true, dim=-1) + 1 |
| 178 | y_true_is_other = torch.sum(y_true, dim=1) | 180 | y_true_is_other = torch.sum(y_true, dim=-1).int() |
| 179 | y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other) | 181 | y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other) |
| 180 | 182 | ||
| 181 | label_true_list.extend(y_true_rebuild.cpu().numpy().tolist()) | 183 | # masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1) |
| 182 | label_pred_list.extend(y_pred_rebuild.cpu().numpy().tolist()) | ||
| 183 | |||
| 184 | 184 | ||
| 185 | for idx, seq_result in enumerate(y_true_rebuild.cpu().numpy().tolist()): | ||
| 186 | label_true_list.extend(seq_result[: valid_lens.cpu().numpy()[idx]]) | ||
| 187 | for idx, seq_result in enumerate(y_pred_rebuild.cpu().numpy().tolist()): | ||
| 188 | label_pred_list.extend(seq_result[: valid_lens.cpu().numpy()[idx]]) | ||
| 189 | |||
| 185 | acc = accuracy_score(label_true_list, label_pred_list) | 190 | acc = accuracy_score(label_true_list, label_pred_list) |
| 186 | cm = confusion_matrix(label_true_list, label_pred_list) | 191 | cm = confusion_matrix(label_true_list, label_pred_list) |
| 187 | report = classification_report(label_true_list, label_pred_list) | 192 | report = classification_report(label_true_list, label_pred_list) | ... | ... |
-
Please register or sign in to post a comment