# fpg.py (16.4 KB)
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule

from ..builder import NECKS


class Transition(BaseModule):
    """Base class for transition.

    A transition maps a feature with ``in_channels`` channels to a feature
    with ``out_channels`` channels. Concrete transitions override
    :meth:`forward`.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self, in_channels, out_channels, init_cfg=None):
        super().__init__(init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x):
        # Abstract hook: subclasses (e.g. UpInterpolationConv, LastConv)
        # provide the actual computation.
        raise NotImplementedError(
            '{} must implement forward()'.format(type(self).__name__))


class UpInterpolationConv(Transition):
    """Up-sampling transition: interpolate first, then refine with a conv.

    The input is resized with ``F.interpolate`` and the resized feature is
    passed through a ``ConvModule`` padded by ``(kernel_size - 1) // 2``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Up-sampling factor. Default: 2.
        mode (str): Interpolation mode passed to ``F.interpolate``.
            Default: 'nearest'.
        align_corners (bool, optional): ``align_corners`` argument passed to
            ``F.interpolate``. Default: None.
        kernel_size (int): Kernel size for the conv. Default: 3.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
        **kwargs: Extra keyword arguments forwarded to ``ConvModule``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor=2,
                 mode='nearest',
                 align_corners=None,
                 kernel_size=3,
                 init_cfg=None,
                 **kwargs):
        super().__init__(in_channels, out_channels, init_cfg)
        self.mode = mode
        self.scale_factor = scale_factor
        self.align_corners = align_corners
        # Symmetric padding derived from the kernel size.
        conv_padding = (kernel_size - 1) // 2
        self.conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            padding=conv_padding,
            **kwargs)

    def forward(self, x):
        resized = F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners)
        return self.conv(resized)


