f3fbac1b by 周伟奇

modify vit cls_token

1 parent 0a93c10d
......@@ -55,10 +55,10 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
skip_list: list 跳过的图片列表
save_dir: str 数据集保存目录
"""
# if os.path.exists(save_dir):
# return
# else:
# os.makedirs(save_dir, exist_ok=True)
if os.path.exists(save_dir):
return
else:
os.makedirs(save_dir, exist_ok=True)
top_text_count = len(top_text_list)
for img_name in sorted(os.listdir(img_dir)):
......@@ -238,11 +238,11 @@ if __name__ == '__main__':
'CH-B102708352-2.jpg',
]
# build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)
build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)
# build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir)
# build_anno_file(train_dataset_dir, train_anno_file_path)
build_anno_file(train_dataset_dir, train_anno_file_path)
# build_anno_file(valid_dataset_dir, valid_anno_file_path)
......
......@@ -206,7 +206,8 @@ class VisionTransformer(nn.Module):
super(VisionTransformer, self).__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 2 if distilled else 1
# self.num_tokens = 2 if distilled else 1
self.num_tokens = 0
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
......@@ -260,17 +261,17 @@ class VisionTransformer(nn.Module):
# [1, 1, 768] -> [B, 1, 768]
# [B, 28+1, 8]
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
if self.dist_token is None:
x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
else:
x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
# cls_token = self.cls_token.expand(x.shape[0], -1, -1)
# if self.dist_token is None:
# x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
# else:
# x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
x = self.pos_drop(x + self.pos_embed)
x = self.blocks(x)
x = self.norm(x)
if self.dist_token is None:
return self.pre_logits(x[:, 0])
return self.pre_logits(x[:, -1])
else:
return x[:, 0], x[:, 1]
......
......@@ -56,7 +56,7 @@ class VITSolver(object):
for batch, (X, y) in enumerate(self.train_loader):
X, y = X.to(self.device), y.to(self.device)
pred = self.model(X)
pred = torch.nn.Softmax(dim=1)(self.model(X))
correct += self.evaluate(pred, y)
......@@ -85,7 +85,7 @@ class VITSolver(object):
for X, y in self.val_loader:
X, y = X.to(self.device), y.to(self.device)
pred = self.model(X)
pred = torch.nn.Softmax(dim=1)(self.model(X))
correct += self.evaluate(pred, y)
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!