mask_hungarian_assigner.py
5.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.core.bbox.builder import BBOX_ASSIGNERS
from mmdet.core.bbox.match_costs.builder import build_match_cost
from .assign_result import AssignResult
from .base_assigner import BaseAssigner
try:
from scipy.optimize import linear_sum_assignment
except ImportError:
linear_sum_assignment = None
@BBOX_ASSIGNERS.register_module()
class MaskHungarianAssigner(BaseAssigner):
    """Computes one-to-one matching between predictions and ground truth for
    mask.

    This class computes an assignment between the targets and the predictions
    based on the costs. The cost is a weighted sum of three components:
    classification cost, mask cost (focal loss by default) and mask dice
    cost. The targets don't include the no_object, so generally there are
    more predictions than targets. After the one-to-one matching, the
    un-matched predictions are treated as backgrounds. Thus each query
    prediction will be assigned with `0` or a positive integer indicating the
    ground truth index:

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        cls_cost (:obj:`mmcv.ConfigDict` | dict): Classification cost config.
        mask_cost (:obj:`mmcv.ConfigDict` | dict): Mask cost config.
        dice_cost (:obj:`mmcv.ConfigDict` | dict): Dice cost config.
    """

    def __init__(self,
                 cls_cost=dict(type='ClassificationCost', weight=1.0),
                 mask_cost=dict(
                     type='FocalLossCost', weight=1.0, binary_input=True),
                 dice_cost=dict(type='DiceCost', weight=1.0)):
        # NOTE(review): the dict defaults are shared between instances; this
        # is only safe while build_match_cost does not mutate its cfg
        # argument (presumably it copies it, as mmcv's build_from_cfg does).
        self.cls_cost = build_match_cost(cls_cost)
        self.mask_cost = build_match_cost(mask_cost)
        self.dice_cost = build_match_cost(dice_cost)

    def assign(self,
               cls_pred,
               mask_pred,
               gt_labels,
               gt_mask,
               img_meta,
               gt_bboxes_ignore=None,
               eps=1e-7):
        """Computes one-to-one matching based on the weighted costs.

        Args:
            cls_pred (Tensor | None): Class prediction in shape
                (num_query, cls_out_channels). May be None, in which case
                the classification cost is skipped and ``mask_pred`` is used
                to determine the number of queries and the device.
            mask_pred (Tensor): Mask prediction in shape (num_query, H, W).
            gt_labels (Tensor): Label of ``gt_mask`` in shape (num_gt, ).
            gt_mask (Tensor): Ground truth mask in shape (num_gt, H, W).
            img_meta (dict): Meta information for current image. Not used
                by this implementation.
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`. Only None is supported. Default None.
            eps (int | float, optional): Not used by this implementation;
                kept for interface compatibility. Default 1e-7.

        Returns:
            :obj:`AssignResult`: The assigned result. ``gt_inds`` is 0 for
            background queries and ``i + 1`` for a query matched to gt ``i``;
            ``labels`` is -1 for background queries and the matched gt label
            otherwise; ``max_overlaps`` is None.

        Raises:
            ImportError: If scipy is not installed and matching is needed.
            ValueError: If every cost term is skipped (all weights are zero,
                or only ``cls_cost`` is weighted but ``cls_pred`` is None).
        """
        assert gt_bboxes_ignore is None, \
            'Only case when gt_bboxes_ignore is None is supported.'
        # ``cls_pred`` is optional when the classification cost is unused,
        # so fall back to ``mask_pred`` for the query count and device.
        ref_pred = cls_pred if cls_pred is not None else mask_pred
        num_gt, num_query = gt_labels.shape[0], ref_pred.shape[0]

        # 1. assign -1 by default
        assigned_gt_inds = ref_pred.new_full((num_query, ),
                                             -1,
                                             dtype=torch.long)
        assigned_labels = ref_pred.new_full((num_query, ),
                                            -1,
                                            dtype=torch.long)
        if num_gt == 0 or num_query == 0:
            # No ground truth or no queries: return an empty assignment.
            if num_gt == 0:
                # No ground truth, assign all queries to background.
                assigned_gt_inds[:] = 0
            return AssignResult(
                num_gt, assigned_gt_inds, None, labels=assigned_labels)

        # Fail fast, before computing any cost, if the solver is missing.
        if linear_sum_assignment is None:
            raise ImportError('Please run "pip install scipy" '
                              'to install scipy first.')

        # 2. compute the weighted costs
        # classification and mask cost.
        if self.cls_cost.weight != 0 and cls_pred is not None:
            cls_cost = self.cls_cost(cls_pred, gt_labels)
        else:
            cls_cost = 0

        if self.mask_cost.weight != 0:
            # mask_pred shape = [num_query, h, w]
            # gt_mask shape = [num_gt, h, w]
            # mask_cost shape = [num_query, num_gt]
            mask_cost = self.mask_cost(mask_pred, gt_mask)
        else:
            mask_cost = 0

        if self.dice_cost.weight != 0:
            dice_cost = self.dice_cost(mask_pred, gt_mask)
        else:
            dice_cost = 0
        cost = cls_cost + mask_cost + dice_cost
        if not torch.is_tensor(cost):
            # Every term was skipped, so there is nothing to match on.
            raise ValueError(
                'At least one of cls_cost, mask_cost and dice_cost must have '
                'a non-zero weight (cls_cost additionally needs cls_pred).')

        # 3. do Hungarian matching on CPU using linear_sum_assignment.
        # Upcast before the numpy hand-off: numpy has no bfloat16, and the
        # upcast is lossless for half and single precision.
        cost = cost.detach().cpu().to(torch.float64).numpy()
        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
        # Force int64 so the tensors are valid indices on every platform.
        matched_row_inds = torch.from_numpy(matched_row_inds).to(
            device=ref_pred.device, dtype=torch.long)
        matched_col_inds = torch.from_numpy(matched_col_inds).to(
            device=ref_pred.device, dtype=torch.long)

        # 4. assign backgrounds and foregrounds
        # assign all indices to backgrounds first
        assigned_gt_inds[:] = 0
        # assign foregrounds based on matching results (1-based gt indices)
        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
        return AssignResult(
            num_gt, assigned_gt_inds, None, labels=assigned_labels)