match_cost.py
12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmdet.core.bbox.iou_calculators import bbox_overlaps
from mmdet.core.bbox.transforms import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh
from .builder import MATCH_COST
@MATCH_COST.register_module()
class BBoxL1Cost:
    """Pairwise L1 distance between predicted and ground-truth boxes.

    Args:
        weight (int | float, optional): Factor the cost matrix is multiplied
            by. Defaults to 1.
        box_format (str, optional): Coordinate format the distance is
            measured in; 'xyxy' for DETR, 'xywh' for Sparse_RCNN.
            Defaults to 'xyxy'.

    Examples:
        >>> from mmdet.core.bbox.match_costs.match_cost import BBoxL1Cost
        >>> import torch
        >>> self = BBoxL1Cost()
        >>> bbox_pred = torch.rand(1, 4)
        >>> gt_bboxes = torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])
        >>> self(bbox_pred, gt_bboxes).shape
        torch.Size([1, 2])
    """

    def __init__(self, weight=1., box_format='xyxy'):
        # Only these two layouts have a conversion branch in ``__call__``.
        assert box_format in ('xyxy', 'xywh')
        self.box_format = box_format
        self.weight = weight

    def __call__(self, bbox_pred, gt_bboxes):
        """Compute the weighted L1 cost matrix.

        Args:
            bbox_pred (Tensor): Predicted boxes with normalized coordinates
                (cx, cy, w, h), which are all in range [0, 1]. Shape
                (num_query, 4).
            gt_bboxes (Tensor): Ground truth boxes with normalized
                coordinates (x1, y1, x2, y2). Shape (num_gt, 4).

        Returns:
            torch.Tensor: bbox_cost value with weight
        """
        fmt = self.box_format
        # Bring both operands into the configured format: predictions arrive
        # as cxcywh and ground truth as xyxy, so exactly one side is
        # converted.
        if fmt == 'xyxy':
            bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred)
        elif fmt == 'xywh':
            gt_bboxes = bbox_xyxy_to_cxcywh(gt_bboxes)
        pairwise_l1 = torch.cdist(bbox_pred, gt_bboxes, p=1)
        return pairwise_l1 * self.weight
@MATCH_COST.register_module()
class FocalLossCost:
    """Focal-loss based matching cost.

    Args:
        weight (int | float, optional): Factor the cost matrix is multiplied
            by. Defaults to 1.
        alpha (int | float, optional): focal_loss alpha. Defaults to 0.25.
        gamma (int | float, optional): focal_loss gamma. Defaults to 2.
        eps (float, optional): Small constant added inside the logarithms.
            Defaults to 1e-12.
        binary_input (bool, optional): Whether the input is binary. When
            True the mask variant of the cost is used. Defaults to False.

    Examples:
        >>> from mmdet.core.bbox.match_costs.match_cost import FocalLossCost
        >>> import torch
        >>> self = FocalLossCost()
        >>> cls_pred = torch.rand(4, 3)
        >>> gt_labels = torch.tensor([0, 1, 2])
        >>> self(cls_pred, gt_labels).shape
        torch.Size([4, 3])
    """

    def __init__(self,
                 weight=1.,
                 alpha=0.25,
                 gamma=2,
                 eps=1e-12,
                 binary_input=False):
        self.binary_input = binary_input
        self.eps = eps
        self.gamma = gamma
        self.alpha = alpha
        self.weight = weight

    def _pos_neg_terms(self, prob):
        """Return the element-wise positive and negative focal terms.

        Args:
            prob (Tensor): Probabilities obtained by applying sigmoid to the
                prediction logits.

        Returns:
            tuple[Tensor, Tensor]: ``(pos_cost, neg_cost)``, each with the
            same shape as ``prob``.
        """
        neg_log = -(1 - prob + self.eps).log()
        neg_term = neg_log * (1 - self.alpha) * prob.pow(self.gamma)
        pos_log = -(prob + self.eps).log()
        pos_term = pos_log * self.alpha * (1 - prob).pow(self.gamma)
        return pos_term, neg_term

    def _focal_loss_cost(self, cls_pred, gt_labels):
        """Focal cost for class-index labels.

        Args:
            cls_pred (Tensor): Predicted classification logits, shape
                (num_query, num_class).
            gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).

        Returns:
            torch.Tensor: cls_cost value with weight
        """
        pos_term, neg_term = self._pos_neg_terms(cls_pred.sigmoid())
        # Pick, for every query, the column belonging to each gt label.
        picked = pos_term[:, gt_labels] - neg_term[:, gt_labels]
        return picked * self.weight

    def _mask_focal_loss_cost(self, cls_pred, gt_labels):
        """Focal cost for binary (mask-like) labels.

        Args:
            cls_pred (Tensor): Predicted classification logits
                in shape (num_query, d1, ..., dn), dtype=torch.float32.
            gt_labels (Tensor): Ground truth in shape (num_gt, d1, ..., dn),
                dtype=torch.long. Labels should be binary.

        Returns:
            Tensor: Focal cost matrix with weight in shape
            (num_query, num_gt).
        """
        logits = cls_pred.flatten(1)
        targets = gt_labels.flatten(1).float()
        # Number of flattened elements per query, used to average the sums.
        num_elems = logits.shape[1]
        pos_term, neg_term = self._pos_neg_terms(logits.sigmoid())
        # Positive term counts where the target is 1, negative where it is 0.
        pos_sum = torch.einsum('nc,mc->nm', pos_term, targets)
        neg_sum = torch.einsum('nc,mc->nm', neg_term, (1 - targets))
        total = pos_sum + neg_sum
        return total / num_elems * self.weight

    def __call__(self, cls_pred, gt_labels):
        """Dispatch to the binary-mask or the class-index focal cost.

        Args:
            cls_pred (Tensor): Predicted classification logits.
            gt_labels (Tensor): Labels.

        Returns:
            Tensor: Focal cost matrix with weight in shape
            (num_query, num_gt).
        """
        cost_fn = (
            self._mask_focal_loss_cost
            if self.binary_input else self._focal_loss_cost)
        return cost_fn(cls_pred, gt_labels)
