skip_solver.py
import os

import torch
import yaml

from core.optimizer import build_lr_scheduler
from utils.registery import SOLVER_REGISTRY

from .base_solver import BaseSolver


@SOLVER_REGISTRY.register()
class SkipSolver(BaseSolver):
    """Solver for a model with two heads: a reconstruction head supervised
    against the clean label, and a residual (skip) head supervised against
    the difference image - label."""

    def __init__(self, cfg):
        super().__init__(cfg)

    def train(self):
        if torch.distributed.get_rank() == 0:
            self.logger.info('==> Start Training')

        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            # Reshuffle the DistributedSampler so every epoch sees a new ordering.
            self.train_loader.sampler.set_epoch(t)
            if torch.distributed.get_rank() == 0:
                self.logger.info(f'==> epoch {t + 1}')

            self.model.train()
            mean_loss = 0.0
            for i, data in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                image = data['image'].cuda(self.local_rank)
                label = data['label'].cuda(self.local_rank)
                residual = image - label

                pred = self.model(image)
                # Supervise both heads with the same criterion and sum the losses.
                reconstruction_loss = self.criterion(pred['reconstruction'], label)
                residual_loss = self.criterion(pred['residual'], residual)
                loss = reconstruction_loss + residual_loss
                mean_loss += loss.item()

                # i % 200 == 0 already covers i == 0, so log every 200 iterations.
                if i % 200 == 0 and torch.distributed.get_rank() == 0:
                    self.logger.info(
                        f'epoch: {t + 1}/{self.epoch}, '
                        f'iteration: {i + 1}/{self.len_train_loader}, '
                        f'reconstruction loss: {reconstruction_loss.item():.4f}, '
                        f'residual loss: {residual_loss.item():.4f}, '
                        f'loss: {loss.item():.4f}')

                loss.backward()
                self.optimizer.step()

            mean_loss = mean_loss / self.len_train_loader
            if torch.distributed.get_rank() == 0:
                self.logger.info(f'==> train mean loss: {mean_loss:.4f}')
                self.save_checkpoint(self.model, self.cfg, self.log_path, t + 1)

            self.val(t + 1)
            lr_scheduler.step()

        if torch.distributed.get_rank() == 0:
            self.logger.info('==> End Training')

    @torch.no_grad()
    def val(self, t):
        self.model.eval()
        pred_list = list()
        label_list = list()
        for i, data in enumerate(self.val_loader):
            image = data['image'].cuda(self.local_rank)
            label = data['label'].cuda(self.local_rank)
            pred = self.model(image)
            pred_list.append(pred['reconstruction'].detach().cpu())
            label_list.append(label.detach().cpu())

        # Note: each rank scores only its own shard of the validation set;
        # predictions are not gathered across processes before computing metrics.
        pred_list = torch.cat(pred_list, dim=0)
        label_list = torch.cat(label_list, dim=0)
        metric_dict = self.metric_fn(pred=pred_list, label=label_list)
        if torch.distributed.get_rank() == 0:
            self.logger.info(f"==> val psnr: {metric_dict['psnr']:.4f}, ssim: {metric_dict['ssim']:.4f}")

    def run(self):
        self.train()

    @staticmethod
    def save_dict_to_yaml(dict_value, save_path):
        with open(save_path, 'w', encoding='utf-8') as file:
            yaml.dump(dict_value, file, sort_keys=False)

    def save_checkpoint(self, model, cfg, log_path, epoch_id):
        # model is DistributedDataParallel-wrapped, so save the underlying module.
        model.eval()
        torch.save(model.module.state_dict(), os.path.join(log_path, f'ckpt_epoch_{epoch_id}.pt'))
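

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file). SkipSolver
# registers itself in SOLVER_REGISTRY, so a typical entry point would build it
# from a YAML config and call run(). The config path, the 'name' key, the
# SOLVER_REGISTRY.get() call, and the entry-script name are assumptions about
# the surrounding repo, not confirmed API:
#
#     import yaml
#     from utils.registery import SOLVER_REGISTRY
#     import solver.skip_solver  # noqa: F401 -- hypothetical module path; import triggers registration
#
#     with open('configs/skip.yaml', 'r', encoding='utf-8') as f:  # hypothetical config file
#         cfg = yaml.safe_load(f)
#     solver = SOLVER_REGISTRY.get(cfg['solver']['name'])(cfg)  # assumed registry lookup
#     solver.run()
#
# Training is distributed (torch.distributed, DistributedSampler, DDP), so the
# script would be launched with something like:
#
#     torchrun --nproc_per_node=<num_gpus> main.py  # main.py is a hypothetical entry script
#
# A minimal config sketch matching the only key this file reads directly,
# cfg['solver']['lr_scheduler']['args']; every other field name below is an
# assumption about the repo's schema:
#
#     solver:
#       name: SkipSolver
#       lr_scheduler:
#         name: StepLR
#         args:
#           step_size: 30
#           gamma: 0.1
# ---------------------------------------------------------------------------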