solver.py
3.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import torch
import torch.nn as nn
from segmentation_models_pytorch.losses import SoftCrossEntropyLoss
import segmentation_models_pytorch as smp
import os
import logging
from data.finetune_loader import *
from torch.optim import Adam, lr_scheduler
import torch.nn.functional as F
import cv2
import imutils
# Expose only GPUs 0 and 1 to this process; this takes effect only if it is
# set before CUDA is initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# Append INFO-and-above log records for this fine-tuning run to a fixed file.
logging.basicConfig(
    filemode='a',
    filename='/home/lxl/work/ocr/documentSeg/log/logfile/finetune_random_crop.log',
    level=logging.INFO,
)
def calc_iou(inputs, targets):
    """Return the intersection-over-union of two tensors as a 0-d CPU tensor.

    Both tensors are detached, copied to the CPU and flattened, so they may
    have any shapes with the same number of elements. No threshold is
    applied: when ``inputs`` holds probabilities this is a "soft" IoU,
    ``sum(a*b) / (sum(a) + sum(b) - sum(a*b))``.

    Args:
        inputs: predicted mask / probabilities.
        targets: ground-truth mask.

    Returns:
        A 0-d tensor. When the union is zero (e.g. both masks are empty) the
        IoU is defined as 1.0 instead of the NaN that 0/0 would produce, so a
        single empty batch cannot turn a running mean into NaN.
    """
    # reshape (unlike view) also accepts non-contiguous tensors.
    inputs = inputs.detach().cpu().reshape(-1)
    targets = targets.detach().cpu().reshape(-1)
    intersection = (inputs * targets).sum()
    union = (inputs + targets).sum() - intersection
    iou = intersection / union
    if union == 0:
        # Nothing predicted and nothing present: perfect agreement.
        iou = torch.ones_like(iou)
    return iou
class BCEDiceLoss(nn.Module):
    """Binary cross-entropy plus (1 - smoothed Dice coefficient).

    ``input`` must already be probabilities in [0, 1]
    (``F.binary_cross_entropy`` applies no sigmoid). ``target`` must be a
    float tensor with the same number of elements. Both are flattened, so
    the loss is computed over the whole batch at once. The result is a
    float64 scalar tensor.
    """

    def __init__(self, weight=None, size_average=True):
        # weight / size_average are accepted for signature compatibility only;
        # they are currently ignored (BCE always uses the default mean reduction).
        super().__init__()

    def forward(self, input, target):
        # reshape (unlike view) also accepts non-contiguous tensors.
        pred = input.reshape(-1)
        truth = target.reshape(-1)
        # Functional BCE avoids constructing a fresh nn.BCELoss module per call.
        bce_loss = F.binary_cross_entropy(pred, truth).double()
        # Dice with +1 smoothing stays defined when both sums are zero; the
        # sums are accumulated in float64.
        intersection = (pred * truth).double().sum()
        dice_coef = (2.0 * intersection + 1) / (
            pred.double().sum() + truth.double().sum() + 1
        )
        return bce_loss + (1 - dice_coef)
class Solver:
    """Fine-tune a DeepLabV3+ (ResNet-50 encoder) single-class segmentation model.

    Builds the train/test loaders, the model (wrapped in ``nn.DataParallel``
    and moved to CUDA), an Adam optimizer and a BCE + Dice criterion.
    ``train()`` runs ``epoch`` epochs, saving a checkpoint and calling
    ``val()`` after each one. Metrics are written through ``logging``.

    All constructor arguments default to the values previously hard-coded
    here, so ``Solver()`` behaves as before.
    """

    def __init__(
        self,
        save_dir='/home/lxl/work/ocr/documentSeg/log/checkpoint/finetune_random_crop/',
        epoch=20,
        lr=0.0008,
        weight_decay=0.00001,
    ):
        """Set up data, model, optimizer and loss.

        Args:
            save_dir: directory for per-epoch checkpoints (created if missing).
            epoch: number of training epochs.
            lr: initial Adam learning rate.
            weight_decay: Adam weight decay.
        """
        # get_loader is provided by the wildcard import of data.finetune_loader.
        self.train_loader, self.test_loader = get_loader()
        self.len_train_loader = len(self.train_loader)
        self.len_test_loader = len(self.test_loader)
        self.model = smp.DeepLabV3Plus(
            encoder_name='resnet50',
            encoder_weights='imagenet',
            in_channels=3,
            classes=1,
            # Outputs probabilities, which the BCE term of BCEDiceLoss requires.
            activation='sigmoid',
        )
        self.save_dir = save_dir
        self.model = nn.DataParallel(self.model)
        self.model = self.model.cuda()
        self.epoch = epoch
        self.lr = lr
        self.weight_decay = weight_decay
        # exist_ok avoids the race in a separate exists() check + makedirs().
        os.makedirs(self.save_dir, exist_ok=True)
        self.optimizer = Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        self.criterion = BCEDiceLoss().cuda()

    def train(self):
        """Train for ``self.epoch`` epochs, checkpointing and validating after each.

        The learning rate is halved every 2 epochs (StepLR stepped once per epoch).
        """
        lr_decay = lr_scheduler.StepLR(self.optimizer, step_size=2, gamma=0.5)
        for t in range(self.epoch):
            epoch_num = t + 1
            logging.info("==============EPOCH {} START================".format(epoch_num))
            miou = self._train_one_epoch(epoch_num)
            lr_decay.step()
            logging.info('EPOCH: {}/{}, MIOU: {:.4f}'.format(epoch_num, self.epoch, miou))
            self._save_checkpoint(epoch_num)
            self.val()

    def _train_one_epoch(self, epoch_num):
        """Run one optimisation pass over the train loader; return the mean per-batch IoU."""
        self.model.train()
        all_iou = 0.0
        for i, (x, y) in enumerate(self.train_loader):
            x = x.cuda()
            y = y.cuda()
            self.optimizer.zero_grad()
            pred = self.model(x)
            loss = self.criterion(pred, y)
            all_iou += calc_iou(pred, y)
            iteration = i + 1
            if iteration % 50 == 0 or iteration == 1:
                logging.info(
                    'epoch: {}/{}, iteration: {}/{}, loss: {:.4f}'.format(
                        epoch_num, self.epoch, iteration, self.len_train_loader, loss.item()))
            loss.backward()
            self.optimizer.step()
        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        return all_iou / max(self.len_train_loader, 1)

    def _save_checkpoint(self, epoch_num):
        """Save the wrapped model's state_dict as ``ckpt_epoch_<epoch_num>.pt`` in save_dir."""
        path = os.path.join(self.save_dir, 'ckpt_epoch_%s.pt' % epoch_num)
        # .module saves the underlying model, not the nn.DataParallel wrapper.
        torch.save(self.model.module.state_dict(), path)

    @torch.no_grad()
    def val(self):
        """Evaluate on the test loader (no gradients) and log the mean per-batch IoU."""
        self.model.eval()
        all_iou = 0.0
        for x, y in self.test_loader:
            pred = self.model(x.cuda())
            all_iou += calc_iou(pred, y.cuda())
        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        miou = all_iou / max(self.len_test_loader, 1)
        logging.info('VAL MIOU: {:.4f}'.format(miou))