口罩二分类
0 parents
Showing
17 changed files
with
351 additions
and
0 deletions
.idea/.gitignore
0 → 100644
.idea/face_mask_classifier.iml
0 → 100644
1 | <?xml version="1.0" encoding="UTF-8"?> | ||
2 | <module type="PYTHON_MODULE" version="4"> | ||
3 | <component name="NewModuleRootManager"> | ||
4 | <content url="file://$MODULE_DIR$" /> | ||
5 | <orderEntry type="inheritedJdk" /> | ||
6 | <orderEntry type="sourceFolder" forTests="false" /> | ||
7 | </component> | ||
8 | </module> | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
.idea/inspectionProfiles/Project_Default.xml
0 → 100644
1 | <component name="InspectionProjectProfileManager"> | ||
2 | <profile version="1.0"> | ||
3 | <option name="myName" value="Project Default" /> | ||
4 | <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true"> | ||
5 | <option name="ignoredPackages"> | ||
6 | <value> | ||
7 | <list size="25"> | ||
8 | <item index="0" class="java.lang.String" itemvalue="tqdm" /> | ||
9 | <item index="1" class="java.lang.String" itemvalue="easydict" /> | ||
10 | <item index="2" class="java.lang.String" itemvalue="scikit_image" /> | ||
11 | <item index="3" class="java.lang.String" itemvalue="matplotlib" /> | ||
12 | <item index="4" class="java.lang.String" itemvalue="tensorboardX" /> | ||
13 | <item index="5" class="java.lang.String" itemvalue="torch" /> | ||
14 | <item index="6" class="java.lang.String" itemvalue="numpy" /> | ||
15 | <item index="7" class="java.lang.String" itemvalue="pycocotools" /> | ||
16 | <item index="8" class="java.lang.String" itemvalue="skimage" /> | ||
17 | <item index="9" class="java.lang.String" itemvalue="Pillow" /> | ||
18 | <item index="10" class="java.lang.String" itemvalue="scipy" /> | ||
19 | <item index="11" class="java.lang.String" itemvalue="torchvision" /> | ||
20 | <item index="12" class="java.lang.String" itemvalue="opencv_python" /> | ||
21 | <item index="13" class="java.lang.String" itemvalue="onnxruntime" /> | ||
22 | <item index="14" class="java.lang.String" itemvalue="onnx-simplifier" /> | ||
23 | <item index="15" class="java.lang.String" itemvalue="onnx" /> | ||
24 | <item index="16" class="java.lang.String" itemvalue="opencv-contrib-python" /> | ||
25 | <item index="17" class="java.lang.String" itemvalue="numba" /> | ||
26 | <item index="18" class="java.lang.String" itemvalue="opencv-python" /> | ||
27 | <item index="19" class="java.lang.String" itemvalue="librosa" /> | ||
28 | <item index="20" class="java.lang.String" itemvalue="tensorboard" /> | ||
29 | <item index="21" class="java.lang.String" itemvalue="dill" /> | ||
30 | <item index="22" class="java.lang.String" itemvalue="pandas" /> | ||
31 | <item index="23" class="java.lang.String" itemvalue="scikit_learn" /> | ||
32 | <item index="24" class="java.lang.String" itemvalue="pytorch-gradual-warmup-lr" /> | ||
33 | </list> | ||
34 | </value> | ||
35 | </option> | ||
36 | </inspection_tool> | ||
37 | </profile> | ||
38 | </component> | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
.idea/misc.xml
0 → 100644
.idea/modules.xml
0 → 100644
1 | <?xml version="1.0" encoding="UTF-8"?> | ||
2 | <project version="4"> | ||
3 | <component name="ProjectModuleManager"> | ||
4 | <modules> | ||
5 | <module fileurl="file://$PROJECT_DIR$/.idea/face_mask_classifier.iml" filepath="$PROJECT_DIR$/.idea/face_mask_classifier.iml" /> | ||
6 | </modules> | ||
7 | </component> | ||
8 | </project> | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
.idea/vcs.xml
0 → 100644
__pycache__/net.cpython-38.pyc
0 → 100644
No preview for this file type
dataset.py
0 → 100644
1 | import os | ||
2 | import random | ||
3 | |||
4 | import cv2 | ||
5 | import torch | ||
6 | from PIL import Image | ||
7 | from torch.utils.data import Dataset, DataLoader | ||
8 | from utils import * | ||
9 | from torchvision import transforms | ||
10 | |||
11 | classes_names = ['normal', 'mask'] | ||
12 | |||
13 | |||
14 | class FaceMaskDataset(Dataset): | ||
15 | def __init__(self, root_path): | ||
16 | self.transform = transforms.Compose([ | ||
17 | transforms.ToTensor() | ||
18 | ]) | ||
19 | self.dataset = [] | ||
20 | class_names = os.listdir(root_path) | ||
21 | for cls in class_names: | ||
22 | image_names = os.listdir(os.path.join(root_path, cls)) | ||
23 | for image in image_names: | ||
24 | self.dataset.append([os.path.join(root_path, cls, image), classes_names.index(cls)]) | ||
25 | |||
26 | def __len__(self): | ||
27 | return len(self.dataset) | ||
28 | |||
29 | def __getitem__(self, index): | ||
30 | lights=[0.6,0.8,1,1.2,1.4,1.6] | ||
31 | data = self.dataset[index] | ||
32 | image_path = data[0] | ||
33 | image_data = keep_resize_image(image_path) | ||
34 | image_data=cv2.convertScaleAbs(image_data,alpha=lights[random.randint(0,4)]) | ||
35 | image_label = data[1] | ||
36 | return self.transform(image_data), image_label | ||
37 | |||
38 | |||
39 | if __name__ == '__main__': | ||
40 | import tqdm | ||
41 | d = FaceMaskDataset('image') | ||
42 | for i in d: | ||
43 | i | ||
44 |
image/img.png
0 → 100644

