# base_solver.py
import torch
from core.data import build_dataloader
from core.model import build_model
from core.optimizer import build_optimizer, build_lr_scheduler
from core.loss import build_loss
from core.metric import build_metric
from utils.registery import SOLVER_REGISTRY
from utils.logger import get_logger_and_log_path
import os
import copy
import datetime
from torch.nn.parallel import DistributedDataParallel
import numpy as np
import pandas as pd
import yaml


@SOLVER_REGISTRY.register()
class BaseSolver(object):
    def __init__(self, cfg):
        self.cfg = copy.deepcopy(cfg)
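        # global rank of this process; also used below as the CUDA device index,
        # which assumes single-node training (one process per GPU) and that the
        # launcher has already initialised torch.distributed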
        self.local_rank = torch.distributed.get_rank()
        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.len_train_loader, self.len_val_loader = len(self.train_loader), len(self.val_loader)
        self.criterion = build_loss(cfg).cuda(self.local_rank)
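        # convert BatchNorm layers to SyncBatchNorm so statistics are synchronised
        # across ranks, then wrap the model in DistributedDataParallel on this GPU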
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(build_model(cfg))
        self.model = DistributedDataParallel(model.cuda(self.local_rank), device_ids=[self.local_rank], find_unused_parameters=True)
        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])
        self.hyper_params = cfg['solver']['args']
        crt_date = datetime.date.today().strftime('%Y-%m-%d')
        self.logger, self.log_path = get_logger_and_log_path(crt_date=crt_date, **cfg['solver']['logger'])
        self.metric_fn = build_metric(cfg)
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError:
            raise KeyError("'epoch' must be specified in solver.args")
        if self.local_rank == 0:
            self.save_dict_to_yaml(self.cfg, os.path.join(self.log_path, 'config.yaml'))
            self.logger.info(self.cfg)

    def train(self):
        if torch.distributed.get_rank() == 0:
            self.logger.info('==> Start Training')
        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.train_loader.sampler.set_epoch(t)
            if torch.distributed.get_rank() == 0:
                self.logger.info(f'==> epoch {t + 1}')
            self.model.train()

            pred_list = list()
            label_list = list()

            mean_loss = 0.0

            for i, data in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                image = data['image'].cuda(self.local_rank)
                label = data['label'].cuda(self.local_rank)

                pred = self.model(image)

                loss = self.criterion(pred, label)
                mean_loss += loss.item()

                if i % 200 == 0 and torch.distributed.get_rank() == 0:
                    loss_value = loss.item()
                    self.logger.info(f'epoch: {t + 1}/{self.epoch}, iteration: {i + 1}/{self.len_train_loader}, loss: {loss_value :.4f}')
                
                loss.backward()
                self.optimizer.step()

                # batch_pred = [torch.zeros_like(pred) for _ in range(torch.distributed.get_world_size())] # 1
                # torch.distributed.all_gather(batch_pred, pred)
                # pred_list.append(torch.cat(batch_pred, dim=0).detach().cpu())

                # batch_label = [torch.zeros_like(label) for _ in range(torch.distributed.get_world_size())]
                # torch.distributed.all_gather(batch_label, label)
                # label_list.append(torch.cat(batch_label, dim=0).detach().cpu())

            # pred_list = torch.cat(pred_list, dim=0)
            # label_list = torch.cat(label_list, dim=0)
            # metric_dict = self.metric_fn(**{'pred': pred_list, 'label': label_list})
            mean_loss = mean_loss / self.len_train_loader
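            # the reported loss is the rank-local mean; only rank 0 logs and writes
            # the checkpoint, while every rank runs validation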
            
            if torch.distributed.get_rank() == 0:
                # self.logger.info(f"==> train mean loss: {mean_loss :.4f}, psnr: {metric_dict['psnr'] :.4f}, ssim: {metric_dict['ssim'] :.4f}")
                self.logger.info(f'==> train mean loss: {mean_loss :.4f}')
                self.save_checkpoint(self.model, self.cfg, self.log_path, t + 1)
            self.val(t + 1)
            lr_scheduler.step()

        if self.local_rank == 0:
            self.logger.info('==> End Training')

    @torch.no_grad()
    def val(self, t):
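        """Run validation and log metrics from rank 0.

        No all_gather is performed, so if the val loader is sharded across
        ranks the logged PSNR/SSIM reflect only rank 0's portion of the data.
        """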
        self.model.eval()

        pred_list = list()
        label_list = list()

        for i, data in enumerate(self.val_loader):
            feat = data['image'].cuda(self.local_rank)
            label = data['label'].cuda(self.local_rank)

            pred = self.model(feat)

            pred_list.append(pred.detach().cpu())
            label_list.append(label.detach().cpu())

        pred_list = torch.cat(pred_list, dim=0)
        label_list = torch.cat(label_list, dim=0)

        metric_dict = self.metric_fn(**{'pred': pred_list, 'label': label_list})
        if torch.distributed.get_rank() == 0:
            self.logger.info(f"==> val psnr: {metric_dict['psnr'] :.4f}, ssim: {metric_dict['ssim'] :.4f}")

    def run(self):
        self.train()

    @staticmethod
    def save_dict_to_yaml(dict_value, save_path):
        with open(save_path, 'w', encoding='utf-8') as file:
            yaml.dump(dict_value, file, sort_keys=False)


    def save_checkpoint(self, model, cfg, log_path, epoch_id):
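        # called from rank 0 only; saves the unwrapped weights (model.module) so
        # the checkpoint can be loaded without DistributedDataParallel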
        model.eval()
        torch.save(model.module.state_dict(), os.path.join(log_path, f'ckpt_epoch_{epoch_id}.pt'))
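

# Example launch (a minimal sketch, not part of the original module). It assumes
# the process group is initialised in the entry script and that `cfg` is a dict
# matching what the build_* factories expect, e.g.
#
#   torchrun --nproc_per_node=4 train.py
#
# where train.py does roughly:
#
#   import yaml
#   import torch
#   torch.distributed.init_process_group(backend='nccl')
#   with open('config.yaml', 'r', encoding='utf-8') as f:
#       cfg = yaml.safe_load(f)
#   solver = BaseSolver(cfg)
#   solver.run()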