diff --git a/config/vit.yaml b/config/vit.yaml
index 1683bb7..1b24479 100644
--- a/config/vit.yaml
+++ b/config/vit.yaml
@@ -40,6 +40,7 @@ solver:
   args:
     epoch: 100
     no_other: false
+    base_on: null
 
   optimizer:
     name: 'Adam'
diff --git a/solver/vit_solver.py b/solver/vit_solver.py
index 73a92d4..3d30e92 100644
--- a/solver/vit_solver.py
+++ b/solver/vit_solver.py
@@ -1,11 +1,12 @@
-import os
 import copy
+import os
+
 import torch
 
-from model import build_model
 from data import build_dataloader
-from optimizer import build_optimizer, build_lr_scheduler
 from loss import build_loss
+from model import build_model
+from optimizer import build_lr_scheduler, build_optimizer
 from utils import SOLVER_REGISTRY, get_logger_and_log_dir
 
 
@@ -30,6 +31,7 @@ class VITSolver(object):
 
         self.hyper_params = cfg['solver']['args']
         self.no_other = self.hyper_params['no_other'] 
+        self.base_on = self.hyper_params['base_on']
         try:
             self.epoch = self.hyper_params['epoch']
         except Exception:
@@ -62,9 +64,8 @@ class VITSolver(object):
             if self.no_other:
                 pred = torch.nn.Softmax(dim=1)(self.model(X))
             else:
-                pred = torch.nn.Sigmoid(self.model(X))
-
-            correct += self.evaluate(pred, y)
+                # pred = torch.nn.Sigmoid()(self.model(X))
+                pred = self.model(X)
 
             # loss = self.loss_fn(pred, y, reduction="mean")
             loss = self.loss_fn(pred, y)
@@ -73,6 +74,8 @@ class VITSolver(object):
             if batch % 100 == 0:
                 loss_value, current = loss.item(), batch
                 self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')
+
+            correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
             
             self.optimizer.zero_grad()
             loss.backward()
@@ -94,13 +97,14 @@ class VITSolver(object):
             if self.no_other:
                 pred = torch.nn.Softmax(dim=1)(self.model(X))
             else:
-                pred = torch.nn.Sigmoid(self.model(X))
-
-            correct += self.evaluate(pred, y)
+                # pred = torch.nn.Sigmoid()(self.model(X))
+                pred = self.model(X)
 
             loss = self.loss_fn(pred, y)
             val_loss += loss.item()
 
+            correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
+
         correct /= self.val_dataset_size
         val_loss /= self.val_loader_size
             
@@ -111,6 +115,10 @@ class VITSolver(object):
         torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))
 
     def run(self):
+        if isinstance(self.base_on, str) and os.path.exists(self.base_on):
+            self.model.load_state_dict(torch.load(self.base_on))
+            self.logger.info(f'==> Load Model from {self.base_on}')
+
         self.logger.info('==> Start Training')
         print(self.model)