# net.py (529 bytes)
import torch
from torch import nn
from torchvision import models


class FaceMaskNet(nn.Module):
    """Image classifier built on a torchvision MobileNetV2 backbone.

    The backbone keeps its own ImageNet classification head (1000 outputs).
    A single linear layer then maps those outputs to ``num_classes`` scores.
    Judging by the name, the two default classes are presumably
    mask / no-mask; confirm against the training code.

    Args:
        num_classes: Number of output scores. Defaults to 2, matching the
            original hard-coded head.
        pretrained: If True, load ImageNet weights into the backbone. torchvision
            downloads them on first use. Pass False to build an untrained
            network, e.g. for offline tests.
    """

    # Output width of MobileNetV2's default ImageNet head, which is the input
    # width of ``self.classifier``.
    _BACKBONE_OUT_FEATURES = 1000

    def __init__(self, num_classes: int = 2, pretrained: bool = True) -> None:
        super().__init__()
        # Each submodule stays wrapped in a one-element Sequential so parameter
        # names remain ``layer.0.*`` / ``classifier.0.*``. This keeps
        # previously saved state_dicts loadable.
        self.layer = nn.Sequential(self._build_backbone(pretrained))
        # NOTE(review): there is no non-linearity between the backbone's
        # final Linear and this one, so the two compose into a single linear
        # map. Replacing the backbone's last layer would be leaner, but it
        # would change the checkpoint format.
        self.classifier = nn.Sequential(
            nn.Linear(self._BACKBONE_OUT_FEATURES, num_classes)
        )

    @staticmethod
    def _build_backbone(pretrained: bool) -> nn.Module:
        """Create MobileNetV2 using whichever weights API torchvision offers."""
        weights_enum = getattr(models, "MobileNet_V2_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the boolean flag.
            return models.mobilenet_v2(pretrained=pretrained)
        # torchvision >= 0.13 deprecates ``pretrained``. IMAGENET1K_V1 is the
        # checkpoint that ``pretrained=True`` maps to, so the weights are unchanged.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.mobilenet_v2(weights=weights)

    def forward(self, x):
        """Return raw (un-normalised) class scores for an image batch.

        ``x`` is assumed to be an (N, 3, H, W) float tensor, which is
        MobileNetV2's input convention. The result has shape (N, num_classes).
        No softmax is applied.
        """
        return self.classifier(self.layer(x))



if __name__ == '__main__':
    # Smoke test: push a random batch of five 3x128x128 images through the
    # network and report the shape of the resulting score tensor.
    model = FaceMaskNet()
    dummy_batch = torch.randn(5, 3, 128, 128)
    scores = model(dummy_batch)
    print(scores.shape)