# net.py 5.93 KB
'''
 _*_coding:utf-8 _*_
 @Time     :2022/1/28   19:05
 @Author   : qiaofengsheng
 @File     :net.py
 @Software :PyCharm
 '''

import torch
from torchvision import models
from torch import nn
from efficientnet_pytorch import EfficientNet


class ClassifierNet(nn.Module):
    """Image classifier wrapping a torchvision or EfficientNet backbone.

    The backbone is stored as the only child of ``self.layer`` (an
    ``nn.Sequential``), so state-dict keys have the form ``layer.0.<param>``.

    Args:
        net_type: Architecture name; must be in ``TORCHVISION_ARCHS`` or
            ``EFFICIENTNET_ARCHS``.
        num_classes: Number of output logits.
        pretrained: Load pretrained weights. For torchvision backbones the
            weights are loaded with their original 1000-class head, which is
            then replaced by a freshly initialised ``num_classes`` head.

    Raises:
        ValueError: If ``net_type`` is not a supported architecture.
    """

    # Builder names looked up lazily on ``torchvision.models``.
    TORCHVISION_ARCHS = frozenset((
        'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
        'resnext50_32x4d', 'resnext101_32x8d',
        'wide_resnet50_2', 'wide_resnet101_2',
        'densenet121', 'densenet161', 'densenet169', 'densenet201',
        'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn',
        'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn',
        'inception_v3', 'googlenet', 'alexnet',
        'mobilenet_v2', 'mobilenet_v3_small', 'mobilenet_v3_large',
        'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
        'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0',
        'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3',
        'squeezenet1_0', 'squeezenet1_1',
    ))
    # Model names passed to ``EfficientNet.from_name`` / ``from_pretrained``.
    EFFICIENTNET_ARCHS = frozenset((
        'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2',
        'efficientnet-b3', 'efficientnet-b4', 'efficientnet-b5',
        'efficientnet-b6', 'efficientnet-b7',
    ))
    # Width of the classification head that torchvision pretrained weights ship with.
    _IMAGENET_CLASSES = 1000

    def __init__(self, net_type='resnet18', num_classes=10, pretrained=False):
        super(ClassifierNet, self).__init__()
        if net_type in self.TORCHVISION_ARCHS:
            backbone = self._build_torchvision(net_type, num_classes, pretrained)
        elif net_type in self.EFFICIENTNET_ARCHS:
            if pretrained:
                backbone = EfficientNet.from_pretrained(net_type, num_classes=num_classes)
            else:
                backbone = EfficientNet.from_name(net_type, num_classes=num_classes)
        else:
            # Fail fast instead of leaving an unusable model whose forward() crashes.
            supported = ', '.join(sorted(self.TORCHVISION_ARCHS | self.EFFICIENTNET_ARCHS))
            raise ValueError(
                'Unsupported net_type %r; choose one of: %s' % (net_type, supported))
        # Keep the Sequential wrapper so existing checkpoints ("layer.0.*") still load.
        self.layer = nn.Sequential(backbone)

    @classmethod
    def _build_torchvision(cls, net_type, num_classes, pretrained):
        """Instantiate torchvision model ``net_type`` with ``num_classes`` outputs."""
        factory = getattr(models, net_type)
        if not pretrained:
            return factory(pretrained=False, num_classes=num_classes)
        # Pretrained weights only fit the original 1000-class head: torchvision
        # refuses (or fails to load) any other num_classes. Load with the default
        # head, then swap it for a new one of the requested width.
        model = factory(pretrained=True)
        if num_classes != cls._IMAGENET_CLASSES:
            cls._replace_head(model, num_classes)
            # Inception v3 also keeps an auxiliary classifier after loading weights.
            aux = getattr(model, 'AuxLogits', None)
            if isinstance(aux, nn.Module):
                cls._replace_head(aux, num_classes)
        return model

    @staticmethod
    def _replace_head(model, num_classes):
        """Replace the final classification layer of ``model`` in place.

        The head is the last ``nn.Linear`` found by ``named_modules()`` (registration
        order). For models without any Linear layer (e.g. SqueezeNet), it is the
        last ``nn.Conv2d`` instead. The new layer keeps the input width,
        kernel geometry and bias flag, but outputs ``num_classes`` features
        and is freshly initialised.

        Raises:
            ValueError: If no Linear/Conv2d submodule exists below ``model``.
        """
        # Skip the root ('' name): a bare layer cannot be replaced in place.
        submodules = [(name, mod) for name, mod in model.named_modules() if name]
        candidates = [(n, m) for n, m in submodules if isinstance(m, nn.Linear)]
        if not candidates:
            candidates = [(n, m) for n, m in submodules if isinstance(m, nn.Conv2d)]
        if not candidates:
            raise ValueError('Cannot locate a classification head to replace')
        name, head = candidates[-1]
        if isinstance(head, nn.Linear):
            new_head = nn.Linear(head.in_features, num_classes,
                                 bias=head.bias is not None)
        else:
            new_head = nn.Conv2d(head.in_channels, num_classes,
                                 kernel_size=head.kernel_size, stride=head.stride,
                                 padding=head.padding, bias=head.bias is not None)
        # Walk to the parent container; setattr re-registers the child in place,
        # preserving its position inside nn.Sequential.
        *path, leaf = name.split('.')
        parent = model
        for part in path:
            parent = getattr(parent, part)
        setattr(parent, leaf, new_head)

    def forward(self, x):
        """Return the backbone's output for input batch ``x``."""
        return self.layer(x)

if __name__ == '__main__':
    # Smoke test: build an untrained backbone and print the logits' shape.
    net = ClassifierNet('mnasnet1_0', pretrained=False)
    # Inference mode: BatchNorm uses running stats (and does not update them),
    # Dropout is disabled.
    net.eval()
    x = torch.randn(1, 3, 125, 125)
    with torch.no_grad():  # a shape check needs no autograd graph
        print(net(x).shape)