From f3fbac1b9edf15ebb09eab8dac5c8fdcd10cd5fb Mon Sep 17 00:00:00 2001 From: zhouweiqi <zhouweiqi@situdata.com> Date: Wed, 14 Dec 2022 16:47:11 +0800 Subject: [PATCH] modify vit cls_token --- data/create_dataset.py | 12 ++++++------ model/vit.py | 15 ++++++++------- solver/vit_solver.py | 4 ++-- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/data/create_dataset.py b/data/create_dataset.py index b5990ba..473b8bb 100644 --- a/data/create_dataset.py +++ b/data/create_dataset.py @@ -55,10 +55,10 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save skip_list: list 跳过的图片列表 save_dir: str 数据集保存目录 """ - # if os.path.exists(save_dir): - # return - # else: - # os.makedirs(save_dir, exist_ok=True) + if os.path.exists(save_dir): + return + else: + os.makedirs(save_dir, exist_ok=True) top_text_count = len(top_text_list) for img_name in sorted(os.listdir(img_dir)): @@ -238,11 +238,11 @@ if __name__ == '__main__': 'CH-B102708352-2.jpg', ] - # build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir) + build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir) # build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir) - # build_anno_file(train_dataset_dir, train_anno_file_path) + build_anno_file(train_dataset_dir, train_anno_file_path) # build_anno_file(valid_dataset_dir, valid_anno_file_path) diff --git a/model/vit.py b/model/vit.py index 1875141..3adf52e 100644 --- a/model/vit.py +++ b/model/vit.py @@ -206,7 +206,8 @@ class VisionTransformer(nn.Module): super(VisionTransformer, self).__init__() self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 2 if distilled else 1 + # self.num_tokens = 2 if distilled else 1 + self.num_tokens = 0 norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU @@ -260,17 +261,17 @@ class VisionTransformer(nn.Module): # [1, 1, 768] -> [B, 1, 768] # [B, 28+1, 8] - cls_token = self.cls_token.expand(x.shape[0], -1, -1) - if self.dist_token is None: - x = torch.cat((cls_token, x), dim=1) # [B, 197, 768] - else: - x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) + # cls_token = self.cls_token.expand(x.shape[0], -1, -1) + # if self.dist_token is None: + # x = torch.cat((cls_token, x), dim=1) # [B, 197, 768] + # else: + # x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) x = self.pos_drop(x + self.pos_embed) x = self.blocks(x) x = self.norm(x) if self.dist_token is None: - return self.pre_logits(x[:, 0]) + return self.pre_logits(x[:, -1]) else: return x[:, 0], x[:, 1] diff --git a/solver/vit_solver.py b/solver/vit_solver.py index 56f298b..8ec5ab8 100644 --- a/solver/vit_solver.py +++ b/solver/vit_solver.py @@ -56,7 +56,7 @@ class VITSolver(object): for batch, (X, y) in enumerate(self.train_loader): X, y = X.to(self.device), y.to(self.device) - pred = self.model(X) + pred = torch.nn.Softmax(dim=1)(self.model(X)) correct += self.evaluate(pred, y) @@ -85,7 +85,7 @@ class VITSolver(object): for X, y in self.val_loader: X, y = X.to(self.device), y.to(self.device) - pred = self.model(X) + pred = torch.nn.Softmax(dim=1)(self.model(X)) correct += self.evaluate(pred, y) -- libgit2 0.24.0