# skip_solver.py
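"""SkipSolver: a distributed (DDP) training solver that supervises two model
heads, a clean reconstruction and the input-minus-label residual, and reports
PSNR/SSIM on the validation split."""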
import os

import torch
import yaml

from core.optimizer import build_lr_scheduler
from utils.registery import SOLVER_REGISTRY

from .base_solver import BaseSolver


@SOLVER_REGISTRY.register()
class SkipSolver(BaseSolver):
    def __init__(self, cfg):
        super().__init__(cfg)

    def train(self):
        if torch.distributed.get_rank() == 0:
            self.logger.info('==> Start Training')
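        # Resolve the scheduler class from the config and instantiate it with
        # the kwargs under solver.lr_scheduler.args; it is stepped once per
        # epoch at the bottom of the loop below.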
        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.train_loader.sampler.set_epoch(t)
            if torch.distributed.get_rank() == 0:
                self.logger.info(f'==> epoch {t + 1}')
            self.model.train()

            mean_loss = 0.0

            for i, data in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                image = data['image'].cuda(self.local_rank)
                label = data['label'].cuda(self.local_rank)
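                # Residual target: the difference between the degraded input
                # and the clean label, i.e. the degradation to be removed.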
                residual = image - label

                pred = self.model(image)
                
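                # The total loss is the unweighted sum of the two head losses,
                # both computed with the same criterion from the config.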
                reconstruction_loss = self.criterion(pred['reconstruction'], label)
                residual_loss = self.criterion(pred['residual'], residual)
                loss = reconstruction_loss + residual_loss
                mean_loss += loss.item()

                if i % 200 == 0 and torch.distributed.get_rank() == 0:
                    self.logger.info(
                        f'epoch: {t + 1}/{self.epoch}, iteration: {i + 1}/{self.len_train_loader}, '
                        f'reconstruction loss: {reconstruction_loss.item():.4f}, '
                        f'residual loss: {residual_loss.item():.4f}, loss: {loss.item():.4f}'
                    )
                
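                # With the model wrapped in DistributedDataParallel (assumed to
                # be done in BaseSolver), backward() also all-reduces gradients
                # across ranks before the optimizer step.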
                loss.backward()
                self.optimizer.step()

            mean_loss = mean_loss / self.len_train_loader
            
            if torch.distributed.get_rank() == 0:
                self.logger.info(f'==> train mean loss: {mean_loss:.4f}')
                self.save_checkpoint(self.model, self.cfg, self.log_path, t + 1)
            self.val(t + 1)
            lr_scheduler.step()

        if torch.distributed.get_rank() == 0:
            self.logger.info('==> End Training')

    @torch.no_grad()
    def val(self, t):
        self.model.eval()

        pred_list = list()
        label_list = list()

        for i, data in enumerate(self.val_loader):
            image = data['image'].cuda(self.local_rank)
            label = data['label'].cuda(self.local_rank)

            pred = self.model(image)

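            # Only the reconstruction head is scored at validation time; the
            # residual head is ignored here.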
            pred_list.append(pred['reconstruction'].detach().cpu())
            label_list.append(label.detach().cpu())

        preds = torch.cat(pred_list, dim=0)
        labels = torch.cat(label_list, dim=0)

        metric_dict = self.metric_fn(pred=preds, label=labels)
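        # NOTE: assuming the val loader uses a DistributedSampler, each rank
        # scores only its own shard; metrics are not all-reduced, so rank 0
        # logs its local values.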
        if torch.distributed.get_rank() == 0:
            self.logger.info(f"==> val psnr: {metric_dict['psnr'] :.4f}, ssim: {metric_dict['ssim'] :.4f}")

    def run(self):
        self.train()

    @staticmethod
    def save_dict_to_yaml(dict_value, save_path):
        with open(save_path, 'w', encoding='utf-8') as file:
            yaml.dump(dict_value, file, sort_keys=False)


    def save_checkpoint(self, model, cfg, log_path, epoch_id):
        model.eval()
        # The model is assumed to be wrapped in DistributedDataParallel, so the
        # raw weights live under model.module; checkpoints are therefore
        # loadable without the DDP wrapper.
        torch.save(model.module.state_dict(), os.path.join(log_path, f'ckpt_epoch_{epoch_id}.pt'))
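
# Minimal usage sketch (illustrative; names outside this file are assumptions):
# the project presumably builds a solver through SOLVER_REGISTRY and launches
# with torchrun. The config path, YAML schema, and registry get() lookup below
# are hypothetical.
#
#   torchrun --nproc_per_node=4 main.py --config configs/skip.yaml
#
#   with open('configs/skip.yaml') as f:
#       cfg = yaml.safe_load(f)
#   solver = SOLVER_REGISTRY.get(cfg['solver']['name'])(cfg)
#   solver.run()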