# sl_solver.py
import copy
import os

import torch

from data import build_dataloader
from loss import build_loss
from model import build_model
from optimizer import build_lr_scheduler, build_optimizer
from utils import SOLVER_REGISTRY, get_logger_and_log_dir
from utils import sequence_mask
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report


@SOLVER_REGISTRY.register()
class SLSolver(object):
    """Sequence-labeling solver.

    Wires together the dataloaders, model, loss, optimizer and logger built
    from ``cfg`` and drives the train / validate / checkpoint loop, plus an
    offline ``evaluate`` pass that reports sklearn metrics.
    """

    def __init__(self, cfg):
        """Build every training component from the nested config dict ``cfg``.

        Raises:
            KeyError: if ``cfg['solver']['args']`` has no 'epoch' entry.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Deep-copy so later mutation of the caller's dict cannot affect us.
        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        # NOTE(review): if the model contains BatchNorm, the train()/eval()
        # mode switches below matter — confirm against the model definition.
        self.model = build_model(cfg).to(self.device)

        self.loss_fn = build_loss(cfg)

        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        self.base_on = self.hyper_params['base_on']        # optional checkpoint to resume training from
        self.model_path = self.hyper_params['model_path']  # checkpoint loaded by evaluate()
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError as err:
            # BUGFIX: the original did `raise '<str>'`, which is itself a
            # TypeError in Python 3; raise a real exception instead.
            raise KeyError('should contain epoch in {solver.args}') from err

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    @staticmethod
    def _rebuild_labels(y_pred, y_true, thresholds=0.5):
        """Collapse per-class scores / one-hot targets to integer label ids.

        Class indices are shifted by +1 so that id 0 can represent "other":
        a prediction whose max sigmoid score is <= ``thresholds``, or a target
        row that is all zeros, maps to 0.

        Args:
            y_pred: [batch_size, seq_len, num_classes] raw model outputs.
            y_true: [batch_size, seq_len, num_classes] one-hot targets
                (an all-zero row means "other").
            thresholds: confidence cut-off applied to the sigmoid scores.

        Returns:
            Tuple ``(y_pred_rebuild, y_true_rebuild)`` of
            [batch_size, seq_len] int tensors.
        """
        y_pred_sigmoid = torch.sigmoid(y_pred)
        # [batch_size, seq_len] — best class per position, shifted by +1.
        y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1
        # 1 where the model is confident in some class, else 0 ("other").
        y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        y_true_idx = torch.argmax(y_true, dim=-1) + 1
        # One-hot rows sum to 1; all-zero ("other") rows sum to 0.
        y_true_is_other = torch.sum(y_true, dim=-1).int()
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        return y_pred_rebuild, y_true_rebuild

    def accuracy(self, y_pred, y_true, valid_lens, thresholds=0.5):
        """Count sequence positions whose predicted label matches the target.

        Padding positions are masked to -1 on the target side — an id the
        prediction can never produce — so they never count as correct.

        Args:
            y_pred: [batch_size, seq_len, num_classes] raw model outputs.
            y_true: [batch_size, seq_len, num_classes] one-hot targets.
            valid_lens: [batch_size] valid (unpadded) length of each sequence.
            thresholds: confidence cut-off for assigning a real class.

        Returns:
            int: number of correctly labeled valid positions.
        """
        y_pred_rebuild, y_true_rebuild = self._rebuild_labels(y_pred, y_true, thresholds)

        masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1)

        return torch.sum((y_pred_rebuild == masked_y_true_rebuild).int()).item()

    def train_loop(self):
        """Run one training epoch and log mean loss / per-position accuracy."""
        self.model.train()

        seq_lens_sum = torch.zeros(1).to(self.device)
        train_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for batch, (X, y, valid_lens) in enumerate(self.train_loader):
            X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)

            # [batch_size, seq_len, num_classes]
            pred = self.model(X, valid_lens)

            loss = self.loss_fn(pred, y, valid_lens)
            batch_loss = loss.sum()
            # BUGFIX: detach before accumulating, otherwise every batch's
            # autograd graph stays alive through `train_loss` all epoch.
            train_loss += batch_loss.detach()

            if batch % 100 == 0:
                loss_value, current = batch_loss.item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')

            self.optimizer.zero_grad()
            batch_loss.backward()
            self.optimizer.step()

            seq_lens_sum += valid_lens.sum()
            correct += self.accuracy(pred, y, valid_lens)

        # Accuracy is measured per valid position, not per sequence/sample.
        correct /= seq_lens_sum
        train_loss /= self.train_loader_size
        self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        """Run one validation pass; ``t`` is the (1-based) epoch number."""
        self.model.eval()

        seq_lens_sum = torch.zeros(1).to(self.device)
        val_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for X, y, valid_lens in self.val_loader:
            X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)

            # [batch_size, seq_len, num_classes]
            pred = self.model(X, valid_lens)

            loss = self.loss_fn(pred, y, valid_lens)
            val_loss += loss.sum()

            seq_lens_sum += valid_lens.sum()
            correct += self.accuracy(pred, y, valid_lens)

        correct /= seq_lens_sum
        val_loss /= self.val_loader_size

        self.logger.info(f"val accuracy: {correct.item() :.4f}, val mean loss: {val_loss.item() :.4f}")

    def save_checkpoint(self, epoch_id):
        """Save model weights to the log dir as ``ckpt_epoch_<epoch_id>.pt``."""
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        """Training entry point: optional resume, then the epoch loop."""
        if isinstance(self.base_on, str) and os.path.exists(self.base_on):
            self.model.load_state_dict(torch.load(self.base_on))
            self.logger.info(f'==> Load Model from {self.base_on}')

        self.logger.info('==> Start Training')
        print(self.model)

        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')

            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)

            lr_scheduler.step()

        self.logger.info('==> End Training')

    @torch.no_grad()
    def evaluate(self):
        """Evaluate the checkpoint at ``self.model_path`` on the val set.

        Prints sklearn accuracy, confusion matrix and classification report
        over all valid (unpadded) positions. Silently returns when
        ``self.model_path`` is not an existing file.
        """
        if isinstance(self.model_path, str) and os.path.exists(self.model_path):
            self.model.load_state_dict(torch.load(self.model_path))
            self.logger.info(f'==> Load Model from {self.model_path}')
        else:
            return

        self.model.eval()

        label_true_list = []
        label_pred_list = []
        for X, y, valid_lens in self.val_loader:
            X, y_true, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)

            # [batch_size, seq_len, num_classes]
            y_pred = self.model(X, valid_lens)

            y_pred_rebuild, y_true_rebuild = self._rebuild_labels(y_pred, y_true)

            # Hoist the device->host copy out of the per-sequence loop and
            # keep only the valid (unpadded) prefix of every sequence.
            lens = valid_lens.cpu().numpy()
            true_rows = y_true_rebuild.cpu().numpy().tolist()
            pred_rows = y_pred_rebuild.cpu().numpy().tolist()
            for idx, (true_seq, pred_seq) in enumerate(zip(true_rows, pred_rows)):
                label_true_list.extend(true_seq[: lens[idx]])
                label_pred_list.extend(pred_seq[: lens[idx]])

        acc = accuracy_score(label_true_list, label_pred_list)
        cm = confusion_matrix(label_true_list, label_pred_list)
        report = classification_report(label_true_list, label_pred_list)
        print(acc)
        print(cm)
        print(report)