口罩二分类
0 parents
Showing
17 changed files
with
351 additions
and
0 deletions
.idea/.gitignore
0 → 100644
.idea/face_mask_classifier.iml
0 → 100644
| 1 | <?xml version="1.0" encoding="UTF-8"?> | ||
| 2 | <module type="PYTHON_MODULE" version="4"> | ||
| 3 | <component name="NewModuleRootManager"> | ||
| 4 | <content url="file://$MODULE_DIR$" /> | ||
| 5 | <orderEntry type="inheritedJdk" /> | ||
| 6 | <orderEntry type="sourceFolder" forTests="false" /> | ||
| 7 | </component> | ||
| 8 | </module> | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
.idea/inspectionProfiles/Project_Default.xml
0 → 100644
| 1 | <component name="InspectionProjectProfileManager"> | ||
| 2 | <profile version="1.0"> | ||
| 3 | <option name="myName" value="Project Default" /> | ||
| 4 | <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true"> | ||
| 5 | <option name="ignoredPackages"> | ||
| 6 | <value> | ||
| 7 | <list size="25"> | ||
| 8 | <item index="0" class="java.lang.String" itemvalue="tqdm" /> | ||
| 9 | <item index="1" class="java.lang.String" itemvalue="easydict" /> | ||
| 10 | <item index="2" class="java.lang.String" itemvalue="scikit_image" /> | ||
| 11 | <item index="3" class="java.lang.String" itemvalue="matplotlib" /> | ||
| 12 | <item index="4" class="java.lang.String" itemvalue="tensorboardX" /> | ||
| 13 | <item index="5" class="java.lang.String" itemvalue="torch" /> | ||
| 14 | <item index="6" class="java.lang.String" itemvalue="numpy" /> | ||
| 15 | <item index="7" class="java.lang.String" itemvalue="pycocotools" /> | ||
| 16 | <item index="8" class="java.lang.String" itemvalue="skimage" /> | ||
| 17 | <item index="9" class="java.lang.String" itemvalue="Pillow" /> | ||
| 18 | <item index="10" class="java.lang.String" itemvalue="scipy" /> | ||
| 19 | <item index="11" class="java.lang.String" itemvalue="torchvision" /> | ||
| 20 | <item index="12" class="java.lang.String" itemvalue="opencv_python" /> | ||
| 21 | <item index="13" class="java.lang.String" itemvalue="onnxruntime" /> | ||
| 22 | <item index="14" class="java.lang.String" itemvalue="onnx-simplifier" /> | ||
| 23 | <item index="15" class="java.lang.String" itemvalue="onnx" /> | ||
| 24 | <item index="16" class="java.lang.String" itemvalue="opencv-contrib-python" /> | ||
| 25 | <item index="17" class="java.lang.String" itemvalue="numba" /> | ||
| 26 | <item index="18" class="java.lang.String" itemvalue="opencv-python" /> | ||
| 27 | <item index="19" class="java.lang.String" itemvalue="librosa" /> | ||
| 28 | <item index="20" class="java.lang.String" itemvalue="tensorboard" /> | ||
| 29 | <item index="21" class="java.lang.String" itemvalue="dill" /> | ||
| 30 | <item index="22" class="java.lang.String" itemvalue="pandas" /> | ||
| 31 | <item index="23" class="java.lang.String" itemvalue="scikit_learn" /> | ||
| 32 | <item index="24" class="java.lang.String" itemvalue="pytorch-gradual-warmup-lr" /> | ||
| 33 | </list> | ||
| 34 | </value> | ||
| 35 | </option> | ||
| 36 | </inspection_tool> | ||
| 37 | </profile> | ||
| 38 | </component> | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
.idea/misc.xml
0 → 100644
.idea/modules.xml
0 → 100644
| 1 | <?xml version="1.0" encoding="UTF-8"?> | ||
| 2 | <project version="4"> | ||
| 3 | <component name="ProjectModuleManager"> | ||
| 4 | <modules> | ||
| 5 | <module fileurl="file://$PROJECT_DIR$/.idea/face_mask_classifier.iml" filepath="$PROJECT_DIR$/.idea/face_mask_classifier.iml" /> | ||
| 6 | </modules> | ||
| 7 | </component> | ||
| 8 | </project> | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
.idea/vcs.xml
0 → 100644
__pycache__/net.cpython-38.pyc
0 → 100644
No preview for this file type
dataset.py
0 → 100644
| 1 | import os | ||
| 2 | import random | ||
| 3 | |||
| 4 | import cv2 | ||
| 5 | import torch | ||
| 6 | from PIL import Image | ||
| 7 | from torch.utils.data import Dataset, DataLoader | ||
| 8 | from utils import * | ||
| 9 | from torchvision import transforms | ||
| 10 | |||
| 11 | classes_names = ['normal', 'mask'] | ||
| 12 | |||
| 13 | |||
| 14 | class FaceMaskDataset(Dataset): | ||
| 15 | def __init__(self, root_path): | ||
| 16 | self.transform = transforms.Compose([ | ||
| 17 | transforms.ToTensor() | ||
| 18 | ]) | ||
| 19 | self.dataset = [] | ||
| 20 | class_names = os.listdir(root_path) | ||
| 21 | for cls in class_names: | ||
| 22 | image_names = os.listdir(os.path.join(root_path, cls)) | ||
| 23 | for image in image_names: | ||
| 24 | self.dataset.append([os.path.join(root_path, cls, image), classes_names.index(cls)]) | ||
| 25 | |||
| 26 | def __len__(self): | ||
| 27 | return len(self.dataset) | ||
| 28 | |||
| 29 | def __getitem__(self, index): | ||
| 30 | lights=[0.6,0.8,1,1.2,1.4,1.6] | ||
| 31 | data = self.dataset[index] | ||
| 32 | image_path = data[0] | ||
| 33 | image_data = keep_resize_image(image_path) | ||
| 34 | image_data=cv2.convertScaleAbs(image_data,alpha=lights[random.randint(0,4)]) | ||
| 35 | image_label = data[1] | ||
| 36 | return self.transform(image_data), image_label | ||
| 37 | |||
| 38 | |||
| 39 | if __name__ == '__main__': | ||
| 40 | import tqdm | ||
| 41 | d = FaceMaskDataset('image') | ||
| 42 | for i in d: | ||
| 43 | i | ||
| 44 |
image/img.png
0 → 100644
55 KB
image/img_1.png
0 → 100644
90.2 KB
infer.py
0 → 100644
| 1 | import os | ||
| 2 | |||
| 3 | import cv2 | ||
| 4 | |||
| 5 | import torch | ||
| 6 | from PIL import Image | ||
| 7 | import numpy as np | ||
| 8 | from torchvision import transforms | ||
| 9 | from net import * | ||
| 10 | |||
| 11 | |||
| 12 | |||
| 13 | def video(net,): | ||
| 14 | cap=cv2.VideoCapture(0) | ||
| 15 | while True: | ||
| 16 | _, frame = cap.read() | ||
| 17 | image = Image.fromarray(frame) | ||
| 18 | w, h = image.size | ||
| 19 | temp = max(w, h) | ||
| 20 | mask = Image.new('RGB', (temp, temp)) | ||
| 21 | if w >= h: | ||
| 22 | mask.paste(image, (0, (w - h) // 2)) | ||
| 23 | else: | ||
| 24 | mask.paste(image, ((h - w) // 2, 0)) | ||
| 25 | mask = mask.resize((128, 128)) | ||
| 26 | mask = np.array(mask) | ||
| 27 | mask = cv2.cvtColor(mask, cv2.COLOR_RGB2BGR) | ||
| 28 | mask_image = torch.unsqueeze(transform(mask), dim=0) | ||
| 29 | out = net(mask_image) | ||
| 30 | print(out) | ||
| 31 | out=torch.argmax(out,dim=1) | ||
| 32 | result = classes_names[int(out.item())] | ||
| 33 | cv2.putText(frame, result, (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), thickness=2) | ||
| 34 | cv2.imshow('frame', frame) | ||
| 35 | |||
| 36 | if cv2.waitKey(1) & 0XFF == ord('q'): | ||
| 37 | break | ||
| 38 | cap.release() | ||
| 39 | cv2.destroyAllWindows() | ||
| 40 | |||
| 41 | def image_cls(net,path): | ||
| 42 | frame=cv2.imread(path) | ||
| 43 | image = Image.fromarray(frame) | ||
| 44 | w, h = image.size | ||
| 45 | temp = max(w, h) | ||
| 46 | mask = Image.new('RGB', (temp, temp)) | ||
| 47 | if w >= h: | ||
| 48 | mask.paste(image, (0, (w - h) // 2)) | ||
| 49 | else: | ||
| 50 | mask.paste(image, ((h - w) // 2, 0)) | ||
| 51 | mask = mask.resize((128, 128)) | ||
| 52 | mask = np.array(mask) | ||
| 53 | mask = cv2.cvtColor(mask, cv2.COLOR_RGB2BGR) | ||
| 54 | mask_image = torch.unsqueeze(transform(mask), dim=0) | ||
| 55 | out = net(mask_image) | ||
| 56 | print(out) | ||
| 57 | out = torch.argmax(out, dim=1) | ||
| 58 | result = classes_names[int(out.item())] | ||
| 59 | cv2.putText(frame, result, (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), thickness=2) | ||
| 60 | cv2.imshow('frame', frame) | ||
| 61 | cv2.waitKey(0) | ||
| 62 | |||
| 63 | if __name__ == '__main__': | ||
| 64 | transform = transforms.Compose([ | ||
| 65 | transforms.ToTensor() | ||
| 66 | ]) | ||
| 67 | net = FaceMaskNet() | ||
| 68 | weights_path = r'params/new_face_mobilenet_v2.pth' | ||
| 69 | classes_names = ['normal', 'mask'] | ||
| 70 | |||
| 71 | if os.path.exists(weights_path): | ||
| 72 | net.load_state_dict(torch.load(weights_path, map_location='cuda:0')) | ||
| 73 | print('successfully loading weights!') | ||
| 74 | net.eval() | ||
| 75 | |||
| 76 | image_cls(net,'image/img_1.png') | ||
| 77 | |||
| 78 |
net.py
0 → 100644
| 1 | import torch | ||
| 2 | from torch import nn | ||
| 3 | from torchvision import models | ||
| 4 | |||
| 5 | |||
| 6 | class FaceMaskNet(nn.Module): | ||
| 7 | def __init__(self): | ||
| 8 | super(FaceMaskNet, self).__init__() | ||
| 9 | self.layer = nn.Sequential( | ||
| 10 | models.mobilenet_v2(pretrained=True) | ||
| 11 | ) | ||
| 12 | self.classifier = nn.Sequential( | ||
| 13 | nn.Linear(1000, 2) | ||
| 14 | ) | ||
| 15 | |||
| 16 | def forward(self, x): | ||
| 17 | return self.classifier(self.layer(x)) | ||
| 18 | |||
| 19 | |||
| 20 | |||
| 21 | if __name__ == '__main__': | ||
| 22 | net = FaceMaskNet() | ||
| 23 | x = torch.randn(5, 3, 128, 128) | ||
| 24 | print(net(x).shape) |
params/cls_face_mobilenet_v2.pth
0 → 100644
This file is too large to display.
test.py
0 → 100644
| 1 | import os | ||
| 2 | |||
| 3 | import cv2 | ||
| 4 | import torch | ||
| 5 | |||
| 6 | from net import * | ||
| 7 | from dataset import * | ||
| 8 | |||
| 9 | transform = transforms.Compose([ | ||
| 10 | transforms.ToTensor() | ||
| 11 | ]) | ||
| 12 | |||
| 13 | data=FaceMaskDataset('/data2/face_mask') | ||
| 14 | d = DataLoader(data, batch_size=1000, shuffle=True) | ||
| 15 | with torch.no_grad(): | ||
| 16 | for i,(image,label) in enumerate(d): | ||
| 17 | net=FaceMaskNet().cuda() | ||
| 18 | net.load_state_dict(torch.load('params/face_mobilenet_v2.pth')) | ||
| 19 | net.eval() | ||
| 20 | out=net(image.cuda()) | ||
| 21 | out=torch.argmax(out,dim=1) | ||
| 22 | acc=torch.mean(torch.eq(label.cuda(), out).float()).item() | ||
| 23 | print(acc) |
train.py
0 → 100644
| 1 | import os.path | ||
| 2 | |||
| 3 | from torch import nn, optim | ||
| 4 | import torch | ||
| 5 | from dataset import * | ||
| 6 | from torch.utils.data import random_split | ||
| 7 | from net import * | ||
| 8 | import tqdm | ||
| 9 | import time | ||
| 10 | |||
| 11 | if __name__ == '__main__': | ||
| 12 | train_rate=0.8 | ||
| 13 | batch_size=50 | ||
| 14 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
| 15 | print(device) | ||
| 16 | epochs=50 | ||
| 17 | |||
| 18 | datasets = FaceMaskDataset('/data2/new_face_mask') | ||
| 19 | train_datasets, test_datasets = random_split( | ||
| 20 | datasets, | ||
| 21 | [int(len(datasets) * train_rate), len(datasets) - int(len(datasets) * train_rate)], | ||
| 22 | ) | ||
| 23 | print(f'train_datasets:{len(train_datasets)} test_datasets:{len(test_datasets)}') | ||
| 24 | train_data_loader = DataLoader(train_datasets, batch_size=batch_size, shuffle=True) | ||
| 25 | test_data_loader = DataLoader(test_datasets, batch_size=batch_size, shuffle=True) | ||
| 26 | loss_fun = nn.CrossEntropyLoss() | ||
| 27 | |||
| 28 | net = FaceMaskNet().to(device) | ||
| 29 | if os.path.exists('params/new_face_mobilenet_v2.pth'): | ||
| 30 | net.load_state_dict(torch.load('params/new_face_mobilenet_v2.pth')) | ||
| 31 | print('successfully loading weights!') | ||
| 32 | opt = optim.Adam(net.parameters()) | ||
| 33 | |||
| 34 | |||
| 35 | for epoch in range(1, epochs): | ||
| 36 | |||
| 37 | with tqdm.tqdm(train_data_loader) as t1: | ||
| 38 | for i, (image_data, image_label) in enumerate(train_data_loader): | ||
| 39 | net.train() | ||
| 40 | image_data, image_label = image_data.to(device), image_label.to(device) | ||
| 41 | out = net(image_data) | ||
| 42 | train_loss = loss_fun(out, image_label) | ||
| 43 | opt.zero_grad() | ||
| 44 | train_loss.backward() | ||
| 45 | opt.step() | ||
| 46 | t1.set_description(f'Epoch {epoch} train') | ||
| 47 | t1.set_postfix(train_loss=train_loss.item(), | ||
| 48 | train_acc=torch.mean(torch.eq(image_label, torch.argmax(out,dim=1)).float()).item()) | ||
| 49 | time.sleep(0.1) | ||
| 50 | t1.update(1) | ||
| 51 | if (i+1) % 10 == 0: | ||
| 52 | torch.save(net.state_dict(), 'params/new_face_mobilenet_v2.pth') | ||
| 53 | print(f'epoch : {epoch} {i} successfully save weights!') | ||
| 54 | |||
| 55 | |||
| 56 | acc, temp = 0, 0 | ||
| 57 | with torch.no_grad(): | ||
| 58 | net.eval() | ||
| 59 | with tqdm.tqdm(test_data_loader) as t2: | ||
| 60 | for j, (image_data, image_label) in enumerate(test_data_loader): | ||
| 61 | image_data, image_label = image_data.to(device), image_label.to(device) | ||
| 62 | out = net(image_data) | ||
| 63 | test_loss = loss_fun(out, image_label) | ||
| 64 | |||
| 65 | t2.set_description(f'Epoch {epoch} test') | ||
| 66 | out = torch.argmax(out, dim=1) | ||
| 67 | t2.set_postfix(test_loss=test_loss.item(), | ||
| 68 | test_acc=torch.mean(torch.eq(image_label, out).float()).item()) | ||
| 69 | time.sleep(0.1) | ||
| 70 | t2.update(1) | ||
| 71 | acc += torch.mean(torch.eq(image_label, out).float()).item() | ||
| 72 | temp += 1 | ||
| 73 | print(f'epoch : {epoch} avg acc : ', acc / temp) | ||
| 74 | |||
| 75 | # acc,temp=0,0 | ||
| 76 | # with torch.no_grad(): | ||
| 77 | # net.eval() | ||
| 78 | # for i, (image_data, image_label) in enumerate(tqdm.tqdm(test_data_loader)): | ||
| 79 | # image_data, image_label = image_data.to(device), image_label.to(device) | ||
| 80 | # out = net(image_data) | ||
| 81 | # test_loss = loss_fun(out, image_label) | ||
| 82 | # | ||
| 83 | # out = torch.argmax(out, dim=1) | ||
| 84 | # | ||
| 85 | # acc += torch.mean(torch.eq(image_label, out).float()).item() | ||
| 86 | # temp+=1 | ||
| 87 | # if i % 5 == 0: | ||
| 88 | # print(f'epoch : {epoch} {i} test_loss : ', test_loss.item()) | ||
| 89 | # print(f'epoch : {epoch} {i} test acc : ',torch.mean(torch.eq(image_label, out).float()).item()) | ||
| 90 | # print(f'epoch : {epoch} avg acc : ',acc/temp) | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
utils.py
0 → 100644
| 1 | import cv2 | ||
| 2 | import numpy as np | ||
| 3 | from PIL import Image | ||
| 4 | # 填充黑边,等比缩放 | ||
| 5 | def keep_resize_image(image_path,size=(128,128)): | ||
| 6 | image=Image.open(image_path) | ||
| 7 | w,h=image.size | ||
| 8 | temp=max(w,h) | ||
| 9 | mask=Image.new('RGB',(temp,temp)) | ||
| 10 | if w>=h: | ||
| 11 | mask.paste(image,(0,(w-h)//2)) | ||
| 12 | else: | ||
| 13 | mask.paste(image,((h-w)//2,0)) | ||
| 14 | mask=mask.resize(size) | ||
| 15 | mask=np.array(mask) | ||
| 16 | mask=cv2.cvtColor(mask,cv2.COLOR_RGB2BGR) | ||
| 17 | return mask | ||
| 18 | if __name__ == '__main__': | ||
| 19 | keep_resize_image('image/mask/img.png') | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or sign in to post a comment