# base_solver.py
import copy
import datetime
import os

import torch
import yaml
from torch.nn.parallel import DistributedDataParallel

from core.data import build_dataloader
from core.loss import build_loss
from core.metric import build_metric
from core.model import build_model
from core.optimizer import build_lr_scheduler, build_optimizer
from utils.logger import get_logger_and_log_path
from utils.registery import SOLVER_REGISTRY


@SOLVER_REGISTRY.register()
class BaseSolver(object):
    def __init__(self, cfg):
        self.cfg = copy.deepcopy(cfg)
        # Note: the global rank doubles as the CUDA device index here, which
        # assumes single-node training with one process per GPU.
        self.local_rank = torch.distributed.get_rank()

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.len_train_loader, self.len_val_loader = len(self.train_loader), len(self.val_loader)
        self.criterion = build_loss(cfg).cuda(self.local_rank)

        # Convert BatchNorm layers to SyncBatchNorm before wrapping in DDP.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(build_model(cfg))
        self.model = DistributedDataParallel(model.cuda(self.local_rank),
                                             device_ids=[self.local_rank],
                                             find_unused_parameters=True)
        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        crt_date = datetime.date.today().strftime('%Y-%m-%d')
        self.logger, self.log_path = get_logger_and_log_path(crt_date=crt_date, **cfg['solver']['logger'])
        self.metric_fn = build_metric(cfg)

        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError:
            raise KeyError('cfg must contain `epoch` in solver.args')

        # Only rank 0 persists the resolved config and logs it.
        if self.local_rank == 0:
            self.save_dict_to_yaml(self.cfg, os.path.join(self.log_path, 'config.yaml'))
            self.logger.info(self.cfg)
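
    # For reference, the `solver` section of the config consumed above might
    # look like this; the field names follow the accesses in __init__, while
    # the concrete builder names and values are illustrative assumptions:
    #
    # solver:
    #   name: BaseSolver
    #   args:
    #     epoch: 100
    #   optimizer:
    #     name: Adam
    #     args:
    #       lr: 1.0e-4
    #   lr_scheduler:
    #     name: CosineAnnealingLR
    #     args:
    #       T_max: 100
    #   logger:
    #     ...  # keyword args forwarded to get_logger_and_log_path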
    def train(self):
        if torch.distributed.get_rank() == 0:
            self.logger.info('==> Start Training')

        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            # Re-seed the DistributedSampler so each epoch gets a fresh shuffle.
            self.train_loader.sampler.set_epoch(t)
            if torch.distributed.get_rank() == 0:
                self.logger.info(f'==> epoch {t + 1}')

            self.model.train()
            # pred_list / label_list feed the optional metric gathering below,
            # which is currently disabled.
            pred_list = list()
            label_list = list()
            mean_loss = 0.0

            for i, data in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                image = data['image'].cuda(self.local_rank)
                label = data['label'].cuda(self.local_rank)

                pred = self.model(image)
                loss = self.criterion(pred, label)
                mean_loss += loss.item()

                if i % 200 == 0 and torch.distributed.get_rank() == 0:
                    loss_value = loss.item()
                    self.logger.info(f'epoch: {t + 1}/{self.epoch}, iteration: {i + 1}/{self.len_train_loader}, loss: {loss_value:.4f}')

                loss.backward()
                self.optimizer.step()

                # Optionally gather predictions and labels from every rank to
                # compute training metrics; disabled to save time and memory.
                # batch_pred = [torch.zeros_like(pred) for _ in range(torch.distributed.get_world_size())]
                # torch.distributed.all_gather(batch_pred, pred)
                # pred_list.append(torch.cat(batch_pred, dim=0).detach().cpu())
                # batch_label = [torch.zeros_like(label) for _ in range(torch.distributed.get_world_size())]
                # torch.distributed.all_gather(batch_label, label)
                # label_list.append(torch.cat(batch_label, dim=0).detach().cpu())

            # pred_list = torch.cat(pred_list, dim=0)
            # label_list = torch.cat(label_list, dim=0)
            # metric_dict = self.metric_fn(**{'pred': pred_list, 'label': label_list})

            mean_loss = mean_loss / self.len_train_loader
            if torch.distributed.get_rank() == 0:
                # self.logger.info(f"==> train mean loss: {mean_loss:.4f}, psnr: {metric_dict['psnr']:.4f}, ssim: {metric_dict['ssim']:.4f}")
                self.logger.info(f'==> train mean loss: {mean_loss:.4f}')
                self.save_checkpoint(self.model, self.cfg, self.log_path, t + 1)

            self.val(t + 1)
            lr_scheduler.step()

        if self.local_rank == 0:
            self.logger.info('==> End Training')
    @torch.no_grad()
    def val(self, t):
        self.model.eval()
        pred_list = list()
        label_list = list()

        for data in self.val_loader:
            feat = data['image'].cuda(self.local_rank)
            label = data['label'].cuda(self.local_rank)
            pred = self.model(feat)
            pred_list.append(pred.detach().cpu())
            label_list.append(label.detach().cpu())

        pred_list = torch.cat(pred_list, dim=0)
        label_list = torch.cat(label_list, dim=0)
        metric_dict = self.metric_fn(**{'pred': pred_list, 'label': label_list})

        # Each rank evaluates its own shard; only rank 0 logs the result.
        if torch.distributed.get_rank() == 0:
            self.logger.info(f"==> val psnr: {metric_dict['psnr']:.4f}, ssim: {metric_dict['ssim']:.4f}")
    def run(self):
        self.train()

    @staticmethod
    def save_dict_to_yaml(dict_value, save_path):
        with open(save_path, 'w', encoding='utf-8') as file:
            yaml.dump(dict_value, file, sort_keys=False)

    def save_checkpoint(self, model, cfg, log_path, epoch_id):
        model.eval()
        # Save the unwrapped module's weights so checkpoints load without DDP.
        torch.save(model.module.state_dict(), os.path.join(log_path, f'ckpt_epoch_{epoch_id}.pt'))
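

# Minimal launch sketch (not part of the original project): BaseSolver expects
# torch.distributed to be initialized before construction, e.g. one process
# per GPU under `torchrun --nproc_per_node=<num_gpus> base_solver.py`. The
# backend choice and the 'config.yaml' path are assumptions for illustration.
if __name__ == '__main__':
    torch.distributed.init_process_group(backend='nccl')
    torch.cuda.set_device(torch.distributed.get_rank())
    with open('config.yaml', 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    BaseSolver(cfg).run()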