# collate_fn.py
import torch
from utils.registery import COLLATE_FN_REGISTRY


@COLLATE_FN_REGISTRY.register()
def base_collate_fn(batch):
    """Collate a sequence of ``(image, label)`` samples into a batched dict.

    Every image and every label is stacked along a new leading dimension,
    so ``N`` samples yield tensors whose first dimension is ``N``.

    Args:
        batch: Iterable of ``(image, label)`` pairs. Each element may be a
            ``torch.Tensor`` or anything ``torch.as_tensor`` accepts (e.g. a
            Python ``int`` class index or a NumPy array).

    Returns:
        dict with keys ``'image'`` and ``'label'``, each a tensor of shape
        ``(N, *per_sample_shape)``.

    Raises:
        RuntimeError: propagated from ``torch.stack`` when ``batch`` is empty
            or when samples have mismatched shapes.
    """
    images, labels = [], []
    # Single pass so a one-shot iterable (e.g. a generator) is also supported.
    for image, label in batch:
        # as_tensor is a no-op for tensors; it lets scalar labels (plain ints)
        # and NumPy arrays through instead of failing on `.unsqueeze`.
        images.append(torch.as_tensor(image))
        labels.append(torch.as_tensor(label))

    # torch.stack == unsqueeze(0) on each element followed by cat(dim=0).
    return {
        'image': torch.stack(images, dim=0),
        'label': torch.stack(labels, dim=0),
    }