recon_metric.py
800 Bytes
import numpy as np
import torchmetrics
from utils.registery import METRIC_REGISTRY
@METRIC_REGISTRY.register()
class Recon(object):
    """Image-reconstruction metric: batch-mean PSNR and SSIM.

    Every sample in a batch is scored on its own with torchmetrics'
    ``PeakSignalNoiseRatio`` and ``StructuralSimilarityIndexMeasure`` (both
    built with their default arguments), and the per-sample scores are
    averaged with equal weight.
    """

    def __init__(self):
        # Stateful torchmetrics objects. They are created on the default
        # device and are moved to the inputs' device on every call.
        self.psnr = torchmetrics.PeakSignalNoiseRatio()
        self.ssim = torchmetrics.StructuralSimilarityIndexMeasure()

    def __call__(self, pred, label):
        """Return the mean per-sample PSNR and SSIM of ``pred`` vs ``label``.

        Args:
            pred: batch of predicted images, iterated along dim 0. Each sample
                is given a leading batch dim of 1 before scoring, so presumably
                ``pred`` is (N, C, H, W) since SSIM expects 4-D input -- TODO
                confirm against callers.
            label: batch of reference images with the same size along dim 0.

        Returns:
            dict with keys ``'psnr'`` and ``'ssim'``; each value is the
            unweighted mean of the per-sample scores.

        Raises:
            ValueError: if the batch sizes differ or the batch is empty.
        """
        # Explicit checks instead of `assert`, which disappears under `python -O`.
        if pred.shape[0] != label.shape[0]:
            raise ValueError(
                f"pred and label batch sizes differ: "
                f"{pred.shape[0]} vs {label.shape[0]}"
            )
        if pred.shape[0] == 0:
            # Averaging over zero samples would otherwise divide by zero.
            raise ValueError("cannot compute reconstruction metrics on an empty batch")

        # torchmetrics rejects inputs that live on a different device than the
        # metric's internal states; follow the inputs (no-op when co-located).
        self.psnr.to(pred.device)
        self.ssim.to(pred.device)

        psnr_list = []
        ssim_list = []
        for sample_pred, sample_label in zip(pred, label):
            # unsqueeze(0) restores a batch dim of 1 so each image is scored alone.
            sample_pred = sample_pred.unsqueeze(0)
            sample_label = sample_label.unsqueeze(0)
            psnr_list.append(self.psnr(sample_pred, sample_label))
            ssim_list.append(self.ssim(sample_pred, sample_label))

        # Calling a Metric returns the value for the given inputs but also folds
        # them into the metric's running state. That accumulated state is never
        # read here (and some torchmetrics versions retain every SSIM input), so
        # clear it to keep it from growing across calls. Returned values are
        # unaffected.
        self.psnr.reset()
        self.ssim.reset()

        psnr_result = sum(psnr_list) / len(psnr_list)
        ssim_result = sum(ssim_list) / len(ssim_list)
        return {'psnr': psnr_result, 'ssim': ssim_result}