# CDLA_loader.py
from torch.utils.data import Dataset, DataLoader
import albumentations as A
import pandas as pd
import os
import cv2
import torch
import numpy as np
from PIL import Image
# Root of the synthetic CDLA split. Every dataset path below is derived from it,
# so relocating the data only requires changing this one constant.
_SYN_ROOT = '/home/mly/data/datasets/text_recognition/CDLA/syn/'

# Image and mask directories for each split. All directory constants end in '/'
# for consistency (val_img_root previously did not); within this module they are
# only consumed through os.path.join, which treats both forms identically.
train_img_root = _SYN_ROOT + 'train/img/'
val_img_root = _SYN_ROOT + 'val/img/'
train_anno_root = _SYN_ROOT + 'train/mask/'
val_anno_root = _SYN_ROOT + 'val/mask/'

# CSV index files; the CDLA dataset reads sample file names from their
# 'path' column.
train_csv_path = _SYN_ROOT + 'train.csv'
val_csv_path = _SYN_ROOT + 'val.csv'
def img_aug(img, mask):
    """Run the random training augmentation pipeline on an image/mask pair.

    Both arrays are passed to a single albumentations ``Compose`` call, so
    spatial transforms sampled for the image are also applied to the mask.

    Args:
        img: RGB image array in HWC layout.
        mask: single-channel mask array with the same height/width as ``img``.

    Returns:
        Tuple ``(img, mask)``. Both have been through a random resized crop to
        256x256, and ``img`` is additionally normalised with the ImageNet channel
        mean/std.
    """
    steps = [
        A.RandomResizedCrop(height=256, width=256),
        A.GaussNoise(p=0.3),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=20, p=0.3),
        A.RandomRotate90(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.Affine(rotate=(-90, 90), shear=(-45, 45), p=0.5),
        A.RandomShadow(p=0.5),
        # Applied last so every preceding transform sees raw pixel values.
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    pipeline = A.Compose(steps)
    result = pipeline(image=img, mask=mask)
    return result['image'], result['mask']
def val_tran(img, mask):
    """Deterministically resize and normalise an image/mask pair for evaluation.

    Args:
        img: RGB image array in HWC layout.
        mask: single-channel mask array with the same height/width as ``img``.

    Returns:
        Tuple ``(img, mask)``. Both are resized to 256x256, and ``img`` is
        normalised with the ImageNet channel mean/std.
    """
    steps = [
        A.Resize(256, 256),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    out = A.Compose(steps)(image=img, mask=mask)
    return out['image'], out['mask']
class CDLA(Dataset):
    """Dataset of ``(image, binary mask)`` pairs listed in a CSV index.

    The CSV must contain a ``path`` column of file names. Each name is resolved
    against both ``img_root`` and ``anno_root``, i.e. an image and its mask
    share the same file name.

    Args:
        img_root: directory containing the colour images.
        anno_root: directory containing the masks (read as grayscale).
        csv: path to the CSV index; its first column is used as the index.
        is_training: if True, samples go through the random ``img_aug``
            pipeline; otherwise through the deterministic ``val_tran``.
    """

    def __init__(self, img_root, anno_root, csv, is_training=True):
        self.df = pd.read_csv(csv, index_col=0)
        # Subscript access works for any column name, unlike attribute access.
        self.img_list = self.df['path'].tolist()
        self.img_root = img_root
        self.anno_root = anno_root
        self.is_training = is_training

    def _load_pair(self, idx):
        """Read the RGB image and grayscale mask for sample ``idx`` from disk."""
        name = self.img_list[idx]
        img_path = os.path.join(self.img_root, name)
        mask_path = os.path.join(self.anno_root, name)
        img = cv2.imread(img_path)
        # cv2.imread reports a missing/unreadable file by returning None instead
        # of raising; fail with the offending path rather than an unrelated
        # error later in cvtColor or the augmentation pipeline.
        if img is None:
            raise FileNotFoundError(f'Could not read image: {img_path}')
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f'Could not read mask: {mask_path}')
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB), mask

    def __getitem__(self, idx):
        """Return ``(img, mask)`` float32 tensors for sample ``idx``.

        ``img`` is in CHW layout after the split's transform. ``mask`` has
        every non-zero pixel set to 1 and every zero pixel left at 0.

        Raises:
            FileNotFoundError: if the image or mask file cannot be read.
        """
        img, mask = self._load_pair(idx)
        transform = img_aug if self.is_training else val_tran
        img, mask = transform(img, mask)
        # HWC -> CHW. ascontiguousarray makes one contiguous float32 copy that
        # torch.from_numpy can wrap without a second copy.
        img = np.ascontiguousarray(np.transpose(img, (2, 0, 1)), dtype=np.float32)
        # Binarise: any non-zero mask value is treated as foreground.
        mask = (mask > 0).astype(np.float32)
        return torch.from_numpy(img), torch.from_numpy(mask)

    def __len__(self):
        """Return the number of samples listed in the CSV index."""
        return len(self.img_list)
def get_loader(batch_size=128, num_workers=8):
    """Build the training and validation DataLoaders for the CDLA synthetic split.

    Args:
        batch_size: samples per batch for both loaders (default 128).
        num_workers: worker processes per loader (default 8).

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and drops
        the final incomplete batch. The validation loader iterates every sample
        in order.
    """
    train_set = CDLA(img_root=train_img_root, anno_root=train_anno_root,
                     csv=train_csv_path, is_training=True)
    val_set = CDLA(img_root=val_img_root, anno_root=val_anno_root,
                   csv=val_csv_path, is_training=False)
    train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=num_workers,
                              shuffle=True, pin_memory=True, drop_last=True)
    # Keep every validation sample: drop_last=True would skip up to
    # batch_size - 1 samples per pass, and a validation set smaller than one
    # batch would yield no batches at all.
    val_loader = DataLoader(val_set, batch_size=batch_size, num_workers=num_workers,
                            shuffle=False, pin_memory=True, drop_last=False)
    return train_loader, val_loader