69e75f77 by 周伟奇

add load model

1 parent fb5f4ba1
......@@ -40,6 +40,7 @@ solver:
args:
epoch: 100
no_other: false
base_on: null
optimizer:
name: 'Adam'
......
import os
import copy
import os
import torch
from model import build_model
from data import build_dataloader
from optimizer import build_optimizer, build_lr_scheduler
from loss import build_loss
from model import build_model
from optimizer import build_lr_scheduler, build_optimizer
from utils import SOLVER_REGISTRY, get_logger_and_log_dir
......@@ -30,6 +31,7 @@ class VITSolver(object):
self.hyper_params = cfg['solver']['args']
self.no_other = self.hyper_params['no_other']
self.base_on = self.hyper_params['base_on']
try:
self.epoch = self.hyper_params['epoch']
except Exception:
......@@ -62,9 +64,8 @@ class VITSolver(object):
if self.no_other:
pred = torch.nn.Softmax(dim=1)(self.model(X))
else:
pred = torch.nn.Sigmoid(self.model(X))
correct += self.evaluate(pred, y)
# pred = torch.nn.Sigmoid()(self.model(X))
pred = self.model(X)
# loss = self.loss_fn(pred, y, reduction="mean")
loss = self.loss_fn(pred, y)
......@@ -74,6 +75,8 @@ class VITSolver(object):
loss_value, current = loss.item(), batch
self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')
correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
......@@ -94,13 +97,14 @@ class VITSolver(object):
if self.no_other:
pred = torch.nn.Softmax(dim=1)(self.model(X))
else:
pred = torch.nn.Sigmoid(self.model(X))
correct += self.evaluate(pred, y)
# pred = torch.nn.Sigmoid()(self.model(X))
pred = self.model(X)
loss = self.loss_fn(pred, y)
val_loss += loss.item()
correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
correct /= self.val_dataset_size
val_loss /= self.val_loader_size
......@@ -111,6 +115,10 @@ class VITSolver(object):
torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))
def run(self):
if isinstance(self.base_on, str) and os.path.exists(self.base_on):
self.model.load_state_dict(torch.load(self.base_on))
self.logger.info(f'==> Load Model from {self.base_on}')
self.logger.info('==> Start Training')
print(self.model)
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!