add VIT

Showing 4 changed files with 187 additions and 0 deletions
config/vit.yaml
0 → 100644
seed: 3407

dataset:
  name: 'CoordinatesData'
  args:
    data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset'
    train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/train.csv'
    val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/valid.csv'

dataloader:
  batch_size: 32
  num_workers: 4
  pin_memory: true
  shuffle: true

model:
  name: 'VisionTransformer'
  args:
    img_size: 224
    patch_size: 16
    in_c: 3
    num_classes: 5
    embed_dim: 768   # must be divisible by num_heads; 8 is not, 768 matches the depth-12/heads-12 (ViT-Base) layout
    depth: 12
    num_heads: 12
    mlp_ratio: 4.0
    qkv_bias: true
    qk_scale: null             # YAML null -> Python None ('none' would arrive as the string 'none')
    representation_size: null
    distilled: false
    drop_ratio: 0.
    attn_drop_ratio: 0.
    drop_path_ratio: 0.
    norm_layer: null
    act_layer: null

solver:
  name: 'VITSolver'
  args:
    epoch: 100

  optimizer:
    name: 'Adam'
    args:
      lr: !!float 1e-4
      weight_decay: !!float 5e-5

  lr_scheduler:
    name: 'StepLR'
    args:
      step_size: 15
      gamma: 0.1

  loss:
    name: 'SigmoidFocalLoss'
    # name: 'CrossEntropyLoss'
    args:
      reduction: "mean"

  logger:
    log_root: '/Users/zhouweiqi/Downloads/test/logs'
    suffix: 'vit'
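How this config is consumed is not shown in the diff, but a minimal sketch of the intended wiring might look like the following. Only the YAML contents and the `config/vit.yaml` path come from this MR; treating model construction as a name-keyed registry lookup is an assumption based on the `build_*` helpers imported by solver/vit_solver.py below.

```python
# Hedged sketch: how config/vit.yaml is presumably loaded and consumed.
# The registry-style lookup is an assumption, not code from this MR.
import yaml

with open('config/vit.yaml') as f:
    cfg = yaml.safe_load(f)

# YAML `null` round-trips to Python None, so optional model args pass
# through untouched ('none' would arrive as the string 'none' instead).
assert cfg['model']['args']['qk_scale'] is None

# A name-keyed registry would then instantiate the model from model.args:
# model = MODEL_REGISTRY[cfg['model']['name']](**cfg['model']['args'])
print(cfg['model']['name'], cfg['dataloader']['batch_size'])  # VisionTransformer 32
```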
model/vit.py
0 → 100644
This diff is collapsed.
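Since the file itself is not expanded here, nothing below reproduces it; as rough orientation, the `model.args` block in config/vit.yaml implies a constructor along these lines. This is an assumption for illustration, not the actual contents of model/vit.py.

```python
# Assumed constructor signature, reverse-engineered from model.args in
# config/vit.yaml. The real model/vit.py is collapsed above and may differ.
import torch.nn as nn

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=5,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0,
                 qkv_bias=True, qk_scale=None, representation_size=None,
                 distilled=False, drop_ratio=0., attn_drop_ratio=0.,
                 drop_path_ratio=0., norm_layer=None, act_layer=None):
        super().__init__()
        # Multi-head attention splits embed_dim across num_heads, so the
        # two must divide evenly (hence 768 rather than 8 in the config).
        assert embed_dim % num_heads == 0, 'embed_dim must be divisible by num_heads'
```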
solver/vit_solver.py
0 → 100644
import os
import copy
import torch

from model import build_model
from data import build_dataloader
from optimizer import build_optimizer, build_lr_scheduler
from loss import build_loss
from utils import SOLVER_REGISTRY, get_logger_and_log_dir


@SOLVER_REGISTRY.register()
class VITSolver(object):

    def __init__(self, cfg):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        # BatchNorm ?
        self.model = build_model(cfg).to(self.device)

        self.loss_fn = build_loss(cfg)

        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError:
            raise KeyError("solver.args must contain 'epoch'")

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    @staticmethod
    def evaluate(y_pred, y_true, thresholds=0.5):
        # Class ids are shifted by +1 so that 0 is reserved for the "other" class.
        y_pred_idx = torch.argmax(y_pred, dim=1) + 1
        # A prediction whose strongest score stays below the threshold maps to "other" (0).
        y_pred_is_other = (torch.amax(y_pred, dim=1) > thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        y_true_idx = torch.argmax(y_true, dim=1) + 1
        # An all-zero one-hot target row sums to 0 and therefore also rebuilds to "other".
        y_true_is_other = torch.sum(y_true, dim=1)
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item()

    def train_loop(self):
        self.model.train()

        train_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for batch, (X, y) in enumerate(self.train_loader):
            X, y = X.to(self.device), y.to(self.device)

            pred = self.model(X)

            correct += self.evaluate(pred, y)

            # loss = self.loss_fn(pred, y, reduction="mean")
            loss = self.loss_fn(pred, y)
            train_loss += loss.item()

            if batch % 100 == 0:
                loss_value, current = loss.item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value:.4f}')

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        correct /= self.train_dataset_size
        train_loss /= self.train_loader_size
        self.logger.info(f'train accuracy: {correct.item():.4f}, train mean loss: {train_loss.item():.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        self.model.eval()

        val_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for X, y in self.val_loader:
            X, y = X.to(self.device), y.to(self.device)

            pred = self.model(X)

            correct += self.evaluate(pred, y)

            loss = self.loss_fn(pred, y)
            val_loss += loss.item()

        correct /= self.val_dataset_size
        val_loss /= self.val_loader_size

        self.logger.info(f"epoch {t} val accuracy: {correct.item():.4f}, val mean loss: {val_loss.item():.4f}")

    def save_checkpoint(self, epoch_id):
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        self.logger.info('==> Start Training')
        print(self.model)

        # lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')

            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)

            # lr_scheduler.step()

        self.logger.info('==> End Training')
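The index arithmetic in `evaluate` is easy to misread: class ids are shifted to 1..num_classes so that 0 can encode the "other" class, which a prediction falls into when its best score does not clear the threshold and a target falls into when its one-hot row is all zeros. The batch below is made up for illustration; only the `VITSolver.evaluate` call itself comes from this MR.

```python
# Made-up batch illustrating VITSolver.evaluate (3 samples, 5 classes).
import torch
from solver.vit_solver import VITSolver

y_pred = torch.tensor([[0.9, 0.1, 0.0, 0.0, 0.0],   # max 0.9 > 0.5  -> class 0 -> rebuilt id 1
                       [0.2, 0.3, 0.1, 0.2, 0.2],   # max 0.3 <= 0.5 -> "other" -> rebuilt id 0
                       [0.0, 0.0, 0.0, 0.8, 0.1]])  # max 0.8 > 0.5  -> class 3 -> rebuilt id 4
y_true = torch.tensor([[1., 0., 0., 0., 0.],        # one-hot class 0 -> rebuilt id 1
                       [0., 0., 0., 0., 0.],        # all-zero row    -> "other", id 0
                       [0., 0., 0., 0., 1.]])       # one-hot class 4 -> rebuilt id 5

print(VITSolver.evaluate(y_pred, y_true))  # 2 -- the first two samples match, the third does not
```

Whether `y_pred` should hold probabilities or raw logits depends on the model head; with SigmoidFocalLoss the network typically emits logits, in which case the fixed 0.5 cut-off is a logit-space threshold rather than a probability one.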