# recon_metric.py
import numpy as np
import torchmetrics


from utils.registery import METRIC_REGISTRY

@METRIC_REGISTRY.register()
class Recon(object):
    """Reconstruction-quality metric: mean per-sample PSNR and SSIM.

    Each sample in the batch is scored on its own with torchmetrics'
    ``PeakSignalNoiseRatio`` and ``StructuralSimilarityIndexMeasure``, and the
    per-sample scores are averaged over the batch.
    """

    def __init__(self, *, psnr_kwargs=None, ssim_kwargs=None):
        """Build the underlying torchmetrics metric objects.

        Args:
            psnr_kwargs: Optional dict of keyword arguments forwarded to
                ``torchmetrics.PeakSignalNoiseRatio`` (e.g. ``data_range``).
                ``None`` keeps the library defaults, which is the previous
                behavior.
            ssim_kwargs: Optional dict of keyword arguments forwarded to
                ``torchmetrics.StructuralSimilarityIndexMeasure``. ``None``
                keeps the library defaults.
        """
        self.psnr = torchmetrics.PeakSignalNoiseRatio(**(psnr_kwargs or {}))
        self.ssim = torchmetrics.StructuralSimilarityIndexMeasure(**(ssim_kwargs or {}))

    def __call__(self, pred, label):
        """Score ``pred`` against ``label``, one sample at a time.

        Args:
            pred: Batch of predictions, indexed along dim 0. Presumably
                (N, C, H, W) images; confirm against callers.
            label: Batch of targets with the same size along dim 0 as ``pred``.

        Returns:
            Dict with keys ``'psnr'`` and ``'ssim'``, each holding the
            arithmetic mean of the per-sample scores.

        Raises:
            ValueError: If the batch sizes differ or the batch is empty.
        """
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if pred.shape[0] != label.shape[0]:
            raise ValueError(
                f"pred and label batch sizes differ: {pred.shape[0]} vs {label.shape[0]}"
            )
        if pred.shape[0] == 0:
            raise ValueError("cannot compute reconstruction metrics on an empty batch")

        # torchmetrics metrics are nn.Modules whose state tensors start on the
        # CPU. Keep them on the inputs' device so forward() does not mix
        # devices. This is a no-op when they are already there.
        self.psnr.to(pred.device)
        self.ssim.to(pred.device)

        # NOTE(review): each forward() call also folds the sample into the
        # metric's running state, which is never reset here. Confirm whether
        # anything reads self.psnr/self.ssim.compute() before adding a reset().
        psnr_list = []
        ssim_list = []
        for sample_pred, sample_label in zip(pred, label):
            # Iterating removes the leading axis; restore a batch axis of size 1.
            sample_pred = sample_pred.unsqueeze(0)
            sample_label = sample_label.unsqueeze(0)
            psnr_list.append(self.psnr(sample_pred, sample_label))
            ssim_list.append(self.ssim(sample_pred, sample_label))

        count = len(psnr_list)
        return {'psnr': sum(psnr_list) / count, 'ssim': sum(ssim_list) / count}