# edge_loader.py
# DataLoader setup for the synthetic CDLA edge-map dataset.
from torch.utils.data import DataLoader, Dataset
import os
import pandas as pd
import cv2
# Root of the synthetic CDLA dataset. Override with the CDLA_SYN_ROOT
# environment variable to run on another machine; otherwise the original
# hard-coded location is used.
syn_root = os.environ.get(
    'CDLA_SYN_ROOT',
    '/home/mly/data/datasets/text_recognition/CDLA/CDLA_DATASET_SYN',
)
# Sub-directories of syn_root. EdgeData reads images from
# <syn_img_root>/<phase>/ and edge maps from <syn_edge_root>/<phase>/.
syn_img_root = os.path.join(syn_root, 'img')
syn_edge_root = os.path.join(syn_root, 'edge')
# Mask directory; not referenced by the loader code in this file.
syn_mask_root = os.path.join(syn_root, 'mask')
# Backward-compatible alias for the original (misspelled) name.
sy_mask_root = syn_mask_root
def img_aug(img=None, edge=None):
    """Identity augmentation: return ``(img, edge)`` unchanged.

    EdgeData calls its transform as ``transform(img, edge)`` and unpacks the
    result into two values. Even a no-op augmentation must therefore accept
    both arguments and return a 2-tuple.

    Args:
        img: The input image, passed through untouched.
        edge: The paired edge map, passed through untouched.

    Returns:
        The tuple ``(img, edge)``.
    """
    # TODO: add real augmentation. Presumably any geometric change must be
    # applied identically to img and edge so they stay aligned -- confirm.
    return img, edge
class EdgeData(Dataset):
    """Paired (image, edge map) dataset for one split of the synthetic CDLA set.

    The split name ("phase") is the CSV file's base name without extension,
    e.g. ``.../train.csv`` -> ``'train'``. Each sample is read from
    ``<syn_img_root>/<phase>/<path>`` (image) and
    ``<syn_edge_root>/<phase>/<path>`` (edge map), where ``path`` is taken
    from the CSV's ``path`` column.

    Args:
        csv: Path (str or path-like) to the split CSV. Its first column is used
            as the index, and it must contain a ``path`` column.
        transform: Callable ``transform(img, edge) -> (img, edge)`` applied to
            every sample, or ``None`` to return the images as loaded.
    """

    def __init__(self, csv, transform=None):
        phase = self._phase_from_csv(csv)
        df = pd.read_csv(csv, index_col=0)
        # File names, relative to the per-phase image / edge directories.
        self.img_list = df['path'].tolist()
        self.img_root = os.path.join(syn_img_root, phase)
        self.edge_root = os.path.join(syn_edge_root, phase)
        self.transform = transform

    @staticmethod
    def _phase_from_csv(csv):
        """Return the split name from the CSV file name ('a/train.csv' -> 'train')."""
        # basename/splitext handle path-like objects and the OS path separator,
        # unlike a manual split on '/'.
        return os.path.splitext(os.path.basename(csv))[0]

    @staticmethod
    def _imread(path):
        """Read an image with OpenCV, raising instead of returning ``None``."""
        img = cv2.imread(path)
        # cv2.imread reports a missing or undecodable file by returning None.
        # Fail here with the offending path instead of later in the transform
        # or collate step.
        if img is None:
            raise OSError(f'cv2 could not read image: {path}')
        return img

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        """Return the ``(img, edge)`` pair for sample ``idx``, after ``transform``."""
        name = self.img_list[idx]
        img = self._imread(os.path.join(self.img_root, name))
        edge = self._imread(os.path.join(self.edge_root, name))
        if self.transform is not None:
            img, edge = self.transform(img, edge)
        return img, edge
def get_loader(batch_size=8, num_workers=4):
    """Build the train and validation DataLoaders for the synthetic edge dataset.

    Reads ``train.csv`` and ``val.csv`` from ``syn_root`` and wraps each split
    in an ``EdgeData`` that uses ``img_aug`` as its transform.

    Args:
        batch_size: Samples per batch for both loaders (default 8).
        num_workers: Loader worker processes per loader (default 4).

    Returns:
        ``(train_loader, val_loader)``.
    """
    edge_train_data = EdgeData(csv=os.path.join(syn_root, 'train.csv'), transform=img_aug)
    edge_val_data = EdgeData(csv=os.path.join(syn_root, 'val.csv'), transform=img_aug)
    # Reshuffle the training set every epoch; without shuffle=True every epoch
    # iterates in CSV order. drop_last keeps training batches a constant size.
    train_loader = DataLoader(
        edge_train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )
    # Validation must see every sample, so keep the final partial batch;
    # a fixed order keeps metrics reproducible.
    val_loader = DataLoader(
        edge_val_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
    )
    return train_loader, val_loader