Skip to content

Commit 1435a8e

Browse files
Add timing cache to accelerate subsequent .engine export (#13386)
* fix: typos
* feat: enable timing cache for engine export
* Auto-format by https://ultralytics.com/actions

---------

Co-authored-by: UltralyticsAssistant <[email protected]>
1 parent 3760e0e commit 1435a8e

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

export.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,9 @@ def export_coreml(model, im, file, int8, half, nms, mlmodel, prefix=colorstr("Co
593593

594594

595595
@try_export
596-
def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr("TensorRT:")):
596+
def export_engine(
597+
model, im, file, half, dynamic, simplify, workspace=4, verbose=False, cache="", prefix=colorstr("TensorRT:")
598+
):
597599
"""
598600
Export a YOLOv5 model to TensorRT engine format, requiring GPU and TensorRT>=7.0.0.
599601
@@ -606,6 +608,7 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose
606608
simplify (bool): Set to True to simplify the model during export.
607609
workspace (int): Workspace size in GB (default is 4).
608610
verbose (bool): Set to True for verbose logging output.
611+
cache (str): Path to save the TensorRT timing cache.
609612
prefix (str): Log message prefix.
610613
611614
Returns:
@@ -660,6 +663,11 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose
660663
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)
661664
else: # TensorRT versions 7, 8
662665
config.max_workspace_size = workspace * 1 << 30
666+
if cache: # enable timing cache
667+
Path(cache).parent.mkdir(parents=True, exist_ok=True)
668+
buf = Path(cache).read_bytes() if Path(cache).exists() else b""
669+
timing_cache = config.create_timing_cache(buf)
670+
config.set_timing_cache(timing_cache, ignore_mismatch=True)
663671
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
664672
network = builder.create_network(flag)
665673
parser = trt.OnnxParser(network, logger)
@@ -688,6 +696,9 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose
688696
build = builder.build_serialized_network if is_trt10 else builder.build_engine
689697
with build(network, config) as engine, open(f, "wb") as t:
690698
t.write(engine if is_trt10 else engine.serialize())
699+
if cache: # save timing cache
700+
with open(cache, "wb") as c:
701+
c.write(config.get_timing_cache().serialize())
691702
return f, None
692703

693704

@@ -1277,6 +1288,7 @@ def run(
12771288
int8=False, # CoreML/TF INT8 quantization
12781289
per_tensor=False, # TF per tensor quantization
12791290
dynamic=False, # ONNX/TF/TensorRT: dynamic axes
1291+
cache="", # TensorRT: timing cache path
12801292
simplify=False, # ONNX: simplify model
12811293
mlmodel=False, # CoreML: Export in *.mlmodel format
12821294
opset=12, # ONNX: opset version
@@ -1306,6 +1318,7 @@ def run(
13061318
int8 (bool): Apply INT8 quantization for CoreML or TensorFlow models. Default is False.
13071319
per_tensor (bool): Apply per tensor quantization for TensorFlow models. Default is False.
13081320
dynamic (bool): Enable dynamic axes for ONNX, TensorFlow, or TensorRT exports. Default is False.
1321+
cache (str): TensorRT timing cache path. Default is an empty string.
13091322
simplify (bool): Simplify the ONNX model during export. Default is False.
13101323
opset (int): ONNX opset version. Default is 12.
13111324
verbose (bool): Enable verbose logging for TensorRT export. Default is False.
@@ -1341,6 +1354,7 @@ def run(
13411354
int8=False,
13421355
per_tensor=False,
13431356
dynamic=False,
1357+
cache="",
13441358
simplify=False,
13451359
opset=12,
13461360
verbose=False,
@@ -1378,7 +1392,8 @@ def run(
13781392
# Input
13791393
gs = int(max(model.stride)) # grid size (max stride)
13801394
imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples
1381-
im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
1395+
ch = next(model.parameters()).size(1) # require input image channels
1396+
im = torch.zeros(batch_size, ch, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
13821397

13831398
# Update model
13841399
model.eval()
@@ -1402,7 +1417,7 @@ def run(
14021417
if jit: # TorchScript
14031418
f[0], _ = export_torchscript(model, im, file, optimize)
14041419
if engine: # TensorRT required before ONNX
1405-
f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
1420+
f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose, cache)
14061421
if onnx or xml: # OpenVINO requires ONNX
14071422
f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
14081423
if xml: # OpenVINO
@@ -1497,6 +1512,7 @@ def parse_opt(known=False):
14971512
parser.add_argument("--int8", action="store_true", help="CoreML/TF/OpenVINO INT8 quantization")
14981513
parser.add_argument("--per-tensor", action="store_true", help="TF per-tensor quantization")
14991514
parser.add_argument("--dynamic", action="store_true", help="ONNX/TF/TensorRT: dynamic axes")
1515+
parser.add_argument("--cache", type=str, default="", help="TensorRT: timing cache file path")
15001516
parser.add_argument("--simplify", action="store_true", help="ONNX: simplify model")
15011517
parser.add_argument("--mlmodel", action="store_true", help="CoreML: Export in *.mlmodel format")
15021518
parser.add_argument("--opset", type=int, default=17, help="ONNX: opset version")

train.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -717,10 +717,10 @@ def main(opt, callbacks=Callbacks()):
717717
"perspective": (True, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
718718
"flipud": (True, 0.0, 1.0), # image flip up-down (probability)
719719
"fliplr": (True, 0.0, 1.0), # image flip left-right (probability)
720-
"mosaic": (True, 0.0, 1.0), # image mixup (probability)
720+
"mosaic": (True, 0.0, 1.0), # image mosaic (probability)
721721
"mixup": (True, 0.0, 1.0), # image mixup (probability)
722-
"copy_paste": (True, 0.0, 1.0),
723-
} # segment copy-paste (probability)
722+
"copy_paste": (True, 0.0, 1.0), # segment copy-paste (probability)
723+
}
724724

725725
# GA configs
726726
pop_size = 50

0 commit comments

Comments
 (0)