# ReconData.py (2.38 KB)
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torch
import pandas as pd
import os
import random
import cv2
import albumentations as A
import albumentations.pytorch
import numpy as np
from PIL import Image
from utils.registery import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class ReconData(Dataset):
    """Paired-image dataset: an input image and its target image ("text_img").

    Each row of the annotation CSV must provide a ``name`` column. The same
    file name is looked up under ``<data_root>/img`` (input) and
    ``<data_root>/text_img`` (target).

    Args:
        data_root: Directory containing the ``img`` and ``text_img`` folders.
        anno_file: Annotation CSV path. It is passed to ``pd.read_csv`` as-is
            and is NOT joined with ``data_root``, so a relative path is
            resolved against the current working directory.
        fixed_size: Both images are resized to ``fixed_size x fixed_size``
            (aspect ratio is not preserved).
        phase: ``'train'`` (resize + random darkening) or ``'val'``
            (resize only).

    Raises:
        KeyError: if ``phase`` is not ``'train'`` or ``'val'``.
    """

    def __init__(self,
                 data_root: str = '/data1/lxl/data/ocr/generate1108',
                 anno_file: str = 'train.csv',
                 fixed_size: int = 448,
                 phase: str = 'train'):
        self.data_root = data_root
        self.df = pd.read_csv(anno_file)
        self.img_root = os.path.join(data_root, 'img')
        self.gt_root = os.path.join(data_root, 'text_img')
        self.fixed_size = fixed_size
        self.phase = phase
        transform_fn = self.__get_transform()
        if phase not in transform_fn:
            # Same exception type as the original dict lookup, but with a
            # message that tells the caller which phases are valid.
            raise KeyError(
                f"unknown phase {phase!r}; expected one of {sorted(transform_fn)}"
            )
        self.transform = transform_fn[phase]

    def __get_transform(self):
        """Build the albumentations pipelines keyed by phase name.

        Both pipelines resize to ``fixed_size`` and then divide pixel values
        by 255 (``Normalize`` with mean 0 and std 1) before converting HWC
        arrays to CHW tensors with ``ToTensorV2``.
        """
        # NOTE(review): 'label' is registered as an additional *image* target,
        # so every pixel-level transform (including the random darkening
        # below) is applied to the target image too, with the same sampled
        # parameters. Confirm this is intended for a reconstruction target.
        train_transform = A.Compose([
            A.Resize(height=self.fixed_size, width=self.fixed_size),
            # Brightness shift sampled from [-0.5, 0] (darken only), with no
            # contrast change, applied half of the time.
            A.RandomBrightnessContrast(brightness_limit=(-0.5, 0), contrast_limit=0, p=0.5),
            A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
            A.pytorch.transforms.ToTensorV2()
        ], additional_targets={'label': 'image'})
        val_transform = A.Compose([
            A.Resize(height=self.fixed_size, width=self.fixed_size),
            A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
            A.pytorch.transforms.ToTensorV2()
        ], additional_targets={'label': 'image'})
        transform_fn = {'train': train_transform, 'val': val_transform}
        return transform_fn

    def __len__(self):
        """Return the number of rows in the annotation CSV."""
        return len(self.df)

    @staticmethod
    def _read_rgb(path):
        """Read ``path`` with OpenCV and return it as an RGB array.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
                signals failure by returning ``None`` rather than raising.
        """
        bgr = cv2.imread(path)
        if bgr is None:
            raise FileNotFoundError(f"could not read image: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return ``(input, target)`` for row ``idx``, both after ``self.transform``."""
        name = self.df.iloc[idx]['name']
        img = self._read_rgb(os.path.join(self.img_root, name))
        gt = self._read_rgb(os.path.join(self.gt_root, name))
        transformed = self.transform(image=img, label=gt)
        return transformed['image'], transformed['label']