# builder.py 2.4 KB
import copy
import torch
import inspect
from utils.registery import LOSS_REGISTRY
from utils import sequence_mask
from torchvision.ops import sigmoid_focal_loss

class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
    """``nn.Module`` wrapper around ``torchvision.ops.sigmoid_focal_loss``.

    Args:
        weight: Forwarded to ``_WeightedLoss``; ``forward`` does not use it.
        size_average: Deprecated legacy argument, forwarded to the base class.
        reduce: Deprecated legacy argument, forwarded to the base class.
        reduction: Reduction mode handed to ``sigmoid_focal_loss``.
        alpha: Class-balancing factor handed to ``sigmoid_focal_loss``.
        gamma: Focusing exponent handed to ``sigmoid_focal_loss``.
    """

    def __init__(self,
                 weight=None,
                 size_average=None,
                 reduce=None,
                 reduction: str = 'mean',
                 alpha: float = 0.25,
                 gamma: float = 2):
        super().__init__(weight=weight, size_average=size_average,
                         reduce=reduce, reduction=reduction)
        self.alpha, self.gamma = alpha, gamma
        # NOTE(review): this reassignment discards any reduction the base class
        # derived from the deprecated size_average/reduce arguments.
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the sigmoid focal loss of raw scores ``inputs`` vs ``targets``."""
        return sigmoid_focal_loss(
            inputs,
            targets,
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=self.reduction,
        )

class MaskedSigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
    """Sigmoid focal loss with a length mask, averaged over the last dimension.

    ``forward`` computes the element-wise focal loss, multiplies it by a mask
    produced by ``sequence_mask(ones_like(targets), valid_lens)`` and averages
    over the last dimension. The average divides by the full size of that
    dimension (masked positions count as zeros), and no further reduction is
    applied, so the result keeps every leading dimension.

    Args:
        weight: Forwarded to ``_WeightedLoss``; ``forward`` does not use it.
        size_average: Deprecated legacy argument, forwarded to the base class.
        reduce: Deprecated legacy argument, forwarded to the base class.
        reduction: Stored on the module but NOT used by ``forward``.
        alpha: Class-balancing factor handed to ``sigmoid_focal_loss``.
        gamma: Focusing exponent handed to ``sigmoid_focal_loss``.
    """

    def __init__(self,
                 weight=None,
                 size_average=None,
                 reduce=None,
                 reduction: str = 'mean',
                 alpha: float = 0.25,
                 gamma: float = 2):
        super().__init__(weight=weight, size_average=size_average,
                         reduce=reduce, reduction=reduction)
        self.alpha, self.gamma = alpha, gamma
        # NOTE(review): overrides any reduction derived from size_average/reduce.
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor, valid_lens) -> torch.Tensor:
        """Return the masked focal loss averaged over the last dimension.

        NOTE(review): assumes ``utils.sequence_mask`` zeroes the positions at or
        beyond each row's ``valid_lens`` entry (d2l-style) — TODO confirm.
        """
        mask = sequence_mask(torch.ones_like(targets), valid_lens)
        per_element = sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma,
                                         reduction='none')
        return torch.mean(per_element * mask, dim=-1)


def register_sigmoid_focal_loss():
    """Register this module's focal-loss wrappers with ``LOSS_REGISTRY``."""
    for loss_cls in (SigmoidFocalLoss, MaskedSigmoidFocalLoss):
        LOSS_REGISTRY.register()(loss_cls)


def register_torch_loss():
    """Register every ``torch.nn`` loss class with ``LOSS_REGISTRY``.

    An attribute of ``torch.nn`` qualifies when its name does not start with a
    double underscore, contains ``"Loss"``, and it is a class deriving from
    ``torch.nn.Module``.
    """
    candidates = (
        getattr(torch.nn, attr_name)
        for attr_name in dir(torch.nn)
        if not attr_name.startswith('__') and 'Loss' in attr_name
    )
    for candidate in candidates:
        is_module_class = inspect.isclass(candidate) and issubclass(candidate, torch.nn.Module)
        if is_module_class:
            LOSS_REGISTRY.register()(candidate)

def build_loss(cfg):
    """Instantiate the loss described by ``cfg['solver']['loss']``.

    On the first call this registers the focal-loss wrappers and all
    ``torch.nn`` loss classes with ``LOSS_REGISTRY``. It then looks up the class
    registered under ``loss_cfg['name']`` and constructs it with the keyword
    arguments in ``loss_cfg['args']``. A missing or empty ``args`` entry means
    the class is constructed with its defaults.

    Args:
        cfg: Nested config mapping containing ``cfg['solver']['loss']``.

    Returns:
        The constructed loss object.

    Raises:
        KeyError: if ``cfg`` has no ``solver.loss`` entry.
    """
    # Register only once per process. NOTE(review): calling register() repeatedly
    # for the same class may fail when LOSS_REGISTRY rejects duplicate names
    # (fvcore-style registries assert on this) — confirm utils.registery behaviour.
    if not getattr(build_loss, '_losses_registered', False):
        register_sigmoid_focal_loss()
        register_torch_loss()
        build_loss._losses_registered = True

    try:
        # Copy only the loss sub-config so the caller's cfg is never aliased.
        loss_cfg = copy.deepcopy(cfg['solver']['loss'])
    except (KeyError, TypeError) as err:
        # Raising a bare string is itself a TypeError in Python 3; raise a real exception.
        raise KeyError('should contain {solver.loss}!') from err

    loss_args = loss_cfg.get('args') or {}
    return LOSS_REGISTRY.get(loss_cfg['name'])(**loss_args)