import os
import json
import torch
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from utils.registery import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class CoordinatesData(Dataset):
    """Dataset that loads one (input coordinates, label) pair per JSON file.

    An annotation CSV lists the samples. Each row's ``name`` column is a
    file name resolved as ``<data_root>/<phase>/<name>``. That file must
    contain a JSON value that unpacks into exactly two items: the input
    coordinates and the label.
    """

    def __init__(self,
                 data_root: str = '/Users/zhouweiqi/Downloads/gcfp/data/dataset',
                 anno_file: str = 'train.csv',
                 phase: str = 'train'):
        """Read the annotation CSV and remember where the sample files live.

        Args:
            data_root: Directory that contains one sub-directory per phase.
            anno_file: Path to the annotation CSV, read with
                ``pandas.read_csv``. It is opened exactly as given, not
                relative to ``data_root``. It must have a ``name`` column.
            phase: Sub-directory of ``data_root`` holding the JSON files
                for this split.
        """
        self.data_root = data_root
        self.df = pd.read_csv(anno_file)
        self.phase = phase

    def __len__(self) -> int:
        """Return the number of rows in the annotation CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the sample described by row ``idx`` of the annotation CSV.

        Args:
            idx: Positional row index into the annotation DataFrame
                (used with ``DataFrame.iloc``).

        Returns:
            A tuple ``(input_coordinates, label)``:
            ``input_coordinates`` is ``torch.tensor`` of the first JSON
            item, with dtype inferred from the data. ``label`` is the
            second JSON item, cast to float32.

        Raises:
            KeyError: If the CSV has no ``name`` column.
            FileNotFoundError: If the sample file does not exist.
            ValueError: If the JSON content does not unpack into two items.
        """
        series = self.df.iloc[idx]
        name = series['name']
        path = os.path.join(self.data_root, self.phase, name)
        # JSON is UTF-8 by specification; without an explicit encoding,
        # open() falls back to the locale's preferred encoding (e.g. cp1252
        # on Windows), making loading platform-dependent.
        with open(path, 'r', encoding='utf-8') as fp:
            input_coordinates_list, label_list = json.load(fp)
        input_coordinates = torch.tensor(input_coordinates_list)
        label = torch.tensor(label_list).float()
        return input_coordinates, label