f3fbac1b by 周伟奇

modify vit cls_token

1 parent 0a93c10d
@@ -55,10 +55,10 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save_dir):
         skip_list: list  list of image filenames to skip
         save_dir: str    directory where the dataset is saved
     """
-    # if os.path.exists(save_dir):
-    #     return
-    # else:
-    #     os.makedirs(save_dir, exist_ok=True)
+    if os.path.exists(save_dir):
+        return
+    else:
+        os.makedirs(save_dir, exist_ok=True)
 
     top_text_count = len(top_text_list)
     for img_name in sorted(os.listdir(img_dir)):
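Re-enabling this guard makes `build_dataset` return immediately whenever `save_dir` already exists, so an existing dataset is never rebuilt or overwritten. If a rebuild is wanted, the directory has to be removed first. A minimal sketch of a hypothetical wrapper (the `rebuild_dataset` helper and its arguments are illustrative and not part of this commit):

```python
import os
import shutil

# Hypothetical helper, not part of the commit: force a rebuild despite the
# re-enabled early-return guard by removing the existing dataset directory.
def rebuild_dataset(save_dir: str, builder, *builder_args) -> None:
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)           # drop the stale dataset first
    builder(*builder_args, save_dir)      # build_dataset recreates save_dir itself
```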
@@ -238,11 +238,11 @@ if __name__ == '__main__':
         'CH-B102708352-2.jpg',
     ]
 
-    # build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)
+    build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)
 
     # build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir)
 
-    # build_anno_file(train_dataset_dir, train_anno_file_path)
+    build_anno_file(train_dataset_dir, train_anno_file_path)
     # build_anno_file(valid_dataset_dir, valid_anno_file_path)
 
 
@@ -206,7 +206,8 @@ class VisionTransformer(nn.Module):
         super(VisionTransformer, self).__init__()
         self.num_classes = num_classes
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
-        self.num_tokens = 2 if distilled else 1
+        # self.num_tokens = 2 if distilled else 1
+        self.num_tokens = 0
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
         act_layer = act_layer or nn.GELU
 
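Setting `num_tokens` to 0 matters beyond the forward pass: in timm-style ViT implementations the positional embedding is typically allocated with `num_patches + num_tokens` positions, so with no prepended tokens `pos_embed` has to match the bare patch sequence (28 tokens of width 8, per the `[B, 28+1, 8]` shape comment in the next hunk). A minimal sketch under that assumption (the allocation line itself is not shown in this diff):

```python
import torch
import torch.nn as nn

# Assumption (timm-style construction, not shown in the diff): pos_embed is
# sized from num_patches + num_tokens, so num_tokens = 0 removes the slot
# that the cls_token used to occupy.
num_patches, embed_dim, num_tokens = 28, 8, 0
pos_embed = nn.Parameter(torch.zeros(1, num_patches + num_tokens, embed_dim))

x = torch.randn(4, num_patches, embed_dim)   # [B, 28, 8] patch embeddings
assert (x + pos_embed).shape == (4, 28, 8)   # shapes line up without a cls slot
```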
@@ -260,17 +261,17 @@ class VisionTransformer(nn.Module):
         # [1, 1, 768] -> [B, 1, 768]
 
         # [B, 28+1, 8]
-        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
-        if self.dist_token is None:
-            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
-        else:
-            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+        # cls_token = self.cls_token.expand(x.shape[0], -1, -1)
+        # if self.dist_token is None:
+        #     x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
+        # else:
+        #     x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
 
         x = self.pos_drop(x + self.pos_embed)
         x = self.blocks(x)
         x = self.norm(x)
         if self.dist_token is None:
-            return self.pre_logits(x[:, 0])
+            return self.pre_logits(x[:, -1])
         else:
             return x[:, 0], x[:, 1]
 
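With the cls/dist token concatenation commented out, the sequence fed to the transformer blocks is the patch tokens alone, and the pooled feature is now the last token (`x[:, -1]`) instead of a leading class token. A condensed sketch of the resulting path, assuming `dist_token is None` (names follow the diff; this is an illustration, not the full module):

```python
import torch
import torch.nn as nn

def forward_features_no_cls(self, x: torch.Tensor) -> torch.Tensor:
    """Condensed sketch of the modified path (dist_token is None): no class
    token is prepended, and the last patch token is the pooled feature."""
    # x: [B, 28, 8] patch embeddings; pos_embed must therefore be [1, 28, 8]
    x = self.pos_drop(x + self.pos_embed)
    x = self.blocks(x)
    x = self.norm(x)
    return self.pre_logits(x[:, -1])     # last token replaces the old x[:, 0]
```

Note that the distilled branch (`return x[:, 0], x[:, 1]`) is left unchanged, so with no prepended tokens it would now return the first two patch tokens; presumably that branch is not exercised here.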
@@ -56,7 +56,7 @@ class VITSolver(object):
         for batch, (X, y) in enumerate(self.train_loader):
             X, y = X.to(self.device), y.to(self.device)
 
-            pred = self.model(X)
+            pred = torch.nn.Softmax(dim=1)(self.model(X))
 
             correct += self.evaluate(pred, y)
 
@@ -85,7 +85,7 @@ class VITSolver(object):
         for X, y in self.val_loader:
             X, y = X.to(self.device), y.to(self.device)
 
-            pred = self.model(X)
+            pred = torch.nn.Softmax(dim=1)(self.model(X))
 
             correct += self.evaluate(pred, y)
 
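Wrapping the model output in `torch.nn.Softmax(dim=1)` converts logits to per-class probabilities before `self.evaluate` sees them. If `evaluate` counts argmax matches, the result is identical either way, since softmax preserves the row-wise ordering; the one thing to watch (an assumption about code not shown in this diff) is that `nn.CrossEntropyLoss` expects raw logits, so the softmaxed `pred` should not also be fed to that loss. A small sketch of the argmax equivalence:

```python
import torch

logits = torch.randn(4, 10)                # [B, num_classes] raw model output
probs = torch.nn.Softmax(dim=1)(logits)    # what self.model(X) is now wrapped in

# Softmax is monotonic along dim=1, so an argmax-based accuracy count is unchanged.
assert torch.equal(logits.argmax(dim=1), probs.argmax(dim=1))
```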