modify evaluate
Showing
3 changed files
with
20 additions
and
15 deletions
... | @@ -166,7 +166,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save | ... | @@ -166,7 +166,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save |
166 | 166 | ||
167 | X = list() | 167 | X = list() |
168 | y_true = list() | 168 | y_true = list() |
169 | for i in range(200): | 169 | for i in range(160): |
170 | if i >= valid_lens: | 170 | if i >= valid_lens: |
171 | X.append([0., 0., 0., 0., 0., 0., 0., 0., 0.]) | 171 | X.append([0., 0., 0., 0., 0., 0., 0., 0., 0.]) |
172 | y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) | 172 | y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) | ... | ... |
... | @@ -65,7 +65,7 @@ class SLSolver(object): | ... | @@ -65,7 +65,7 @@ class SLSolver(object): |
65 | train_loss = torch.zeros(1).to(self.device) | 65 | train_loss = torch.zeros(1).to(self.device) |
66 | correct = torch.zeros(1).to(self.device) | 66 | correct = torch.zeros(1).to(self.device) |
67 | for batch, (X, y, valid_lens) in enumerate(self.train_loader): | 67 | for batch, (X, y, valid_lens) in enumerate(self.train_loader): |
68 | X, y = X.to(self.device), y.to(self.device) | 68 | X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device) |
69 | 69 | ||
70 | pred = self.model(X, valid_lens) | 70 | pred = self.model(X, valid_lens) |
71 | # [batch_size, seq_len, num_classes] | 71 | # [batch_size, seq_len, num_classes] |
... | @@ -97,7 +97,7 @@ class SLSolver(object): | ... | @@ -97,7 +97,7 @@ class SLSolver(object): |
97 | val_loss = torch.zeros(1).to(self.device) | 97 | val_loss = torch.zeros(1).to(self.device) |
98 | correct = torch.zeros(1).to(self.device) | 98 | correct = torch.zeros(1).to(self.device) |
99 | for X, y, valid_lens in self.val_loader: | 99 | for X, y, valid_lens in self.val_loader: |
100 | X, y = X.to(self.device), y.to(self.device) | 100 | X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device) |
101 | 101 | ||
102 | # pred = torch.nn.Sigmoid()(self.model(X)) | 102 | # pred = torch.nn.Sigmoid()(self.model(X)) |
103 | pred = self.model(X, valid_lens) | 103 | pred = self.model(X, valid_lens) |
... | @@ -162,25 +162,30 @@ class SLSolver(object): | ... | @@ -162,25 +162,30 @@ class SLSolver(object): |
162 | 162 | ||
163 | label_true_list = [] | 163 | label_true_list = [] |
164 | label_pred_list = [] | 164 | label_pred_list = [] |
165 | for X, y in self.val_loader: | 165 | for X, y, valid_lens in self.val_loader: |
166 | X, y_true = X.to(self.device), y.to(self.device) | 166 | X, y_true, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device) |
167 | 167 | ||
168 | # pred = torch.nn.Sigmoid()(self.model(X)) | 168 | # pred = torch.nn.Sigmoid()(self.model(X)) |
169 | pred = self.model(X) | 169 | y_pred = self.model(X, valid_lens) |
170 | |||
171 | y_pred = torch.nn.Sigmoid()(pred) | ||
172 | 170 | ||
173 | y_pred_idx = torch.argmax(y_pred, dim=1) + 1 | 171 | # [batch_size, seq_len, num_classes] |
174 | y_pred_is_other = (torch.amax(y_pred, dim=1) > 0.5).int() | 172 | y_pred_sigmoid = torch.nn.Sigmoid()(y_pred) |
173 | # [batch_size, seq_len] | ||
174 | y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1 | ||
175 | # [batch_size, seq_len] | ||
176 | y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > 0.5).int() | ||
175 | y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other) | 177 | y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other) |
176 | 178 | ||
177 | y_true_idx = torch.argmax(y_true, dim=1) + 1 | 179 | y_true_idx = torch.argmax(y_true, dim=-1) + 1 |
178 | y_true_is_other = torch.sum(y_true, dim=1) | 180 | y_true_is_other = torch.sum(y_true, dim=-1).int() |
179 | y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other) | 181 | y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other) |
180 | 182 | ||
181 | label_true_list.extend(y_true_rebuild.cpu().numpy().tolist()) | 183 | # masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1) |
182 | label_pred_list.extend(y_pred_rebuild.cpu().numpy().tolist()) | ||
183 | 184 | ||
185 | for idx, seq_result in enumerate(y_true_rebuild.cpu().numpy().tolist()): | ||
186 | label_true_list.extend(seq_result[: valid_lens.cpu().numpy()[idx]]) | ||
187 | for idx, seq_result in enumerate(y_pred_rebuild.cpu().numpy().tolist()): | ||
188 | label_pred_list.extend(seq_result[: valid_lens.cpu().numpy()[idx]]) | ||
184 | 189 | ||
185 | acc = accuracy_score(label_true_list, label_pred_list) | 190 | acc = accuracy_score(label_true_list, label_pred_list) |
186 | cm = confusion_matrix(label_true_list, label_pred_list) | 191 | cm = confusion_matrix(label_true_list, label_pred_list) | ... | ... |
-
Please register or sign in to post a comment