55 KB
image/img_1.png
0 → 100644

90.2 KB
infer.py
0 → 100644
1 | import os | ||
2 | |||
3 | import cv2 | ||
4 | |||
5 | import torch | ||
6 | from PIL import Image | ||
7 | import numpy as np | ||
8 | from torchvision import transforms | ||
9 | from net import * | ||
10 | |||
11 | |||
12 | |||
13 | def video(net,): | ||
14 | cap=cv2.VideoCapture(0) | ||
15 | while True: | ||
16 | _, frame = cap.read() | ||
17 | image = Image.fromarray(frame) | ||
18 | w, h = image.size | ||
19 | temp = max(w, h) | ||
20 | mask = Image.new('RGB', (temp, temp)) | ||
21 | if w >= h: | ||
22 | mask.paste(image, (0, (w - h) // 2)) | ||
23 | else: | ||
24 | mask.paste(image, ((h - w) // 2, 0)) | ||
25 | mask = mask.resize((128, 128)) | ||
26 | mask = np.array(mask) | ||
27 | mask = cv2.cvtColor(mask, cv2.COLOR_RGB2BGR) | ||
28 | mask_image = torch.unsqueeze(transform(mask), dim=0) | ||
29 | out = net(mask_image) | ||
30 | print(out) | ||
31 | out=torch.argmax(out,dim=1) | ||
32 | result = classes_names[int(out.item())] | ||
33 | cv2.putText(frame, result, (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), thickness=2) | ||
34 | cv2.imshow('frame', frame) | ||
35 | |||
36 | if cv2.waitKey(1) & 0XFF == ord('q'): | ||
37 | break | ||
38 | cap.release() | ||
39 | cv2.destroyAllWindows() | ||
40 | |||
41 | def image_cls(net,path): | ||
42 | frame=cv2.imread(path) | ||
43 | image = Image.fromarray(frame) | ||
44 | w, h = image.size | ||
45 | temp = max(w, h) | ||
46 | mask = Image.new('RGB', (temp, temp)) | ||
47 | if w >= h: | ||
48 | mask.paste(image, (0, (w - h) // 2)) | ||
49 | else: | ||
50 | mask.paste(image, ((h - w) // 2, 0)) | ||
51 | mask = mask.resize((128, 128)) | ||
52 | mask = np.array(mask) | ||
53 | mask = cv2.cvtColor(mask, cv2.COLOR_RGB2BGR) | ||
54 | mask_image = torch.unsqueeze(transform(mask), dim=0) | ||
55 | out = net(mask_image) | ||
56 | print(out) | ||
57 | out = torch.argmax(out, dim=1) | ||
58 | result = classes_names[int(out.item())] | ||
59 | cv2.putText(frame, result, (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), thickness=2) | ||
60 | cv2.imshow('frame', frame) | ||
61 | cv2.waitKey(0) | ||
62 | |||
63 | if __name__ == '__main__': | ||
64 | transform = transforms.Compose([ | ||
65 | transforms.ToTensor() | ||
66 | ]) | ||
67 | net = FaceMaskNet() | ||
68 | weights_path = r'params/new_face_mobilenet_v2.pth' | ||
69 | classes_names = ['normal', 'mask'] | ||
70 | |||
71 | if os.path.exists(weights_path): | ||
72 | net.load_state_dict(torch.load(weights_path, map_location='cuda:0')) | ||
73 | print('successfully loading weights!') | ||
74 | net.eval() | ||
75 | |||
76 | image_cls(net,'image/img_1.png') | ||
77 | |||
78 |
net.py
0 → 100644
1 | import torch | ||
2 | from torch import nn | ||
3 | from torchvision import models | ||
4 | |||
5 | |||
6 | class FaceMaskNet(nn.Module): | ||
7 | def __init__(self): | ||
8 | super(FaceMaskNet, self).__init__() | ||
9 | self.layer = nn.Sequential( | ||
10 | models.mobilenet_v2(pretrained=True) | ||
11 | ) | ||
12 | self.classifier = nn.Sequential( | ||
13 | nn.Linear(1000, 2) | ||
14 | ) | ||
15 | |||
16 | def forward(self, x): | ||
17 | return self.classifier(self.layer(x)) | ||
18 | |||
19 | |||
20 | |||
21 | if __name__ == '__main__': | ||
22 | net = FaceMaskNet() | ||
23 | x = torch.randn(5, 3, 128, 128) | ||
24 | print(net(x).shape) |
params/cls_face_mobilenet_v2.pth
0 → 100644
This file is too large to display.
test.py
0 → 100644
1 | import os | ||
2 | |||
3 | import cv2 | ||
4 | import torch | ||
5 | |||
6 | from net import * | ||
7 | from dataset import * | ||
8 | |||
9 | transform = transforms.Compose([ | ||
10 | transforms.ToTensor() | ||
11 | ]) | ||
12 | |||
13 | data=FaceMaskDataset('/data2/face_mask') | ||
14 | d = DataLoader(data, batch_size=1000, shuffle=True) | ||
15 | with torch.no_grad(): | ||
16 | for i,(image,label) in enumerate(d): | ||
17 | net=FaceMaskNet().cuda() | ||
18 | net.load_state_dict(torch.load('params/face_mobilenet_v2.pth')) | ||
19 | net.eval() | ||
20 | out=net(image.cuda()) | ||
21 | out=torch.argmax(out,dim=1) | ||
22 | acc=torch.mean(torch.eq(label.cuda(), out).float()).item() | ||
23 | print(acc) |
train.py
0 → 100644
1 | import os.path | ||
2 | |||
3 | from torch import nn, optim | ||
4 | import torch | ||
5 | from dataset import * | ||
6 | from torch.utils.data import random_split | ||
7 | from net import * | ||
8 | import tqdm | ||
9 | import time | ||
10 | |||
11 | if __name__ == '__main__': | ||
12 | train_rate=0.8 | ||
13 | batch_size=50 | ||
14 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
15 | print(device) | ||
16 | epochs=50 | ||
17 | |||
18 | datasets = FaceMaskDataset('/data2/new_face_mask') | ||
19 | train_datasets, test_datasets = random_split( | ||
20 | datasets, | ||
21 | [int(len(datasets) * train_rate), len(datasets) - int(len(datasets) * train_rate)], | ||
22 | ) | ||
23 | print(f'train_datasets:{len(train_datasets)} test_datasets:{len(test_datasets)}') | ||
24 | train_data_loader = DataLoader(train_datasets, batch_size=batch_size, shuffle=True) | ||
25 | test_data_loader = DataLoader(test_datasets, batch_size=batch_size, shuffle=True) | ||
26 | loss_fun = nn.CrossEntropyLoss() | ||
27 | |||
28 | net = FaceMaskNet().to(device) | ||
29 | if os.path.exists('params/new_face_mobilenet_v2.pth'): | ||
30 | net.load_state_dict(torch.load('params/new_face_mobilenet_v2.pth')) | ||
31 | print('successfully loading weights!') | ||
32 | opt = optim.Adam(net.parameters()) | ||
33 | |||
34 | |||
35 | for epoch in range(1, epochs): | ||
36 | |||
37 | with tqdm.tqdm(train_data_loader) as t1: | ||
38 | for i, (image_data, image_label) in enumerate(train_data_loader): | ||
39 | net.train() | ||
40 | image_data, image_label = image_data.to(device), image_label.to(device) | ||
41 | out = net(image_data) | ||
42 | train_loss = loss_fun(out, image_label) | ||
43 | opt.zero_grad() | ||
44 | train_loss.backward() | ||
45 | opt.step() | ||
46 | t1.set_description(f'Epoch {epoch} train') | ||
47 | t1.set_postfix(train_loss=train_loss.item(), | ||
48 | train_acc=torch.mean(torch.eq(image_label, torch.argmax(out,dim=1)).float()).item()) | ||
49 | time.sleep(0.1) | ||
50 | t1.update(1) | ||
51 | if (i+1) % 10 == 0: | ||
52 | torch.save(net.state_dict(), 'params/new_face_mobilenet_v2.pth') | ||
53 | print(f'epoch : {epoch} {i} successfully save weights!') | ||
54 | |||
55 | |||
56 | acc, temp = 0, 0 | ||
57 | with torch.no_grad(): | ||
58 | net.eval() | ||
59 | with tqdm.tqdm(test_data_loader) as t2: | ||
60 | for j, (image_data, image_label) in enumerate(test_data_loader): | ||
61 | image_data, image_label = image_data.to(device), image_label.to(device) | ||
62 | out = net(image_data) | ||
63 | test_loss = loss_fun(out, image_label) | ||
64 | |||
65 | t2.set_description(f'Epoch {epoch} test') | ||
66 | out = torch.argmax(out, dim=1) | ||
67 | t2.set_postfix(test_loss=test_loss.item(), | ||
68 | test_acc=torch.mean(torch.eq(image_label, out).float()).item()) | ||
69 | time.sleep(0.1) | ||
70 | t2.update(1) | ||
71 | acc += torch.mean(torch.eq(image_label, out).float()).item() | ||
72 | temp += 1 | ||
73 | print(f'epoch : {epoch} avg acc : ', acc / temp) | ||
74 | |||
75 | # acc,temp=0,0 | ||
76 | # with torch.no_grad(): | ||
77 | # net.eval() | ||
78 | # for i, (image_data, image_label) in enumerate(tqdm.tqdm(test_data_loader)): | ||
79 | # image_data, image_label = image_data.to(device), image_label.to(device) | ||
80 | # out = net(image_data) | ||
81 | # test_loss = loss_fun(out, image_label) | ||
82 | # | ||
83 | # out = torch.argmax(out, dim=1) | ||
84 | # | ||
85 | # acc += torch.mean(torch.eq(image_label, out).float()).item() | ||
86 | # temp+=1 | ||
87 | # if i % 5 == 0: | ||
88 | # print(f'epoch : {epoch} {i} test_loss : ', test_loss.item()) | ||
89 | # print(f'epoch : {epoch} {i} test acc : ',torch.mean(torch.eq(image_label, out).float()).item()) | ||
90 | # print(f'epoch : {epoch} avg acc : ',acc/temp) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
utils.py
0 → 100644
1 | import cv2 | ||
2 | import numpy as np | ||
3 | from PIL import Image | ||
4 | # 填充黑边,等比缩放 | ||
5 | def keep_resize_image(image_path,size=(128,128)): | ||
6 | image=Image.open(image_path) | ||
7 | w,h=image.size | ||
8 | temp=max(w,h) | ||
9 | mask=Image.new('RGB',(temp,temp)) | ||
10 | if w>=h: | ||
11 | mask.paste(image,(0,(w-h)//2)) | ||
12 | else: | ||
13 | mask.paste(image,((h-w)//2,0)) | ||
14 | mask=mask.resize(size) | ||
15 | mask=np.array(mask) | ||
16 | mask=cv2.cvtColor(mask,cv2.COLOR_RGB2BGR) | ||
17 | return mask | ||
18 | if __name__ == '__main__': | ||
19 | keep_resize_image('image/mask/img.png') | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or sign in to post a comment