# hed.py
import torch
import torchvision
from torch import nn


class HED_vgg16(nn.Module):
    """Multi-scale side-output network on top of the VGG-16 feature stack.

    The VGG-16 convolutional layers are split into five stages.  After each
    stage a 1x1 convolution (followed by ReLU) reduces the feature map to a
    single channel; stages 2-5 are then bilinearly upsampled by 2, 4, 8 and
    16 so every side map is back at the input resolution.  The five side
    maps are concatenated and fused by a final 1x1 convolution into
    ``class_number`` channels.

    Args:
        num_filters: Base width used to size the side-branch convolutions
            (``num_filters * 2/4/8/16`` input channels).  These have to
            match the channel counts produced by the encoder stages, so the
            default of 32 is the value that fits the VGG-16 encoder.
        pretrained: Forwarded unchanged to ``torchvision.models.vgg16``.
        class_number: Number of output channels of the fused score map.

    Note:
        The size comments (``256*256`` etc.) assume a 256x256 input.  The
        side maps are concatenated, so the input height and width need to
        be multiples of 16 for the upsampled maps to line up.
    """

    def __init__(self, num_filters=32, pretrained=False, class_number=2):
        super().__init__()
        encoder = torchvision.models.vgg16(pretrained=pretrained).features

        # The encoder slices below skip indices 4, 9, 16 and 23 (the
        # encoder's own pooling layers); this shared pool is applied
        # explicitly between stages in ``forward`` instead.
        self.pool = nn.MaxPool2d(2, 2)

        # Stage 1: no upsampling needed, the side map is already full size.
        self.conv1 = encoder[0:4]
        self.score1 = nn.Sequential(nn.Conv2d(num_filters * 2, 1, 1, 1), nn.ReLU(inplace=True))  # 256*256

        # Stage 2
        self.conv2 = encoder[5:9]
        self.d_conv2 = nn.Sequential(nn.Conv2d(num_filters * 4, 1, 1, 1), nn.ReLU(inplace=True))  # 128*128
        self.score2 = nn.UpsamplingBilinear2d(scale_factor=2)  # 256*256

        # Stage 3
        self.conv3 = encoder[10:16]
        self.d_conv3 = nn.Sequential(nn.Conv2d(num_filters * 8, 1, 1, 1), nn.ReLU(inplace=True))  # 64*64
        self.score3 = nn.UpsamplingBilinear2d(scale_factor=4)  # 256*256

        # Stage 4
        self.conv4 = encoder[17:23]
        self.d_conv4 = nn.Sequential(nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU(inplace=True))  # 32*32
        self.score4 = nn.UpsamplingBilinear2d(scale_factor=8)  # 256*256

        # Stage 5
        self.conv5 = encoder[24:30]
        self.d_conv5 = nn.Sequential(nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU(inplace=True))  # 16*16
        self.score5 = nn.UpsamplingBilinear2d(scale_factor=16)  # 256*256

        # Fuses the five single-channel side maps; no activation is applied.
        self.score = nn.Conv2d(5, class_number, 1, 1)  # No relu

    def forward(self, x):
        """Run the encoder and return the fused score map.

        Args:
            x: Input batch fed to the first VGG-16 convolution.

        Returns:
            Tensor with ``class_number`` channels at the input resolution,
            without any final activation.
        """
        x = self.conv1(x)
        s1 = self.score1(x)
        x = self.pool(x)

        x = self.conv2(x)
        s_x = self.d_conv2(x)
        s2 = self.score2(s_x)
        x = self.pool(x)

        x = self.conv3(x)
        s_x = self.d_conv3(x)
        s3 = self.score3(s_x)
        x = self.pool(x)

        # Stage 4 must use conv4: re-applying conv3 here would feed its
        # output back into a block sized for stage 3's input and leave
        # conv4 unused.
        x = self.conv4(x)
        s_x = self.d_conv4(x)
        s4 = self.score4(s_x)
        x = self.pool(x)

        x = self.conv5(x)
        s_x = self.d_conv5(x)
        s5 = self.score5(s_x)

        score = self.score(torch.cat((s1, s2, s3, s4, s5), dim=1))

        return score


