# hed_edge_detection_solver.py
import torch
import torch.nn as nn
import os
import logging
# NOTE(review): wildcard import — HED_res34 and cross_entropy_loss_RCF used
# below presumably come from model.hed; prefer explicit imports. Confirm.
from model.hed import *
from data.edge_loader import get_loader
from torch.optim import Adam, lr_scheduler
import torch.nn.functional as F
import cv2  # NOTE(review): cv2 appears unused in this chunk — confirm before removing

# Restrict the process to GPUs 0 and 1; nn.DataParallel below spans them.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# Append all run logs to a fixed absolute log file.
logging.basicConfig(level=logging.INFO,
                    filename='/home/lxl/work/ocr/documentSeg/log/logfile/hed_edge_detection.log',
                    filemode='a')
class Solver:
    """Training/validation driver for HED edge detection (ResNet-34 backbone).

    Wires together the project data loaders, a DataParallel-wrapped
    ``HED_res34`` model on CUDA, an Adam optimizer with weight decay, a
    per-epoch StepLR schedule, per-epoch checkpointing, and a validation
    pass that reports mean loss.
    """

    def __init__(self):
        # Project-level loader factory returns (train, test) DataLoaders.
        self.train_loader, self.test_loader = get_loader()
        self.len_train_loader = len(self.train_loader)
        self.len_test_loader = len(self.test_loader)
        # HED with ResNet-34 backbone, replicated across the visible GPUs.
        self.model = HED_res34()
        self.model = nn.DataParallel(self.model)
        self.model = self.model.cuda()
        self.epoch = 10
        self.lr = 0.00005
        self.weight_decay = 0.00001
        self.save_dir = '/home/lxl/work/ocr/documentSeg/log/checkpoint/hed_edge_detection'
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(self.save_dir, exist_ok=True)
        self.optimizer = Adam(self.model.parameters(), lr=self.lr,
                              weight_decay=self.weight_decay)

    def train(self):
        """Run the full training loop: optimize, decay LR, checkpoint, validate.

        One StepLR step per epoch halves the learning rate (gamma=0.5).
        Checkpoints the unwrapped (non-DataParallel) weights after each epoch,
        then runs a validation pass.
        """
        lr_decay = lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.5)
        for t in range(self.epoch):
            logging.info("==============EPOCH {} START================".format(t + 1))
            self.model.train()
            for i, (x, y) in enumerate(self.train_loader):
                x = x.cuda()
                y = y.cuda()
                self.optimizer.zero_grad()
                pred = self.model(x)
                # RCF-style balanced cross entropy; labels must be int64.
                loss = cross_entropy_loss_RCF(pred, y.type(torch.int64))
                iteration = i + 1
                if iteration % 50 == 0 or iteration == 1:
                    # .item() extracts the Python float; formatting the raw
                    # tensor with {:.4f} is the original bug fixed here.
                    logging.info(
                        'epoch: {}/{}, iteration: {}/{}, loss: {:.4f}'.format(
                            t + 1, self.epoch, iteration, self.len_train_loader,
                            loss.item()))
                loss.backward()
                self.optimizer.step()
            lr_decay.step()
            logging.info('EPOCH: {}/{} finished'.format(t + 1, self.epoch))
            # Save the underlying module's weights so the checkpoint loads
            # without a DataParallel wrapper.
            torch.save(self.model.module.state_dict(),
                       '%s/ckpt_epoch_%s.pt' % (self.save_dir, str(t + 1)))
            self.test()

    @torch.no_grad()
    def test(self):
        """Run one validation pass and log the mean loss over the test set.

        Uses the same RCF criterion as training so train/val numbers are
        directly comparable. (The original computed predictions but
        discarded them and logged an empty 'VAL ' line.)
        """
        self.model.eval()
        total_loss = 0.0
        for x, y in self.test_loader:
            x = x.cuda()
            y = y.cuda()
            pred = self.model(x)
            total_loss += cross_entropy_loss_RCF(pred, y.type(torch.int64)).item()
        # Guard against an empty test loader.
        mean_loss = total_loss / max(self.len_test_loader, 1)
        logging.info('VAL mean loss: {:.4f}'.format(mean_loss))