From f3fbac1b9edf15ebb09eab8dac5c8fdcd10cd5fb Mon Sep 17 00:00:00 2001
From: zhouweiqi <zhouweiqi@situdata.com>
Date: Wed, 14 Dec 2022 16:47:11 +0800
Subject: [PATCH] modify vit: drop cls_token concat, classify from last token

---
 data/create_dataset.py | 12 ++++++------
 model/vit.py           | 15 ++++++++-------
 solver/vit_solver.py   |  4 ++--
 3 files changed, 16 insertions(+), 15 deletions(-)

diff --git a/data/create_dataset.py b/data/create_dataset.py
index b5990ba..473b8bb 100644
--- a/data/create_dataset.py
+++ b/data/create_dataset.py
@@ -55,10 +55,10 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
         skip_list: list 跳过的图片列表
         save_dir: str 数据集保存目录
     """
-    # if os.path.exists(save_dir):
-    #     return
-    # else:
-    #     os.makedirs(save_dir, exist_ok=True)
+    if os.path.exists(save_dir):
+        return
+    else:
+        os.makedirs(save_dir, exist_ok=True)
 
     top_text_count = len(top_text_list)
     for img_name in sorted(os.listdir(img_dir)):
@@ -238,11 +238,11 @@ if __name__ == '__main__':
         'CH-B102708352-2.jpg',
     ]
 
-    # build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)
+    build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)
 
     # build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir)
 
-    # build_anno_file(train_dataset_dir, train_anno_file_path)
+    build_anno_file(train_dataset_dir, train_anno_file_path)
     # build_anno_file(valid_dataset_dir, valid_anno_file_path)
 
     
diff --git a/model/vit.py b/model/vit.py
index 1875141..3adf52e 100644
--- a/model/vit.py
+++ b/model/vit.py
@@ -206,7 +206,8 @@ class VisionTransformer(nn.Module):
         super(VisionTransformer, self).__init__()
         self.num_classes = num_classes
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
-        self.num_tokens = 2 if distilled else 1
+        # self.num_tokens = 2 if distilled else 1
+        self.num_tokens = 0 
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
         act_layer = act_layer or nn.GELU
 
@@ -260,17 +261,17 @@ class VisionTransformer(nn.Module):
         # [1, 1, 768] -> [B, 1, 768]
 
         # [B, 28+1, 8]
-        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
-        if self.dist_token is None:
-            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
-        else:
-            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+        # cls_token = self.cls_token.expand(x.shape[0], -1, -1)
+        # if self.dist_token is None:
+        #     x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
+        # else:
+        #     x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
 
         x = self.pos_drop(x + self.pos_embed)
         x = self.blocks(x)
         x = self.norm(x)
         if self.dist_token is None:
-            return self.pre_logits(x[:, 0])
+            return self.pre_logits(x[:, -1])
         else:
             return x[:, 0], x[:, 1]
 
diff --git a/solver/vit_solver.py b/solver/vit_solver.py
index 56f298b..8ec5ab8 100644
--- a/solver/vit_solver.py
+++ b/solver/vit_solver.py
@@ -56,7 +56,7 @@ class VITSolver(object):
         for batch, (X, y) in enumerate(self.train_loader):
             X, y = X.to(self.device), y.to(self.device)
 
-            pred = self.model(X)
+            pred = torch.nn.Softmax(dim=1)(self.model(X))
 
             correct += self.evaluate(pred, y)
 
@@ -85,7 +85,7 @@ class VITSolver(object):
         for X, y in self.val_loader:
             X, y = X.to(self.device), y.to(self.device)
 
-            pred = self.model(X)
+            pred = torch.nn.Softmax(dim=1)(self.model(X))
 
             correct += self.evaluate(pred, y)
 
--
libgit2 0.24.0