class HED_res34(nn.Module):
    """Multi-scale side-output network on top of a ResNet-34 encoder.

    Five feature maps are tapped: the stem (conv1/bn1/relu) and the four
    residual layers.  Each tap is reduced to one channel by a 1x1
    convolution with ReLU, bilinearly upsampled (by 2, 4, 8, 16 and 32
    respectively), and the five maps are concatenated and fused by a final
    1x1 convolution into ``class_number`` channels.

    Args:
        num_filters: Base width used to size the side-branch convolutions
            (``num_filters * 2/2/4/8/16`` input channels).
        pretrained: Forwarded unchanged to ``torchvision.models.resnet34``.
        class_number: Number of output channels of the fused score map.

    Note:
        The size comments assume a 256x256 input.
    """

    @staticmethod
    def _side_branch(in_channels):
        """Build the 1x1 conv + ReLU head that squeezes a tap to one channel."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1, stride=1),
            nn.ReLU(inplace=True),
        )

    def __init__(self, num_filters=32, pretrained=False, class_number=2):
        super().__init__()
        backbone = torchvision.models.resnet34(pretrained=pretrained)

        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stem tap -- 128*128, upsampled x2 back to 256*256.
        self.start = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.d_convs = self._side_branch(num_filters * 2)
        self.scores = nn.UpsamplingBilinear2d(scale_factor=2)

        # layer1 tap -- 64*64, upsampled x4.
        self.layer1 = backbone.layer1
        self.d_conv1 = self._side_branch(num_filters * 2)
        self.score1 = nn.UpsamplingBilinear2d(scale_factor=4)

        # layer2 tap -- 32*32, upsampled x8.
        self.layer2 = backbone.layer2
        self.d_conv2 = self._side_branch(num_filters * 4)
        self.score2 = nn.UpsamplingBilinear2d(scale_factor=8)

        # layer3 tap -- 16*16, upsampled x16.
        self.layer3 = backbone.layer3
        self.d_conv3 = self._side_branch(num_filters * 8)
        self.score3 = nn.UpsamplingBilinear2d(scale_factor=16)

        # layer4 tap -- 8*8, upsampled x32.
        self.layer4 = backbone.layer4
        self.d_conv4 = self._side_branch(num_filters * 16)
        self.score4 = nn.UpsamplingBilinear2d(scale_factor=32)

        # Fusion of the five side maps; left without an activation
        # (original note: "loss_func has softmax").
        self.score = nn.Conv2d(5, class_number, kernel_size=1, stride=1)

    def forward(self, x):
        """Return the fused score map for input batch ``x``.

        The side maps are concatenated in the order layer1, layer2, layer3,
        layer4, stem before the fusion convolution.
        """
        features = self.start(x)
        stem_side = self.scores(self.d_convs(features))
        features = self.pool(features)

        taps = (
            (self.layer1, self.d_conv1, self.score1),
            (self.layer2, self.d_conv2, self.score2),
            (self.layer3, self.d_conv3, self.score3),
            (self.layer4, self.d_conv4, self.score4),
        )
        side_maps = []
        for stage, squeeze, upsample in taps:
            features = stage(features)
            side_maps.append(upsample(squeeze(features)))
        side_maps.append(stem_side)

        return self.score(torch.cat(side_maps, dim=1))


def cross_entropy_loss_RCF(prediction, labelf, beta=1.1):
    """Class-balanced binary cross-entropy, summed over all elements.

    Elements are grouped by the integer value of ``labelf``:

    * label 1 (positive) is weighted by ``num_negative / (num_positive + num_negative)``;
    * label 0 (negative) is weighted by ``beta * num_positive / (num_positive + num_negative)``;
    * label 2 gets weight 0 and therefore does not contribute to the loss.

    Elements with any other label value keep their own label value as the
    weight, because the weight mask starts out as a copy of ``labelf``.

    Args:
        prediction: Tensor of probabilities passed as the input of
            ``binary_cross_entropy``; same shape as ``labelf``.
        labelf: Floating-point label tensor, used both to derive the weights
            and as the target.  It is not modified.
        beta: Extra scale applied to the negative-class weight.

    Returns:
        Scalar tensor holding the weighted cross-entropy summed over all
        elements (``reduction='sum'``).
    """
    label = labelf.long()
    # Work on a copy so the caller's label tensor is left untouched.
    mask = labelf.clone()
    num_positive = torch.sum(label == 1).float()
    num_negative = torch.sum(label == 0).float()

    # Each class is weighted by the frequency of the *other* class so the
    # rarer class is not swamped by the more frequent one.
    mask[label == 1] = 1.0 * num_negative / (num_positive + num_negative)
    mask[label == 0] = beta * num_positive / (num_positive + num_negative)
    mask[label == 2] = 0

    # ``F`` was never imported in this module; reach the functional API
    # through the already-imported ``nn`` instead.
    cost = nn.functional.binary_cross_entropy(
        prediction, labelf, weight=mask, reduction='sum')

    return cost


if __name__ == '__main__':
    # Smoke check: build each network with default arguments and print its
    # parameter count in millions (ResNet-34 variant first, then VGG-16).
    for build_net in (HED_res34, HED_vgg16):
        net = build_net()
        n_params = sum(p.nelement() for p in net.parameters())
        print(n_params / 1e6)
        # Drop the network before the next one is constructed.
        del net
        del n_params
    del build_net