hed.py
4.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
import torchvision
from torch import nn
class HED_vgg16(nn.Module):
    """HED-style multi-scale network built on the torchvision VGG16 encoder.

    The VGG16 convolutional stack is split into five stages. After every
    stage a 1x1 convolution + ReLU reduces the features to a single-channel
    side map; the side maps of stages 2-5 are bilinearly upsampled by
    2/4/8/16 to undo the 2x2 max-pools that precede them. The five side maps
    are concatenated along the channel axis and fused by a final 1x1
    convolution.

    Args:
        num_filters: Base width used to size the side-branch inputs
            (``num_filters * 2`` up to ``num_filters * 16``). The VGG16
            stages emit 64/128/256/512/512 channels, so only the default
            of 32 matches the encoder.
        pretrained: Forwarded to ``torchvision.models.vgg16``.
        class_number: Number of output channels of the fusion layer.
    """

    def __init__(self, num_filters=32, pretrained=False, class_number=2):
        super().__init__()
        encoder = torchvision.models.vgg16(pretrained=pretrained).features

        # The slices below skip the encoder's own max-pool layers
        # (indices 4, 9, 16, 23); this shared pool is applied between the
        # stages in forward() instead.
        self.pool = nn.MaxPool2d(2, 2)

        # Size comments assume a 256x256 input.
        self.conv1 = encoder[0:4]
        self.score1 = nn.Sequential(nn.Conv2d(num_filters * 2, 1, 1, 1), nn.ReLU(inplace=True))  # 256*256

        self.conv2 = encoder[5:9]
        self.d_conv2 = nn.Sequential(nn.Conv2d(num_filters * 4, 1, 1, 1), nn.ReLU(inplace=True))  # 128*128
        self.score2 = nn.UpsamplingBilinear2d(scale_factor=2)  # 256*256

        self.conv3 = encoder[10:16]
        self.d_conv3 = nn.Sequential(nn.Conv2d(num_filters * 8, 1, 1, 1), nn.ReLU(inplace=True))  # 64*64
        self.score3 = nn.UpsamplingBilinear2d(scale_factor=4)  # 256*256

        self.conv4 = encoder[17:23]
        self.d_conv4 = nn.Sequential(nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU(inplace=True))  # 32*32
        self.score4 = nn.UpsamplingBilinear2d(scale_factor=8)  # 256*256

        self.conv5 = encoder[24:30]
        self.d_conv5 = nn.Sequential(nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU(inplace=True))  # 16*16
        self.score5 = nn.UpsamplingBilinear2d(scale_factor=16)  # 256*256

        # Fusion of the five side maps; no activation is applied here.
        self.score = nn.Conv2d(5, class_number, 1, 1)

    def forward(self, x):
        """Return the fused score map for the image batch ``x``.

        The input is pooled four times by a factor of two and the side maps
        are upsampled by fixed factors, so height and width should be
        divisible by 16 for the side maps to line up when concatenated.

        Args:
            x: Input batch fed to the first VGG16 stage.

        Returns:
            Raw (un-activated) scores with ``class_number`` channels at the
            input resolution.
        """
        # Stage 1: full resolution, no upsampling needed.
        x = self.conv1(x)
        s1 = self.score1(x)

        # Stage 2: 1/2 resolution.
        x = self.pool(x)
        x = self.conv2(x)
        s2 = self.score2(self.d_conv2(x))

        # Stage 3: 1/4 resolution.
        x = self.pool(x)
        x = self.conv3(x)
        s3 = self.score3(self.d_conv3(x))

        # Stage 4: 1/8 resolution. This stage must run conv4; reusing
        # conv3 here feeds it its own output and breaks on channel count.
        x = self.pool(x)
        x = self.conv4(x)
        s4 = self.score4(self.d_conv4(x))

        # Stage 5: 1/16 resolution.
        x = self.pool(x)
        x = self.conv5(x)
        s5 = self.score5(self.d_conv5(x))

        score = self.score(torch.cat((s1, s2, s3, s4, s5), dim=1))
        return score
