d865f629 by 周伟奇

add VIT

1 parent c424aba7
1 .DS_Store 1 .DS_Store
2 logs/ 2 logs/
3
4 __pycache__
5
6 *.log
7
8 dataset/
......
1 seed: 3407
2
3 dataset:
4 name: 'CoordinatesData'
5 args:
6 data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset'
7 train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/train.csv'
8 val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/valid.csv'
9
10 dataloader:
11 batch_size: 32
12 num_workers: 4
13 pin_memory: true
14 shuffle: true
15
16 model:
17 name: 'VisionTransformer'
18 args:
19 img_size: 224
20 patch_size: 16
21 in_c: 3
22 num_classes: 5
23 embed_dim: 8
24 depth: 12
25 num_heads: 12
26 mlp_ratio: 4.0
27 qkv_bias: true
28 qk_scale: none
29 representation_size: none
30 distilled: false
31 drop_ratio: 0.
32 attn_drop_ratio: 0.
33 drop_path_ratio: 0.
34 norm_layer: none
35 act_layer: none
36
37 solver:
38 name: 'VITSolver'
39 args:
40 epoch: 100
41
42 optimizer:
43 name: 'Adam'
44 args:
45 lr: !!float 1e-4
46 weight_decay: !!float 5e-5
47
48 lr_scheduler:
49 name: 'StepLR'
50 args:
51 step_size: 15
52 gamma: 0.1
53
54 loss:
55 name: 'SigmoidFocalLoss'
56 # name: 'CrossEntropyLoss'
57 args:
58 reduction: "mean"
59
60 logger:
61 log_root: '/Users/zhouweiqi/Downloads/test/logs'
62 suffix: 'vit'
...\ No newline at end of file ...\ No newline at end of file
1 from functools import partial
2 from collections import OrderedDict
3
4 import torch
5 import torch.nn as nn
6 from utils.registery import MODEL_REGISTRY
7
8
9 def _init_vit_weights(m):
10 """
11 ViT weight initialization
12 :param m: module
13 """
14 if isinstance(m, nn.Linear):
15 nn.init.trunc_normal_(m.weight, std=.01)
16 if m.bias is not None:
17 nn.init.zeros_(m.bias)
18 elif isinstance(m, nn.Conv2d):
19 nn.init.kaiming_normal_(m.weight, mode="fan_out")
20 if m.bias is not None:
21 nn.init.zeros_(m.bias)
22 elif isinstance(m, nn.LayerNorm):
23 nn.init.zeros_(m.bias)
24 nn.init.ones_(m.weight)
25
26
27 def drop_path(x, drop_prob: float = 0., training: bool = False):
28 """
29 Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
30 This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
31 the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
32 See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
33 changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
34 'survival rate' as the argument.
35 """
36 if drop_prob == 0. or not training:
37 return x
38 keep_prob = 1 - drop_prob
39 shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
40 random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
41 random_tensor.floor_() # binarize
42 output = x.div(keep_prob) * random_tensor
43 return output
44
45
46 class DropPath(nn.Module):
47 """
48 Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
49 """
50 def __init__(self, drop_prob=None):
51 super(DropPath, self).__init__()
52 self.drop_prob = drop_prob
53
54 def forward(self, x):
55 return drop_path(x, self.drop_prob, self.training)
56
57
58 class Attention(nn.Module):
59 def __init__(self,
60 dim, # 输入token的dim
61 num_heads=8,
62 qkv_bias=False,
63 qk_scale=None,
64 attn_drop_ratio=0.,
65 proj_drop_ratio=0.):
66 super(Attention, self).__init__()
67 self.num_heads = num_heads
68 head_dim = dim // num_heads
69 self.scale = qk_scale or head_dim ** -0.5
70 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
71 self.attn_drop = nn.Dropout(attn_drop_ratio)
72 self.proj = nn.Linear(dim, dim)
73 self.proj_drop = nn.Dropout(proj_drop_ratio)
74
75 def forward(self, x):
76 # [batch_size, num_patches + 1, total_embed_dim]
77 B, N, C = x.shape
78
79 # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
80 # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
81 # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
82 qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
83 # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
84 q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
85
86 # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
87 # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
88 attn = (q @ k.transpose(-2, -1)) * self.scale
89 attn = attn.softmax(dim=-1)
90 attn = self.attn_drop(attn)
91
92 # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
93 # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
94 # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
95 x = (attn @ v).transpose(1, 2).reshape(B, N, C)
96 x = self.proj(x)
97 x = self.proj_drop(x)
98 return x
99
100
101 class Mlp(nn.Module):
102 """
103 MLP as used in Vision Transformer, MLP-Mixer and related networks
104 """
105 def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
106 super().__init__()
107 out_features = out_features or in_features
108 hidden_features = hidden_features or in_features
109 self.fc1 = nn.Linear(in_features, hidden_features)
110 self.act = act_layer()
111 self.fc2 = nn.Linear(hidden_features, out_features)
112 self.drop = nn.Dropout(drop)
113
114 def forward(self, x):
115 x = self.fc1(x)
116 x = self.act(x)
117 x = self.drop(x)
118 x = self.fc2(x)
119 x = self.drop(x)
120 return x
121
122
123 class Block(nn.Module):
124 def __init__(self,
125 dim,
126 num_heads,
127 mlp_ratio=4.,
128 qkv_bias=False,
129 qk_scale=None,
130 drop_ratio=0.,
131 attn_drop_ratio=0.,
132 drop_path_ratio=0.,
133 act_layer=nn.GELU,
134 norm_layer=nn.LayerNorm):
135 super(Block, self).__init__()
136 self.norm1 = norm_layer(dim)
137 self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
138 attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
139 # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
140 self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
141 self.norm2 = norm_layer(dim)
142 mlp_hidden_dim = int(dim * mlp_ratio)
143 self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
144
145 def forward(self, x):
146 x = x + self.drop_path(self.attn(self.norm1(x)))
147 x = x + self.drop_path(self.mlp(self.norm2(x)))
148 return x
149
150
151 class PatchEmbed(nn.Module):
152 """
153 2D Image to Patch Embedding
154 """
155 def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
156 super().__init__()
157 img_size = (img_size, img_size)
158 patch_size = (patch_size, patch_size)
159 self.img_size = img_size
160 self.patch_size = patch_size
161 self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
162 self.num_patches = self.grid_size[0] * self.grid_size[1]
163
164 self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
165 self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
166
167 def forward(self, x):
168 B, C, H, W = x.shape
169 assert H == self.img_size[0] and W == self.img_size[1], \
170 f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
171
172 # flatten: [B, C, H, W] -> [B, C, HW]
173 # transpose: [B, C, HW] -> [B, HW, C]
174 x = self.proj(x).flatten(2).transpose(1, 2)
175 x = self.norm(x)
176 return x
177
178
179 @MODEL_REGISTRY.register()
180 class VisionTransformer(nn.Module):
181 def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
182 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
183 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
184 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
185 act_layer=None):
186 """
187 Args:
188 img_size (int, tuple): input image size
189 patch_size (int, tuple): patch size
190 in_c (int): number of input channels
191 num_classes (int): number of classes for classification head
192 embed_dim (int): embedding dimension
193 depth (int): depth of transformer
194 num_heads (int): number of attention heads
195 mlp_ratio (int): ratio of mlp hidden dim to embedding dim
196 qkv_bias (bool): enable bias for qkv if True
197 qk_scale (float): override default qk scale of head_dim ** -0.5 if set
198 representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
199 distilled (bool): model includes a distillation token and head as in DeiT models
200 drop_ratio (float): dropout rate
201 attn_drop_ratio (float): attention dropout rate
202 drop_path_ratio (float): stochastic depth rate
203 embed_layer (nn.Module): patch embedding layer
204 norm_layer: (nn.Module): normalization layer
205 """
206 super(VisionTransformer, self).__init__()
207 self.num_classes = num_classes
208 self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
209 self.num_tokens = 2 if distilled else 1
210 norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
211 act_layer = act_layer or nn.GELU
212
213 self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
214 num_patches = self.patch_embed.num_patches
215
216 self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
217 self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
218 self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
219 self.pos_drop = nn.Dropout(p=drop_ratio)
220
221 dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule
222 self.blocks = nn.Sequential(*[
223 Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
224 drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
225 norm_layer=norm_layer, act_layer=act_layer)
226 for i in range(depth)
227 ])
228 self.norm = norm_layer(embed_dim)
229
230 # Representation layer
231 if representation_size and not distilled:
232 self.has_logits = True
233 self.num_features = representation_size
234 self.pre_logits = nn.Sequential(OrderedDict([
235 ("fc", nn.Linear(embed_dim, representation_size)),
236 ("act", nn.Tanh())
237 ]))
238 else:
239 self.has_logits = False
240 self.pre_logits = nn.Identity()
241
242 # Classifier head(s)
243 self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
244 self.head_dist = None
245 if distilled:
246 self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
247
248 # Weight init
249 nn.init.trunc_normal_(self.pos_embed, std=0.02)
250 if self.dist_token is not None:
251 nn.init.trunc_normal_(self.dist_token, std=0.02)
252
253 nn.init.trunc_normal_(self.cls_token, std=0.02)
254 self.apply(_init_vit_weights)
255
256 def forward_features(self, x):
257 # [B, C, H, W] -> [B, num_patches, embed_dim]
258 # x = self.patch_embed(x) # [B, 196, 768]
259 # [1, 1, 768] -> [B, 1, 768]
260
261 # [B, 28+1, 8]
262 cls_token = self.cls_token.expand(x.shape[0], -1, -1)
263 if self.dist_token is None:
264 x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
265 else:
266 x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
267
268 x = self.pos_drop(x + self.pos_embed)
269 x = self.blocks(x)
270 x = self.norm(x)
271 if self.dist_token is None:
272 return self.pre_logits(x[:, 0])
273 else:
274 return x[:, 0], x[:, 1]
275
276 def forward(self, x):
277 x = self.forward_features(x)
278 if self.head_dist is not None:
279 x, x_dist = self.head(x[0]), self.head_dist(x[1])
280 if self.training and not torch.jit.is_scripting():
281 # during inference, return the average of both classifier predictions
282 return x, x_dist
283 else:
284 return (x + x_dist) / 2
285 else:
286 x = self.head(x)
287 return x
1 import os
2 import copy
3 import torch
4
5 from model import build_model
6 from data import build_dataloader
7 from optimizer import build_optimizer, build_lr_scheduler
8 from loss import build_loss
9 from utils import SOLVER_REGISTRY, get_logger_and_log_dir
10
11
12 @SOLVER_REGISTRY.register()
13 class VITSolver(object):
14
15 def __init__(self, cfg):
16 self.device = "cuda" if torch.cuda.is_available() else "cpu"
17
18 self.cfg = copy.deepcopy(cfg)
19
20 self.train_loader, self.val_loader = build_dataloader(cfg)
21 self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
22 self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)
23
24 # BatchNorm ?
25 self.model = build_model(cfg).to(self.device)
26
27 self.loss_fn = build_loss(cfg)
28
29 self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])
30
31 self.hyper_params = cfg['solver']['args']
32 try:
33 self.epoch = self.hyper_params['epoch']
34 except Exception:
35 raise 'should contain epoch in {solver.args}'
36
37 self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])
38
39 @staticmethod
40 def evaluate(y_pred, y_true, thresholds=0.5):
41 y_pred_idx = torch.argmax(y_pred, dim=1) + 1
42 y_pred_is_other = (torch.amax(y_pred, dim=1) > 0.5).int()
43 y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)
44
45 y_true_idx = torch.argmax(y_true, dim=1) + 1
46 y_true_is_other = torch.sum(y_true, dim=1)
47 y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)
48
49 return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item()
50
51 def train_loop(self):
52 self.model.train()
53
54 train_loss = torch.zeros(1).to(self.device)
55 correct = torch.zeros(1).to(self.device)
56 for batch, (X, y) in enumerate(self.train_loader):
57 X, y = X.to(self.device), y.to(self.device)
58
59 pred = self.model(X)
60
61 correct += self.evaluate(pred, y)
62
63 # loss = self.loss_fn(pred, y, reduction="mean")
64 loss = self.loss_fn(pred, y)
65 train_loss += loss.item()
66
67 if batch % 100 == 0:
68 loss_value, current = loss.item(), batch
69 self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')
70
71 self.optimizer.zero_grad()
72 loss.backward()
73 self.optimizer.step()
74
75 correct /= self.train_dataset_size
76 train_loss /= self.train_loader_size
77 self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')
78
79 @torch.no_grad()
80 def val_loop(self, t):
81 self.model.eval()
82
83 val_loss = torch.zeros(1).to(self.device)
84 correct = torch.zeros(1).to(self.device)
85 for X, y in self.val_loader:
86 X, y = X.to(self.device), y.to(self.device)
87
88 pred = self.model(X)
89
90 correct += self.evaluate(pred, y)
91
92 loss = self.loss_fn(pred, y)
93 val_loss += loss.item()
94
95 correct /= self.val_dataset_size
96 val_loss /= self.val_loader_size
97
98 self.logger.info(f"val accuracy: {correct.item() :.4f}, val mean loss: {val_loss.item() :.4f}")
99
100 def save_checkpoint(self, epoch_id):
101 self.model.eval()
102 torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))
103
104 def run(self):
105 self.logger.info('==> Start Training')
106 print(self.model)
107
108 # lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])
109
110 for t in range(self.epoch):
111 self.logger.info(f'==> epoch {t + 1}')
112
113 self.train_loop()
114 self.val_loop(t + 1)
115 self.save_checkpoint(t + 1)
116
117 # lr_scheduler.step()
118
119 self.logger.info('==> End Training')
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!