net.py
529 Bytes
import torch
from torch import nn
from torchvision import models
class FaceMaskNet(nn.Module):
    """MobileNetV2 followed by a small linear head.

    The torchvision MobileNetV2 is used unmodified, so it produces its stock
    1000-wide output; ``classifier`` maps those 1000 values to
    ``num_classes`` scores.

    The attribute names (``layer``, ``classifier``) and their
    ``nn.Sequential`` wrappers are kept deliberately so that ``state_dict``
    keys stay ``layer.0.*`` / ``classifier.0.*``.

    Args:
        num_classes: Width of the final output. Defaults to 2.
        pretrained: Forwarded to ``models.mobilenet_v2``. ``True`` (the
            default) asks torchvision for its pretrained weights, which may
            trigger a download if they are not cached. Pass ``False`` when
            the weights will be restored from a checkpoint anyway, or when
            working offline.
    """

    # Output width of an unmodified torchvision MobileNetV2.
    _BACKBONE_OUT_FEATURES = 1000

    def __init__(self, num_classes: int = 2, *, pretrained: bool = True) -> None:
        super().__init__()
        # NOTE: newer torchvision releases deprecate ``pretrained=`` in favour
        # of ``weights=``; the keyword is kept here so the module still works
        # with the torchvision version this file was written against.
        self.layer = nn.Sequential(
            models.mobilenet_v2(pretrained=pretrained)
        )
        self.classifier = nn.Sequential(
            nn.Linear(self._BACKBONE_OUT_FEATURES, num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and then the linear head.

        Args:
            x: Input batch for MobileNetV2 -- presumably a float tensor of
                shape ``(batch, 3, H, W)``; confirm against the data loader.

        Returns:
            The raw output of the final linear layer, one row per input
            sample and ``num_classes`` columns. No softmax is applied here.
        """
        features = self.layer(x)
        return self.classifier(features)
if __name__ == '__main__':
    # Smoke test: build the network, push one random batch through it and
    # print the shape of what comes out.
    model = FaceMaskNet()
    dummy_batch = torch.randn(5, 3, 128, 128)
    scores = model(dummy_batch)
    print(scores.shape)