grid_roi_head.py
6.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmdet.core import bbox2result, bbox2roi
from ..builder import HEADS, build_head, build_roi_extractor
from .standard_roi_head import StandardRoIHead
@HEADS.register_module()
class GridRoIHead(StandardRoIHead):
    """Grid roi head for Grid R-CNN.

    https://arxiv.org/abs/1811.12030

    Extends :class:`StandardRoIHead` with a grid branch: a grid RoI extractor
    followed by a grid head. The branch is run on the (jittered) positive
    proposals during training and on the detected boxes during testing.

    Args:
        grid_roi_extractor: Config handed to ``build_roi_extractor`` for the
            grid branch. If ``None``, the parent's ``bbox_roi_extractor`` is
            reused and ``share_roi_extractor`` is set to ``True``.
        grid_head: Config handed to ``build_head``. Must not be ``None``.
        **kwargs: Forwarded unchanged to ``StandardRoIHead.__init__``.
    """

    def __init__(self, grid_roi_extractor, grid_head, **kwargs):
        assert grid_head is not None
        super(GridRoIHead, self).__init__(**kwargs)
        if grid_roi_extractor is not None:
            self.grid_roi_extractor = build_roi_extractor(grid_roi_extractor)
            self.share_roi_extractor = False
        else:
            # No dedicated extractor configured: share the bbox one.
            self.share_roi_extractor = True
            self.grid_roi_extractor = self.bbox_roi_extractor
        self.grid_head = build_head(grid_head)

    def _random_jitter(self, sampling_results, img_metas, amplitude=0.15):
        """Random jitter positive proposals for training.

        For every sampling result, the positive boxes (read as
        ``x1, y1, x2, y2``) get their centre shifted by a random fraction of
        their width/height and their width/height scaled by ``1 + r``, with
        all random values drawn uniformly between ``-amplitude`` and
        ``amplitude``. The jittered boxes are clipped to the image and
        written back to ``sampling_result.pos_bboxes``.

        Args:
            sampling_results: Iterable of sampling results exposing a
                ``pos_bboxes`` tensor; modified in place.
            img_metas: Iterable of per-image meta dicts, paired with
                ``sampling_results``. ``img_meta['img_shape']`` is read as
                ``(height, width, ...)`` for clipping; clipping is skipped
                when it is ``None``.
            amplitude (float): Half-width of the uniform jitter range.

        Returns:
            The same ``sampling_results`` object that was passed in.
        """
        for sampling_result, img_meta in zip(sampling_results, img_metas):
            bboxes = sampling_result.pos_bboxes
            # One row of 4 offsets per box: (dx, dy) for the centre and
            # (dw, dh) for the size.
            random_offsets = bboxes.new_empty(bboxes.shape[0], 4).uniform_(
                -amplitude, amplitude)
            # before jittering
            cxcy = (bboxes[:, 2:4] + bboxes[:, :2]) / 2
            wh = (bboxes[:, 2:4] - bboxes[:, :2]).abs()
            # after jittering
            new_cxcy = cxcy + wh * random_offsets[:, :2]
            new_wh = wh * (1 + random_offsets[:, 2:])
            # xywh to xyxy
            new_x1y1 = (new_cxcy - new_wh / 2)
            new_x2y2 = (new_cxcy + new_wh / 2)
            new_bboxes = torch.cat([new_x1y1, new_x2y2], dim=1)
            # clip bboxes
            max_shape = img_meta['img_shape']
            if max_shape is not None:
                # Columns 0 and 2 are x coordinates, 1 and 3 are y.
                new_bboxes[:, 0::2].clamp_(min=0, max=max_shape[1] - 1)
                new_bboxes[:, 1::2].clamp_(min=0, max=max_shape[0] - 1)

            sampling_result.pos_bboxes = new_bboxes
        return sampling_results

    def forward_dummy(self, x, proposals):
        """Dummy forward function.

        Runs the bbox branch (if present), the grid branch, and the mask
        branch (if present) and collects their raw outputs.

        Args:
            x: Multi-level features; sliceable, and passed to the RoI
                extractors.
            proposals: Proposals of a single image; wrapped in a one-element
                list before being passed to ``bbox2roi``.

        Returns:
            tuple: ``cls_score`` and ``bbox_pred`` (only when a bbox head
            exists), then the grid head output, then ``mask_pred`` (only
            when a mask head exists).
        """
        # bbox head
        outs = ()
        rois = bbox2roi([proposals])
        if self.with_bbox:
            bbox_results = self._bbox_forward(x, rois)
            outs = outs + (bbox_results['cls_score'],
                           bbox_results['bbox_pred'])

        # grid head (only the first 100 RoIs are used)
        grid_rois = rois[:100]
        grid_feats = self.grid_roi_extractor(
            x[:self.grid_roi_extractor.num_inputs], grid_rois)
        if self.with_shared_head:
            grid_feats = self.shared_head(grid_feats)
        grid_pred = self.grid_head(grid_feats)
        outs = outs + (grid_pred, )

        # mask head (only the first 100 RoIs are used)
        if self.with_mask:
            mask_rois = rois[:100]
            mask_results = self._mask_forward(x, mask_rois)
            outs = outs + (mask_results['mask_pred'], )
        return outs

    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
                            img_metas):
        """Run forward function and calculate loss for box head in training.

        First delegates to the parent implementation for the bbox head, then
        jitters the positive proposals, runs the grid branch on a random
        subset of them and merges the grid loss into
        ``bbox_results['loss_bbox']``.

        Args:
            x: Multi-level features; sliceable, and passed to the RoI
                extractor.
            sampling_results: Per-image sampling results. Their
                ``pos_bboxes`` are replaced by jittered boxes as a side
                effect (see ``_random_jitter``).
            gt_bboxes: Forwarded to the parent ``_bbox_forward_train``.
            gt_labels: Forwarded to the parent ``_bbox_forward_train``.
            img_metas: Per-image meta dicts.

        Returns:
            dict: The parent's ``bbox_results``. Its ``'loss_bbox'`` entry is
            updated with the grid loss unless there are no positive RoIs, in
            which case it is returned unchanged.
        """
        bbox_results = super(GridRoIHead,
                             self)._bbox_forward_train(x, sampling_results,
                                                       gt_bboxes, gt_labels,
                                                       img_metas)

        # Grid head forward and loss
        sampling_results = self._random_jitter(sampling_results, img_metas)
        pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])

        # GN in head does not support zero shape input
        if pos_rois.shape[0] == 0:
            return bbox_results

        grid_feats = self.grid_roi_extractor(
            x[:self.grid_roi_extractor.num_inputs], pos_rois)
        if self.with_shared_head:
            grid_feats = self.shared_head(grid_feats)
        # Accelerate training: keep at most ``max_num_grid`` (default 192)
        # randomly chosen positive RoIs for the grid head.
        max_sample_num_grid = self.train_cfg.get('max_num_grid', 192)
        sample_idx = torch.randperm(
            grid_feats.shape[0])[:min(grid_feats.shape[0], max_sample_num_grid
                                      )]
        grid_feats = grid_feats[sample_idx]

        grid_pred = self.grid_head(grid_feats)

        grid_targets = self.grid_head.get_targets(sampling_results,
                                                  self.train_cfg)
        # Targets are built for all positives, so apply the same random
        # selection to keep predictions and targets aligned.
        grid_targets = grid_targets[sample_idx]

        loss_grid = self.grid_head.loss(grid_pred, grid_targets)

        bbox_results['loss_bbox'].update(loss_grid)
        return bbox_results

    def simple_test(self,
                    x,
                    proposal_list,
                    img_metas,
                    proposals=None,
                    rescale=False):
        """Test without augmentation.

        Detects boxes with the bbox head, refines them with the grid head
        and optionally runs the mask branch.

        Args:
            x: Multi-level features; sliceable, and passed to the RoI
                extractor.
            proposal_list: Per-image proposals, forwarded to
                ``simple_test_bboxes``.
            img_metas: Per-image meta dicts. ``'scale_factor'`` is read when
                ``rescale`` is ``True``.
            proposals: Not used by this implementation.
            rescale (bool): If ``True``, the first four columns of the
                refined boxes are divided by the image's ``scale_factor``.

        Returns:
            list: One entry per image. Without a mask head each entry is the
            per-class bbox result (for images without detections: a list of
            ``num_classes`` float32 arrays of shape ``(0, 5)``). With a mask
            head each entry is a ``(bbox_result, segm_result)`` tuple.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'

        # Rescaling is deferred: boxes are divided by the scale factor only
        # after the grid refinement below.
        det_bboxes, det_labels = self.simple_test_bboxes(
            x, img_metas, proposal_list, self.test_cfg, rescale=False)
        num_classes = self.bbox_head.num_classes

        # pack rois into bboxes
        grid_rois = bbox2roi([det_bbox[:, :4] for det_bbox in det_bboxes])
        if grid_rois.shape[0] != 0:
            # Use the same feature slicing and shared-head handling as the
            # training path so the grid head sees equivalent inputs.
            grid_feats = self.grid_roi_extractor(
                x[:self.grid_roi_extractor.num_inputs], grid_rois)
            if self.with_shared_head:
                grid_feats = self.shared_head(grid_feats)
            self.grid_head.test_mode = True
            grid_pred = self.grid_head(grid_feats)
            # split batch grid head prediction back to each image
            num_roi_per_img = tuple(len(det_bbox) for det_bbox in det_bboxes)
            grid_pred = {
                k: v.split(num_roi_per_img, 0)
                for k, v in grid_pred.items()
            }

            # apply bbox post-processing to each image individually
            bbox_results = []
            for i, img_det_bboxes in enumerate(det_bboxes):
                if img_det_bboxes.shape[0] == 0:
                    # No detections for this image: empty result per class.
                    bbox_results.append([
                        np.zeros((0, 5), dtype=np.float32)
                        for _ in range(num_classes)
                    ])
                    continue

                det_bbox = self.grid_head.get_bboxes(
                    img_det_bboxes, grid_pred['fused'][i], [img_metas[i]])
                if rescale:
                    # scale_factor may be stored as a NumPy array or a plain
                    # number in the meta dict; convert it explicitly so the
                    # in-place division is a tensor op on the boxes' device.
                    scale_factor = torch.as_tensor(
                        img_metas[i]['scale_factor'],
                        dtype=det_bbox.dtype,
                        device=det_bbox.device)
                    det_bbox[:, :4] /= scale_factor
                bbox_results.append(
                    bbox2result(det_bbox, det_labels[i], num_classes))
        else:
            # No detections in the whole batch: skip the grid head entirely.
            bbox_results = [[
                np.zeros((0, 5), dtype=np.float32)
                for _ in range(num_classes)
            ] for _ in range(len(det_bboxes))]

        if not self.with_mask:
            return bbox_results

        segm_results = self.simple_test_mask(
            x, img_metas, det_bboxes, det_labels, rescale=rescale)
        return list(zip(bbox_results, segm_results))