net.py
5.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
'''
_*_coding:utf-8 _*_
@Time :2022/1/28 19:05
@Author : qiaofengsheng
@File :net.py
@Software :PyCharm
'''
import torch
from torchvision import models
from torch import nn
from efficientnet_pytorch import EfficientNet
class ClassifierNet(nn.Module):
    """Image classifier built around a torchvision or EfficientNet backbone.

    Args:
        net_type: Backbone name. Either a key of ``_TORCHVISION_HEADS`` (the
            builder's name on ``torchvision.models``) or an entry of
            ``_EFFICIENTNETS``.
        num_classes: Number of outputs of the classification head.
        pretrained: Load pretrained weights. Torchvision checkpoints have a
            1000-class head. When ``num_classes != 1000`` the pretrained
            model is built with that head, which is then replaced by a
            freshly initialised layer with ``num_classes`` outputs.

    Raises:
        ValueError: If ``net_type`` is not a supported backbone name.

    The backbone is stored as the only element of the ``nn.Sequential``
    ``self.layer``, so state-dict keys have the form ``layer.0.<param>``.
    """

    # Supported torchvision builders (attribute names on ``torchvision.models``)
    # mapped to the dotted path(s) of their final classification layer(s).
    # The paths are only used when pretrained weights are combined with
    # num_classes != 1000, to swap the 1000-class head for a new one.
    _TORCHVISION_HEADS = {
        **dict.fromkeys(
            ('resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
             'resnext50_32x4d', 'resnext101_32x8d',
             'wide_resnet50_2', 'wide_resnet101_2',
             'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
             'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0',
             'googlenet'),
            ('fc',)),
        # inception_v3 also carries an auxiliary classifier with its own head.
        'inception_v3': ('fc', 'AuxLogits.fc'),
        **dict.fromkeys(
            ('densenet121', 'densenet161', 'densenet169', 'densenet201'),
            ('classifier',)),
        **dict.fromkeys(
            ('alexnet', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn',
             'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn'),
            ('classifier.6',)),
        **dict.fromkeys(
            ('mobilenet_v2', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0',
             'mnasnet1_3', 'squeezenet1_0', 'squeezenet1_1'),
            ('classifier.1',)),
        **dict.fromkeys(
            ('mobilenet_v3_small', 'mobilenet_v3_large'),
            ('classifier.3',)),
    }

    # Names accepted by efficientnet_pytorch's EfficientNet constructors.
    _EFFICIENTNETS = (
        'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2',
        'efficientnet-b3', 'efficientnet-b4', 'efficientnet-b5',
        'efficientnet-b6', 'efficientnet-b7',
    )

    def __init__(self, net_type='resnet18', num_classes=10, pretrained=False):
        super(ClassifierNet, self).__init__()
        if net_type in self._EFFICIENTNETS:
            # Both EfficientNet constructors accept num_classes directly.
            if pretrained:
                backbone = EfficientNet.from_pretrained(net_type, num_classes=num_classes)
            else:
                backbone = EfficientNet.from_name(net_type, num_classes=num_classes)
        elif net_type in self._TORCHVISION_HEADS:
            backbone = self._build_torchvision(net_type, num_classes, pretrained)
        else:
            # Fail early instead of leaving self.layer = None, which would
            # only surface later as an opaque error inside forward().
            supported = sorted(self._TORCHVISION_HEADS) + list(self._EFFICIENTNETS)
            raise ValueError('Unsupported net_type {!r}; expected one of: {}'.format(
                net_type, ', '.join(supported)))
        # Keep the Sequential wrapper so state-dict keys stay ``layer.0.*``.
        self.layer = nn.Sequential(backbone)

    @classmethod
    def _build_torchvision(cls, net_type, num_classes, pretrained):
        """Instantiate torchvision backbone ``net_type`` with ``num_classes`` outputs."""
        builder = getattr(models, net_type)
        if not pretrained or num_classes == 1000:
            return builder(pretrained=pretrained, num_classes=num_classes)
        # Torchvision cannot load 1000-class pretrained weights into a model
        # built with a different num_classes. Load the stock model instead,
        # then replace its classification head(s).
        model = builder(pretrained=True)
        for path in cls._TORCHVISION_HEADS[net_type]:
            cls._replace_head(model, path, num_classes)
        return model

    @staticmethod
    def _replace_head(model, path, num_classes):
        """Replace the layer at dotted ``path`` in ``model`` with a fresh layer
        producing ``num_classes`` outputs.

        A missing optional head (e.g. an auxiliary classifier set to None) is
        skipped. Only ``nn.Linear`` and ``nn.Conv2d`` heads are supported;
        any other layer type raises ``TypeError``.
        """
        *parent_names, leaf = path.split('.')
        parent = model
        for name in parent_names:
            parent = getattr(parent, name, None)
            if parent is None:
                return
        old = getattr(parent, leaf, None)
        if old is None:
            return
        if isinstance(old, nn.Linear):
            new = nn.Linear(old.in_features, num_classes, bias=old.bias is not None)
        elif isinstance(old, nn.Conv2d):
            # e.g. SqueezeNet, whose classifier ends in a 1x1 convolution.
            new = nn.Conv2d(old.in_channels, num_classes,
                            kernel_size=old.kernel_size, stride=old.stride,
                            padding=old.padding, bias=old.bias is not None)
        else:
            raise TypeError('Cannot replace head {!r} of type {}'.format(
                path, type(old).__name__))
        setattr(parent, leaf, new)
        # Keep a stored class count (if the model keeps one) consistent.
        if hasattr(model, 'num_classes'):
            model.num_classes = num_classes

    def forward(self, x):
        """Run ``x`` through the wrapped backbone and return its output unchanged.

        NOTE(review): for some torchvision backbones (e.g. inception_v3,
        googlenet) the output in training mode may be a namedtuple that
        includes auxiliary logits, not a single tensor.
        """
        return self.layer(x)
if __name__ == '__main__':
    # Smoke test: build a randomly initialised MNASNet classifier, push one
    # random 3x125x125 image through it and print the resulting shape.
    model = ClassifierNet('mnasnet1_0', pretrained=False)
    dummy_batch = torch.randn(1, 3, 125, 125)
    logits = model(dummy_batch)
    print(logits.shape)