# edge_loader.py (1.41 KB)
from torch.utils.data import DataLoader, Dataset
import os
import pandas as pd
import cv2

# Root of the CDLA synthetic dataset. Images, edge maps and masks are stored
# in the 'img', 'edge' and 'mask' sub-directories below it.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
syn_root = '/home/mly/data/datasets/text_recognition/CDLA/CDLA_DATASET_SYN'
syn_img_root = os.path.join(syn_root, 'img')
syn_edge_root = os.path.join(syn_root, 'edge')
syn_mask_root = os.path.join(syn_root, 'mask')
# Backward-compatible alias for the previous, misspelled name.
sy_mask_root = syn_mask_root


def img_aug(img=None, edge=None):
    """Placeholder paired augmentation: return ``(img, edge)`` unchanged.

    ``EdgeData`` calls its transform as ``transform(img, edge)`` and unpacks
    two values from the result, so a transform must accept both arrays and
    return a pair. A zero-argument stub raises ``TypeError`` there.

    Args:
        img: The input image.
        edge: The matching edge map.

    Returns:
        tuple: ``(img, edge)``, unmodified.
    """
    # TODO: add real augmentations; any geometric transform presumably has to be
    # applied identically to img and edge so they stay aligned — confirm.
    return img, edge


class EdgeData(Dataset):
    """Dataset of paired (image, edge map) samples listed in a CSV file.

    The CSV's ``path`` column holds file names relative to a per-phase
    sub-directory. The phase is the CSV file name up to its first '.', so
    ``.../train.csv`` reads images from ``<img_base>/train`` and edge maps
    from ``<edge_base>/train``, using the same file name for both.

    Args:
        csv: Path to a CSV file with an index column and a ``path`` column.
        transform: Callable ``(img, edge) -> (img, edge)`` applied to every
            sample. ``None`` returns the loaded arrays unchanged.
        img_base: Base image directory. Defaults to the module-level
            ``syn_img_root``.
        edge_base: Base edge-map directory. Defaults to the module-level
            ``syn_edge_root``.
    """

    def __init__(self, csv, transform=None, *, img_base=None, edge_base=None):
        phase = self._phase_from_csv(csv)
        df = pd.read_csv(csv, index_col=0)
        self.img_list = df['path'].tolist()
        # Resolve the module-level defaults here (not in the signature) so they
        # are looked up at construction time and can be overridden per dataset.
        if img_base is None:
            img_base = syn_img_root
        if edge_base is None:
            edge_base = syn_edge_root
        self.img_root = os.path.join(img_base, phase)
        self.edge_root = os.path.join(edge_base, phase)
        self.transform = transform

    @staticmethod
    def _phase_from_csv(csv):
        """Return the split name encoded in the CSV file name (text before the first '.')."""
        # basename instead of split('/') so OS-specific path separators are handled too.
        return os.path.basename(csv).split('.')[0]

    @staticmethod
    def _read_image(path):
        """Load ``path`` with OpenCV, raising instead of silently returning None."""
        image = cv2.imread(path)
        # cv2.imread reports a missing/unreadable file by returning None, which
        # would otherwise only surface later as an obscure error in the transform.
        if image is None:
            raise FileNotFoundError(f"cv2 could not read image: {path}")
        return image

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        name = self.img_list[idx]
        img = self._read_image(os.path.join(self.img_root, name))
        # NOTE(review): the edge map is read with cv2's default (3-channel color)
        # flag; if edge maps are single-channel, IMREAD_GRAYSCALE may be intended.
        edge = self._read_image(os.path.join(self.edge_root, name))
        if self.transform is not None:
            img, edge = self.transform(img, edge)
        return img, edge


def get_loader(batch_size=8, num_workers=4):
    """Build the training and validation DataLoaders for the synthetic edge data.

    Reads ``train.csv`` and ``val.csv`` from ``syn_root``. Each is wrapped in an
    ``EdgeData`` dataset that uses ``img_aug`` as its transform.

    Args:
        batch_size: Samples per batch for both loaders (default 8).
        num_workers: DataLoader worker processes per loader (default 4).

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    edge_train_data = EdgeData(csv=os.path.join(syn_root, 'train.csv'), transform=img_aug)
    edge_val_data = EdgeData(csv=os.path.join(syn_root, 'val.csv'), transform=img_aug)

    # DataLoader defaults to shuffle=False. Without shuffling, every epoch would
    # visit training samples in the same CSV order. drop_last keeps training
    # batches a uniform size.
    train_loader = DataLoader(edge_train_data, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True, drop_last=True)
    # Validation should score every sample, so keep the order fixed and keep
    # the final partial batch instead of dropping it.
    val_loader = DataLoader(edge_val_data, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True, drop_last=False)

    return train_loader, val_loader