d865f629 by 周伟奇

add VIT

1 parent c424aba7
.DS_Store
logs/
__pycache__
*.log
dataset/
......
seed: 3407

dataset:
  name: 'CoordinatesData'
  args:
    data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset'
    train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/train.csv'
    val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/valid.csv'

dataloader:
  batch_size: 32
  num_workers: 4
  pin_memory: true
  shuffle: true

model:
  name: 'VisionTransformer'
  args:
    img_size: 224
    patch_size: 16
    in_c: 3
    num_classes: 5
    embed_dim: 8
    depth: 12
    num_heads: 12
    mlp_ratio: 4.0
    qkv_bias: true
    qk_scale: none
    representation_size: none
    distilled: false
    drop_ratio: 0.
    attn_drop_ratio: 0.
    drop_path_ratio: 0.
    norm_layer: none
    act_layer: none

solver:
  name: 'VITSolver'
  args:
    epoch: 100
  optimizer:
    name: 'Adam'
    args:
      lr: !!float 1e-4
      weight_decay: !!float 5e-5
  lr_scheduler:
    name: 'StepLR'
    args:
      step_size: 15
      gamma: 0.1
  loss:
    name: 'SigmoidFocalLoss'
    # name: 'CrossEntropyLoss'
    args:
      reduction: "mean"
  logger:
    log_root: '/Users/zhouweiqi/Downloads/test/logs'
    suffix: 'vit'
\ No newline at end of file
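Not part of this commit: a minimal sketch of how a config in this shape could be loaded, assuming PyYAML; the file path and variable names below are hypothetical. The explicit `!!float` tags matter because PyYAML would otherwise read `1e-4` and `5e-5` as strings.

import yaml

with open('config/vit.yaml') as f:   # hypothetical path
    cfg = yaml.safe_load(f)          # safe_load resolves the standard !!float tag

print(cfg['model']['args']['embed_dim'])           # 8
print(cfg['solver']['optimizer']['args']['lr'])    # 0.0001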
import os
import copy
import torch

from model import build_model
from data import build_dataloader
from optimizer import build_optimizer, build_lr_scheduler
from loss import build_loss
from utils import SOLVER_REGISTRY, get_logger_and_log_dir


@SOLVER_REGISTRY.register()
class VITSolver(object):

    def __init__(self, cfg):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        # BatchNorm ?
        self.model = build_model(cfg).to(self.device)

        self.loss_fn = build_loss(cfg)

        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError:
            raise KeyError('config should contain epoch in solver.args')

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    @staticmethod
    def evaluate(y_pred, y_true, thresholds=0.5):
        # Labels are one-hot over the real classes; an all-zero row stands for the
        # "other" class. Rebuild an integer label per sample, with 0 meaning "other".
        y_pred_idx = torch.argmax(y_pred, dim=1) + 1
        y_pred_is_other = (torch.amax(y_pred, dim=1) > thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        y_true_idx = torch.argmax(y_true, dim=1) + 1
        y_true_is_other = torch.sum(y_true, dim=1)
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        # Number of correctly classified samples in the batch.
        return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item()

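    # Illustration (not part of the original commit), assuming one-hot labels where
    # an all-zero row stands for the "other" class and predictions lie in [0, 1]:
    #   y_true = [[0, 1, 0, 0, 0],             -> rebuilt label 2
    #             [0, 0, 0, 0, 0]]             -> rebuilt label 0 ("other")
    #   y_pred = [[0.1, 0.9, 0.2, 0.0, 0.0],   -> argmax + 1 = 2, max > 0.5 -> label 2
    #             [0.3, 0.2, 0.1, 0.0, 0.0]]   -> max <= 0.5 -> label 0
    # evaluate(y_pred, y_true) returns 2: both samples are counted as correct.
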
    def train_loop(self):
        self.model.train()

        train_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for batch, (X, y) in enumerate(self.train_loader):
            X, y = X.to(self.device), y.to(self.device)

            pred = self.model(X)
            correct += self.evaluate(pred, y)

            # loss = self.loss_fn(pred, y, reduction="mean")
            loss = self.loss_fn(pred, y)
            train_loss += loss.item()

            if batch % 100 == 0:
                loss_value, current = loss.item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value:.4f}')

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        correct /= self.train_dataset_size
        train_loss /= self.train_loader_size

        self.logger.info(f'train accuracy: {correct.item():.4f}, train mean loss: {train_loss.item():.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        self.model.eval()

        val_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for X, y in self.val_loader:
            X, y = X.to(self.device), y.to(self.device)

            pred = self.model(X)
            correct += self.evaluate(pred, y)

            loss = self.loss_fn(pred, y)
            val_loss += loss.item()

        correct /= self.val_dataset_size
        val_loss /= self.val_loader_size

        self.logger.info(f"val accuracy: {correct.item():.4f}, val mean loss: {val_loss.item():.4f}")

    def save_checkpoint(self, epoch_id):
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        self.logger.info('==> Start Training')
        print(self.model)

        # lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')

            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)
            # lr_scheduler.step()

        self.logger.info('==> End Training')
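The diff does not show how the solver is launched. Below is a minimal, hypothetical entry point sketching how the YAML above could be wired to VITSolver; the module path `solver`, the config path, and the use of PyYAML and argparse are assumptions, not part of this commit.

import argparse
import yaml

from solver import VITSolver  # hypothetical module path for the class above


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='config/vit.yaml')  # hypothetical default
    args = parser.parse_args()

    with open(args.config) as f:
        cfg = yaml.safe_load(f)

    solver = VITSolver(cfg)   # builds model, dataloaders, optimizer, loss, logger
    solver.run()              # 100 epochs per the config; checkpoints go to log_dir


if __name__ == '__main__':
    main()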