diff --git a/config/vit.yaml b/config/vit.yaml index 1683bb7..1b24479 100644 --- a/config/vit.yaml +++ b/config/vit.yaml @@ -40,6 +40,7 @@ solver: args: epoch: 100 no_other: false + base_on: null optimizer: name: 'Adam' diff --git a/solver/vit_solver.py b/solver/vit_solver.py index 73a92d4..3d30e92 100644 --- a/solver/vit_solver.py +++ b/solver/vit_solver.py @@ -1,11 +1,12 @@ -import os import copy +import os + import torch -from model import build_model from data import build_dataloader -from optimizer import build_optimizer, build_lr_scheduler from loss import build_loss +from model import build_model +from optimizer import build_lr_scheduler, build_optimizer from utils import SOLVER_REGISTRY, get_logger_and_log_dir @@ -30,6 +31,7 @@ class VITSolver(object): self.hyper_params = cfg['solver']['args'] self.no_other = self.hyper_params['no_other'] + self.base_on = self.hyper_params['base_on'] try: self.epoch = self.hyper_params['epoch'] except Exception: @@ -62,9 +64,8 @@ class VITSolver(object): if self.no_other: pred = torch.nn.Softmax(dim=1)(self.model(X)) else: - pred = torch.nn.Sigmoid(self.model(X)) - - correct += self.evaluate(pred, y) + # pred = torch.nn.Sigmoid()(self.model(X)) + pred = self.model(X) # loss = self.loss_fn(pred, y, reduction="mean") loss = self.loss_fn(pred, y) @@ -73,6 +74,8 @@ class VITSolver(object): if batch % 100 == 0: loss_value, current = loss.item(), batch self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}') + + correct += self.evaluate(torch.nn.Sigmoid()(pred), y) self.optimizer.zero_grad() loss.backward() @@ -94,13 +97,14 @@ class VITSolver(object): if self.no_other: pred = torch.nn.Softmax(dim=1)(self.model(X)) else: - pred = torch.nn.Sigmoid(self.model(X)) - - correct += self.evaluate(pred, y) + # pred = torch.nn.Sigmoid()(self.model(X)) + pred = self.model(X) loss = self.loss_fn(pred, y) val_loss += loss.item() + correct += self.evaluate(torch.nn.Sigmoid()(pred), y) + correct /= self.val_dataset_size val_loss /= self.val_loader_size @@ -111,6 +115,10 @@ class VITSolver(object): torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt')) def run(self): + if isinstance(self.base_on, str) and os.path.exists(self.base_on): + self.model.load_state_dict(torch.load(self.base_on)) + self.logger.info(f'==> Load Model from {self.base_on}') + self.logger.info('==> Start Training') print(self.model)