# dataset.py (1.17 KB)
import os
import random

import cv2
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from utils import *
from torchvision import transforms

# Label vocabulary: a sample's label is its class folder's index in this list.
classes_names = ["normal", "mask"]


class FaceMaskDataset(Dataset):
    """Image-classification dataset laid out as ``root_path/<class_name>/<image>``.

    Each immediate sub-directory of ``root_path`` is one class. A sample's
    label is that sub-directory name's index in ``class_names``. On access,
    the image is loaded with ``keep_resize_image`` (from ``utils``). Its
    pixel values are then scaled by a randomly chosen factor from ``lights``
    via ``cv2.convertScaleAbs``. Finally the result goes through
    ``transform``.

    Args:
        root_path: Directory whose sub-directories are the class folders.
        transform: Callable applied to the augmented image. Defaults to
            ``transforms.Compose([transforms.ToTensor()])``.
        class_names: Ordered class names; a name's index is its label.
            Defaults to the module-level ``classes_names``.
        lights: Candidate ``alpha`` factors for ``cv2.convertScaleAbs``; one
            is drawn uniformly per sample. Defaults to ``LIGHTS``.

    Raises:
        ValueError: if a sub-directory of ``root_path`` is not listed in
            ``class_names``, or if ``lights`` is empty.
    """

    # Per-sample intensity scale factors; every entry can be drawn.
    LIGHTS = (0.6, 0.8, 1, 1.2, 1.4, 1.6)

    def __init__(self, root_path, transform=None, class_names=None, lights=None):
        if transform is None:
            transform = transforms.Compose([
                transforms.ToTensor()
            ])
        if class_names is None:
            class_names = classes_names
        self.transform = transform
        self.lights = tuple(lights) if lights is not None else self.LIGHTS
        if not self.lights:
            raise ValueError("lights must contain at least one scale factor")
        # List of [image_path, label] pairs, in a deterministic (sorted) order.
        self.dataset = self._collect_samples(root_path, class_names)

    @staticmethod
    def _collect_samples(root_path, class_names):
        """Return ``[image_path, label]`` pairs for every file in each class folder.

        Folder and file names are sorted so the index -> sample mapping is the
        same on every run and platform. Non-directory entries directly under
        ``root_path`` (e.g. ``.DS_Store``) are ignored. So are nested
        directories inside a class folder.

        Raises:
            ValueError: if a class folder's name is not in ``class_names``.
        """
        names = list(class_names)
        samples = []
        for cls in sorted(os.listdir(root_path)):
            cls_dir = os.path.join(root_path, cls)
            if not os.path.isdir(cls_dir):
                continue
            if cls not in names:
                raise ValueError(
                    f"unknown class folder {cls!r} in {root_path!r}; "
                    f"expected one of {names}"
                )
            label = names.index(cls)
            for image in sorted(os.listdir(cls_dir)):
                image_path = os.path.join(cls_dir, image)
                if os.path.isfile(image_path):
                    samples.append([image_path, label])
        return samples

    def __len__(self):
        return len(self.dataset)

    def _random_light(self):
        """Return one scale factor drawn uniformly from ``self.lights``."""
        # random.choice covers every element, including the last one.
        return random.choice(self.lights)

    def __getitem__(self, index):
        # Out-of-range indices raise IndexError here, which also ends plain
        # `for sample in dataset` iteration.
        image_path, image_label = self.dataset[index]
        # NOTE(review): keep_resize_image lives in utils (not visible here);
        # presumably it returns an HxWxC uint8 array that cv2 accepts -- confirm.
        image_data = keep_resize_image(image_path)
        # convertScaleAbs computes |alpha * pixel| saturated to uint8.
        image_data = cv2.convertScaleAbs(image_data, alpha=self._random_light())
        return self.transform(image_data), image_label


if __name__ == '__main__':
    # Smoke test: load every sample once so unreadable images fail here
    # rather than in the middle of training.
    import sys

    import tqdm

    # Optional first CLI argument overrides the default dataset root 'image'.
    root = sys.argv[1] if len(sys.argv) > 1 else 'image'
    d = FaceMaskDataset(root)
    for _image, _label in tqdm.tqdm(d, total=len(d), desc='loading samples'):
        pass
    print(f'{len(d)} samples loaded from {root!r}')