@MATCH_COST.register_module()
class ClassificationCost:
    """Softmax-score based classification matching cost.

    Args:
        weight (int | float, optional): Factor the cost matrix is multiplied
            by. Defaults to 1.

    Examples:
        >>> from mmdet.core.bbox.match_costs.match_cost import \
        ...     ClassificationCost
        >>> import torch
        >>> self = ClassificationCost()
        >>> cls_pred = torch.rand(4, 3)
        >>> gt_labels = torch.tensor([0, 1, 2])
        >>> self(cls_pred, gt_labels).shape
        torch.Size([4, 3])
    """

    def __init__(self, weight=1.):
        self.weight = weight

    def __call__(self, cls_pred, gt_labels):
        """Compute the weighted classification cost matrix.

        Args:
            cls_pred (Tensor): Predicted classification logits, shape
                (num_query, num_class).
            gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).

        Returns:
            torch.Tensor: cls_cost value with weight
        """
        # As in the official DETR repo, the matching cost is not the NLL
        # used by the loss but the approximation 1 - cls_score[gt_label].
        # The constant 1 has no influence on the matching and is dropped.
        probabilities = cls_pred.softmax(dim=-1)
        negated_scores = -probabilities[:, gt_labels]
        return negated_scores * self.weight
@MATCH_COST.register_module()
class IoUCost:
    """Matching cost based on the overlap between boxes.

    Args:
        iou_mode (str, optional): iou mode such as 'iou' | 'giou'.
            Defaults to 'giou'.
        weight (int | float, optional): Factor the cost matrix is multiplied
            by. Defaults to 1.

    Examples:
        >>> from mmdet.core.bbox.match_costs.match_cost import IoUCost
        >>> import torch
        >>> self = IoUCost()
        >>> bboxes = torch.FloatTensor([[1,1, 2, 2], [2, 2, 3, 4]])
        >>> gt_bboxes = torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])
        >>> self(bboxes, gt_bboxes)
        tensor([[-0.1250,  0.1667],
                [ 0.1667, -0.5000]])
    """

    def __init__(self, iou_mode='giou', weight=1.):
        self.iou_mode = iou_mode
        self.weight = weight

    def __call__(self, bboxes, gt_bboxes):
        """Compute the weighted overlap cost matrix.

        Args:
            bboxes (Tensor): Predicted boxes with unnormalized coordinates
                (x1, y1, x2, y2). Shape (num_query, 4).
            gt_bboxes (Tensor): Ground truth boxes with unnormalized
                coordinates (x1, y1, x2, y2). Shape (num_gt, 4).

        Returns:
            torch.Tensor: iou_cost value with weight
        """
        # Every prediction against every gt box: [num_bboxes, num_gt].
        pairwise = bbox_overlaps(
            bboxes, gt_bboxes, mode=self.iou_mode, is_aligned=False)
        # The cost would be 1 - overlap; the constant 1 does not affect the
        # matching, so only the negated overlap is kept.
        return -pairwise * self.weight