class LastConv(Transition):
    """Output transition that refines only the last feature of a list.

    ``forward`` receives a list of features, checks that it has exactly
    ``num_inputs`` entries, and applies a ``ConvModule`` to the final entry.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_inputs (int): Expected number of features in the input list.
        kernel_size (int): Kernel size for the conv. Default: 3.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
        **kwargs: Extra keyword arguments forwarded to ``ConvModule``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_inputs,
                 kernel_size=3,
                 init_cfg=None,
                 **kwargs):
        super().__init__(in_channels, out_channels, init_cfg)
        self.num_inputs = num_inputs
        conv_padding = (kernel_size - 1) // 2
        self.conv_out = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            padding=conv_padding,
            **kwargs)

    def forward(self, inputs):
        assert len(inputs) == self.num_inputs
        *_, last_feat = inputs
        return self.conv_out(last_feat)


@NECKS.register_module()
class FPG(BaseModule):
    """FPG.

    Implementation of `Feature Pyramid Grids (FPG)
    <https://arxiv.org/abs/2004.03580>`_.
    This implementation only gives the basic structure stated in the paper.
    Users can implement different types of transitions to fully explore the
    potential power of the FPG structure.

    Args:
        in_channels (list[int]): Number of input channels of each backbone
            level.
        out_channels (int): Number of output channels (used at each scale).
        num_outs (int): Number of output scales.
        stack_times (int): The number of times the pyramid architecture will
            be stacked.
        paths (list[str]): Specify the path order of each stack level.
            Each element in the list should be either 'bu' (bottom-up) or
            'td' (top-down). Its length must equal ``stack_times``.
        inter_channels (int | list[int], optional): Number of intermediate
            channels, either shared by all levels (int) or given per output
            level (list of length ``num_outs``). Default: None, which means
            ``out_channels`` for every level.
        same_down_trans (dict, optional): Same-stage transition from level
            ``i + 1`` to level ``i``; used in top-down ('td') paths.
        same_up_trans (dict, optional): Same-stage transition from level
            ``i - 1`` to level ``i``; used in bottom-up ('bu') paths.
        across_lateral_trans (dict, optional): Across-pathway transition from
            the same level of the previous stage.
        across_down_trans (dict, optional): Across-pathway transition from
            level ``i + 1`` of the previous stage to level ``i``.
        across_up_trans (dict, optional): Across-pathway transition from
            level ``i - 1`` of the previous stage to level ``i``.
        across_skip_trans (dict, optional): Across-pathway skip connection
            from the same level of the initial (lateral) features.
        output_trans (dict): Transition applied, per level, to the list of
            that level's features from every stage (initial features
            included) to produce the final output.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last level.
        add_extra_convs (bool): How extra levels beyond the backbone levels
            are produced from the previous level: 3x3 stride-2 convs if True,
            stride-2 max pooling if False. Default: False.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        skip_inds (list[tuple[int]], optional): ``skip_inds[i]`` holds the
            stage indices at which level ``i`` is not updated (its previous
            feature is passed through unchanged). Must be given when
            ``across_skip_trans`` is not None. Default: None, which (only
            when ``across_skip_trans`` is None) means no level is skipped.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    # 'identity' backs the default ``across_skip_trans``; ``nn.Identity``
    # accepts and ignores the channel arguments passed by ``build_trans``.
    transition_types = {
        'conv': ConvModule,
        'interpolation_conv': UpInterpolationConv,
        'last_conv': LastConv,
        'identity': nn.Identity,
    }

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 stack_times,
                 paths,
                 inter_channels=None,
                 same_down_trans=None,
                 same_up_trans=dict(
                     type='conv', kernel_size=3, stride=2, padding=1),
                 across_lateral_trans=dict(type='conv', kernel_size=1),
                 across_down_trans=dict(type='conv', kernel_size=3),
                 across_up_trans=None,
                 across_skip_trans=dict(type='identity'),
                 output_trans=dict(type='last_conv', kernel_size=3),
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 norm_cfg=None,
                 skip_inds=None,
                 init_cfg=[
                     dict(type='Caffe2Xavier', layer='Conv2d'),
                     dict(
                         type='Constant',
                         layer=[
                             '_BatchNorm', '_InstanceNorm', 'GroupNorm',
                             'LayerNorm'
                         ],
                         val=1.0)
                 ]):
        super(FPG, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        if inter_channels is None:
            self.inter_channels = [out_channels for _ in range(num_outs)]
        elif isinstance(inter_channels, int):
            self.inter_channels = [inter_channels for _ in range(num_outs)]
        else:
            assert isinstance(inter_channels, list)
            assert len(inter_channels) == num_outs
            self.inter_channels = inter_channels
        self.stack_times = stack_times
        self.paths = paths
        assert isinstance(paths, list) and len(paths) == stack_times
        for d in paths:
            assert d in ('bu', 'td')

        self.same_down_trans = same_down_trans
        self.same_up_trans = same_up_trans
        self.across_lateral_trans = across_lateral_trans
        self.across_down_trans = across_down_trans
        self.across_up_trans = across_up_trans
        self.output_trans = output_trans
        self.across_skip_trans = across_skip_trans

        self.with_bias = norm_cfg is None
        # skip inds must be specified if across skip trans is not None
        if self.across_skip_trans is not None:
            assert skip_inds is not None, (
                'skip_inds must be specified when across_skip_trans is set')
        elif skip_inds is None:
            # No skip transition and no explicit skips: update every level at
            # every stage.
            skip_inds = [() for _ in range(num_outs)]
        self.skip_inds = skip_inds
        assert len(self.skip_inds[0]) <= self.stack_times

        if end_level == -1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level < inputs, no extra level is allowed
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs

        # build lateral 1x1 convs to reduce channels
        self.lateral_convs = nn.ModuleList()
        for i in range(self.start_level, self.backbone_end_level):
            l_conv = nn.Conv2d(self.in_channels[i],
                               self.inter_channels[i - self.start_level], 1)
            self.lateral_convs.append(l_conv)

        # extra levels on top of the backbone levels, each derived from the
        # previous level in forward()
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        self.extra_downsamples = nn.ModuleList()
        for i in range(extra_levels):
            if self.add_extra_convs:
                fpn_idx = self.backbone_end_level - self.start_level + i
                extra_conv = nn.Conv2d(
                    self.inter_channels[fpn_idx - 1],
                    self.inter_channels[fpn_idx],
                    3,
                    stride=2,
                    padding=1)
                self.extra_downsamples.append(extra_conv)
            else:
                self.extra_downsamples.append(nn.MaxPool2d(1, stride=2))

        self.fpn_transitions = nn.ModuleList()  # stack times
        for s in range(self.stack_times):
            stage_trans = nn.ModuleList()  # num of feature levels
            for i in range(self.num_outs):
                # same, across_lateral, across_down, across_up, across_skip
                trans = nn.ModuleDict()
                if s in self.skip_inds[i]:
                    # skipped level: keep an empty placeholder so indexing by
                    # level stays aligned
                    stage_trans.append(trans)
                    continue
                # build same-stage up trans: level i - 1 -> i (bottom-up)
                if i == 0 or self.same_up_trans is None:
                    same_up_trans = None
                else:
                    same_up_trans = self.build_trans(
                        self.same_up_trans, self.inter_channels[i - 1],
                        self.inter_channels[i])
                trans['same_up'] = same_up_trans
                # build same-stage down trans: level i + 1 -> i (top-down)
                if i == self.num_outs - 1 or self.same_down_trans is None:
                    same_down_trans = None
                else:
                    same_down_trans = self.build_trans(
                        self.same_down_trans, self.inter_channels[i + 1],
                        self.inter_channels[i])
                trans['same_down'] = same_down_trans
                # build across lateral trans
                across_lateral_trans = self.build_trans(
                    self.across_lateral_trans, self.inter_channels[i],
                    self.inter_channels[i])
                trans['across_lateral'] = across_lateral_trans
                # build across down trans: previous stage level i + 1 -> i
                if i == self.num_outs - 1 or self.across_down_trans is None:
                    across_down_trans = None
                else:
                    across_down_trans = self.build_trans(
                        self.across_down_trans, self.inter_channels[i + 1],
                        self.inter_channels[i])
                trans['across_down'] = across_down_trans
                # build across up trans: previous stage level i - 1 -> i
                if i == 0 or self.across_up_trans is None:
                    across_up_trans = None
                else:
                    across_up_trans = self.build_trans(
                        self.across_up_trans, self.inter_channels[i - 1],
                        self.inter_channels[i])
                trans['across_up'] = across_up_trans
                # build across skip trans: it consumes the same level of the
                # initial lateral features (outs[0][lvl] in forward), so its
                # input has this level's channel count.
                if self.across_skip_trans is None:
                    across_skip_trans = None
                else:
                    across_skip_trans = self.build_trans(
                        self.across_skip_trans, self.inter_channels[i],
                        self.inter_channels[i])
                trans['across_skip'] = across_skip_trans
                stage_trans.append(trans)
            self.fpn_transitions.append(stage_trans)

        self.output_transition = nn.ModuleList()  # output levels
        for i in range(self.num_outs):
            trans = self.build_trans(
                self.output_trans,
                self.inter_channels[i],
                self.out_channels,
                num_inputs=self.stack_times + 1)
            self.output_transition.append(trans)

        self.relu = nn.ReLU(inplace=True)

    def build_trans(self, cfg, in_channels, out_channels, **extra_args):
        """Instantiate a transition from a config dict.

        ``cfg['type']`` selects the class from ``transition_types``; the
        remaining keys plus ``extra_args`` are passed as keyword arguments.
        ``cfg`` itself is not modified.
        """
        cfg_ = cfg.copy()
        trans_type = cfg_.pop('type')
        trans_cls = self.transition_types[trans_type]
        return trans_cls(in_channels, out_channels, **cfg_, **extra_args)

    def fuse(self, fuse_dict):
        """Sum all non-None values of ``fuse_dict``; None if there are none."""
        out = None
        for item in fuse_dict.values():
            if item is not None:
                if out is None:
                    out = item
                else:
                    out = out + item
        return out

    def forward(self, inputs):
        """Build the feature grid and return one output per level.

        Args:
            inputs (list | tuple): Backbone features; the length must equal
                ``len(self.in_channels)``.

        Returns:
            list: ``num_outs`` outputs, one per pyramid level, each produced
            by that level's output transition.
        """
        assert len(inputs) == len(self.in_channels)

        # build all levels from original feature maps
        feats = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]
        for downsample in self.extra_downsamples:
            feats.append(downsample(feats[-1]))

        outs = [feats]

        for i in range(self.stack_times):
            current_outs = outs[-1]
            next_outs = []
            direction = self.paths[i]
            for j in range(self.num_outs):
                # feature level: top-down paths visit levels in reverse order
                if direction == 'td':
                    lvl = self.num_outs - j - 1
                else:
                    lvl = j
                # skip_inds and the placeholder ModuleDicts are indexed by
                # level, not by visiting order; a skipped level passes its
                # previous feature through unchanged.
                if i in self.skip_inds[lvl]:
                    next_outs.append(current_outs[lvl])
                    continue
                # get transitions
                if direction == 'td':
                    same_trans = self.fpn_transitions[i][lvl]['same_down']
                else:
                    same_trans = self.fpn_transitions[i][lvl]['same_up']
                across_lateral_trans = self.fpn_transitions[i][lvl][
                    'across_lateral']
                across_down_trans = self.fpn_transitions[i][lvl]['across_down']
                across_up_trans = self.fpn_transitions[i][lvl]['across_up']
                across_skip_trans = self.fpn_transitions[i][lvl]['across_skip']
                # init output
                to_fuse = dict(
                    same=None, lateral=None, across_up=None, across_down=None)
                # same-stage: previously visited level of this stage
                if same_trans is not None:
                    to_fuse['same'] = same_trans(next_outs[-1])
                # across lateral
                if across_lateral_trans is not None:
                    to_fuse['lateral'] = across_lateral_trans(
                        current_outs[lvl])
                # across up: from level lvl - 1 of the previous stage
                if lvl > 0 and across_up_trans is not None:
                    to_fuse['across_up'] = across_up_trans(current_outs[lvl -
                                                                        1])
                # across down: from level lvl + 1 of the previous stage
                if (lvl < self.num_outs - 1 and across_down_trans is not None):
                    to_fuse['across_down'] = across_down_trans(
                        current_outs[lvl + 1])
                # across skip: from the initial lateral features
                if across_skip_trans is not None:
                    to_fuse['across_skip'] = across_skip_trans(outs[0][lvl])
                x = self.fuse(to_fuse)
                next_outs.append(x)

            # store in level order regardless of visiting order
            if direction == 'td':
                outs.append(next_outs[::-1])
            else:
                outs.append(next_outs)

        # output trans: each level sees its features from every stage
        final_outs = []
        for i in range(self.num_outs):
            lvl_out_list = [stage_outs[i] for stage_outs in outs]
            lvl_out = self.output_transition[i](lvl_out_list)
            final_outs.append(lvl_out)

        return final_outs