# mlp_solver.py
import os
import copy
import torch

from model import build_model
from data import build_dataloader
from optimizer import build_optimizer, build_lr_scheduler
from loss import build_loss
from utils import SOLVER_REGISTRY, get_logger_and_log_dir


@SOLVER_REGISTRY.register()
class MLPSolver(object):
    """Train and validate a model whose components are all built from ``cfg``.

    The data loaders, model, loss, optimizer and logger are created through the
    project's ``build_*`` / ``get_logger_and_log_dir`` factories. This class
    itself reads ``cfg['solver']['optimizer']['args']``,
    ``cfg['solver']['args']['epoch']`` and ``cfg['solver']['logger']``; the
    factories may read further keys.
    """

    def __init__(self, cfg):
        """Build every training component from ``cfg``.

        Args:
            cfg: nested configuration mapping (see class docstring for the keys
                read directly here).

        Raises:
            KeyError: if ``cfg['solver']['args']`` has no ``'epoch'`` entry
                (or is not a mapping).
        """
        # Private copy so later mutation of the caller's dict cannot affect
        # this solver's stored configuration.
        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        # len(loader) and len(loader.dataset) are kept separately: the former
        # normalizes per-batch losses, the latter per-sample accuracy.
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        # TODO: open question from the original author — should the model use BatchNorm?
        self.model = build_model(cfg)

        self.loss_fn = build_loss(cfg)

        # build_optimizer returns an optimizer class/factory; it is called with
        # the model parameters plus the configured keyword arguments.
        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        try:
            self.epoch = self.hyper_params['epoch']
        except (KeyError, TypeError) as err:
            # Only real exception instances can be raised; chain the cause so
            # the missing key / bad type is still visible in the traceback.
            raise KeyError('should contain epoch in {solver.args}') from err

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    @staticmethod
    def evaluate(y_pred, y_true, thresholds=0.5):
        """Count rows whose predicted label matches the true label.

        Each row is collapsed to one integer label: ``argmax + 1`` when the row
        is "active", otherwise ``0`` (an implicit "none of the classes" label).
        A prediction row is active when its maximum score exceeds
        ``thresholds``; a target row is active when its entries sum to more
        than zero.

        Args:
            y_pred: 2-D tensor of per-class scores, classes along dim 1.
            y_true: 2-D tensor with classes along dim 1; presumably one-hot rows
                or all-zero rows for "none" — TODO confirm against the dataset.
            thresholds: score the row maximum must exceed for a prediction to
                count as a class rather than "none". Defaults to 0.5.

        Returns:
            int: number of rows whose collapsed labels agree.

        Example:
            >>> y_true = torch.tensor([[0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0]])
            >>> y_pred = torch.tensor([[0.1, 0.8, 0.9], [0.2, 0.8, 0.1],
            ...                        [0.2, 0.1, 0.85], [0.2, 0.6, 0.1]])
            >>> MLPSolver.evaluate(y_pred, y_true)
            2
        """
        # +1 shifts class indices so that 0 is free to mean "none".
        y_pred_idx = torch.argmax(y_pred, dim=1) + 1
        y_pred_active = (torch.amax(y_pred, dim=1) > thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_active)

        y_true_idx = torch.argmax(y_true, dim=1) + 1
        # Gate on "any class set" rather than multiplying by the raw row sum,
        # which would scale the label for multi-hot rows.
        y_true_active = (torch.sum(y_true, dim=1) > 0).int()
        y_true_rebuild = torch.multiply(y_true_idx, y_true_active)

        return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item()

    def train_loop(self):
        """Run one training epoch over ``self.train_loader``.

        Logs the batch loss every 100 batches and, at the end, the summed batch
        losses divided by ``len(self.train_loader)``.
        """
        self.model.train()

        train_loss = 0.0
        for batch, (X, y) in enumerate(self.train_loader):
            pred = self.model(X)

            loss = self.loss_fn(pred, y)
            # .item() copies the scalar to the host; do it once per batch.
            loss_value = loss.item()
            train_loss += loss_value

            if batch % 100 == 0:
                self.logger.info(f'train iteration: {batch}/{self.train_loader_size}, train loss: {loss_value :.4f}')

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        train_loss /= max(self.train_loader_size, 1)
        self.logger.info(f'train mean loss: {train_loss :.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        """Evaluate on ``self.val_loader`` without gradients and log the metrics.

        Args:
            t: 1-based epoch number passed by ``run``; not used in the body.

        Returns:
            tuple[float, float]: ``(accuracy, val_loss)``. Accuracy is matches
            (per ``evaluate``) divided by ``len(self.val_loader.dataset)``.
            ``val_loss`` is the summed batch losses divided by
            ``len(self.val_loader)``.
        """
        self.model.eval()

        val_loss, correct = 0.0, 0
        for X, y in self.val_loader:
            pred = self.model(X)

            correct += self.evaluate(pred, y)

            loss = self.loss_fn(pred, y)
            val_loss += loss.item()

        # Guard against empty loaders/datasets instead of dividing by zero.
        accuracy = correct / max(self.val_dataset_size, 1)
        val_loss /= max(self.val_loader_size, 1)

        self.logger.info(f"val accuracy: {accuracy :.4f}, val loss: {val_loss :.4f}")
        return accuracy, val_loss

    def save_checkpoint(self, epoch_id):
        """Save the model ``state_dict`` to ``<log_dir>/ckpt_epoch_<epoch_id>.pt``.

        Also switches the model to eval mode as a side effect.
        """
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        """Train for ``self.epoch`` epochs.

        After each epoch, validate and write a checkpoint.
        """
        self.logger.info('==> Start Training')
        print(self.model)

        # LR scheduling is currently disabled. To enable it, build it with
        #   build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])
        # and call its .step() at the end of each epoch.

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')

            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)

        self.logger.info('==> End Training')