@MATCH_COST.register_module()
class DiceCost:
    """Cost of mask assignments based on dice losses.

    Args:
        weight (int | float, optional): Factor the cost matrix is multiplied
            by. Defaults to 1.
        pred_act (bool, optional): Whether to apply sigmoid to mask_pred.
            Defaults to False.
        eps (float, optional): Constant added to both numerator and
            denominator of the dice ratio. Defaults to 1e-3.
    """

    def __init__(self, weight=1., pred_act=False, eps=1e-3):
        self.eps = eps
        self.pred_act = pred_act
        self.weight = weight

    def binary_mask_dice_loss(self, mask_preds, gt_masks):
        """Compute the unweighted pairwise dice loss.

        Args:
            mask_preds (Tensor): Mask prediction in shape (num_query, *).
            gt_masks (Tensor): Ground truth in shape (num_gt, *)
                store 0 or 1, 0 for negative class and 1 for
                positive class.

        Returns:
            Tensor: Dice cost matrix in shape (num_query, num_gt).
        """
        preds_flat = mask_preds.flatten(1)
        gts_flat = gt_masks.flatten(1).float()
        # Sum of element-wise products for every (query, gt) pair.
        overlap = torch.einsum('nc,mc->nm', preds_flat, gts_flat)
        numerator = 2 * overlap
        # Broadcast per-query and per-gt sums into a (num_query, num_gt)
        # grid.
        pred_totals = preds_flat.sum(-1).unsqueeze(1)
        gt_totals = gts_flat.sum(-1).unsqueeze(0)
        denominator = pred_totals + gt_totals
        return 1 - (numerator + self.eps) / (denominator + self.eps)

    def __call__(self, mask_preds, gt_masks):
        """Compute the weighted dice cost matrix.

        Args:
            mask_preds (Tensor): Mask prediction logits in shape
                (num_query, *). Sigmoid is applied first only when
                ``pred_act`` is True.
            gt_masks (Tensor): Ground truth in shape (num_gt, *)

        Returns:
            Tensor: Dice cost matrix with weight in shape
            (num_query, num_gt).
        """
        activated = mask_preds.sigmoid() if self.pred_act else mask_preds
        pairwise_dice = self.binary_mask_dice_loss(activated, gt_masks)
        return pairwise_dice * self.weight
@MATCH_COST.register_module()
class CrossEntropyLossCost:
    """Cross-entropy based matching cost.

    Args:
        weight (int | float, optional): Factor the cost matrix is multiplied
            by. Defaults to 1.
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            or softmax. Only True is currently supported. Defaults to True.

    Examples:
        >>> from mmdet.core.bbox.match_costs import CrossEntropyLossCost
        >>> import torch
        >>> bce = CrossEntropyLossCost(use_sigmoid=True)
        >>> cls_pred = torch.tensor([[7.6, 1.2], [-1.3, 10]])
        >>> gt_labels = torch.tensor([[1, 1], [1, 0]])
        >>> bce(cls_pred, gt_labels).shape
        torch.Size([2, 2])
    """

    def __init__(self, weight=1., use_sigmoid=True):
        assert use_sigmoid, 'use_sigmoid = False is not supported yet.'
        self.use_sigmoid = use_sigmoid
        self.weight = weight

    def _binary_cross_entropy(self, cls_pred, gt_labels):
        """Compute the unweighted pairwise binary cross entropy.

        Args:
            cls_pred (Tensor): The prediction with shape (num_query, 1, *) or
                (num_query, *).
            gt_labels (Tensor): The learning label of prediction with
                shape (num_gt, *).

        Returns:
            Tensor: Cross entropy cost matrix in shape (num_query, num_gt).
        """
        logits = cls_pred.flatten(1).float()
        targets = gt_labels.flatten(1).float()
        # Number of flattened elements per query, used to average the sums.
        num_elems = logits.shape[1]
        # Element-wise BCE of every logit against an all-ones and an
        # all-zeros target; the einsums below select the matching term
        # per gt element.
        all_ones = torch.ones_like(logits)
        all_zeros = torch.zeros_like(logits)
        bce_if_pos = F.binary_cross_entropy_with_logits(
            logits, all_ones, reduction='none')
        bce_if_neg = F.binary_cross_entropy_with_logits(
            logits, all_zeros, reduction='none')
        pos_sum = torch.einsum('nc,mc->nm', bce_if_pos, targets)
        neg_sum = torch.einsum('nc,mc->nm', bce_if_neg, 1 - targets)
        return (pos_sum + neg_sum) / num_elems

    def __call__(self, cls_pred, gt_labels):
        """Compute the weighted cross entropy cost matrix.

        Args:
            cls_pred (Tensor): Predicted classification logits.
            gt_labels (Tensor): Labels.

        Returns:
            Tensor: Cross entropy cost matrix with weight in
            shape (num_query, num_gt).

        Raises:
            NotImplementedError: If ``use_sigmoid`` is falsy.
        """
        if not self.use_sigmoid:
            raise NotImplementedError
        pairwise_bce = self._binary_cross_entropy(cls_pred, gt_labels)
        return pairwise_bce * self.weight