class HED_res34(nn.Module):
    """HED-style multi-scale network built on the torchvision ResNet34 encoder.

    Side maps are taken from the stem (conv1/bn1/relu) and from each of the
    four residual stages. Every side map is produced by a 1x1 convolution +
    ReLU and bilinearly upsampled by 2/4/8/16/32 respectively; the five maps
    are concatenated and fused by a final 1x1 convolution.

    Args:
        num_filters: Base width used to size the side-branch inputs
            (``num_filters * 2`` up to ``num_filters * 16``). ResNet34
            emits 64/64/128/256/512 channels, so only the default of 32
            matches the encoder.
        pretrained: Forwarded to ``torchvision.models.resnet34``.
        class_number: Number of output channels of the fusion layer.
    """

    def __init__(self, num_filters=32, pretrained=False, class_number=2):
        super().__init__()
        backbone = torchvision.models.resnet34(pretrained=pretrained)

        def side_head(width_multiplier):
            # 1x1 conv collapsing a feature map to a single channel.
            return nn.Sequential(
                nn.Conv2d(num_filters * width_multiplier, 1, kernel_size=1, stride=1),
                nn.ReLU(inplace=True),
            )

        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Resolution notes assume a 256x256 input.
        # Stem: 128x128 features, side map upsampled x2.
        self.start = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.d_convs = side_head(2)
        self.scores = nn.UpsamplingBilinear2d(scale_factor=2)

        # Residual stage 1: 64x64 features, side map upsampled x4.
        self.layer1 = backbone.layer1
        self.d_conv1 = side_head(2)
        self.score1 = nn.UpsamplingBilinear2d(scale_factor=4)

        # Residual stage 2: 32x32 features, side map upsampled x8.
        self.layer2 = backbone.layer2
        self.d_conv2 = side_head(4)
        self.score2 = nn.UpsamplingBilinear2d(scale_factor=8)

        # Residual stage 3: 16x16 features, side map upsampled x16.
        self.layer3 = backbone.layer3
        self.d_conv3 = side_head(8)
        self.score3 = nn.UpsamplingBilinear2d(scale_factor=16)

        # Residual stage 4: 8x8 features, side map upsampled x32.
        self.layer4 = backbone.layer4
        self.d_conv4 = side_head(16)
        self.score4 = nn.UpsamplingBilinear2d(scale_factor=32)

        # Fusion layer has no ReLU; the original note says the loss
        # function applies the softmax.
        self.score = nn.Conv2d(5, class_number, kernel_size=1, stride=1)

    def forward(self, x):
        """Return the fused score map for the image batch ``x``.

        Args:
            x: Input batch fed to the ResNet34 stem.

        Returns:
            Raw (un-activated) scores with ``class_number`` channels. The
            channel order fed to the fusion layer is stage 1, stage 2,
            stage 3, stage 4, then the stem.
        """
        stem = self.start(x)
        side_stem = self.scores(self.d_convs(stem))

        feat = self.layer1(self.pool(stem))
        side_1 = self.score1(self.d_conv1(feat))

        feat = self.layer2(feat)
        side_2 = self.score2(self.d_conv2(feat))

        feat = self.layer3(feat)
        side_3 = self.score3(self.d_conv3(feat))

        feat = self.layer4(feat)
        side_4 = self.score4(self.d_conv4(feat))

        stacked = torch.cat((side_1, side_2, side_3, side_4, side_stem), dim=1)
        return self.score(stacked)
def cross_entropy_loss_RCF(prediction, labelf, beta=1.1):
    """Class-balanced binary cross-entropy, summed over all elements.

    Labels are truncated to integers for classification: 1 marks a
    positive, 0 a negative and 2 an element to ignore. Positives are
    weighted by the fraction of negatives, negatives by ``beta`` times the
    fraction of positives, and ignored elements get weight 0.

    Args:
        prediction: Predicted probabilities; ``binary_cross_entropy``
            requires values in [0, 1] and the same shape as ``labelf``.
        labelf: Floating-point label tensor, also used as the BCE target.
        beta: Extra scale applied to the negative-class weight.

    Returns:
        A scalar tensor holding the weighted loss summed over all elements.
    """
    label = labelf.long()
    num_positive = torch.sum(label == 1).float()
    num_negative = torch.sum(label == 0).float()
    num_counted = num_positive + num_negative

    # Start from a copy of the labels so the weight tensor shares their
    # shape, dtype and device; the labelled entries are overwritten below.
    mask = labelf.clone()
    mask[label == 1] = 1.0 * num_negative / num_counted
    mask[label == 0] = beta * num_positive / num_counted
    mask[label == 2] = 0

    # Reached through the torch namespace: this module never imports the
    # functional API under the alias ``F``.
    cost = torch.nn.functional.binary_cross_entropy(
        prediction, labelf, weight=mask, reduction='sum')
    return cost
if __name__ == '__main__':
    # Print the parameter count, in millions, of each network in turn.
    for build_network in (HED_res34, HED_vgg16):
        network = build_network()
        param_count = sum(p.nelement() for p in network.parameters())
        print(param_count / 1e6)
        # Release the model before the next one is constructed.
        del network
        del param_count
    del build_network