import os
import copy

import torch

from model import build_model
from data import build_dataloader
from optimizer import build_optimizer, build_lr_scheduler
from loss import build_loss
from utils import SOLVER_REGISTRY, get_logger_and_log_dir


@SOLVER_REGISTRY.register()
class MLPSolver(object):
    def __init__(self, cfg):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        # BatchNorm?
        self.model = build_model(cfg).to(self.device)
        self.loss_fn = build_loss(cfg)
        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError:
            raise KeyError("cfg['solver']['args'] must contain an 'epoch' entry")

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    @staticmethod
    def evaluate(y_pred, y_true, threshold=0.5):
        """Count samples whose rebuilt class id matches the one-hot target.

        Class ids are 1-based; a prediction whose max score does not exceed
        `threshold`, or a target whose one-hot row is all zeros, rebuilds to
        class 0 ("other").
        """
        y_pred_idx = torch.argmax(y_pred, dim=1) + 1
        y_pred_is_other = (torch.amax(y_pred, dim=1) > threshold).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        y_true_idx = torch.argmax(y_true, dim=1) + 1
        y_true_is_other = torch.sum(y_true, dim=1)
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item()

    def train_loop(self):
        self.model.train()

        train_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for batch, (X, y) in enumerate(self.train_loader):
            X, y = X.to(self.device), y.to(self.device)

            pred = self.model(X)
            correct += self.evaluate(pred, y)

            # loss = self.loss_fn(pred, y, reduction="mean")
            loss = self.loss_fn(pred, y)
            train_loss += loss.item()

            if batch % 100 == 0:
                loss_value, current = loss.item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        correct /= self.train_dataset_size
        train_loss /= self.train_loader_size
        self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        self.model.eval()

        val_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for X, y in self.val_loader:
            X, y = X.to(self.device), y.to(self.device)

            pred = self.model(X)
            correct += self.evaluate(pred, y)

            loss = self.loss_fn(pred, y)
            val_loss += loss.item()

        correct /= self.val_dataset_size
        val_loss /= self.val_loader_size
        self.logger.info(f"val accuracy: {correct.item() :.4f}, val mean loss: {val_loss.item() :.4f}")

    def save_checkpoint(self, epoch_id):
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        self.logger.info('==> Start Training')
        print(self.model)

        # lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')
            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)
            # lr_scheduler.step()

        self.logger.info('==> End Training')

        # Debugging leftovers: a one-batch smoke test and a toy check of evaluate().
        # for X, y in self.train_loader:
        #     print(X.size())
        #     print(y.size())
        #     pred = self.model(X)
        #     print(pred)
        #     print(y)
        #     loss = self.loss_fn(pred, y, reduction="mean")
        #     print(loss)
        #     break

        # y_true = [
        #     [0, 1, 0],
        #     [0, 1, 0],
        #     [0, 0, 1],
        #     [0, 0, 0],
        # ]
        # y_pred = [
        #     [0.1, 0.8, 0.9],
        #     [0.2, 0.8, 0.1],
        #     [0.2, 0.1, 0.85],
        #     [0.2, 0.6, 0.1],
        # ]
        # acc_num = self.evaluate(torch.tensor(y_pred), torch.tensor(y_true))
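
# A quick standalone check of MLPSolver.evaluate, using the toy tensors from
# the commented-out block above. This is a minimal sketch, not part of the
# original training flow; running it assumes the repo modules imported at the
# top of this file are importable. With the default threshold of 0.5, only the
# second and third rows rebuild to matching class ids, so it prints 2.
if __name__ == "__main__":
    y_true = torch.tensor([
        [0, 1, 0],
        [0, 1, 0],
        [0, 0, 1],
        [0, 0, 0],  # all-zero row rebuilds to the "other" class (id 0)
    ])
    y_pred = torch.tensor([
        [0.1, 0.8, 0.9],   # argmax picks class 3, but the target is class 2
        [0.2, 0.8, 0.1],   # rebuilds to class 2, matching the target
        [0.2, 0.1, 0.85],  # rebuilds to class 3, matching the target
        [0.2, 0.6, 0.1],   # 0.6 > threshold -> class 2, but the target is 0
    ])
    print(MLPSolver.evaluate(y_pred, y_true))  # -> 2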