# CoordinatesData.py
import os
import json
import torch
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from utils.registery import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class CoordinatesData(Dataset):
    """Dataset of (coordinates, label) samples indexed by a CSV file.

    The annotation CSV must contain a ``name`` column. Each row's ``name`` is
    a file under ``<data_root>/<phase>/`` holding a JSON array with exactly
    two elements: ``[coordinates, labels]``.
    """

    def __init__(self,
                 data_root: str = '/Users/zhouweiqi/Downloads/gcfp/data/dataset',
                 anno_file: str = 'train.csv',
                 phase: str = 'train'):
        """Load the annotation index.

        Args:
            data_root: Directory containing one sub-directory per phase.
                NOTE(review): the default is a machine-specific absolute
                path; pass ``data_root`` explicitly on other machines.
            anno_file: Path to the annotation CSV. It is opened exactly as
                given (i.e. relative to the current working directory), NOT
                relative to ``data_root``.
            phase: Sub-directory of ``data_root`` that holds the per-sample
                JSON files (e.g. ``'train'``).
        """
        self.data_root = data_root
        self.df = pd.read_csv(anno_file)
        self.phase = phase

    def __len__(self) -> int:
        """Return the number of rows in the annotation CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the sample described by row ``idx`` of the annotation CSV.

        Returns:
            ``(input_coordinates, label)``: ``input_coordinates`` is a tensor
            whose dtype torch infers from the JSON values; ``label`` is a
            float32 tensor.

        Raises:
            FileNotFoundError: if the sample file does not exist.
            ValueError: if the file is not valid JSON, or its top-level value
                is not a two-element array ``[coordinates, labels]``.
        """
        # pandas parses purely numeric names as integers; os.path.join needs str.
        name = str(self.df.iloc[idx]['name'])
        path = os.path.join(self.data_root, self.phase, name)

        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(path, 'r', encoding='utf-8') as fp:
            payload = json.load(fp)

        # Fail with a message that names the offending file instead of an
        # opaque unpacking / tensor-construction error.
        if not isinstance(payload, list) or len(payload) != 2:
            raise ValueError(
                f"{path}: expected a JSON array [coordinates, labels], "
                f"got {type(payload).__name__}"
                + (f" of length {len(payload)}" if isinstance(payload, list) else "")
            )
        input_coordinates_list, label_list = payload

        input_coordinates = torch.tensor(input_coordinates_list)
        label = torch.tensor(label_list).float()

        return input_coordinates, label