yolo_bbox_coder.py
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch

from ..builder import BBOX_CODERS
from .base_bbox_coder import BaseBBoxCoder


@BBOX_CODERS.register_module()
class YOLOBBoxCoder(BaseBBoxCoder):
"""YOLO BBox coder.
Following `YOLO <https://arxiv.org/abs/1506.02640>`_, this coder divide
image into grids, and encode bbox (x1, y1, x2, y2) into (cx, cy, dw, dh).
cx, cy in [0., 1.], denotes relative center position w.r.t the center of
bboxes. dw, dh are the same as :obj:`DeltaXYWHBBoxCoder`.
Args:
eps (float): Min value of cx, cy when encoding.
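
    Example:
        A minimal round-trip sketch; the anchor and ground-truth values
        below are made up for illustration:

        >>> import torch
        >>> coder = YOLOBBoxCoder()
        >>> anchors = torch.Tensor([[0., 0., 32., 32.]])
        >>> gts = torch.Tensor([[4., 6., 28., 26.]])
        >>> deltas = coder.encode(anchors, gts, stride=32)
        >>> boxes = coder.decode(anchors, deltas, stride=32)
        >>> # ``boxes`` matches ``gts`` up to floating-point error.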
"""

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    @mmcv.jit(coderize=True)
    def encode(self, bboxes, gt_bboxes, stride):
"""Get box regression transformation deltas that can be used to
transform the ``bboxes`` into the ``gt_bboxes``.
Args:
bboxes (torch.Tensor): Source boxes, e.g., anchors.
gt_bboxes (torch.Tensor): Target of the transformation, e.g.,
ground-truth boxes.
stride (torch.Tensor | int): Stride of bboxes.
Returns:
torch.Tensor: Box transformation deltas
"""
        assert bboxes.size(0) == gt_bboxes.size(0)
        assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
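        # Centers and sizes of the ground-truth and the source boxes.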
        x_center_gt = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) * 0.5
        y_center_gt = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) * 0.5
        w_gt = gt_bboxes[..., 2] - gt_bboxes[..., 0]
        h_gt = gt_bboxes[..., 3] - gt_bboxes[..., 1]
        x_center = (bboxes[..., 0] + bboxes[..., 2]) * 0.5
        y_center = (bboxes[..., 1] + bboxes[..., 3]) * 0.5
        w = bboxes[..., 2] - bboxes[..., 0]
        h = bboxes[..., 3] - bboxes[..., 1]
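        # Size targets are log-space ratios gt/source, clamped away from
        # zero so the logarithm stays finite.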
        w_target = torch.log((w_gt / w).clamp(min=self.eps))
        h_target = torch.log((h_gt / h).clamp(min=self.eps))
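        # Center targets are the gt-center offsets from the source center,
        # normalized by the stride and shifted into (0, 1); clamping by eps
        # keeps them strictly inside the grid cell.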
        x_center_target = ((x_center_gt - x_center) / stride + 0.5).clamp(
            self.eps, 1 - self.eps)
        y_center_target = ((y_center_gt - y_center) / stride + 0.5).clamp(
            self.eps, 1 - self.eps)
        encoded_bboxes = torch.stack(
            [x_center_target, y_center_target, w_target, h_target], dim=-1)
        return encoded_bboxes

    @mmcv.jit(coderize=True)
    def decode(self, bboxes, pred_bboxes, stride):
"""Apply transformation `pred_bboxes` to `boxes`.
Args:
boxes (torch.Tensor): Basic boxes, e.g. anchors.
pred_bboxes (torch.Tensor): Encoded boxes with shape
stride (torch.Tensor | int): Strides of bboxes.
Returns:
torch.Tensor: Decoded boxes.
"""
        assert pred_bboxes.size(-1) == bboxes.size(-1) == 4
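        # Invert the encoding: shift the anchor centers by the predicted
        # in-cell offset (scaled by stride) and rescale the anchor
        # half-sizes by exp(dw), exp(dh).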
        xy_centers = (bboxes[..., :2] + bboxes[..., 2:]) * 0.5 + (
            pred_bboxes[..., :2] - 0.5) * stride
        whs = (bboxes[..., 2:] -
               bboxes[..., :2]) * 0.5 * pred_bboxes[..., 2:].exp()
        decoded_bboxes = torch.stack(
            (xy_centers[..., 0] - whs[..., 0], xy_centers[..., 1] -
             whs[..., 1], xy_centers[..., 0] + whs[..., 0],
             xy_centers[..., 1] + whs[..., 1]),
            dim=-1)
        return decoded_bboxes
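
# Note: in a detector config this coder is normally selected through the
# registry, e.g. ``bbox_coder=dict(type='YOLOBBoxCoder')`` inside a YOLOV3
# head config; the surrounding keys depend on the mmdet version, so treat
# that snippet as a sketch rather than a verified config.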