import copy

from torch.utils.data import DataLoader

from utils.registery import DATASET_REGISTRY
# NOTE(review): not referenced below; presumably imported so the class
# registers itself in DATASET_REGISTRY as an import side effect — TODO confirm.
from .CoordinatesData import CoordinatesData  # noqa: F401


def _get_section(cfg, key):
    """Return a deep copy of ``cfg[key]``.

    Args:
        cfg: Top-level config mapping.
        key: Name of the required section.

    Returns:
        An independent deep copy of the section, so callers may mutate it
        without touching ``cfg``.

    Raises:
        KeyError: If ``cfg`` has no ``key`` entry.
    """
    try:
        section = cfg[key]
    except KeyError as err:
        # Produces e.g. "should contain {dataset}!" — the message the
        # original code intended (it previously raised a bare str, which is
        # itself a TypeError in Python 3).
        raise KeyError(f"should contain {{{key}}}!") from err
    return copy.deepcopy(section)


def _phase_args(dataset_cfg, phase, anno_key, other_anno_key):
    """Build the dataset constructor kwargs for one split.

    Copies ``dataset_cfg['args']``, moves ``anno_key`` to ``anno_file``,
    drops ``other_anno_key`` if present, and sets ``phase``.

    Raises:
        KeyError: If ``dataset_cfg`` has no ``'args'`` or the args lack
            ``anno_key``.
    """
    args = copy.deepcopy(dataset_cfg['args'])
    args['anno_file'] = args.pop(anno_key)
    # The other split's annotation file is not a constructor argument.
    args.pop(other_anno_key, None)
    args['phase'] = phase
    return args


def build_dataset(cfg):
    """Build the training and validation datasets described by ``cfg``.

    ``cfg['dataset']['name']`` selects the dataset class from
    ``DATASET_REGISTRY``. ``cfg['dataset']['args']`` supplies the constructor
    kwargs, with ``train_anno_file`` / ``val_anno_file`` mapped to
    ``anno_file`` for the respective split and ``phase`` set to ``'train'`` /
    ``'valid'``. ``cfg`` itself is not modified.

    Returns:
        ``(train_data, val_data)`` dataset instances.

    Raises:
        KeyError: If the ``dataset`` section, its ``name``/``args`` entries,
            or either annotation-file entry is missing.
    """
    dataset_cfg = _get_section(cfg, 'dataset')
    # Resolve both splits' kwargs first so a missing entry fails before any
    # (possibly expensive) dataset is constructed.
    train_args = _phase_args(dataset_cfg, 'train', 'train_anno_file', 'val_anno_file')
    val_args = _phase_args(dataset_cfg, 'valid', 'val_anno_file', 'train_anno_file')

    dataset_cls = DATASET_REGISTRY.get(dataset_cfg['name'])
    train_data = dataset_cls(**train_args)
    val_data = dataset_cls(**val_args)
    return train_data, val_data


def build_dataloader(cfg):
    """Build training and validation ``DataLoader`` objects from ``cfg``.

    ``cfg['dataloader']`` is passed as keyword arguments to both
    ``DataLoader`` constructors. Datasets come from :func:`build_dataset`.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        KeyError: If the ``dataloader`` section (or any section required by
            :func:`build_dataset`) is missing.
    """
    dataloader_cfg = _get_section(cfg, 'dataloader')
    train_ds, val_ds = build_dataset(cfg)
    train_loader = DataLoader(train_ds, **dataloader_cfg)
    # NOTE(review): the validation loader reuses the training options verbatim
    # (including any shuffle/drop_last settings) — confirm that is intended.
    val_loader = DataLoader(val_ds, **dataloader_cfg)
    return train_loader, val_loader