"""CoordinatesData.py: dataset of (coordinates, label) pairs stored as JSON files listed in a CSV."""
import os
import json
import torch
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from utils.registery import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class CoordinatesData(Dataset):
    """Map-style dataset yielding ``(input_coordinates, label)`` tensor pairs.

    Each row of the annotation CSV names one JSON sample file located at
    ``<data_root>/<phase>/<name>``. Each sample file is expected to hold a
    two-element JSON array ``[input_coordinates, label]``.
    """

    def __init__(self,
                 data_root: str = '/Users/zhouweiqi/Downloads/gcfp/data/dataset',
                 anno_file: str = 'train.csv',
                 phase: str = 'train'):
        """Load the annotation table.

        Args:
            data_root: Root directory containing one sub-directory per phase.
                NOTE(review): the default is a machine-specific absolute path;
                callers on other machines must pass their own.
            anno_file: Path to the annotation CSV, which must have a ``name``
                column. It is opened exactly as given (i.e. relative to the
                current working directory), *not* relative to ``data_root``.
            phase: Sub-directory of ``data_root`` holding the sample files.
        """
        self.data_root = data_root
        self.df = pd.read_csv(anno_file)
        self.phase = phase

    def __len__(self) -> int:
        """Return the number of samples (rows in the annotation CSV)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load and return sample ``idx``.

        Returns:
            ``(input_coordinates, label)``. ``input_coordinates`` is a tensor
            whose dtype is inferred from the JSON values; ``label`` is cast to
            float32.
        """
        # pandas parses purely numeric names as integers; os.path.join needs str.
        name = str(self.df.iloc[idx]['name'])
        path = os.path.join(self.data_root, self.phase, name)
        # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(path, 'r', encoding='utf-8') as fp:
            input_coordinates_list, label_list = json.load(fp)
        input_coordinates = torch.tensor(input_coordinates_list)
        label = torch.tensor(label_list).float()
        return input_coordinates, label