0a93c10d by 周伟奇

fix bug: replace invalid YAML `none` values with `null` (qk_scale, representation_size, norm_layer, act_layer), set num_heads to 2 to match embed_dim=8, add `input_length` config/parameter so `pos_embed` is sized by the actual sequence length, and register VisionTransformer / VITSolver in their builders

1 parent d865f629
...@@ -22,17 +22,18 @@ model: ...@@ -22,17 +22,18 @@ model:
22 num_classes: 5 22 num_classes: 5
23 embed_dim: 8 23 embed_dim: 8
24 depth: 12 24 depth: 12
25 num_heads: 12 25 num_heads: 2
26 mlp_ratio: 4.0 26 mlp_ratio: 4.0
27 qkv_bias: true 27 qkv_bias: true
28 qk_scale: none 28 qk_scale: null
29 representation_size: none 29 representation_size: null
30 distilled: false 30 distilled: false
31 drop_ratio: 0. 31 drop_ratio: 0.
32 attn_drop_ratio: 0. 32 attn_drop_ratio: 0.
33 drop_path_ratio: 0. 33 drop_path_ratio: 0.
34 norm_layer: none 34 norm_layer: null
35 act_layer: none 35 act_layer: null
36 input_length: 29
36 37
37 solver: 38 solver:
38 name: 'VITSolver' 39 name: 'VITSolver'
......
...@@ -2,6 +2,7 @@ import copy ...@@ -2,6 +2,7 @@ import copy
2 from utils import MODEL_REGISTRY 2 from utils import MODEL_REGISTRY
3 3
4 from .mlp import MLPModel 4 from .mlp import MLPModel
5 from .vit import VisionTransformer
5 6
6 7
7 def build_model(cfg): 8 def build_model(cfg):
......
...@@ -182,7 +182,7 @@ class VisionTransformer(nn.Module): ...@@ -182,7 +182,7 @@ class VisionTransformer(nn.Module):
182 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, 182 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
183 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0., 183 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
184 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None, 184 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
185 act_layer=None): 185 act_layer=None, input_length=29):
186 """ 186 """
187 Args: 187 Args:
188 img_size (int, tuple): input image size 188 img_size (int, tuple): input image size
...@@ -215,7 +215,8 @@ class VisionTransformer(nn.Module): ...@@ -215,7 +215,8 @@ class VisionTransformer(nn.Module):
215 215
216 self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 216 self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
217 self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 217 self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
218 self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 218 # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
219 self.pos_embed = nn.Parameter(torch.zeros(1, input_length + self.num_tokens, embed_dim))
219 self.pos_drop = nn.Dropout(p=drop_ratio) 220 self.pos_drop = nn.Dropout(p=drop_ratio)
220 221
221 dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule 222 dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule
......
...@@ -2,6 +2,7 @@ import copy ...@@ -2,6 +2,7 @@ import copy
2 2
3 from utils.registery import SOLVER_REGISTRY 3 from utils.registery import SOLVER_REGISTRY
4 from .mlp_solver import MLPSolver 4 from .mlp_solver import MLPSolver
5 from .vit_solver import VITSolver
5 6
6 7
7 def build_solver(cfg): 8 def build_solver(cfg):
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!