collate_fn.py
404 Bytes
import torch
from utils.registery import COLLATE_FN_REGISTRY
@COLLATE_FN_REGISTRY.register()
def base_collate_fn(batch):
    """Collate a batch of ``(image, label)`` samples into batched tensors.

    Each element of ``batch`` is unpacked as an ``(image, label)`` pair.
    All images are stacked along a new leading dimension, and so are all
    labels, so a batch of ``B`` samples yields tensors whose first
    dimension is ``B``.

    Args:
        batch: Iterable of ``(image, label)`` pairs. Each item may be a
            tensor or anything ``torch.as_tensor`` accepts (e.g. a Python
            number or a NumPy array). All images must share one shape, and
            all labels must share one shape.

    Returns:
        dict: ``{'image': Tensor, 'label': Tensor}`` holding the stacked
        images and labels.

    Raises:
        RuntimeError: Raised by ``torch.stack`` if ``batch`` is empty or if
            per-sample shapes differ.
    """
    images, labels = [], []
    for image, label in batch:
        # as_tensor returns existing tensors unchanged (no copy) and converts
        # plain Python numbers / NumPy arrays, so e.g. int labels also work.
        images.append(torch.as_tensor(image))
        labels.append(torch.as_tensor(label))
    # stack(dim=0) is equivalent to cat of unsqueeze(0)'d tensors.
    return {
        'image': torch.stack(images, dim=0),
        'label': torch.stack(labels, dim=0),
    }