0a93c10d by 周伟奇

fix bug

1 parent d865f629
@@ -22,17 +22,18 @@ model:
   num_classes: 5
   embed_dim: 8
   depth: 12
-  num_heads: 12
+  num_heads: 2
   mlp_ratio: 4.0
   qkv_bias: true
-  qk_scale: none
-  representation_size: none
+  qk_scale: null
+  representation_size: null
   distilled: false
   drop_ratio: 0.
   attn_drop_ratio: 0.
   drop_path_ratio: 0.
-  norm_layer: none
-  act_layer: none
+  norm_layer: null
+  act_layer: null
+  input_length: 29
 solver:
   name: 'VITSolver'
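Two of these config edits are load-bearing rather than cosmetic. YAML has no keyword `none`: PyYAML parses it as the plain string `'none'`, so the old values never became Python `None`; only `null` (or `~`) does. And multi-head attention splits `embed_dim` across heads, so `num_heads` must divide it evenly (8 % 12 != 0, while 8 % 2 == 0), which is plausibly the bug this commit fixes. A quick check, assuming the config is loaded with PyYAML:

```python
import yaml

cfg = yaml.safe_load("""
embed_dim: 8
num_heads: 2
qk_scale: none
representation_size: null
""")

# YAML has no keyword `none`: it parses as the string 'none', so the old
# config silently passed a truthy str where the model expects Python None.
assert cfg["qk_scale"] == "none"
assert cfg["representation_size"] is None

# Multi-head attention splits embed_dim across heads, so it must divide
# evenly: 8 % 12 != 0 under the old num_heads, 8 % 2 == 0 after the fix.
assert cfg["embed_dim"] % cfg["num_heads"] == 0
```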
@@ -2,6 +2,7 @@ import copy
 from utils import MODEL_REGISTRY
 from .mlp import MLPModel
+from .vit import VisionTransformer

 def build_model(cfg):
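For context: these `__init__` imports are presumably what registers the new classes with their registries, since a registration decorator only runs when its defining module is imported; without the import, `build_model` could not resolve the model named in the config (the `utils.registery` hunk at the end of this commit mirrors this for `SOLVER_REGISTRY`). A minimal sketch of the pattern, where the `Registry` class, decorator name, and cfg layout are assumptions rather than this repo's exact code:

```python
# Sketch of a registry: importing the module executes the decorator,
# which is what makes config-driven lookup possible in build_model().
class Registry:
    def __init__(self):
        self._map = {}

    def register(self, cls):            # used as @MODEL_REGISTRY.register
        self._map[cls.__name__] = cls
        return cls

    def get(self, name):
        return self._map[name]

MODEL_REGISTRY = Registry()

@MODEL_REGISTRY.register
class VisionTransformer:                # stand-in for models/vit.py
    pass

def build_model(cfg):
    # hypothetical cfg layout: {'model': {'name': 'VisionTransformer'}}
    return MODEL_REGISTRY.get(cfg['model']['name'])()
```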
@@ -182,7 +182,7 @@ class VisionTransformer(nn.Module):
                  embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                  qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                  attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
-                 act_layer=None):
+                 act_layer=None, input_length=29):
         """
         Args:
             img_size (int, tuple): input image size
@@ -215,7 +215,8 @@ class VisionTransformer(nn.Module):
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+        # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+        self.pos_embed = nn.Parameter(torch.zeros(1, input_length + self.num_tokens, embed_dim))
         self.pos_drop = nn.Dropout(p=drop_ratio)
         dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
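The `pos_embed` fix is a shape-consistency issue: in a ViT forward pass the positional embedding is added element-wise to the token sequence, so its length must equal the number of input tokens plus the class (and optional distillation) tokens. Here the model consumes length-29 sequences rather than image patches, so sizing by `num_patches` would make that addition fail with a broadcast error. A minimal shape check of the constraint, with `input_length` and `embed_dim` taken from this commit's config, a made-up batch size of 4, and `num_tokens = 1` assuming `distilled: false`:

```python
import torch

input_length, num_tokens, embed_dim = 29, 1, 8   # values from this commit's config
x = torch.zeros(4, input_length, embed_dim)      # batch of token sequences

cls_token = torch.zeros(4, num_tokens, embed_dim)
x = torch.cat([cls_token, x], dim=1)             # (4, 30, 8)

# pos_embed must match the sequence length after the cls token is prepended
pos_embed = torch.zeros(1, input_length + num_tokens, embed_dim)
x = x + pos_embed                                 # broadcasts over batch; shapes agree
print(x.shape)                                    # torch.Size([4, 30, 8])
```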
@@ -2,6 +2,7 @@ import copy
 from utils.registery import SOLVER_REGISTRY
 from .mlp_solver import MLPSolver
+from .vit_solver import VITSolver

 def build_solver(cfg):