# CDLA_loader.py (2.98 KB)
from torch.utils.data import Dataset, DataLoader
import albumentations as A
import pandas as pd
import os
import cv2
import torch
import numpy as np
from PIL import Image

# Root of the synthetic CDLA split on disk. Set the CDLA_SYN_ROOT environment
# variable to point the loader elsewhere without editing this file; the default
# is the original author's location.
_CDLA_SYN_ROOT = os.environ.get(
    'CDLA_SYN_ROOT', '/home/mly/data/datasets/text_recognition/CDLA/syn'
)

# Image and mask directories (all with a trailing separator, for consistency);
# each CSV row's ``path`` value is joined onto these.
train_img_root = os.path.join(_CDLA_SYN_ROOT, 'train', 'img', '')
val_img_root = os.path.join(_CDLA_SYN_ROOT, 'val', 'img', '')
train_anno_root = os.path.join(_CDLA_SYN_ROOT, 'train', 'mask', '')
val_anno_root = os.path.join(_CDLA_SYN_ROOT, 'val', 'mask', '')
# CSV index files listing the samples of each split.
train_csv_path = os.path.join(_CDLA_SYN_ROOT, 'train.csv')
val_csv_path = os.path.join(_CDLA_SYN_ROOT, 'val.csv')


# Training augmentation pipeline, built once at import time rather than on every
# sample. albumentations samples fresh random parameters on each call, so reusing
# one Compose gives the same augmentation distribution without re-instantiating
# ten transform objects per __getitem__.
# NOTE(review): three transforms rotate here (Rotate, RandomRotate90 and
# Affine's rotate=(-90, 90)); confirm this much combined rotation is intended.
_TRAIN_TRANSFORM = A.Compose([
    A.RandomResizedCrop(height=256, width=256),
    A.GaussNoise(p=0.3),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=20, p=0.3),
    A.RandomRotate90(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.Affine(rotate=(-90, 90), shear=(-45, 45), p=0.5),
    A.RandomShadow(p=0.5),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def img_aug(img, mask):
    """Apply the random training augmentation to an image/mask pair.

    Args:
        img: image array handed to albumentations as ``image`` (``CDLA`` passes
            an RGB array read with OpenCV).
        mask: mask array handed to albumentations as ``mask``, so it receives the
            same geometric transforms as ``img``.

    Returns:
        ``(img, mask)`` after a 256x256 random resized crop, the random
        geometric/photometric transforms above, and normalization of ``img``
        with the standard ImageNet mean/std.
    """
    augmented = _TRAIN_TRANSFORM(image=img, mask=mask)
    return augmented['image'], augmented['mask']

# Deterministic evaluation pipeline, built once at import time instead of on
# every sample.
_VAL_TRANSFORM = A.Compose([
    A.Resize(256, 256),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def val_tran(img, mask):
    """Resize an image/mask pair to 256x256 and normalize the image.

    Args:
        img: image array handed to albumentations as ``image``.
        mask: mask array handed to albumentations as ``mask`` (resized only).

    Returns:
        ``(img, mask)`` resized to 256x256, with ``img`` normalized using the
        same mean/std as the training pipeline.
    """
    augmented = _VAL_TRANSFORM(image=img, mask=mask)
    return augmented['image'], augmented['mask']



class CDLA(Dataset):
    """Image/mask segmentation dataset driven by a CSV index.

    The CSV's ``path`` column is joined onto both ``img_root`` and ``anno_root``,
    so an image and its mask must share the same relative path (file name
    included).

    Args:
        img_root: directory holding the input images.
        anno_root: directory holding the masks.
        csv: anything ``pandas.read_csv`` accepts; its first column is used as
            the index and it must contain a ``path`` column.
        is_training: if true, samples go through ``img_aug`` (random
            augmentation); otherwise through ``val_tran`` (resize + normalize).
    """

    def __init__(self, img_root, anno_root, csv, is_training=True):
        self.df = pd.read_csv(csv, index_col=0)
        self.img_list = self.df['path'].tolist()
        self.img_root = img_root
        self.anno_root = anno_root
        self.is_training = is_training

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` tensors for sample ``idx``.

        ``image`` is float32 with shape ``(C, H, W)``; ``mask`` is float32 with
        the spatial shape produced by the transform, every positive pixel mapped
        to 1.0 and all others to 0.0.

        Raises:
            FileNotFoundError: if the image or the mask cannot be read.
        """
        name = self.img_list[idx]
        img = self._read(os.path.join(self.img_root, name), cv2.IMREAD_COLOR)
        # OpenCV decodes colour images in BGR order; convert to RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = self._read(os.path.join(self.anno_root, name), cv2.IMREAD_GRAYSCALE)

        if self.is_training:
            img, mask = img_aug(img, mask)
        else:
            img, mask = val_tran(img, mask)

        return self._to_tensors(img, mask)

    def __len__(self):
        return len(self.img_list)

    @staticmethod
    def _read(path, flags):
        """Read an image with OpenCV, raising instead of returning None."""
        data = cv2.imread(path, flags)
        # cv2.imread reports a missing or undecodable file by returning None
        # rather than raising; without this check the failure shows up later as
        # an opaque error inside cvtColor or the albumentations pipeline.
        if data is None:
            raise FileNotFoundError(f'could not read image: {path}')
        return data

    @staticmethod
    def _to_tensors(img, mask):
        """Convert an HWC image and a mask to float32 tensors.

        The image is transposed to CHW; the mask is binarized (``> 0`` -> 1.0).
        """
        # ascontiguousarray makes the single required copy (transpose yields a
        # strided view); from_numpy then wraps it without copying again.
        img = np.ascontiguousarray(np.transpose(img, (2, 0, 1)), dtype=np.float32)
        mask = (np.asarray(mask) > 0).astype(np.float32)
        return torch.from_numpy(img), torch.from_numpy(mask)


def get_loader(batch_size=128, num_workers=8):
    """Build the training and validation DataLoaders for the synthetic CDLA split.

    Args:
        batch_size: samples per batch for both loaders (default 128).
        num_workers: worker processes per loader (default 8).

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and drops
        the last incomplete batch; the validation loader is unshuffled and keeps
        every sample.
    """
    CDLA_train_data = CDLA(img_root=train_img_root, anno_root=train_anno_root, csv=train_csv_path, is_training=True)
    CDLA_val_data = CDLA(img_root=val_img_root, anno_root=val_anno_root, csv=val_csv_path, is_training=False)

    train_loader = DataLoader(CDLA_train_data, batch_size=batch_size, num_workers=num_workers,
                              shuffle=True, pin_memory=True, drop_last=True)
    # Keep the final partial batch during validation: drop_last=True would
    # silently exclude up to batch_size - 1 samples from evaluation, and a
    # validation set smaller than batch_size would yield no batches at all.
    val_loader = DataLoader(CDLA_val_data, batch_size=batch_size, num_workers=num_workers,
                            shuffle=False, pin_memory=True, drop_last=False)

    return train_loader, val_loader