d865f629 by 周伟奇

add VIT

1 parent c424aba7
.DS_Store
logs/

__pycache__

*.log

dataset/
......
seed: 3407

dataset:
  name: 'CoordinatesData'
  args:
    data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset'
    train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/train.csv'
    val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/valid.csv'

dataloader:
  batch_size: 32
  num_workers: 4
  pin_memory: true
  shuffle: true

model:
  name: 'VisionTransformer'
  args:
    img_size: 224
    patch_size: 16
    in_c: 3
    num_classes: 5
    embed_dim: 8          # note: standard multi-head attention requires embed_dim % num_heads == 0
    depth: 12
    num_heads: 12
    mlp_ratio: 4.0
    qkv_bias: true
    qk_scale: null        # YAML null; a bare 'none' would load as the string 'none'
    representation_size: null
    distilled: false
    drop_ratio: 0.0
    attn_drop_ratio: 0.0
    drop_path_ratio: 0.0
    norm_layer: null
    act_layer: null

solver:
  name: 'VITSolver'
  args:
    epoch: 100

  optimizer:
    name: 'Adam'
    args:
      lr: !!float 1e-4
      weight_decay: !!float 5e-5

  lr_scheduler:
    name: 'StepLR'
    args:
      step_size: 15
      gamma: 0.1

  loss:
    name: 'SigmoidFocalLoss'
    # name: 'CrossEntropyLoss'
    args:
      reduction: "mean"

  logger:
    log_root: '/Users/zhouweiqi/Downloads/test/logs'
    suffix: 'vit'
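For reference, a config in this shape would normally be read with PyYAML and handed straight to the solver. The explicit !!float tags are load-bearing: PyYAML's resolver only recognizes exponent notation as a float when it contains a decimal point, so a bare 1e-4 would load as the string '1e-4'. A minimal loading sketch (the config path is illustrative):

import yaml

with open('config/vit.yaml') as f:
    cfg = yaml.safe_load(f)

# The !!float tag is what makes this a float rather than the string '1e-4'.
assert isinstance(cfg['solver']['optimizer']['args']['lr'], float)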
......
import os
import copy
import torch

from model import build_model
from data import build_dataloader
from optimizer import build_optimizer, build_lr_scheduler
from loss import build_loss
from utils import SOLVER_REGISTRY, get_logger_and_log_dir


@SOLVER_REGISTRY.register()
class VITSolver(object):

    def __init__(self, cfg):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        # TODO: decide whether the model should use BatchNorm here
        self.model = build_model(cfg).to(self.device)

        self.loss_fn = build_loss(cfg)

        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError:
            # raising a bare string is a TypeError in Python 3; raise a real exception
            raise KeyError("cfg['solver']['args'] must contain 'epoch'")

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    @staticmethod
    def evaluate(y_pred, y_true, thresholds=0.5):
        # rebuild predictions: argmax gives a 1-based class id, zeroed out when
        # no score clears the threshold (the implicit "other" class)
        y_pred_idx = torch.argmax(y_pred, dim=1) + 1
        y_pred_is_other = (torch.amax(y_pred, dim=1) > thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        # rebuild targets the same way: an all-zero one-hot row sums to 0 ("other")
        y_true_idx = torch.argmax(y_true, dim=1) + 1
        y_true_is_other = torch.sum(y_true, dim=1)
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        # number of correctly classified samples in the batch
        return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item()

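The rebuild step in evaluate encodes an implicit sixth class: class ids are shifted to 1..5, and a sample whose scores (or one-hot targets) are all zero is mapped to id 0, i.e. "other". A small worked example of the target side:

import torch

# Two samples: one labeled class 2, one with an all-zero row ("other").
y_true = torch.tensor([[0., 1., 0., 0., 0.],
                       [0., 0., 0., 0., 0.]])
y_true_idx = torch.argmax(y_true, dim=1) + 1   # tensor([2, 1])
y_true_is_other = torch.sum(y_true, dim=1)     # tensor([1., 0.])
y_true_rebuild = y_true_idx * y_true_is_other  # tensor([2., 0.]) -- "other" becomes 0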
    def train_loop(self):
        self.model.train()

        train_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for batch, (X, y) in enumerate(self.train_loader):
            X, y = X.to(self.device), y.to(self.device)

            pred = self.model(X)

            correct += self.evaluate(pred, y)

            # loss = self.loss_fn(pred, y, reduction="mean")
            loss = self.loss_fn(pred, y)
            train_loss += loss.item()

            if batch % 100 == 0:
                loss_value, current = loss.item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value:.4f}')

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        correct /= self.train_dataset_size
        train_loss /= self.train_loader_size
        self.logger.info(f'train accuracy: {correct.item():.4f}, train mean loss: {train_loss.item():.4f}')

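The SigmoidFocalLoss wrapper built by build_loss is not part of this diff; assuming it wraps something like torchvision.ops.sigmoid_focal_loss (a plausible but unconfirmed choice), the loss call in train_loop behaves roughly as follows:

import torch
from torchvision.ops import sigmoid_focal_loss

pred = torch.randn(4, 5)        # raw logits, as produced by self.model(X)
target = torch.zeros(4, 5)      # one-hot rows; an all-zero row is "other"
target[0, 1] = 1.0
# reduction="mean" matches the args block in the config above
loss = sigmoid_focal_loss(pred, target, reduction="mean")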
    @torch.no_grad()
    def val_loop(self, t):
        self.model.eval()

        val_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for X, y in self.val_loader:
            X, y = X.to(self.device), y.to(self.device)

            pred = self.model(X)

            correct += self.evaluate(pred, y)

            loss = self.loss_fn(pred, y)
            val_loss += loss.item()

        correct /= self.val_dataset_size
        val_loss /= self.val_loader_size

        self.logger.info(f"val accuracy: {correct.item():.4f}, val mean loss: {val_loss.item():.4f}")

    def save_checkpoint(self, epoch_id):
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

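The checkpoints written by save_checkpoint can be restored later with the usual state-dict round trip. The path below is illustrative, following the ckpt_epoch_{epoch_id}.pt naming above; cfg is the same parsed config used for training:

import torch
from model import build_model

state = torch.load('logs/vit/ckpt_epoch_100.pt', map_location='cpu')
model = build_model(cfg)
model.load_state_dict(state)
model.eval()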
    def run(self):
        self.logger.info('==> Start Training')
        print(self.model)

        # lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')

            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)

            # lr_scheduler.step()

        self.logger.info('==> End Training')
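The training entry point itself is not included in this commit. Assuming SOLVER_REGISTRY offers a name-based lookup (the get() method below is hypothetical; adjust to the project's actual registry API), a runner could look like:

import yaml
from utils import SOLVER_REGISTRY

with open('config/vit.yaml') as f:   # illustrative path
    cfg = yaml.safe_load(f)

solver = SOLVER_REGISTRY.get(cfg['solver']['name'])(cfg)  # hypothetical get()
solver.run()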