# finetune_loader.py (2.89 KB)
import pandas
import os
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import transforms as Ap
import cv2
import pandas as pd
# Dataset root directory. Defaults to the original hard-coded location; set the
# FINETUNE_DATA_ROOT environment variable to point the loader elsewhere without
# editing this file.
root = os.environ.get('FINETUNE_DATA_ROOT', '/home/mly/data/datasets/text_recognition/finetune/src')
img_root = os.path.join(root, 'img')  # images, joined with the CSV 'path' entries
label_root = os.path.join(root, 'mask')  # masks, looked up as <image stem>.png
train_csv_path = os.path.join(root, 'train.csv')
val_csv_path = os.path.join(root, 'test.csv')  # NOTE: the validation split is read from test.csv
# Target side length in pixels used by the crop/resize steps of the augmentations.
SIZE = 512
def train_aug(img, mask):
    """Run the random training augmentation pipeline on an image/mask pair.

    The pair is resized to 768x768 and randomly cropped to SIZE x SIZE. It then
    goes through random noise, flip, rotation, brightness/contrast, affine and
    shadow transforms, is normalised with the given mean/std (the commonly used
    ImageNet statistics), and is converted with ``ToTensorV2``. A new
    ``A.Compose`` is built on every call.

    Args:
        img: image array passed to albumentations as the ``image`` target.
        mask: mask array passed to albumentations as the ``mask`` target.

    Returns:
        Tuple ``(image, mask)`` taken from the albumentations result dict.
    """
    steps = [
        # Geometry: fixed resize followed by a random SIZE x SIZE crop.
        A.Resize(768, 768),
        A.RandomCrop(SIZE, SIZE),
        # Random photometric / geometric perturbations.
        A.GaussNoise(p=0.3),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=20, p=0.3),
        A.RandomRotate90(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.Affine(rotate=(-90, 90), shear=(-45, 45), p=0.5),
        A.RandomShadow(p=0.5),
        # Normalisation and conversion to tensors.
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        Ap.ToTensorV2(),
    ]
    pipeline = A.Compose(steps)
    result = pipeline(image=img, mask=mask)
    return result['image'], result['mask']
def val_aug(img, mask):
    """Apply deterministic validation preprocessing to an image/mask pair.

    Resizes both to SIZE x SIZE, normalises the image with the same mean/std
    used in training, and converts the pair with ``ToTensorV2``.

    Args:
        img: image array passed to albumentations as the ``image`` target.
        mask: mask array passed to albumentations as the ``mask`` target.

    Returns:
        Tuple ``(image, mask)`` taken from the albumentations result dict.
    """
    transform = A.Compose([
        # Both sides come from SIZE (previously the width was a hard-coded 512),
        # so validation inputs always match the SIZE x SIZE training crops.
        A.Resize(SIZE, SIZE),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        Ap.ToTensorV2(),
    ])
    augmented = transform(image=img, mask=mask)
    return augmented['image'], augmented['mask']
class FineTuneData(Dataset):
    """Image/mask dataset whose samples are listed in a CSV file.

    The CSV is read with ``pd.read_csv(csv, index_col=0)`` and must contain a
    ``path`` column of image file names relative to ``img_root``. The mask for
    an image is read from ``anno_root`` using the image's name with its final
    extension replaced by ``.png``.

    Args:
        img_root: directory joined with each ``path`` entry to locate images.
        anno_root: directory holding the mask files.
        csv: anything ``pandas.read_csv`` accepts (path or buffer).
        is_training: use ``train_aug`` when True, otherwise ``val_aug``.
    """

    def __init__(self, img_root, anno_root, csv, is_training=True):
        self.df = pd.read_csv(csv, index_col=0)
        self.img_list = self.df.path.tolist()
        self.img_root = img_root
        self.anno_root = anno_root
        self.is_training = is_training

    def _mask_path(self, img_name):
        """Return the mask path for ``img_name``: same stem, ``.png`` extension."""
        # splitext strips only the final extension. The previous split('.')[0]
        # truncated at the first dot, mapping 'scan.v2.jpg' to 'scan.png' and
        # './a.jpg' to '.png'.
        stem = os.path.splitext(img_name)[0]
        return os.path.join(self.anno_root, stem + '.png')

    def __getitem__(self, idx):
        """Load, augment and return ``(img, mask)`` for sample ``idx``.

        Both outputs are cast to ``torch.float32``; the mask is binarised so
        every non-zero pixel becomes 1.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image or the mask.
        """
        name = self.img_list[idx]
        img_path = os.path.join(self.img_root, name)
        mask_path = self._mask_path(name)

        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise FileNotFoundError(f'could not read image: {img_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, 0)  # flag 0 == cv2.IMREAD_GRAYSCALE
        if mask is None:
            raise FileNotFoundError(f'could not read mask: {mask_path}')

        if self.is_training:
            img, mask = train_aug(img, mask)
        else:
            img, mask = val_aug(img, mask)
        mask[mask > 0] = 1
        return img.to(torch.float32), mask.to(torch.float32)

    def __len__(self):
        """Return the number of samples listed in the CSV."""
        return len(self.img_list)
def get_loader(batch_size=16, num_workers=4):
    """Build the fine-tuning training and validation DataLoaders.

    Args:
        batch_size: samples per batch for both loaders (default 16, the
            previously hard-coded value).
        num_workers: worker processes per loader (default 4).

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and drops
        the final incomplete batch; the validation loader keeps every sample
        in file order.
    """
    finetune_train_data = FineTuneData(img_root=img_root, anno_root=label_root, csv=train_csv_path, is_training=True)
    finetune_val_data = FineTuneData(img_root=img_root, anno_root=label_root, csv=val_csv_path, is_training=False)
    train_loader = DataLoader(finetune_train_data, batch_size=batch_size, num_workers=num_workers,
                              shuffle=True, pin_memory=True, drop_last=True)
    # drop_last=False: dropping the tail batch would silently exclude up to
    # batch_size - 1 validation samples from every evaluation.
    val_loader = DataLoader(finetune_val_data, batch_size=batch_size, num_workers=num_workers,
                            shuffle=False, pin_memory=True, drop_last=False)
    return train_loader, val_loader