# finetune_loader.py 2.89 KB
import pandas
import os
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import transforms as Ap
import cv2
import pandas as pd

# Dataset locations.
# NOTE(review): `root` is a machine-specific absolute path; it has to be
# edited before this module can be used on another machine.
root = '/home/mly/data/datasets/text_recognition/finetune/src'
# Directory joined with each CSV `path` entry to locate the input image.
img_root = os.path.join(root, 'img')
# Directory holding the masks; passed to FineTuneData as `anno_root`.
label_root = os.path.join(root, 'mask')
# CSV files listing the samples of each split (read via their `path` column).
train_csv_path = os.path.join(root, 'train.csv')
# The validation split is read from test.csv.
val_csv_path = os.path.join(root, 'test.csv')

# Target side length in pixels: the random-crop size in train_aug and the
# resize target in val_aug.
SIZE = 512
def train_aug(img, mask):
    """Run the randomized training augmentation pipeline on an image/mask pair.

    The pair is resized to 768x768 and randomly cropped to SIZE x SIZE, then
    passed through random noise, flip, rotation, brightness/contrast, affine
    and shadow transforms, followed by normalization with the listed mean/std
    and conversion via ``ToTensorV2``.

    Args:
        img: Image handed to the pipeline as ``image``.
        mask: Mask handed to the pipeline as ``mask``.

    Returns:
        Tuple ``(image, mask)`` read from the pipeline output.
    """
    # Resize larger than the target first so the random crop has room to move.
    resize_and_crop = [
        A.Resize(768, 768),
        A.RandomCrop(SIZE, SIZE),
    ]
    # Randomly applied perturbations; the order matches the original pipeline.
    random_steps = [
        A.GaussNoise(p=0.3),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=20, p=0.3),
        A.RandomRotate90(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.Affine(rotate=(-90, 90), shear=(-45, 45), p=0.5),
        A.RandomShadow(p=0.5),
    ]
    finalize = [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        Ap.ToTensorV2(),
    ]

    pipeline = A.Compose(resize_and_crop + random_steps + finalize)
    result = pipeline(image=img, mask=mask)
    return result['image'], result['mask']

def val_aug(img, mask):
    """Apply the deterministic validation preprocessing to an image/mask pair.

    The pair is resized to SIZE x SIZE, normalized with the listed mean/std
    and converted via ``ToTensorV2``. No random augmentation is applied.

    Args:
        img: Image handed to the pipeline as ``image``.
        mask: Mask handed to the pipeline as ``mask``.

    Returns:
        Tuple ``(image, mask)`` read from the pipeline output.
    """
    transform = A.Compose([
        # Use SIZE for both height and width so validation inputs stay square
        # and match the training crop size if SIZE is ever changed.
        A.Resize(SIZE, SIZE),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        Ap.ToTensorV2(),
    ])

    augmented = transform(image=img, mask=mask)
    return augmented['image'], augmented['mask']


class FineTuneData(Dataset):
    """Dataset of image / binary-mask pairs listed in a CSV file.

    The CSV is read with its first column as the index and must provide a
    ``path`` column holding image file names relative to ``img_root``. The
    mask of each image is looked up in ``anno_root`` under the same name with
    the extension replaced by ``.png``.

    Args:
        img_root: Directory joined with each ``path`` entry to find the image.
        anno_root: Directory containing the ``.png`` masks.
        csv: Path (or file-like object) accepted by ``pandas.read_csv``.
        is_training: If true, samples go through ``train_aug``; otherwise
            through ``val_aug``.
    """

    def __init__(self, img_root, anno_root, csv, is_training=True):
        self.df = pd.read_csv(csv, index_col=0)
        self.img_list = self.df.path.tolist()
        self.img_root = img_root
        self.anno_root = anno_root
        self.is_training = is_training

    def _mask_path(self, img_name):
        """Return the mask path for *img_name* (same stem, ``.png`` suffix)."""
        # splitext strips only the final extension, so names that contain
        # extra dots (e.g. "img.v2.jpg", "./a.jpg") keep their full stem.
        stem, _ = os.path.splitext(img_name)
        return os.path.join(self.anno_root, stem + '.png')

    def __getitem__(self, idx):
        """Load, augment and return sample *idx* as ``(img, mask)``.

        Returns:
            Tuple of float32 tensors; the mask is binarized so that every
            non-zero pixel becomes 1.

        Raises:
            FileNotFoundError: If the image or the mask cannot be read.
        """
        img_name = self.img_list[idx]
        img_path = os.path.join(self.img_root, img_name)
        mask_path = self._mask_path(img_name)

        # cv2.imread signals failure by returning None instead of raising;
        # check explicitly so a bad path is reported with its file name.
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f'Could not read image: {img_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f'Could not read mask: {mask_path}')

        if self.is_training:
            img, mask = train_aug(img, mask)
        else:
            img, mask = val_aug(img, mask)

        # Any non-zero mask value counts as foreground.
        mask[mask > 0] = 1

        img = img.to(torch.float32)
        mask = mask.to(torch.float32)
        return img, mask

    def __len__(self):
        """Return the number of samples listed in the CSV."""
        return len(self.img_list)

def get_loader(batch_size=16, num_workers=4):
    """Build the training and validation data loaders.

    Both datasets read images from ``img_root`` and masks from ``label_root``;
    the training split is listed in ``train_csv_path`` and the validation
    split in ``val_csv_path``.

    Args:
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of worker processes used by each loader.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    finetune_train_data = FineTuneData(
        img_root=img_root, anno_root=label_root, csv=train_csv_path, is_training=True,
    )
    finetune_val_data = FineTuneData(
        img_root=img_root, anno_root=label_root, csv=val_csv_path, is_training=False,
    )

    # Training: shuffle and drop the last partial batch to keep batch size fixed.
    train_loader = DataLoader(
        finetune_train_data,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        pin_memory=True,
        drop_last=True,
    )
    # Validation: keep the last partial batch so every sample is evaluated
    # (drop_last=True would skip up to batch_size - 1 samples, or all of them
    # when the split is smaller than one batch).
    val_loader = DataLoader(
        finetune_val_data,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
        drop_last=False,
    )

    return train_loader, val_loader