3e58f6b0 by 周伟奇

modify evaluate

1 parent 40ca6fe1
...@@ -16,7 +16,7 @@ dataloader: ...@@ -16,7 +16,7 @@ dataloader:
16 model: 16 model:
17 name: 'SLTransformer' 17 name: 'SLTransformer'
18 args: 18 args:
19 seq_lens: 200 19 seq_lens: 160
20 num_classes: 10 20 num_classes: 10
21 embed_dim: 9 21 embed_dim: 9
22 depth: 6 22 depth: 6
......
...@@ -166,7 +166,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save ...@@ -166,7 +166,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
166 166
167 X = list() 167 X = list()
168 y_true = list() 168 y_true = list()
169 for i in range(200): 169 for i in range(160):
170 if i >= valid_lens: 170 if i >= valid_lens:
171 X.append([0., 0., 0., 0., 0., 0., 0., 0., 0.]) 171 X.append([0., 0., 0., 0., 0., 0., 0., 0., 0.])
172 y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) 172 y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
......
...@@ -65,7 +65,7 @@ class SLSolver(object): ...@@ -65,7 +65,7 @@ class SLSolver(object):
65 train_loss = torch.zeros(1).to(self.device) 65 train_loss = torch.zeros(1).to(self.device)
66 correct = torch.zeros(1).to(self.device) 66 correct = torch.zeros(1).to(self.device)
67 for batch, (X, y, valid_lens) in enumerate(self.train_loader): 67 for batch, (X, y, valid_lens) in enumerate(self.train_loader):
68 X, y = X.to(self.device), y.to(self.device) 68 X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)
69 69
70 pred = self.model(X, valid_lens) 70 pred = self.model(X, valid_lens)
71 # [batch_size, seq_len, num_classes] 71 # [batch_size, seq_len, num_classes]
...@@ -97,7 +97,7 @@ class SLSolver(object): ...@@ -97,7 +97,7 @@ class SLSolver(object):
97 val_loss = torch.zeros(1).to(self.device) 97 val_loss = torch.zeros(1).to(self.device)
98 correct = torch.zeros(1).to(self.device) 98 correct = torch.zeros(1).to(self.device)
99 for X, y, valid_lens in self.val_loader: 99 for X, y, valid_lens in self.val_loader:
100 X, y = X.to(self.device), y.to(self.device) 100 X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)
101 101
102 # pred = torch.nn.Sigmoid()(self.model(X)) 102 # pred = torch.nn.Sigmoid()(self.model(X))
103 pred = self.model(X, valid_lens) 103 pred = self.model(X, valid_lens)
...@@ -162,25 +162,30 @@ class SLSolver(object): ...@@ -162,25 +162,30 @@ class SLSolver(object):
162 162
163 label_true_list = [] 163 label_true_list = []
164 label_pred_list = [] 164 label_pred_list = []
165 for X, y in self.val_loader: 165 for X, y, valid_lens in self.val_loader:
166 X, y_true = X.to(self.device), y.to(self.device) 166 X, y_true, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)
167 167
168 # pred = torch.nn.Sigmoid()(self.model(X)) 168 # pred = torch.nn.Sigmoid()(self.model(X))
169 pred = self.model(X) 169 y_pred = self.model(X, valid_lens)
170
171 y_pred = torch.nn.Sigmoid()(pred)
172 170
173 y_pred_idx = torch.argmax(y_pred, dim=1) + 1 171 # [batch_size, seq_len, num_classes]
174 y_pred_is_other = (torch.amax(y_pred, dim=1) > 0.5).int() 172 y_pred_sigmoid = torch.nn.Sigmoid()(y_pred)
173 # [batch_size, seq_len]
174 y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1
175 # [batch_size, seq_len]
176 y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > 0.5).int()
175 y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other) 177 y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)
176 178
177 y_true_idx = torch.argmax(y_true, dim=1) + 1 179 y_true_idx = torch.argmax(y_true, dim=-1) + 1
178 y_true_is_other = torch.sum(y_true, dim=1) 180 y_true_is_other = torch.sum(y_true, dim=-1).int()
179 y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other) 181 y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)
180 182
181 label_true_list.extend(y_true_rebuild.cpu().numpy().tolist()) 183 # masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1)
182 label_pred_list.extend(y_pred_rebuild.cpu().numpy().tolist())
183 184
185 for idx, seq_result in enumerate(y_true_rebuild.cpu().numpy().tolist()):
186 label_true_list.extend(seq_result[: valid_lens.cpu().numpy()[idx]])
187 for idx, seq_result in enumerate(y_pred_rebuild.cpu().numpy().tolist()):
188 label_pred_list.extend(seq_result[: valid_lens.cpu().numpy()[idx]])
184 189
185 acc = accuracy_score(label_true_list, label_pred_list) 190 acc = accuracy_score(label_true_list, label_pred_list)
186 cm = confusion_matrix(label_true_list, label_pred_list) 191 cm = confusion_matrix(label_true_list, label_pred_list)
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!