reppoints_head.py
34.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops import DeformConv2d
from mmdet.core import (build_assigner, build_sampler, images_to_levels,
multi_apply, unmap)
from mmdet.core.anchor.point_generator import MlvlPointGenerator
from mmdet.core.utils import filter_scores_and_topk
from ..builder import HEADS, build_loss
from .anchor_free_head import AnchorFreeHead
@HEADS.register_module()
class RepPointsHead(AnchorFreeHead):
    """RepPoint head.

    Args:
        point_feat_channels (int): Number of channels of points features.
        num_points (int): Number of representative points per location.
            Must be an odd square number because the points also define
            the sampling grid of a square deformable convolution.
        gradient_mul (float): The multiplier to gradients from
            points refinement and recognition.
        point_strides (Iterable): points strides.
        point_base_scale (int): bbox scale for assigning labels.
        loss_cls (dict): Config of classification loss.
        loss_bbox_init (dict): Config of initial points loss.
        loss_bbox_refine (dict): Config of points loss in refinement.
        use_grid_points (bool): If we use bounding box representation, the
            reppoints is represented as grid points on the bounding box.
        center_init (bool): Whether to use center point assignment.
        transform_method (str): The methods to transform RepPoints to bbox.
            One of ``'minmax'``, ``'partial_minmax'`` or ``'moment'``.
        moment_mul (float): Gradient multiplier applied to the learnable
            moment transfer parameter when ``transform_method='moment'``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """  # noqa: W605

    def __init__(self,
                 num_classes,
                 in_channels,
                 point_feat_channels=256,
                 num_points=9,
                 gradient_mul=0.1,
                 point_strides=[8, 16, 32, 64, 128],
                 point_base_scale=4,
                 loss_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox_init=dict(
                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5),
                 loss_bbox_refine=dict(
                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
                 use_grid_points=False,
                 center_init=True,
                 transform_method='moment',
                 moment_mul=0.01,
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='reppoints_cls_out',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs):
        self.num_points = num_points
        self.point_feat_channels = point_feat_channels
        self.use_grid_points = use_grid_points
        self.center_init = center_init

        # we use deform conv to extract points features
        self.dcn_kernel = int(np.sqrt(num_points))
        self.dcn_pad = int((self.dcn_kernel - 1) / 2)
        assert self.dcn_kernel * self.dcn_kernel == num_points, \
            'The points number should be a square number.'
        assert self.dcn_kernel % 2 == 1, \
            'The points number should be an odd square number.'
        # Regular kernel grid offsets, interleaved as (y, x) per kernel
        # position: y repeats row by row, x tiles across each row.
        dcn_base = np.arange(-self.dcn_pad,
                             self.dcn_pad + 1).astype(np.float64)
        dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
        dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
        dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape(
            (-1))
        # Plain attribute (not a registered buffer); forward_single casts it
        # to the input's dtype/device with ``type_as`` on every call.
        self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1)

        super().__init__(
            num_classes,
            in_channels,
            loss_cls=loss_cls,
            init_cfg=init_cfg,
            **kwargs)

        self.gradient_mul = gradient_mul
        self.point_base_scale = point_base_scale
        self.point_strides = point_strides
        self.prior_generator = MlvlPointGenerator(
            self.point_strides, offset=0.)

        self.sampling = loss_cls['type'] not in ['FocalLoss']
        if self.train_cfg:
            self.init_assigner = build_assigner(self.train_cfg.init.assigner)
            self.refine_assigner = build_assigner(
                self.train_cfg.refine.assigner)
            # use PseudoSampler when sampling is False
            if self.sampling and hasattr(self.train_cfg, 'sampler'):
                sampler_cfg = self.train_cfg.sampler
            else:
                sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)
        self.transform_method = transform_method
        if self.transform_method == 'moment':
            self.moment_transfer = nn.Parameter(
                data=torch.zeros(2), requires_grad=True)
            self.moment_mul = moment_mul

        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        if self.use_sigmoid_cls:
            self.cls_out_channels = self.num_classes
        else:
            self.cls_out_channels = self.num_classes + 1
        self.loss_bbox_init = build_loss(loss_bbox_init)
        self.loss_bbox_refine = build_loss(loss_bbox_refine)

    def _init_layers(self):
        """Initialize layers of the head."""
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
        # Grid-point mode regresses 4 box deltas; otherwise one (y, x)
        # offset pair per representative point.
        pts_out_dim = 4 if self.use_grid_points else 2 * self.num_points
        self.reppoints_cls_conv = DeformConv2d(self.feat_channels,
                                               self.point_feat_channels,
                                               self.dcn_kernel, 1,
                                               self.dcn_pad)
        self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels,
                                           self.cls_out_channels, 1, 1, 0)
        self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels,
                                                 self.point_feat_channels, 3,
                                                 1, 1)
        self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels,
                                                pts_out_dim, 1, 1, 0)
        self.reppoints_pts_refine_conv = DeformConv2d(self.feat_channels,
                                                      self.point_feat_channels,
                                                      self.dcn_kernel, 1,
                                                      self.dcn_pad)
        self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels,
                                                  pts_out_dim, 1, 1, 0)

    def points2bbox(self, pts, y_first=True):
        """Convert point sets into bounding boxes.

        Args:
            pts (Tensor): Point sets of shape (N, 2 * num_points, ...); each
                set is stored as 2n scalars.
            y_first (bool): If True the set is laid out as
                [y1, x1, y2, x2 ... yn, xn], otherwise as
                [x1, y1, x2, y2 ... xn, yn].

        Returns:
            Tensor: Boxes of shape (N, 4, ...) as [x1, y1, x2, y2].

        Raises:
            NotImplementedError: If ``self.transform_method`` is unknown.
        """
        pts_reshape = pts.view(pts.shape[0], -1, 2, *pts.shape[2:])
        pts_y = pts_reshape[:, :, 0, ...] if y_first else pts_reshape[:, :, 1,
                                                                      ...]
        pts_x = pts_reshape[:, :, 1, ...] if y_first else pts_reshape[:, :, 0,
                                                                      ...]
        if self.transform_method == 'minmax':
            bbox_left = pts_x.min(dim=1, keepdim=True)[0]
            bbox_right = pts_x.max(dim=1, keepdim=True)[0]
            bbox_up = pts_y.min(dim=1, keepdim=True)[0]
            bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
            bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
                             dim=1)
        elif self.transform_method == 'partial_minmax':
            # Only the first four points define the box.
            pts_y = pts_y[:, :4, ...]
            pts_x = pts_x[:, :4, ...]
            bbox_left = pts_x.min(dim=1, keepdim=True)[0]
            bbox_right = pts_x.max(dim=1, keepdim=True)[0]
            bbox_up = pts_y.min(dim=1, keepdim=True)[0]
            bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
            bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
                             dim=1)
        elif self.transform_method == 'moment':
            pts_y_mean = pts_y.mean(dim=1, keepdim=True)
            pts_x_mean = pts_x.mean(dim=1, keepdim=True)
            pts_y_std = torch.std(pts_y - pts_y_mean, dim=1, keepdim=True)
            pts_x_std = torch.std(pts_x - pts_x_mean, dim=1, keepdim=True)
            # Value equals ``moment_transfer``; only its gradient is scaled
            # by ``moment_mul``.
            moment_transfer = (self.moment_transfer * self.moment_mul) + (
                self.moment_transfer.detach() * (1 - self.moment_mul))
            moment_width_transfer = moment_transfer[0]
            moment_height_transfer = moment_transfer[1]
            half_width = pts_x_std * torch.exp(moment_width_transfer)
            half_height = pts_y_std * torch.exp(moment_height_transfer)
            bbox = torch.cat([
                pts_x_mean - half_width, pts_y_mean - half_height,
                pts_x_mean + half_width, pts_y_mean + half_height
            ],
                             dim=1)
        else:
            raise NotImplementedError
        return bbox

    def gen_grid_from_reg(self, reg, previous_boxes):
        """Regress boxes from previous boxes and place a point grid on them.

        Args:
            reg (Tensor): Regression values of shape (B, 4, H, W). Channels
                0-1 shift the box center (in units of the previous box size);
                channels 2-3 are log-scale factors of width and height.
            previous_boxes (Tensor): Previous boxes [x1, y1, x2, y2] along
                dim 1, broadcastable to ``reg``.

        Returns:
            tuple[Tensor, Tensor]: ``grid_yx`` of shape
            (B, 2 * dcn_kernel**2, H, W) with (y, x) interleaved, and the
            regressed boxes of shape (B, 4, H, W).
        """
        b, _, h, w = reg.shape
        bxy = (previous_boxes[:, :2, ...] + previous_boxes[:, 2:, ...]) / 2.
        bwh = (previous_boxes[:, 2:, ...] -
               previous_boxes[:, :2, ...]).clamp(min=1e-6)
        grid_topleft = bxy + bwh * reg[:, :2, ...] - 0.5 * bwh * torch.exp(
            reg[:, 2:, ...])
        grid_wh = bwh * torch.exp(reg[:, 2:, ...])
        grid_left = grid_topleft[:, [0], ...]
        grid_top = grid_topleft[:, [1], ...]
        grid_width = grid_wh[:, [0], ...]
        grid_height = grid_wh[:, [1], ...]
        interval = torch.linspace(0., 1., self.dcn_kernel).view(
            1, self.dcn_kernel, 1, 1).type_as(reg)
        grid_x = grid_left + grid_width * interval
        grid_x = grid_x.unsqueeze(1).repeat(1, self.dcn_kernel, 1, 1, 1)
        grid_x = grid_x.view(b, -1, h, w)
        grid_y = grid_top + grid_height * interval
        grid_y = grid_y.unsqueeze(2).repeat(1, 1, self.dcn_kernel, 1, 1)
        grid_y = grid_y.view(b, -1, h, w)
        grid_yx = torch.stack([grid_y, grid_x], dim=2)
        grid_yx = grid_yx.view(b, -1, h, w)
        regressed_bbox = torch.cat([
            grid_left, grid_top, grid_left + grid_width, grid_top + grid_height
        ], 1)
        return grid_yx, regressed_bbox

    def forward(self, feats):
        """Run :meth:`forward_single` on every FPN level.

        Returns:
            tuple[list[Tensor]]: Per-level outputs of ``forward_single``,
            grouped by output position.
        """
        return multi_apply(self.forward_single, feats)

    def forward_single(self, x):
        """Forward feature map of a single FPN level.

        Returns:
            tuple: ``(cls_out, pts_out_init, pts_out_refine)`` in training
            mode, otherwise ``(cls_out, bbox_refine)`` where the refined
            points are converted with :meth:`points2bbox`.
        """
        dcn_base_offset = self.dcn_base_offset.type_as(x)
        # If we use center_init, the initial reppoints is from center points.
        # If we use bounding bbox representation, the initial reppoints is
        # from regular grid placed on a pre-defined bbox.
        if self.use_grid_points or not self.center_init:
            scale = self.point_base_scale / 2
            points_init = dcn_base_offset / dcn_base_offset.max() * scale
            bbox_init = x.new_tensor([-scale, -scale, scale,
                                      scale]).view(1, 4, 1, 1)
        else:
            points_init = 0
        cls_feat = x
        pts_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            pts_feat = reg_conv(pts_feat)
        # initialize reppoints
        pts_out_init = self.reppoints_pts_init_out(
            self.relu(self.reppoints_pts_init_conv(pts_feat)))
        if self.use_grid_points:
            pts_out_init, bbox_out_init = self.gen_grid_from_reg(
                pts_out_init, bbox_init.detach())
        else:
            pts_out_init = pts_out_init + points_init
        # refine and classify reppoints
        # Same value as pts_out_init, but only ``gradient_mul`` of the
        # gradient flows back through the deformable-conv offsets.
        pts_out_init_grad_mul = (1 - self.gradient_mul) * \
            pts_out_init.detach() + self.gradient_mul * pts_out_init
        # Deformable-conv offsets are relative to the regular kernel grid.
        dcn_offset = pts_out_init_grad_mul - dcn_base_offset
        cls_out = self.reppoints_cls_out(
            self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset)))
        pts_out_refine = self.reppoints_pts_refine_out(
            self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset)))
        if self.use_grid_points:
            pts_out_refine, bbox_out_refine = self.gen_grid_from_reg(
                pts_out_refine, bbox_out_init.detach())
        else:
            pts_out_refine = pts_out_refine + pts_out_init.detach()

        if self.training:
            return cls_out, pts_out_init, pts_out_refine
        else:
            return cls_out, self.points2bbox(pts_out_refine)

    def get_points(self, featmap_sizes, img_metas, device):
        """Get points according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info.
            device (torch.device | str): Device on which both the points and
                the valid flags are created.

        Returns:
            tuple: points of each image, valid flags of each image
        """
        num_imgs = len(img_metas)

        # since feature map sizes of all images are the same, we only compute
        # points center for one time
        multi_level_points = self.prior_generator.grid_priors(
            featmap_sizes, device=device, with_stride=True)
        points_list = [[point.clone() for point in multi_level_points]
                       for _ in range(num_imgs)]

        # For each image, compute valid flags of multi level grids. Pass
        # ``device`` explicitly so the flags live on the same device as the
        # points they mask (the generator's default device is not ours).
        valid_flag_list = [
            self.prior_generator.valid_flags(
                featmap_sizes, img_meta['pad_shape'], device=device)
            for img_meta in img_metas
        ]

        return points_list, valid_flag_list

    def centers_to_bboxes(self, point_list):
        """Get bboxes according to center points.

        Each center becomes a square box of side
        ``point_base_scale * stride`` for its level.

        Only used in :class:`MaxIoUAssigner`.
        """
        bbox_list = []
        for i_img, point in enumerate(point_list):
            bbox = []
            for i_lvl in range(len(self.point_strides)):
                scale = self.point_base_scale * self.point_strides[i_lvl] * 0.5
                bbox_shift = torch.Tensor([-scale, -scale, scale,
                                           scale]).view(1, 4).type_as(point[0])
                bbox_center = torch.cat(
                    [point[i_lvl][:, :2], point[i_lvl][:, :2]], dim=1)
                bbox.append(bbox_center + bbox_shift)
            bbox_list.append(bbox)
        return bbox_list

    def offset_to_pts(self, center_list, pred_list):
        """Change from point offset to point coordinate.

        Args:
            center_list (list[list[Tensor]]): Per-image, per-level centers;
                the first two columns are (x, y).
            pred_list (list[Tensor]): Per-level offsets; indexing
                ``[i_lvl][i_img]`` gives a (2 * num_points, H, W) map with
                (y, x) interleaved, in units of the level stride.

        Returns:
            list[Tensor]: Per-level absolute points of shape
            (num_imgs, H * W, 2 * num_points) with (x, y) interleaved.
        """
        pts_list = []
        for i_lvl in range(len(self.point_strides)):
            pts_lvl = []
            for i_img in range(len(center_list)):
                pts_center = center_list[i_img][i_lvl][:, :2].repeat(
                    1, self.num_points)
                pts_shift = pred_list[i_lvl][i_img]
                yx_pts_shift = pts_shift.permute(1, 2, 0).view(
                    -1, 2 * self.num_points)
                # Swap the (y, x) interleaving to (x, y) to match centers.
                y_pts_shift = yx_pts_shift[..., 0::2]
                x_pts_shift = yx_pts_shift[..., 1::2]
                xy_pts_shift = torch.stack([x_pts_shift, y_pts_shift], -1)
                xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1)
                pts = xy_pts_shift * self.point_strides[i_lvl] + pts_center
                pts_lvl.append(pts)
            pts_lvl = torch.stack(pts_lvl, 0)
            pts_list.append(pts_lvl)
        return pts_list

    def _point_target_single(self,
                             flat_proposals,
                             valid_flags,
                             gt_bboxes,
                             gt_bboxes_ignore,
                             gt_labels,
                             stage='init',
                             unmap_outputs=True):
        """Compute point/bbox targets for one image.

        Args:
            flat_proposals (Tensor): All-level proposals concatenated.
            valid_flags (Tensor): Boolean mask selecting usable proposals.
            gt_bboxes (Tensor): Ground truth boxes.
            gt_bboxes_ignore (Tensor | None): Boxes to ignore.
            gt_labels (Tensor | None): Ground truth labels (None for RPN).
            stage (str): ``'init'`` or ``'refine'``; picks assigner and
                ``pos_weight`` from the matching ``train_cfg`` section.
            unmap_outputs (bool): Scatter results back to the full proposal
                set using ``valid_flags``.

        Returns:
            tuple: ``(labels, label_weights, bbox_gt, pos_proposals,
            proposals_weights, pos_inds, neg_inds)``, or seven ``None`` when
            no proposal is valid.
        """
        inside_flags = valid_flags
        if not inside_flags.any():
            return (None, ) * 7
        # assign gt and sample proposals
        proposals = flat_proposals[inside_flags, :]

        if stage == 'init':
            assigner = self.init_assigner
            pos_weight = self.train_cfg.init.pos_weight
        else:
            assigner = self.refine_assigner
            pos_weight = self.train_cfg.refine.pos_weight
        assign_result = assigner.assign(proposals, gt_bboxes, gt_bboxes_ignore,
                                        None if self.sampling else gt_labels)
        sampling_result = self.sampler.sample(assign_result, proposals,
                                              gt_bboxes)

        num_valid_proposals = proposals.shape[0]
        bbox_gt = proposals.new_zeros([num_valid_proposals, 4])
        pos_proposals = torch.zeros_like(proposals)
        proposals_weights = proposals.new_zeros([num_valid_proposals, 4])
        # Unassigned proposals default to the background index num_classes.
        labels = proposals.new_full((num_valid_proposals, ),
                                    self.num_classes,
                                    dtype=torch.long)
        label_weights = proposals.new_zeros(
            num_valid_proposals, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            pos_gt_bboxes = sampling_result.pos_gt_bboxes
            bbox_gt[pos_inds, :] = pos_gt_bboxes
            pos_proposals[pos_inds, :] = proposals[pos_inds, :]
            proposals_weights[pos_inds, :] = 1.0
            if gt_labels is None:
                # Only rpn gives gt_labels as None
                # Foreground is the first class
                labels[pos_inds] = 0
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            if pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of proposals
        if unmap_outputs:
            num_total_proposals = flat_proposals.size(0)
            labels = unmap(labels, num_total_proposals, inside_flags)
            label_weights = unmap(label_weights, num_total_proposals,
                                  inside_flags)
            bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags)
            pos_proposals = unmap(pos_proposals, num_total_proposals,
                                  inside_flags)
            proposals_weights = unmap(proposals_weights, num_total_proposals,
                                      inside_flags)

        return (labels, label_weights, bbox_gt, pos_proposals,
                proposals_weights, pos_inds, neg_inds)

    def get_targets(self,
                    proposals_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    stage='init',
                    label_channels=1,
                    unmap_outputs=True):
        """Compute corresponding GT box and classification targets for
        proposals.

        Note:
            ``proposals_list`` and ``valid_flag_list`` are modified in place:
            each image's per-level list is replaced by its concatenation.

        Args:
            proposals_list (list[list]): Multi level points/bboxes of each
                image.
            valid_flag_list (list[list]): Multi level valid flags of each
                image.
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored.
            gt_labels_list (list[Tensor]): Ground truth labels of each box.
            stage (str): `init` or `refine`. Generate target for init stage or
                refine stage
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple | None: ``None`` if any image has no valid proposal,
            otherwise:
                - labels_list (list[Tensor]): Labels of each level.
                - label_weights_list (list[Tensor]): Label weights of each level.  # noqa: E501
                - bbox_gt_list (list[Tensor]): Ground truth bbox of each level.
                - proposal_list (list[Tensor]): Proposals(points/bboxes) of each level.  # noqa: E501
                - proposal_weights_list (list[Tensor]): Proposal weights of each level.  # noqa: E501
                - num_total_pos (int): Number of positive samples in all images.  # noqa: E501
                - num_total_neg (int): Number of negative samples in all images.  # noqa: E501
        """
        assert stage in ['init', 'refine']
        num_imgs = len(img_metas)
        assert len(proposals_list) == len(valid_flag_list) == num_imgs

        # points number of multi levels
        num_level_proposals = [points.size(0) for points in proposals_list[0]]

        # concat all level points and flags to a single tensor
        for i in range(num_imgs):
            assert len(proposals_list[i]) == len(valid_flag_list[i])
            proposals_list[i] = torch.cat(proposals_list[i])
            valid_flag_list[i] = torch.cat(valid_flag_list[i])

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        (all_labels, all_label_weights, all_bbox_gt, all_proposals,
         all_proposal_weights, pos_inds_list, neg_inds_list) = multi_apply(
             self._point_target_single,
             proposals_list,
             valid_flag_list,
             gt_bboxes_list,
             gt_bboxes_ignore_list,
             gt_labels_list,
             stage=stage,
             unmap_outputs=unmap_outputs)
        # no valid points
        if any([labels is None for labels in all_labels]):
            return None
        # sampled points of all images; each image counts at least 1 so the
        # totals can safely be used as averaging factors
        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
        labels_list = images_to_levels(all_labels, num_level_proposals)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_proposals)
        bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals)
        proposals_list = images_to_levels(all_proposals, num_level_proposals)
        proposal_weights_list = images_to_levels(all_proposal_weights,
                                                 num_level_proposals)
        return (labels_list, label_weights_list, bbox_gt_list, proposals_list,
                proposal_weights_list, num_total_pos, num_total_neg)

    def loss_single(self, cls_score, pts_pred_init, pts_pred_refine, labels,
                    label_weights, bbox_gt_init, bbox_weights_init,
                    bbox_gt_refine, bbox_weights_refine, stride,
                    num_total_samples_init, num_total_samples_refine):
        """Compute classification and point losses for one FPN level.

        Box predictions are produced from (x, y)-ordered point sets via
        :meth:`points2bbox`. Predictions and targets are divided by
        ``point_base_scale * stride`` before the box losses.

        Returns:
            tuple[Tensor]: ``(loss_cls, loss_pts_init, loss_pts_refine)``.
        """
        # classification loss
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        cls_score = cls_score.contiguous()
        loss_cls = self.loss_cls(
            cls_score,
            labels,
            label_weights,
            avg_factor=num_total_samples_refine)

        # points loss
        bbox_gt_init = bbox_gt_init.reshape(-1, 4)
        bbox_weights_init = bbox_weights_init.reshape(-1, 4)
        bbox_pred_init = self.points2bbox(
            pts_pred_init.reshape(-1, 2 * self.num_points), y_first=False)
        bbox_gt_refine = bbox_gt_refine.reshape(-1, 4)
        bbox_weights_refine = bbox_weights_refine.reshape(-1, 4)
        bbox_pred_refine = self.points2bbox(
            pts_pred_refine.reshape(-1, 2 * self.num_points), y_first=False)
        normalize_term = self.point_base_scale * stride
        loss_pts_init = self.loss_bbox_init(
            bbox_pred_init / normalize_term,
            bbox_gt_init / normalize_term,
            bbox_weights_init,
            avg_factor=num_total_samples_init)
        loss_pts_refine = self.loss_bbox_refine(
            bbox_pred_refine / normalize_term,
            bbox_gt_refine / normalize_term,
            bbox_weights_refine,
            avg_factor=num_total_samples_refine)
        return loss_cls, loss_pts_init, loss_pts_refine

    def loss(self,
             cls_scores,
             pts_preds_init,
             pts_preds_refine,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Targets for the init stage are assigned to centers (``PointAssigner``)
        or to fixed-size boxes around them. Targets for the refine stage are
        assigned to the boxes predicted by the init stage.

        Returns:
            dict[str, list[Tensor]]: Per-level ``loss_cls``,
            ``loss_pts_init`` and ``loss_pts_refine``.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        device = cls_scores[0].device
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1

        # target for initial stage
        center_list, valid_flag_list = self.get_points(featmap_sizes,
                                                       img_metas, device)
        pts_coordinate_preds_init = self.offset_to_pts(center_list,
                                                       pts_preds_init)
        if self.train_cfg.init.assigner['type'] == 'PointAssigner':
            # Assign target for center list
            candidate_list = center_list
        else:
            # transform center list to bbox list and
            # assign target for bbox list
            bbox_list = self.centers_to_bboxes(center_list)
            candidate_list = bbox_list
        cls_reg_targets_init = self.get_targets(
            candidate_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            stage='init',
            label_channels=label_channels)
        (*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init,
         num_total_pos_init, num_total_neg_init) = cls_reg_targets_init
        num_total_samples_init = (
            num_total_pos_init +
            num_total_neg_init if self.sampling else num_total_pos_init)

        # target for refinement stage
        # ``get_targets`` concatenated ``valid_flag_list`` (and possibly
        # ``center_list``) in place, so regenerate the per-level lists.
        center_list, valid_flag_list = self.get_points(featmap_sizes,
                                                       img_metas, device)
        pts_coordinate_preds_refine = self.offset_to_pts(
            center_list, pts_preds_refine)
        # The init-stage boxes do not depend on the image index, so convert
        # each level once instead of once per (image, level).
        bbox_shifts = [
            self.points2bbox(pts_preds_init[i_lvl].detach()) *
            self.point_strides[i_lvl]
            for i_lvl in range(len(pts_preds_refine))
        ]
        bbox_list = []
        for i_img, center in enumerate(center_list):
            bbox = []
            for i_lvl, bbox_shift in enumerate(bbox_shifts):
                bbox_center = torch.cat(
                    [center[i_lvl][:, :2], center[i_lvl][:, :2]], dim=1)
                bbox.append(bbox_center +
                            bbox_shift[i_img].permute(1, 2, 0).reshape(-1, 4))
            bbox_list.append(bbox)
        cls_reg_targets_refine = self.get_targets(
            bbox_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            stage='refine',
            label_channels=label_channels)
        (labels_list, label_weights_list, bbox_gt_list_refine,
         candidate_list_refine, bbox_weights_list_refine, num_total_pos_refine,
         num_total_neg_refine) = cls_reg_targets_refine
        num_total_samples_refine = (
            num_total_pos_refine +
            num_total_neg_refine if self.sampling else num_total_pos_refine)

        # compute loss
        losses_cls, losses_pts_init, losses_pts_refine = multi_apply(
            self.loss_single,
            cls_scores,
            pts_coordinate_preds_init,
            pts_coordinate_preds_refine,
            labels_list,
            label_weights_list,
            bbox_gt_list_init,
            bbox_weights_list_init,
            bbox_gt_list_refine,
            bbox_weights_list_refine,
            self.point_strides,
            num_total_samples_init=num_total_samples_init,
            num_total_samples_refine=num_total_samples_refine)
        loss_dict_all = {
            'loss_cls': losses_cls,
            'loss_pts_init': losses_pts_init,
            'loss_pts_refine': losses_pts_refine
        }
        return loss_dict_all

    # Same as base_dense_head/_get_bboxes_single except self._bbox_decode
    def _get_bboxes_single(self,
                           cls_score_list,
                           bbox_pred_list,
                           score_factor_list,
                           mlvl_priors,
                           img_meta,
                           cfg,
                           rescale=False,
                           with_nms=True,
                           **kwargs):
        """Transform outputs of a single image into bbox predictions.

        Args:
            cls_score_list (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_priors * num_classes, H, W).
            bbox_pred_list (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has shape
                (num_priors * 4, H, W).
            score_factor_list (list[Tensor]): Score factor from all scale
                levels of a single image. RepPoints head does not need
                this value.
            mlvl_priors (list[Tensor]): Each element in the list is
                the priors of a single level in feature pyramid, has shape
                (num_priors, 2).
            img_meta (dict): Image meta info.
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Default: False.
            with_nms (bool): If True, do nms before return boxes.
                Default: True.

        Returns:
            tuple[Tensor]: Results of detected bboxes and labels. If with_nms
                is False and mlvl_score_factor is None, return mlvl_bboxes and
                mlvl_scores, else return mlvl_bboxes, mlvl_scores and
                mlvl_score_factor. Usually with_nms is False is used for aug
                test. If with_nms is True, then return the following format

                - det_bboxes (Tensor): Predicted bboxes with shape \
                    [num_bboxes, 5], where the first 4 columns are bounding \
                    box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
                    column are scores between 0 and 1.
                - det_labels (Tensor): Predicted labels of the corresponding \
                    box with shape [num_bboxes].
        """
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_score_list) == len(bbox_pred_list)
        img_shape = img_meta['img_shape']
        nms_pre = cfg.get('nms_pre', -1)

        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_labels = []
        for level_idx, (cls_score, bbox_pred, priors) in enumerate(
                zip(cls_score_list, bbox_pred_list, mlvl_priors)):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)

            cls_score = cls_score.permute(1, 2,
                                          0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                # Drop the trailing background column of the softmax.
                scores = cls_score.softmax(-1)[:, :-1]

            # After https://github.com/open-mmlab/mmdetection/pull/6268/,
            # this operation keeps fewer bboxes under the same `nms_pre`.
            # There is no difference in performance for most models. If you
            # find a slight drop in performance, you can set a larger
            # `nms_pre` than before.
            results = filter_scores_and_topk(
                scores, cfg.score_thr, nms_pre,
                dict(bbox_pred=bbox_pred, priors=priors))
            scores, labels, _, filtered_results = results

            bbox_pred = filtered_results['bbox_pred']
            priors = filtered_results['priors']

            bboxes = self._bbox_decode(priors, bbox_pred,
                                       self.point_strides[level_idx],
                                       img_shape)

            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_labels.append(labels)

        return self._bbox_post_process(
            mlvl_scores,
            mlvl_labels,
            mlvl_bboxes,
            img_meta['scale_factor'],
            cfg,
            rescale=rescale,
            with_nms=with_nms)

    def _bbox_decode(self, points, bbox_pred, stride, max_shape):
        """Decode stride-normalized box offsets into clipped image boxes.

        Args:
            points (Tensor): Priors; the first two columns are (x, y).
            bbox_pred (Tensor): Offsets of shape (N, 4) relative to the prior
                in units of ``stride``.
            stride (int): Stride of the feature level.
            max_shape (tuple): Image shape; ``max_shape[0]`` bounds y and
                ``max_shape[1]`` bounds x.

        Returns:
            Tensor: Boxes of shape (N, 4) as [x1, y1, x2, y2], clamped to
            ``[0, max_shape]``.
        """
        bbox_pos_center = torch.cat([points[:, :2], points[:, :2]], dim=1)
        bboxes = bbox_pred * stride + bbox_pos_center
        x1 = bboxes[:, 0].clamp(min=0, max=max_shape[1])
        y1 = bboxes[:, 1].clamp(min=0, max=max_shape[0])
        x2 = bboxes[:, 2].clamp(min=0, max=max_shape[1])
        y2 = bboxes[:, 3].clamp(min=0, max=max_shape[0])
        decoded_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
        return decoded_bboxes