add VIT

Showing 4 changed files with 474 additions and 0 deletions
config/vit.yaml
0 → 100644
seed: 3407

dataset:
  name: 'CoordinatesData'
  args:
    data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset'
    train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/train.csv'
    val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/valid.csv'

dataloader:
  batch_size: 32
  num_workers: 4
  pin_memory: true
  shuffle: true

model:
  name: 'VisionTransformer'
  args:
    img_size: 224
    patch_size: 16
    in_c: 3
    num_classes: 5
    embed_dim: 8
    depth: 12
    num_heads: 12
    mlp_ratio: 4.0
    qkv_bias: true
    qk_scale: null
    representation_size: null
    distilled: false
    drop_ratio: 0.
    attn_drop_ratio: 0.
    drop_path_ratio: 0.
    norm_layer: null
    act_layer: null

solver:
  name: 'VITSolver'
  args:
    epoch: 100

  optimizer:
    name: 'Adam'
    args:
      lr: !!float 1e-4
      weight_decay: !!float 5e-5

  lr_scheduler:
    name: 'StepLR'
    args:
      step_size: 15
      gamma: 0.1

  loss:
    name: 'SigmoidFocalLoss'
    # name: 'CrossEntropyLoss'
    args:
      reduction: "mean"

  logger:
    log_root: '/Users/zhouweiqi/Downloads/test/logs'
    suffix: 'vit'
\ No newline at end of file
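VITSolver below reads the optimizer, lr_scheduler, loss and logger blocks from under the solver key, so the nesting above matters. A minimal sketch (not part of the commit) of loading the file with PyYAML and checking the fields the solver relies on; the repo's build_* helpers are not shown in this diff, so only the parsed structure is verified here. Note that `null` (rather than the string `none`) is what lets the model constructor receive Python None for qk_scale, representation_size, norm_layer and act_layer.

import yaml

# Load the config and verify the nesting VITSolver expects:
# cfg['solver']['args'], cfg['solver']['optimizer']['args'], cfg['solver']['logger'], ...
with open('config/vit.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['model']['name'])                       # 'VisionTransformer'
print(cfg['model']['args']['qk_scale'])           # None
print(cfg['solver']['args']['epoch'])             # 100
print(cfg['solver']['optimizer']['args']['lr'])   # 0.0001
print(cfg['solver']['logger']['suffix'])          # 'vit'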
model/vit.py
0 → 100644
from functools import partial
from collections import OrderedDict

import torch
import torch.nn as nn
from utils.registery import MODEL_REGISTRY


def _init_vit_weights(m):
    """
    ViT weight initialization
    :param m: module
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Attention(nn.Module):
    def __init__(self,
                 dim,   # dimension of the input tokens
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        # [batch_size, num_patches + 1, total_embed_dim]
        B, N, C = x.shape

        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x

@MODEL_REGISTRY.register()
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_c (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer (nn.Module): normalization layer
            act_layer (nn.Module): activation layer
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        # x = self.patch_embed(x)  # [B, 196, 768]
        # [1, 1, 768] -> [B, 1, 768]

        # [B, 28+1, 8]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)

        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # during training, return both classifier predictions
                return x, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x + x_dist) / 2
        else:
            x = self.head(x)
        return x
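A quick shape check for the attention and stochastic-depth pieces above (a sketch, not part of the commit; it assumes the repo's utils.registery import resolves so model.vit can be imported). Attention requires dim to be divisible by num_heads, since head_dim = dim // num_heads; the committed config's embed_dim: 8 with num_heads: 12 does not satisfy that.

import torch
from model.vit import Attention, drop_path

# Attention keeps the token count and embedding dim: [B, N, C] -> [B, N, C].
attn = Attention(dim=64, num_heads=8, qkv_bias=True)
tokens = torch.randn(2, 29, 64)          # [batch, num_patches + 1, dim]
print(attn(tokens).shape)                # torch.Size([2, 29, 64])

# Stochastic depth: with drop_prob=0.5 each sample is either zeroed or
# rescaled by 1 / keep_prob, so every entry here is 0.0 or 2.0.
print(drop_path(torch.ones(4, 3), drop_prob=0.5, training=True))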
solver/vit_solver.py
0 → 100644
import os
import copy
import torch

from model import build_model
from data import build_dataloader
from optimizer import build_optimizer, build_lr_scheduler
from loss import build_loss
from utils import SOLVER_REGISTRY, get_logger_and_log_dir


@SOLVER_REGISTRY.register()
class VITSolver(object):

    def __init__(self, cfg):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        # BatchNorm ?
        self.model = build_model(cfg).to(self.device)

        self.loss_fn = build_loss(cfg)

        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError:
            raise KeyError('cfg["solver"]["args"] must contain "epoch"')

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    @staticmethod
    def evaluate(y_pred, y_true, thresholds=0.5):
        y_pred_idx = torch.argmax(y_pred, dim=1) + 1
        y_pred_is_other = (torch.amax(y_pred, dim=1) > thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        y_true_idx = torch.argmax(y_true, dim=1) + 1
        y_true_is_other = torch.sum(y_true, dim=1)
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item()

    def train_loop(self):
        self.model.train()

        train_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for batch, (X, y) in enumerate(self.train_loader):
            X, y = X.to(self.device), y.to(self.device)

            pred = self.model(X)

            correct += self.evaluate(pred, y)

            # loss = self.loss_fn(pred, y, reduction="mean")
            loss = self.loss_fn(pred, y)
            train_loss += loss.item()

            if batch % 100 == 0:
                loss_value, current = loss.item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        correct /= self.train_dataset_size
        train_loss /= self.train_loader_size
        self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        self.model.eval()

        val_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for X, y in self.val_loader:
            X, y = X.to(self.device), y.to(self.device)

            pred = self.model(X)

            correct += self.evaluate(pred, y)

            loss = self.loss_fn(pred, y)
            val_loss += loss.item()

        correct /= self.val_dataset_size
        val_loss /= self.val_loader_size

        self.logger.info(f"val accuracy: {correct.item() :.4f}, val mean loss: {val_loss.item() :.4f}")

    def save_checkpoint(self, epoch_id):
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        self.logger.info('==> Start Training')
        print(self.model)

        # lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')

            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)

            # lr_scheduler.step()

        self.logger.info('==> End Training')
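The evaluate() helper above reconstructs a single class id from multi-label scores: argmax gives a 1-based index, and a row whose maximum score does not clear the threshold (or an all-zero target row) collapses to 0 ("other"). A small sketch of that behaviour, assuming the package imports in solver/vit_solver.py resolve in the repo:

import torch
from solver.vit_solver import VITSolver

y_pred = torch.tensor([[0.9, 0.1, 0.0, 0.0, 0.0],    # confident class 1
                       [0.2, 0.3, 0.1, 0.1, 0.1]])   # below threshold -> 0 ("other")
y_true = torch.tensor([[1., 0., 0., 0., 0.],
                       [0., 0., 0., 0., 0.]])         # all-zero target row -> 0 ("other")

# Both rows agree after the rebuild, so two correct predictions are counted.
print(VITSolver.evaluate(y_pred, y_true))             # 2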