Commit 1b7560c

hhaAndroid authored and ZwwWayne committed
Add benchmark training scripts
1 parent 78bab5e commit 1b7560c

File tree

3 files changed: +150 −0 lines changed


Diff for: .dev_scripts/benchmark_train.py

+135
@@ -0,0 +1,135 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import os.path as osp
from argparse import ArgumentParser

from mmengine.config import Config, DictAction
from mmengine.logging import MMLogger, print_log
from mmengine.registry import RUNNERS
from mmengine.runner import Runner

from mmdet.testing import FastStopTrainingHook  # noqa: F401,F403
from mmdet.utils import register_all_modules, replace_cfg_vals


def parse_args():
    parser = ArgumentParser()
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--amp',
        action='store_true',
        default=False,
        help='enable automatic-mixed-precision training')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


# TODO: Need to refactor train.py so that it can be reused.
def fast_train_model(config_name, args, logger=None):
    cfg = Config.fromfile(config_name)
    cfg = replace_cfg_vals(cfg)
    cfg.launcher = args.launcher
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    # append the fast-stop hook so each config only trains a few iterations
    if 'custom_hooks' in cfg:
        cfg.custom_hooks.append(dict(type='FastStopTrainingHook'))
    else:
        custom_hooks = [dict(type='FastStopTrainingHook')]
        cfg.custom_hooks = custom_hooks

    # TODO: temporary plan
    if 'visualizer' in cfg:
        if 'name' in cfg.visualizer:
            del cfg.visualizer.name

    # enable automatic-mixed-precision training
    if args.amp is True:
        optim_wrapper = cfg.optim_wrapper.type
        if optim_wrapper == 'AmpOptimWrapper':
            print_log(
                'AMP training is already enabled in your config.',
                logger='current',
                level=logging.WARNING)
        else:
            assert optim_wrapper == 'OptimWrapper', (
                '`--amp` is only supported when the optimizer wrapper type '
                f'is `OptimWrapper` but got {optim_wrapper}.')
            cfg.optim_wrapper.type = 'AmpOptimWrapper'
            cfg.optim_wrapper.loss_scale = 'dynamic'

    # build the runner from config
    if 'runner_type' not in cfg:
        # build the default runner
        runner = Runner.from_cfg(cfg)
    else:
        # build customized runner from the registry
        # if 'runner_type' is set in the cfg
        runner = RUNNERS.build(cfg)

    runner.train()


# Sample test of whether the training code is correct
def main(args):
    # register all modules in mmdet into the registries
    register_all_modules(init_default_scope=False)

    config = Config.fromfile(args.config)

    # test all models
    logger = MMLogger.get_instance(
        name='MMLogger',
        log_file='benchmark_train.log',
        log_level=logging.ERROR)

    for model_key in config:
        model_infos = config[model_key]
        if not isinstance(model_infos, list):
            model_infos = [model_infos]
        for model_info in model_infos:
            print('processing: ', model_info['config'], flush=True)
            config_name = model_info['config'].strip()
            try:
                fast_train_model(config_name, args, logger)
            except RuntimeError as e:
                # 'quick exit' is the normal early stop raised by the hook
                if 'quick exit' not in repr(e):
                    logger.error(f'{config_name} : {repr(e)}')
            except Exception as e:
                logger.error(f'{config_name} : {repr(e)}')


if __name__ == '__main__':
    args = parse_args()
    main(args)
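
For reference, main() expects the config passed on the command line to map benchmark names to one entry, or a list of entries, each carrying a 'config' path. A minimal sketch of such a model-list file; the file name and the two config paths below are illustrative assumptions, not part of this commit:

# model_list.py (hypothetical): each top-level variable is a model_key;
# values may be a single dict or a list of dicts with a 'config' path.
retinanet = dict(config='configs/retinanet/retinanet_r50_fpn_1x_coco.py')
faster_rcnn = [
    dict(config='configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py'),
]

With such a file in place, a single-machine run would then look something like:

python .dev_scripts/benchmark_train.py model_list.py --work-dir work_dirs/benchmark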

Diff for: mmdet/testing/__init__.py

+1
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from ._fast_stop_training_hook import FastStopTrainingHook  # noqa: F401,F403
 from ._utils import (demo_mm_inputs, demo_mm_proposals,
                      demo_mm_sampling_results, get_detector_cfg,
                      get_roi_head_cfg)

Diff for: mmdet/testing/_fast_stop_training_hook.py

+14
@@ -0,0 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import Hook

from mmdet.registry import HOOKS


@HOOKS.register_module()
class FastStopTrainingHook(Hook):
    """Stop training quickly after a few iterations by raising an error."""

    def after_train_iter(self, runner, batch_idx: int, data_batch=None,
                         outputs=None) -> None:
        if batch_idx >= 5:
            raise RuntimeError('quick exit')
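
The hook is registered in MMDetection's HOOKS registry, so once its module has been imported (benchmark_train.py does this via `from mmdet.testing import FastStopTrainingHook`), it can be enabled from any config by type name. A minimal sketch, mirroring exactly what fast_train_model() appends:

custom_hooks = [dict(type='FastStopTrainingHook')]

Raising RuntimeError('quick exit') after five training iterations lets each benchmarked config exercise its full training setup without paying for a real run; main() treats that message as a normal exit rather than an error.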
