"""Training / validation / evaluation solver registered as ``VITSolver``."""

import copy
import os

import torch
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from data import build_dataloader
from loss import build_loss
from model import build_model
from optimizer import build_lr_scheduler, build_optimizer
from utils import SOLVER_REGISTRY, get_logger_and_log_dir


@SOLVER_REGISTRY.register()
class VITSolver(object):
    """Drive training, validation, checkpointing and evaluation of a model.

    Everything is configured from ``cfg``: data loaders, model, loss,
    optimizer, LR scheduler and logger are built by the project's
    ``build_*`` helpers.  ``cfg['solver']['args']`` must provide the keys
    ``no_other``, ``base_on``, ``model_path`` and ``epoch``.

    When ``no_other`` is true the model output is passed through a softmax
    and predictions are the plain arg-max.  Otherwise the raw output is used
    for the loss, a sigmoid is applied for metrics, and a sample whose best
    score does not exceed the threshold is mapped to label ``0`` (presumably
    the "other" class, going by the ``no_other`` flag) while real classes
    are numbered from ``1``.
    """

    def __init__(self, cfg):
        """Build all training components from ``cfg``.

        Raises:
            KeyError: if ``cfg['solver']['args']`` has no ``epoch`` entry
                (or any other required key is missing).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Keep a private copy so later mutation of the caller's cfg cannot
        # affect the scheduler construction in run().
        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        # Number of batches per loader (used to average the loss).
        self.train_loader_size = len(self.train_loader)
        self.val_loader_size = len(self.val_loader)
        # Number of samples per dataset (used to average the accuracy).
        self.train_dataset_size = len(self.train_loader.dataset)
        self.val_dataset_size = len(self.val_loader.dataset)

        # BatchNorm ?
        self.model = build_model(cfg).to(self.device)
        self.loss_fn = build_loss(cfg)
        self.optimizer = build_optimizer(cfg)(
            self.model.parameters(), **cfg['solver']['optimizer']['args']
        )

        self.hyper_params = cfg['solver']['args']
        self.no_other = self.hyper_params['no_other']
        self.base_on = self.hyper_params['base_on']
        self.model_path = self.hyper_params['model_path']
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError as err:
            # The original raised a bare string, which is itself a TypeError
            # in Python 3 and hid the intended message.
            raise KeyError('should contain epoch in {solver.args}') from err

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _rebuild_pred_labels(y_pred, thresholds=0.5):
        """Turn per-class scores into labels: ``argmax + 1``, or ``0`` when
        the best score does not exceed ``thresholds``."""
        best_class = torch.argmax(y_pred, dim=1) + 1
        above_threshold = (torch.amax(y_pred, dim=1) > thresholds).int()
        return torch.multiply(best_class, above_threshold)

    @staticmethod
    def _rebuild_true_labels(y_true):
        """Turn target rows into labels: ``(argmax + 1) * row_sum``.

        An all-zero row therefore becomes ``0``.
        # NOTE(review): assumes each target row is one-hot or all-zero - confirm.
        """
        true_class = torch.argmax(y_true, dim=1) + 1
        return torch.multiply(true_class, torch.sum(y_true, dim=1))

    def _forward(self, X):
        """Run the model; apply softmax over dim 1 when ``no_other`` is set."""
        output = self.model(X)
        if self.no_other:
            # NOTE(review): the softmax output is what gets fed to the loss;
            # confirm the configured loss expects probabilities, not logits.
            return torch.nn.Softmax(dim=1)(output)
        # pred = torch.nn.Sigmoid()(self.model(X))
        return output

    def _count_correct(self, pred, y):
        """Count correct predictions for one batch from ``_forward`` output."""
        if self.no_other:
            return self.accuracy(pred, y)
        # Raw outputs need a sigmoid before thresholding.
        return self.accuracy(torch.nn.Sigmoid()(pred), y)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def accuracy(self, y_pred, y_true, thresholds=0.5):
        """Return the number of correctly classified samples in a batch.

        Args:
            y_pred: per-class scores, one row per sample.
            y_true: target rows aligned with ``y_pred``.
            thresholds: minimum best score for a prediction to count as a
                real class when ``no_other`` is false.  Ignored otherwise.

        Returns:
            The count of matching samples (a float in ``no_other`` mode,
            an int otherwise, as before).
        """
        if self.no_other:
            return (y_pred.argmax(1) == y_true.argmax(1)).type(torch.float).sum().item()

        pred_labels = self._rebuild_pred_labels(y_pred, thresholds)
        true_labels = self._rebuild_true_labels(y_true)
        return torch.sum((pred_labels == true_labels).int()).item()

    def train_loop(self):
        """Train for one pass over ``train_loader`` and log accuracy / loss."""
        self.model.train()
        train_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)

        for batch, (X, y) in enumerate(self.train_loader):
            X, y = X.to(self.device), y.to(self.device)
            pred = self._forward(X)

            # loss = self.loss_fn(pred, y, reduction="mean")
            loss = self.loss_fn(pred, y)
            train_loss += loss.item()

            if batch % 100 == 0:
                loss_value, current = loss.item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Metrics do not need gradients; detach to avoid extending the graph.
            correct += self._count_correct(pred.detach(), y)

        correct /= self.train_dataset_size
        train_loss /= self.train_loader_size
        self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        """Run one pass over ``val_loader`` and log accuracy / mean loss.

        Args:
            t: epoch number supplied by the caller; currently unused.
        """
        self.model.eval()
        val_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)

        for X, y in self.val_loader:
            X, y = X.to(self.device), y.to(self.device)
            pred = self._forward(X)
            val_loss += self.loss_fn(pred, y).item()
            correct += self._count_correct(pred, y)

        correct /= self.val_dataset_size
        val_loss /= self.val_loader_size
        self.logger.info(f"val accuracy: {correct.item() :.4f}, val mean loss: {val_loss.item() :.4f}")

    def save_checkpoint(self, epoch_id):
        """Save the model ``state_dict`` to ``<log_dir>/ckpt_epoch_<epoch_id>.pt``.

        The model is switched to eval mode first; ``train_loop`` switches it
        back to train mode at the start of the next epoch.
        """
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        """Optionally warm-start from ``base_on``, then train for ``epoch`` epochs.

        Each epoch trains, validates, saves a checkpoint and steps the LR
        scheduler.
        """
        if isinstance(self.base_on, str) and os.path.exists(self.base_on):
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            self.model.load_state_dict(torch.load(self.base_on, map_location=self.device))
            self.logger.info(f'==> Load Model from {self.base_on}')

        self.logger.info('==> Start Training')
        print(self.model)

        lr_scheduler = build_lr_scheduler(self.cfg)(
            self.optimizer, **self.cfg['solver']['lr_scheduler']['args']
        )
        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')
            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)
            lr_scheduler.step()

        self.logger.info('==> End Training')

    def evaluate(self):
        """Load ``model_path`` and print accuracy, confusion matrix and report.

        Labels are 1-based class indices; ``0`` is produced when no score
        exceeds 0.5 (prediction, non-``no_other`` mode only) or the target
        row sums to zero.  Does nothing except log a message when
        ``model_path`` is not an existing path.
        """
        if not (isinstance(self.model_path, str) and os.path.exists(self.model_path)):
            self.logger.info(f'==> Skip evaluation, model_path not found: {self.model_path}')
            return

        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        self.model.load_state_dict(torch.load(self.model_path, map_location=self.device))
        self.logger.info(f'==> Load Model from {self.model_path}')
        self.model.eval()

        label_true_list = []
        label_pred_list = []
        with torch.no_grad():  # inference only; no autograd graph needed
            for X, y in self.val_loader:
                X, y_true = X.to(self.device), y.to(self.device)
                pred = self._forward(X)

                if self.no_other:
                    # Softmax output: every sample gets a real class, so no
                    # thresholding (and no second activation) is applied.
                    y_pred_rebuild = torch.argmax(pred, dim=1) + 1
                else:
                    y_pred_rebuild = self._rebuild_pred_labels(torch.nn.Sigmoid()(pred))
                y_true_rebuild = self._rebuild_true_labels(y_true)

                label_true_list.extend(y_true_rebuild.cpu().tolist())
                label_pred_list.extend(y_pred_rebuild.cpu().tolist())

        acc = accuracy_score(label_true_list, label_pred_list)
        cm = confusion_matrix(label_true_list, label_pred_list)
        report = classification_report(label_true_list, label_pred_list)
        print(acc)
        print(cm)
        print(report)