Modify ViT cls_token
Showing 3 changed files with 16 additions and 15 deletions
... | @@ -55,10 +55,10 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save | ... | @@ -55,10 +55,10 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save |
55 | skip_list: list 跳过的图片列表 | 55 | skip_list: list 跳过的图片列表 |
56 | save_dir: str 数据集保存目录 | 56 | save_dir: str 数据集保存目录 |
57 | """ | 57 | """ |
58 | # if os.path.exists(save_dir): | 58 | if os.path.exists(save_dir): |
59 | # return | 59 | return |
60 | # else: | 60 | else: |
61 | # os.makedirs(save_dir, exist_ok=True) | 61 | os.makedirs(save_dir, exist_ok=True) |
62 | 62 | ||
63 | top_text_count = len(top_text_list) | 63 | top_text_count = len(top_text_list) |
64 | for img_name in sorted(os.listdir(img_dir)): | 64 | for img_name in sorted(os.listdir(img_dir)): |
... | @@ -238,11 +238,11 @@ if __name__ == '__main__': | ... | @@ -238,11 +238,11 @@ if __name__ == '__main__': |
238 | 'CH-B102708352-2.jpg', | 238 | 'CH-B102708352-2.jpg', |
239 | ] | 239 | ] |
240 | 240 | ||
241 | # build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir) | 241 | build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir) |
242 | 242 | ||
243 | # build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir) | 243 | # build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir) |
244 | 244 | ||
245 | # build_anno_file(train_dataset_dir, train_anno_file_path) | 245 | build_anno_file(train_dataset_dir, train_anno_file_path) |
246 | # build_anno_file(valid_dataset_dir, valid_anno_file_path) | 246 | # build_anno_file(valid_dataset_dir, valid_anno_file_path) |
247 | 247 | ||
248 | 248 | ... | ... |
... | @@ -206,7 +206,8 @@ class VisionTransformer(nn.Module): | ... | @@ -206,7 +206,8 @@ class VisionTransformer(nn.Module): |
206 | super(VisionTransformer, self).__init__() | 206 | super(VisionTransformer, self).__init__() |
207 | self.num_classes = num_classes | 207 | self.num_classes = num_classes |
208 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models | 208 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models |
209 | self.num_tokens = 2 if distilled else 1 | 209 | # self.num_tokens = 2 if distilled else 1 |
210 | self.num_tokens = 0 | ||
210 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) | 211 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) |
211 | act_layer = act_layer or nn.GELU | 212 | act_layer = act_layer or nn.GELU |
212 | 213 | ||
... | @@ -260,17 +261,17 @@ class VisionTransformer(nn.Module): | ... | @@ -260,17 +261,17 @@ class VisionTransformer(nn.Module): |
260 | # [1, 1, 768] -> [B, 1, 768] | 261 | # [1, 1, 768] -> [B, 1, 768] |
261 | 262 | ||
262 | # [B, 28+1, 8] | 263 | # [B, 28+1, 8] |
263 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) | 264 | # cls_token = self.cls_token.expand(x.shape[0], -1, -1) |
264 | if self.dist_token is None: | 265 | # if self.dist_token is None: |
265 | x = torch.cat((cls_token, x), dim=1) # [B, 197, 768] | 266 | # x = torch.cat((cls_token, x), dim=1) # [B, 197, 768] |
266 | else: | 267 | # else: |
267 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) | 268 | # x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) |
268 | 269 | ||
269 | x = self.pos_drop(x + self.pos_embed) | 270 | x = self.pos_drop(x + self.pos_embed) |
270 | x = self.blocks(x) | 271 | x = self.blocks(x) |
271 | x = self.norm(x) | 272 | x = self.norm(x) |
272 | if self.dist_token is None: | 273 | if self.dist_token is None: |
273 | return self.pre_logits(x[:, 0]) | 274 | return self.pre_logits(x[:, -1]) |
274 | else: | 275 | else: |
275 | return x[:, 0], x[:, 1] | 276 | return x[:, 0], x[:, 1] |
276 | 277 | ... | ... |
... | @@ -56,7 +56,7 @@ class VITSolver(object): | ... | @@ -56,7 +56,7 @@ class VITSolver(object): |
56 | for batch, (X, y) in enumerate(self.train_loader): | 56 | for batch, (X, y) in enumerate(self.train_loader): |
57 | X, y = X.to(self.device), y.to(self.device) | 57 | X, y = X.to(self.device), y.to(self.device) |
58 | 58 | ||
59 | pred = self.model(X) | 59 | pred = torch.nn.Softmax(dim=1)(self.model(X)) |
60 | 60 | ||
61 | correct += self.evaluate(pred, y) | 61 | correct += self.evaluate(pred, y) |
62 | 62 | ||
... | @@ -85,7 +85,7 @@ class VITSolver(object): | ... | @@ -85,7 +85,7 @@ class VITSolver(object): |
85 | for X, y in self.val_loader: | 85 | for X, y in self.val_loader: |
86 | X, y = X.to(self.device), y.to(self.device) | 86 | X, y = X.to(self.device), y.to(self.device) |
87 | 87 | ||
88 | pred = self.model(X) | 88 | pred = torch.nn.Softmax(dim=1)(self.model(X)) |
89 | 89 | ||
90 | correct += self.evaluate(pred, y) | 90 | correct += self.evaluate(pred, y) |
91 | 91 | ... | ... |
-
Please register or sign in to post a comment