fix bug
Showing 4 changed files with 11 additions and 7 deletions
@@ -22,17 +22,18 @@ model:
   num_classes: 5
   embed_dim: 8
   depth: 12
-  num_heads: 12
+  num_heads: 2
   mlp_ratio: 4.0
   qkv_bias: true
-  qk_scale: none
-  representation_size: none
+  qk_scale: null
+  representation_size: null
   distilled: false
   drop_ratio: 0.
   attn_drop_ratio: 0.
   drop_path_ratio: 0.
-  norm_layer: none
-  act_layer: none
+  norm_layer: null
+  act_layer: null
+  input_length: 29
 
 solver:
   name: 'VITSolver'
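These value changes are easy to misread in YAML: a bare `none` parses as the string "none", whereas `null` maps to Python None, and `num_heads: 2` keeps embed_dim (8) evenly divisible by the head count, which multi-head attention implementations typically require. A minimal loading sketch, assuming the config is read with PyYAML (the file path and exact loader are assumptions, not shown in this commit):

import yaml

with open("config.yaml") as f:   # hypothetical path; the repo's config file is not named here
    cfg = yaml.safe_load(f)

model_cfg = cfg["model"]
# `null` loads as Python None; the old bare `none` would have loaded as the string "none".
assert model_cfg["qk_scale"] is None
assert model_cfg["norm_layer"] is None
# 8-dim embeddings split cleanly across 2 heads (head_dim = 4); 12 heads would not divide evenly.
assert model_cfg["embed_dim"] % model_cfg["num_heads"] == 0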
@@ -182,7 +182,7 @@ class VisionTransformer(nn.Module):
                  embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                  qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                  attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
-                 act_layer=None):
+                 act_layer=None, input_length=29):
         """
         Args:
             img_size (int, tuple): input image size

@@ -215,7 +215,8 @@ class VisionTransformer(nn.Module):
 
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+        # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+        self.pos_embed = nn.Parameter(torch.zeros(1, input_length + self.num_tokens, embed_dim))
         self.pos_drop = nn.Dropout(p=drop_ratio)
 
         dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
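The constructor change sizes the positional embedding by the new input_length argument instead of the patch count, so it lines up with a 29-token input sequence plus the class token. A standalone shape check, assuming the usual ViT-style token handling (the surrounding forward pass is not part of this diff):

import torch
import torch.nn as nn

embed_dim, input_length, num_tokens = 8, 29, 1   # num_tokens = 1: class token only, no distillation

cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
pos_embed = nn.Parameter(torch.zeros(1, input_length + num_tokens, embed_dim))

x = torch.randn(4, input_length, embed_dim)                # hypothetical batch of embedded tokens
x = torch.cat([cls_token.expand(x.shape[0], -1, -1), x], dim=1)
x = x + pos_embed                                          # broadcasts: (4, 30, 8) + (1, 30, 8)
print(x.shape)                                             # torch.Size([4, 30, 8])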
@@ -2,6 +2,7 @@ import copy
 
 from utils.registery import SOLVER_REGISTRY
 from .mlp_solver import MLPSolver
+from .vit_solver import VITSolver
 
 
 def build_solver(cfg):
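The added import makes VITSolver visible when the solvers package is loaded, which is presumably what lets build_solver resolve the config's name: 'VITSolver' through SOLVER_REGISTRY. A rough sketch of how such a registry-driven builder usually works; the real utils.registery API may differ:

# Hypothetical stand-in for utils.registery; only illustrates the lookup pattern.
SOLVER_REGISTRY = {}

def register(cls):
    SOLVER_REGISTRY[cls.__name__] = cls   # importing the module is what triggers registration
    return cls

@register
class VITSolver:
    def __init__(self, cfg):
        self.cfg = cfg

def build_solver(cfg):
    # Look the solver class up by the name string given in the config.
    return SOLVER_REGISTRY[cfg["solver"]["name"]](cfg)

solver = build_solver({"solver": {"name": "VITSolver"}})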