add VIT
Showing 4 changed files with 187 additions and 0 deletions
config/vit.yaml
0 → 100644
seed: 3407

dataset:
  name: 'CoordinatesData'
  args:
    data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset'
    train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/train.csv'
    val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/valid.csv'

dataloader:
  batch_size: 32
  num_workers: 4
  pin_memory: true
  shuffle: true

model:
  name: 'VisionTransformer'
  args:
    img_size: 224
    patch_size: 16
    in_c: 3
    num_classes: 5
    embed_dim: 8        # caution: standard ViT attention requires embed_dim % num_heads == 0
    depth: 12
    num_heads: 12
    mlp_ratio: 4.0
    qkv_bias: true
    qk_scale: null      # YAML null -> Python None; bare `none` would parse as the string 'none'
    representation_size: null
    distilled: false
    drop_ratio: 0.
    attn_drop_ratio: 0.
    drop_path_ratio: 0.
    norm_layer: null
    act_layer: null

solver:
  name: 'VITSolver'
  args:
    epoch: 100

  optimizer:
    name: 'Adam'
    args:
      lr: !!float 1e-4            # !!float tag needed: PyYAML parses bare 1e-4 as a string
      weight_decay: !!float 5e-5

  lr_scheduler:
    name: 'StepLR'
    args:
      step_size: 15
      gamma: 0.1

  loss:
    name: 'SigmoidFocalLoss'
    # name: 'CrossEntropyLoss'
    args:
      reduction: 'mean'

  logger:
    log_root: '/Users/zhouweiqi/Downloads/test/logs'
    suffix: 'vit'
\ No newline at end of file
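Aside (not part of the diff): the `null` values above matter because PyYAML resolves bare `none` to the string 'none', not Python None, which would then reach the model constructor as a truthy string. A minimal check, assuming PyYAML:

import yaml

cfg = yaml.safe_load(open('config/vit.yaml'))
assert cfg['model']['args']['qk_scale'] is None              # YAML `null` -> Python None
assert yaml.safe_load('x: none') == {'x': 'none'}            # bare `none` stays a string
assert yaml.safe_load('lr: !!float 1e-4') == {'lr': 0.0001}  # why the !!float tag is used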
model/vit.py
0 → 100644
This diff is collapsed.
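The model/vit.py diff is collapsed above, so its contents are not reproduced here. For orientation only, a VisionTransformer taking the constructor arguments listed in config/vit.yaml conventionally follows the forward pass sketched below; this is a generic sketch built from standard torch.nn modules, not the collapsed file. Note that the config's embed_dim: 8 with num_heads: 12 would fail the usual divisibility check, so the sketch defaults to ViT-Base-like sizes.

import torch
import torch.nn as nn

# Generic ViT skeleton for orientation only -- NOT the collapsed model/vit.py.
class ViTSketch(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=5,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        # patch embedding as a strided convolution: one token per 16x16 patch
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        block = nn.TransformerEncoderLayer(embed_dim, num_heads,
                                           int(embed_dim * mlp_ratio), batch_first=True)
        self.blocks = nn.TransformerEncoder(block, depth)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):                                # x: (B, 3, 224, 224)
        x = self.proj(x).flatten(2).transpose(1, 2)      # (B, 196, embed_dim)
        cls = self.cls_token.expand(x.shape[0], -1, -1)  # prepend [CLS] token
        x = self.blocks(torch.cat([cls, x], dim=1) + self.pos_embed)
        return self.head(x[:, 0])                        # classify from [CLS]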
solver/vit_solver.py
0 → 100644
import os
import copy
import torch

from model import build_model
from data import build_dataloader
from optimizer import build_optimizer, build_lr_scheduler
from loss import build_loss
from utils import SOLVER_REGISTRY, get_logger_and_log_dir


@SOLVER_REGISTRY.register()
class VITSolver(object):

    def __init__(self, cfg):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        # TODO: BatchNorm?
        self.model = build_model(cfg).to(self.device)

        self.loss_fn = build_loss(cfg)

        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError:
            # raising a bare string is a TypeError in Python 3; raise a real exception
            raise KeyError("config must contain 'epoch' in solver.args")

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    @staticmethod
    def evaluate(y_pred, y_true, thresholds=0.5):
        # Rebuild predictions and labels as class ids in [0, num_classes], where
        # id 0 is an implicit "other" class: a prediction counts as "other" when
        # no score clears the threshold, and a label counts as "other" when its
        # one-hot row is all zeros (row sum 0).
        y_pred_idx = torch.argmax(y_pred, dim=1) + 1
        y_pred_is_other = (torch.amax(y_pred, dim=1) > thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        y_true_idx = torch.argmax(y_true, dim=1) + 1
        y_true_is_other = torch.sum(y_true, dim=1)
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item()

    def train_loop(self):
        self.model.train()

        train_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for batch, (X, y) in enumerate(self.train_loader):
            X, y = X.to(self.device), y.to(self.device)

            pred = self.model(X)

            correct += self.evaluate(pred, y)

            loss = self.loss_fn(pred, y)
            train_loss += loss.item()

            if batch % 100 == 0:
                self.logger.info(f'train iteration: {batch}/{self.train_loader_size}, train loss: {loss.item():.4f}')

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        correct /= self.train_dataset_size
        train_loss /= self.train_loader_size
        self.logger.info(f'train accuracy: {correct.item():.4f}, train mean loss: {train_loss.item():.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        self.model.eval()

        val_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for X, y in self.val_loader:
            X, y = X.to(self.device), y.to(self.device)

            pred = self.model(X)

            correct += self.evaluate(pred, y)

            loss = self.loss_fn(pred, y)
            val_loss += loss.item()

        correct /= self.val_dataset_size
        val_loss /= self.val_loader_size

        self.logger.info(f'val accuracy: {correct.item():.4f}, val mean loss: {val_loss.item():.4f}')

    def save_checkpoint(self, epoch_id):
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        self.logger.info('==> Start Training')
        print(self.model)

        # lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')

            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)

            # lr_scheduler.step()

        self.logger.info('==> End Training')
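For context, a solver registered this way is normally constructed from the parsed config and driven by run(). A sketch of such an entry point follows; the registry's get()-style lookup is an assumption, since the utils module is not part of this diff.

# sketch of a driver script; SOLVER_REGISTRY.get() is an assumed lookup API
import yaml

import solver.vit_solver  # noqa: F401 -- imported for its @register side effect
from utils import SOLVER_REGISTRY

with open('config/vit.yaml') as f:
    cfg = yaml.safe_load(f)

solver_cls = SOLVER_REGISTRY.get(cfg['solver']['name'])  # -> VITSolver
solver_cls(cfg).run()  # train/validate/checkpoint for solver.args.epoch epochs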