yolox_mode_switch_hook.py
2.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.parallel import is_module_wrapper
from mmcv.runner.hooks import HOOKS, Hook
@HOOKS.register_module()
class YOLOXModeSwitchHook(Hook):
    """Switch the mode of YOLOX during training.

    This hook turns off the mosaic and mixup data augmentation and switches
    to use L1 loss in bbox_head.

    The switch is performed exactly once, at the first training epoch for
    which ``epoch + 1 >= max_epochs - num_last_epochs``. Using ``>=`` instead
    of ``==`` means the switch is still applied when training is resumed from
    a checkpoint saved after the switch epoch, or when ``num_last_epochs`` is
    not smaller than ``max_epochs``.

    Args:
        num_last_epochs (int): The number of latter epochs in the end of the
            training to close the data augmentation and switch to L1 loss.
            Default: 15.
        skip_type_keys (Sequence[str], optional): Type names of the pipeline
            steps to skip once the switch happens; passed unchanged to the
            dataset's ``update_skip_type_keys``.
            Default: ('Mosaic', 'RandomAffine', 'MixUp').
    """

    def __init__(self,
                 num_last_epochs=15,
                 skip_type_keys=('Mosaic', 'RandomAffine', 'MixUp')):
        self.num_last_epochs = num_last_epochs
        self.skip_type_keys = skip_type_keys
        # True between the epoch in which the dataloader restart was forced
        # and the following epoch, where the initialization flag is restored.
        self._restart_dataloader = False
        # Guards the switch so that it is applied only once per hook instance.
        self._has_switched = False

    def before_train_epoch(self, runner):
        """Close mosaic and mixup augmentation and switches to use L1 loss.

        Args:
            runner: The runner driving the training. This method reads its
                ``epoch``, ``max_epochs``, ``data_loader``, ``model`` and
                ``logger`` attributes.
        """
        epoch = runner.epoch
        train_loader = runner.data_loader
        model = runner.model
        if is_module_wrapper(model):
            # Reach through the parallel wrapper to the actual detector.
            model = model.module

        # ``>=`` (not ``==``) so that a run resumed past the switch epoch
        # still disables the augmentations and enables the L1 loss.
        switch_epoch_reached = (
            (epoch + 1) >= runner.max_epochs - self.num_last_epochs)

        if switch_epoch_reached and not self._has_switched:
            runner.logger.info('No mosaic and mixup aug now!')
            # The dataset pipeline cannot be updated when persistent_workers
            # is True, so we need to force the dataloader's multi-process
            # restart. This is a very hacky approach.
            train_loader.dataset.update_skip_type_keys(self.skip_type_keys)
            if getattr(train_loader, 'persistent_workers', False) is True:
                train_loader._DataLoader__initialized = False
                train_loader._iterator = None
                self._restart_dataloader = True
            runner.logger.info('Add additional L1 loss now!')
            model.bbox_head.use_l1 = True
            self._has_switched = True
        elif self._restart_dataloader:
            # Once the restart is complete, we need to restore
            # the initialization flag. This only has to happen once, so the
            # marker is cleared afterwards.
            train_loader._DataLoader__initialized = True
            self._restart_dataloader = False