add VIT
Showing 4 changed files with 474 additions and 0 deletions
config/vit.yaml
0 → 100644
seed: 3407

dataset:
  name: 'CoordinatesData'
  args:
    data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset'
    train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/train.csv'
    val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/valid.csv'

dataloader:
  batch_size: 32
  num_workers: 4
  pin_memory: true
  shuffle: true

model:
  name: 'VisionTransformer'
  args:
    img_size: 224
    patch_size: 16
    in_c: 3
    num_classes: 5
    embed_dim: 8
    depth: 12
    num_heads: 12
    mlp_ratio: 4.0
    qkv_bias: true
    qk_scale: null
    representation_size: null
    distilled: false
    drop_ratio: 0.
    attn_drop_ratio: 0.
    drop_path_ratio: 0.
    norm_layer: null
    act_layer: null

solver:
  name: 'VITSolver'
  args:
    epoch: 100

  optimizer:
    name: 'Adam'
    args:
      lr: !!float 1e-4
      weight_decay: !!float 5e-5

  lr_scheduler:
    name: 'StepLR'
    args:
      step_size: 15
      gamma: 0.1

  loss:
    name: 'SigmoidFocalLoss'
    # name: 'CrossEntropyLoss'
    args:
      reduction: "mean"

  logger:
    log_root: '/Users/zhouweiqi/Downloads/test/logs'
    suffix: 'vit'
\ No newline at end of file
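Before launching a run it can be worth sanity-checking this file, since VisionTransformer's Attention splits embed_dim across num_heads (head_dim = dim // num_heads) and the null-valued options need to parse as YAML nulls rather than the string 'none'. The following is a minimal sketch, not part of this change; the script name and the relative config path are assumptions. Note that with the values committed above (embed_dim: 8, num_heads: 12) the divisibility check would flag the config.

# check_config.py -- illustrative sketch, assumed filename and path
import yaml

with open('config/vit.yaml') as f:
    cfg = yaml.safe_load(f)

margs = cfg['model']['args']
# Attention uses head_dim = embed_dim // num_heads, so the division must be exact.
assert margs['embed_dim'] % margs['num_heads'] == 0, \
    f"embed_dim={margs['embed_dim']} is not divisible by num_heads={margs['num_heads']}"
# 'null' parses to Python None; the bare word 'none' would arrive as the string 'none'.
assert margs['qk_scale'] is None and margs['norm_layer'] is None
print('config OK:', cfg['model']['name'])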
model/vit.py
0 → 100644
from functools import partial
from collections import OrderedDict

import torch
import torch.nn as nn
from utils.registery import MODEL_REGISTRY


def _init_vit_weights(m):
    """
    ViT weight initialization
    :param m: module
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

class Attention(nn.Module):
    def __init__(self,
                 dim,  # dimension of the input tokens
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        # [batch_size, num_patches + 1, total_embed_dim]
        B, N, C = x.shape

        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x

@MODEL_REGISTRY.register()
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_c (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer (nn.Module): normalization layer
            act_layer (nn.Module): activation layer
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        # The patch embedding is bypassed here: the input is expected to already be a
        # token sequence of shape [B, num_tokens, embed_dim] (e.g. [B, 28, 8] for the
        # coordinate features used in this project). For image input, re-enable the line below.
        # x = self.patch_embed(x)  # [B, C, H, W] -> [B, num_patches, embed_dim]

        # prepend the class token: [1, 1, embed_dim] -> [B, 1, embed_dim]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, num_tokens + 1, embed_dim]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)

        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # during training, return both classifier predictions
                return x, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x + x_dist) / 2
        else:
            x = self.head(x)
            return x
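As a quick smoke test of the model file as committed: forward_features bypasses patch_embed, so the network consumes an already-embedded token sequence of shape [B, num_tokens, embed_dim], where num_tokens must equal the positional-embedding length implied by img_size and patch_size. The dimensions below are illustrative stand-ins (not the values from config/vit.yaml), chosen so that embed_dim divides evenly by num_heads; the snippet assumes utils.registery is importable so the registry decorator resolves.

# smoke_test_vit.py -- illustrative sketch with toy dimensions
import torch
from model.vit import VisionTransformer

model = VisionTransformer(img_size=224, patch_size=16, num_classes=5,
                          embed_dim=64, depth=2, num_heads=8)
model.eval()

# (224 // 16) ** 2 = 196 tokens, matching pos_embed (196 patches + 1 class token)
tokens = torch.randn(4, 196, 64)
with torch.no_grad():
    logits = model(tokens)
print(logits.shape)  # torch.Size([4, 5])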
solver/vit_solver.py
0 → 100644
import os
import copy
import torch

from model import build_model
from data import build_dataloader
from optimizer import build_optimizer, build_lr_scheduler
from loss import build_loss
from utils import SOLVER_REGISTRY, get_logger_and_log_dir


@SOLVER_REGISTRY.register()
class VITSolver(object):

    def __init__(self, cfg):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        # BatchNorm ?
        self.model = build_model(cfg).to(self.device)

        self.loss_fn = build_loss(cfg)

        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError as e:
            raise KeyError('config should contain epoch in solver.args') from e

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    @staticmethod
    def evaluate(y_pred, y_true, thresholds=0.5):
        y_pred_idx = torch.argmax(y_pred, dim=1) + 1
        y_pred_is_other = (torch.amax(y_pred, dim=1) > thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        y_true_idx = torch.argmax(y_true, dim=1) + 1
        y_true_is_other = torch.sum(y_true, dim=1)
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item()

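    # Worked example of the rebuild logic above (illustrative values, assuming the
    # predictions are scores in [0, 1] and y_true rows are one-hot or all zeros):
    # a prediction row [0.1, 0.8, 0.2, 0.1, 0.1] gives argmax+1 = 2 and max > 0.5,
    # so it is rebuilt as class 2; a row whose max stays below the threshold is
    # rebuilt as 0, meaning "none of the five classes". An all-zero target row also
    # rebuilds to 0 (its sum is 0), so a sample counts as correct only when both
    # sides agree on the same class index or both fall back to 0.
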
    def train_loop(self):
        self.model.train()

        train_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for batch, (X, y) in enumerate(self.train_loader):
            X, y = X.to(self.device), y.to(self.device)

            pred = self.model(X)

            correct += self.evaluate(pred, y)

            # loss = self.loss_fn(pred, y, reduction="mean")
            loss = self.loss_fn(pred, y)
            train_loss += loss.item()

            if batch % 100 == 0:
                loss_value, current = loss.item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value:.4f}')

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        correct /= self.train_dataset_size
        train_loss /= self.train_loader_size
        self.logger.info(f'train accuracy: {correct.item():.4f}, train mean loss: {train_loss.item():.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        self.model.eval()

        val_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for X, y in self.val_loader:
            X, y = X.to(self.device), y.to(self.device)

            pred = self.model(X)

            correct += self.evaluate(pred, y)

            loss = self.loss_fn(pred, y)
            val_loss += loss.item()

        correct /= self.val_dataset_size
        val_loss /= self.val_loader_size

        self.logger.info(f"val accuracy: {correct.item():.4f}, val mean loss: {val_loss.item():.4f}")

    def save_checkpoint(self, epoch_id):
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        self.logger.info('==> Start Training')
        print(self.model)

        # lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')

            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)

            # lr_scheduler.step()

        self.logger.info('==> End Training')
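For context, a minimal entry point that would drive this solver might look like the sketch below. It is not part of this change; the script name is assumed, and it presumes the layout shown above (config/vit.yaml on disk, solver/vit_solver.py importable). VITSolver's constructor builds the dataloaders, model, loss and optimizer from the same config dict, and run() executes the train/val loop for solver.args.epoch epochs.

# train.py -- illustrative sketch of a launch script
import yaml
from solver.vit_solver import VITSolver


def main():
    with open('config/vit.yaml') as f:
        cfg = yaml.safe_load(f)
    solver = VITSolver(cfg)
    solver.run()


if __name__ == '__main__':
    main()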