# loss_fun.py 625 Bytes
'''
 _*_coding:utf-8 _*_
 @Time     :2022/1/28   19:05
 @Author   : qiaofengsheng
 @File     :loss_fun.py
 @Software :PyCharm
 '''

from torch import nn


class Loss:
    """Select a ``torch.nn`` loss module by name.

    Supported names are ``'mse'``, ``'l1'``, ``'smooth_l1'`` and
    ``'cross_entropy'``; each loss is built with its default constructor
    arguments. Any other value falls back to ``nn.MSELoss`` and emits a
    ``UserWarning`` so that a misspelled name does not go unnoticed.
    """

    # Maps each supported name to the nn loss class instantiated for it.
    _LOSS_CLASSES = {
        'mse': nn.MSELoss,
        'l1': nn.L1Loss,
        'smooth_l1': nn.SmoothL1Loss,
        'cross_entropy': nn.CrossEntropyLoss,
    }

    def __init__(self, loss_type='mse'):
        """Build the loss module named by *loss_type*.

        Args:
            loss_type: One of ``'mse'``, ``'l1'``, ``'smooth_l1'`` or
                ``'cross_entropy'``. Unrecognised values select
                ``nn.MSELoss`` and trigger a ``UserWarning``.
        """
        # Guard the lookup with isinstance so unhashable arguments take the
        # fallback path instead of raising TypeError from dict.get.
        loss_cls = None
        if isinstance(loss_type, str):
            loss_cls = self._LOSS_CLASSES.get(loss_type)

        if loss_cls is None:
            # Imported here so the block does not depend on a file-level import.
            import warnings

            warnings.warn(
                "Unknown loss_type %r; falling back to 'mse'. "
                "Supported values: %s"
                % (loss_type, ', '.join(sorted(self._LOSS_CLASSES))),
                stacklevel=2,  # attribute the warning to the caller
            )
            loss_cls = nn.MSELoss

        self.loss_fun = loss_cls()

    def get_loss_fun(self):
        """Return the loss module created in ``__init__``."""
        return self.loss_fun