mlp_solver.py
import copy
import os

import torch

from data import build_dataloader
from loss import build_loss
from model import build_model
from optimizer import build_optimizer, build_lr_scheduler
from utils import SOLVER_REGISTRY, get_logger_and_log_dir


@SOLVER_REGISTRY.register()
class MLPSolver(object):
    def __init__(self, cfg):
        self.cfg = copy.deepcopy(cfg)
        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)
        # TODO: decide whether the model needs BatchNorm layers.
        self.model = build_model(cfg)
        self.loss_fn = build_loss(cfg)
        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])
        self.hyper_params = cfg['solver']['args']
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError:
            raise KeyError('cfg must contain "epoch" under solver.args')
        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])
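
    # Sketch of the cfg layout __init__ expects, inferred from the keys read
    # above (the concrete values here are illustrative, not from the repo):
    #
    #   cfg = {
    #       'solver': {
    #           'args': {'epoch': 20},
    #           'optimizer': {'args': {'lr': 1e-3}},
    #           'lr_scheduler': {'args': {...}},  # only if the scheduler is enabled
    #           'logger': {...},
    #       },
    #       # plus whatever build_model / build_dataloader / build_loss consume
    #   }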

    @staticmethod
    def evaluate(y_pred, y_true, threshold=0.5):
        # Labels are one-hot rows; an all-zero row denotes the "other" class.
        # Rebuild a single class id per sample: argmax + 1 for a positive
        # class, 0 when no score clears the threshold (prediction) or the row
        # is all zeros (ground truth), then count exact matches.
        y_pred_idx = torch.argmax(y_pred, dim=1) + 1
        y_pred_is_other = (torch.amax(y_pred, dim=1) > threshold).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)
        y_true_idx = torch.argmax(y_true, dim=1) + 1
        y_true_is_other = torch.sum(y_true, dim=1)
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)
        return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item()
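
    # Worked example of evaluate(), reusing the toy tensors from the original
    # debugging notes: rows 2 and 3 predict class 2 and class 3 correctly,
    # row 1 overshoots on its last logit (picking class 3 instead of class 2),
    # and row 4 predicts a class where the label is all-zero ("other"), so the
    # returned match count is 2.
    #
    #   y_true = torch.tensor([[0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0]])
    #   y_pred = torch.tensor([[0.1, 0.8, 0.9], [0.2, 0.8, 0.1],
    #                          [0.2, 0.1, 0.85], [0.2, 0.6, 0.1]])
    #   MLPSolver.evaluate(y_pred, y_true)  # -> 2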

    def train_loop(self):
        self.model.train()
        train_loss = 0
        for batch, (X, y) in enumerate(self.train_loader):
            pred = self.model(X)
            # loss = self.loss_fn(pred, y, reduction="mean")
            loss = self.loss_fn(pred, y)
            train_loss += loss.item()
            if batch % 100 == 0:
                self.logger.info(f'train iteration: {batch}/{self.train_loader_size}, train loss: {loss.item():.4f}')
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        train_loss /= self.train_loader_size
        self.logger.info(f'train mean loss: {train_loss:.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        self.model.eval()
        val_loss, correct = 0, 0
        for X, y in self.val_loader:
            pred = self.model(X)
            correct += self.evaluate(pred, y)
            val_loss += self.loss_fn(pred, y).item()
        correct /= self.val_dataset_size
        val_loss /= self.val_loader_size
        self.logger.info(f'epoch {t} val accuracy: {correct:.4f}, val loss: {val_loss:.4f}')

    def save_checkpoint(self, epoch_id):
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))
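
    def load_checkpoint(self, epoch_id):
        # Counterpart to save_checkpoint (a sketch added here, not in the
        # original file): restores the weights written for a given epoch.
        ckpt_path = os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt')
        self.model.load_state_dict(torch.load(ckpt_path))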

    def run(self):
        self.logger.info('==> Start Training')
        self.logger.info(self.model)
        # Optionally enable a lr scheduler:
        # lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])
        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')
            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)
            # lr_scheduler.step()
        self.logger.info('==> End Training')
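

# Minimal usage sketch (not in the original file). Assumes the project feeds a
# plain dict config into the solver; 'config.yaml' is a placeholder path and
# PyYAML is an assumed dependency.
if __name__ == '__main__':
    import yaml

    with open('config.yaml') as f:
        cfg = yaml.safe_load(f)
    MLPSolver(cfg).run()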