# ReconData.py
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torch
import pandas as pd
import os
import random
import cv2
import albumentations as A
import albumentations.pytorch
import numpy as np
from PIL import Image

from utils.registery import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class ReconData(Dataset):
    """Paired image dataset: an input image and its ground-truth image.

    Each row of the annotation CSV must have a ``name`` column. The same
    file name is looked up under ``<data_root>/img`` (input) and
    ``<data_root>/text_img`` (ground truth).

    Args:
        data_root: Root directory containing the ``img`` and ``text_img``
            sub-directories.
        anno_file: Path to the annotation CSV. It is used as given when it
            points to an existing file. Otherwise, if it is a relative
            string path, it is also tried relative to ``data_root``.
        fixed_size: Both images are resized to ``fixed_size x fixed_size``.
        phase: ``'train'`` (adds a random brightness augmentation) or
            ``'val'``.

    Raises:
        KeyError: If ``phase`` is not one of the supported phases.
    """

    def __init__(self,
                 data_root: str = '/data1/lxl/data/ocr/generate1108',
                 anno_file: str = 'train.csv',
                 fixed_size: int = 448,
                 phase: str = 'train'):
        self.data_root = data_root
        self.df = pd.read_csv(self._resolve_anno_path(data_root, anno_file))
        self.img_root = os.path.join(data_root, 'img')
        self.gt_root = os.path.join(data_root, 'text_img')
        self.fixed_size = fixed_size

        self.phase = phase
        transform_fn = self.__get_transform()
        if phase not in transform_fn:
            raise KeyError(
                f"Unknown phase {phase!r}; expected one of {sorted(transform_fn)}"
            )
        self.transform = transform_fn[phase]

    @staticmethod
    def _resolve_anno_path(data_root, anno_file):
        """Return anno_file, or data_root/anno_file if only the latter exists."""
        # Only plain string paths are rewritten; buffers/URLs go straight to
        # pd.read_csv unchanged.
        if (isinstance(anno_file, str)
                and not os.path.isfile(anno_file)
                and not os.path.isabs(anno_file)):
            candidate = os.path.join(data_root, anno_file)
            if os.path.isfile(candidate):
                return candidate
        return anno_file

    def __get_transform(self):
        """Build the per-phase albumentations pipelines.

        Returns:
            dict mapping ``'train'`` / ``'val'`` to an ``A.Compose``.
            Both pipelines resize to ``fixed_size`` and scale pixels by 1/255
            (Normalize with mean 0, std 1, max_pixel_value 255). They then
            convert to a torch tensor with ToTensorV2.
        """
        # 'label' is registered as an extra *image* target, so it receives the
        # same transforms (with the same sampled parameters) as 'image'.
        # NOTE(review): this means the train-time brightness change is also
        # applied to the ground truth — confirm this is intended.
        train_transform = A.Compose([
            A.Resize(height=self.fixed_size, width=self.fixed_size),
            # Darken only (brightness in [-0.5, 0]), contrast untouched.
            A.RandomBrightnessContrast(brightness_limit=(-0.5, 0), contrast_limit=0, p=0.5),
            A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
            A.pytorch.transforms.ToTensorV2()
        ], additional_targets={'label': 'image'})

        val_transform = A.Compose([
            A.Resize(height=self.fixed_size, width=self.fixed_size),
            A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
            A.pytorch.transforms.ToTensorV2()
        ], additional_targets={'label': 'image'})

        return {'train': train_transform, 'val': val_transform}

    def __len__(self):
        """Number of rows in the annotation CSV."""
        return len(self.df)

    @staticmethod
    def _read_rgb(path):
        """Read the image at ``path`` with OpenCV and return it in RGB order.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising;
            # without this check cvtColor fails with an opaque assertion.
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return the transformed ``(input_image, ground_truth_image)`` pair."""
        name = self.df.iloc[idx]['name']

        img = self._read_rgb(os.path.join(self.img_root, name))
        gt = self._read_rgb(os.path.join(self.gt_root, name))

        transformed = self.transform(image=img, label=gt)
        return transformed['image'], transformed['label']