varifocal_loss.py
5.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
from .utils import weight_reduce_loss
@mmcv.jit(derivate=True, coderize=True)
def varifocal_loss(pred,
                   target,
                   weight=None,
                   alpha=0.75,
                   gamma=2.0,
                   iou_weighted=True,
                   reduction='mean',
                   avg_factor=None):
    """Compute `Varifocal Loss <https://arxiv.org/abs/2008.13367>`_.

    The loss is a binary cross entropy on the logits ``pred``, scaled
    element-wise by a focal weight. Elements whose target is greater than
    zero are weighted by the target itself (or by 1 when ``iou_weighted``
    is False); all other elements are weighted by
    ``alpha * |sigmoid(pred) - target| ** gamma``.

    Args:
        pred (torch.Tensor): The prediction (logits) with shape (N, C),
            C is the number of classes.
        target (torch.Tensor): The learning target of the iou-aware
            classification score with shape (N, C), C is the number of
            classes. It is cast to the dtype of ``pred``.
        weight (torch.Tensor, optional): The weight of loss for each
            prediction. Defaults to None.
        alpha (float, optional): A balance factor for the negative part of
            Varifocal Loss, which is different from the alpha of Focal
            Loss. Defaults to 0.75.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        iou_weighted (bool, optional): Whether to weight the loss of the
            positive example with the iou target. Defaults to True.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and
            "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: The focal-weighted loss after it has been passed
        through ``weight_reduce_loss``.
    """
    # The element-wise formulation below needs matching shapes.
    assert pred.size() == target.size()

    target = target.type_as(pred)
    pred_sigmoid = pred.sigmoid()

    # Complementary 0/1 masks splitting the elements by target sign.
    pos_mask = (target > 0.0).float()
    neg_mask = (target <= 0.0).float()

    # Positive elements: scaled by the target score, or left at unit weight.
    if iou_weighted:
        pos_weight = target * pos_mask
    else:
        pos_weight = pos_mask

    # Remaining elements: down-weighted by how far the score is from target.
    neg_weight = alpha * (pred_sigmoid - target).abs().pow(gamma) * neg_mask

    focal_weight = pos_weight + neg_weight

    bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
    loss = bce * focal_weight
    return weight_reduce_loss(loss, weight, reduction, avg_factor)
@LOSSES.register_module()
class VarifocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 alpha=0.75,
                 gamma=2.0,
                 iou_weighted=True,
                 reduction='mean',
                 loss_weight=1.0):
        """Module wrapper around :func:`varifocal_loss`.

        See `Varifocal Loss <https://arxiv.org/abs/2008.13367>`_.

        Args:
            use_sigmoid (bool, optional): Whether the prediction is
                used for sigmoid or softmax. Only True is accepted.
                Defaults to True.
            alpha (float, optional): A balance factor for the negative
                part of Varifocal Loss, which is different from the alpha
                of Focal Loss. Must be non-negative. Defaults to 0.75.
            gamma (float, optional): The gamma for calculating the
                modulating factor. Defaults to 2.0.
            iou_weighted (bool, optional): Whether to weight the loss of
                the positive examples with the iou target.
                Defaults to True.
            reduction (str, optional): The method used to reduce the loss
                into a scalar. Defaults to 'mean'. Options are "none",
                "mean" and "sum".
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
        """
        super().__init__()

        # Validate the configuration before storing any of it.
        assert use_sigmoid is True, \
            'Only sigmoid varifocal loss supported now.'
        assert alpha >= 0.0

        self.use_sigmoid = use_sigmoid
        self.alpha = alpha
        self.gamma = gamma
        self.iou_weighted = iou_weighted
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: The calculated loss, scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')

        # Only the sigmoid formulation is implemented.
        if not self.use_sigmoid:
            raise NotImplementedError

        # An explicit override wins over the reduction set at construction.
        reduction = reduction_override or self.reduction

        raw_loss = varifocal_loss(
            pred,
            target,
            weight,
            alpha=self.alpha,
            gamma=self.gamma,
            iou_weighted=self.iou_weighted,
            reduction=reduction,
            avg_factor=avg_factor)
        return self.loss_weight * raw_loss