@@ -593,7 +593,9 @@ def export_coreml(model, im, file, int8, half, nms, mlmodel, prefix=colorstr("CoreML:")):


 @try_export
-def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr("TensorRT:")):
+def export_engine(
+    model, im, file, half, dynamic, simplify, workspace=4, verbose=False, cache="", prefix=colorstr("TensorRT:")
+):
     """
     Export a YOLOv5 model to TensorRT engine format, requiring GPU and TensorRT>=7.0.0.

@@ -606,6 +608,7 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr("TensorRT:")):
         simplify (bool): Set to True to simplify the model during export.
         workspace (int): Workspace size in GB (default is 4).
         verbose (bool): Set to True for verbose logging output.
+        cache (str): Path to save the TensorRT timing cache.
         prefix (str): Log message prefix.

     Returns:
@@ -660,6 +663,11 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr("TensorRT:")):
         config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)
     else:  # TensorRT versions 7, 8
         config.max_workspace_size = workspace * 1 << 30
+    if cache:  # enable timing cache
+        Path(cache).parent.mkdir(parents=True, exist_ok=True)
+        buf = Path(cache).read_bytes() if Path(cache).exists() else b""
+        timing_cache = config.create_timing_cache(buf)
+        config.set_timing_cache(timing_cache, ignore_mismatch=True)
     flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
     network = builder.create_network(flag)
     parser = trt.OnnxParser(network, logger)
@@ -688,6 +696,9 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr("TensorRT:")):
     build = builder.build_serialized_network if is_trt10 else builder.build_engine
     with build(network, config) as engine, open(f, "wb") as t:
         t.write(engine if is_trt10 else engine.serialize())
+    if cache:  # save timing cache
+        with open(cache, "wb") as c:
+            c.write(config.get_timing_cache().serialize())
     return f, None
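The two added blocks above round-trip TensorRT's builder timing cache: the first seeds the builder config from any cache file left by a previous build, the second writes the updated cache back once the engine is serialized. A minimal standalone sketch of the same pattern, assuming a TensorRT 8+ Python install ("trt.timing.cache" is an arbitrary illustrative path):

    from pathlib import Path

    import tensorrt as trt

    cache_path = Path("trt.timing.cache")
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    config = builder.create_builder_config()

    # Seed from disk if a previous build saved a cache; an empty buffer starts a fresh one
    buf = cache_path.read_bytes() if cache_path.exists() else b""
    config.set_timing_cache(config.create_timing_cache(buf), ignore_mismatch=True)

    # ... build and serialize the engine with this config; kernel timings accumulate in the cache ...

    # Persist the updated cache so later builds can skip kernel auto-tuning
    cache_path.write_bytes(bytes(config.get_timing_cache().serialize()))

On repeated exports of the same model this lets the builder reuse measured kernel timings instead of re-profiling, which is where most of the engine build time goes.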
@@ -1277,6 +1288,7 @@ def run(
     int8=False,  # CoreML/TF INT8 quantization
     per_tensor=False,  # TF per tensor quantization
     dynamic=False,  # ONNX/TF/TensorRT: dynamic axes
+    cache="",  # TensorRT: timing cache path
     simplify=False,  # ONNX: simplify model
     mlmodel=False,  # CoreML: Export in *.mlmodel format
     opset=12,  # ONNX: opset version
@@ -1306,6 +1318,7 @@ def run(
         int8 (bool): Apply INT8 quantization for CoreML or TensorFlow models. Default is False.
         per_tensor (bool): Apply per tensor quantization for TensorFlow models. Default is False.
         dynamic (bool): Enable dynamic axes for ONNX, TensorFlow, or TensorRT exports. Default is False.
+        cache (str): TensorRT timing cache path. Default is an empty string.
         simplify (bool): Simplify the ONNX model during export. Default is False.
         opset (int): ONNX opset version. Default is 12.
         verbose (bool): Enable verbose logging for TensorRT export. Default is False.
@@ -1341,6 +1354,7 @@ def run(
             int8=False,
             per_tensor=False,
             dynamic=False,
+            cache="",
             simplify=False,
             opset=12,
             verbose=False,
@@ -1378,7 +1392,8 @@ def run(
     # Input
     gs = int(max(model.stride))  # grid size (max stride)
     imgsz = [check_img_size(x, gs) for x in imgsz]  # verify img_size are gs-multiples
-    im = torch.zeros(batch_size, 3, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection
+    ch = next(model.parameters()).size(1)  # require input image channels
+    im = torch.zeros(batch_size, ch, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection

     # Update model
     model.eval()
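Reading the channel count off the first parameter tensor lets non-RGB models (e.g. single-channel grayscale) export with a matching dummy input instead of the hard-coded 3. A small sketch of why this works (the one-layer model is purely illustrative):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(1, 16, 3))  # hypothetical grayscale model
    ch = next(model.parameters()).size(1)  # conv weight shape is (out, in, kH, kW), so dim 1 is input channels
    im = torch.zeros(1, ch, 320, 192)  # dummy input now matches the model: (1, 1, 320, 192)

This assumes the first registered parameter belongs to the input layer, which holds for YOLOv5's architecture but is not guaranteed for arbitrary modules.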
@@ -1402,7 +1417,7 @@ def run(
     if jit:  # TorchScript
         f[0], _ = export_torchscript(model, im, file, optimize)
     if engine:  # TensorRT required before ONNX
-        f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
+        f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose, cache)
     if onnx or xml:  # OpenVINO requires ONNX
         f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
     if xml:  # OpenVINO
@@ -1497,6 +1512,7 @@ def parse_opt(known=False):
     parser.add_argument("--int8", action="store_true", help="CoreML/TF/OpenVINO INT8 quantization")
     parser.add_argument("--per-tensor", action="store_true", help="TF per-tensor quantization")
     parser.add_argument("--dynamic", action="store_true", help="ONNX/TF/TensorRT: dynamic axes")
+    parser.add_argument("--cache", type=str, default="", help="TensorRT: timing cache file path")
     parser.add_argument("--simplify", action="store_true", help="ONNX: simplify model")
     parser.add_argument("--mlmodel", action="store_true", help="CoreML: Export in *.mlmodel format")
     parser.add_argument("--opset", type=int, default=17, help="ONNX: opset version")
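With the flag wired through parse_opt() and run(), a TensorRT export can persist its timing cache across invocations, for example (weights and cache path are illustrative):

    # First run profiles kernels and writes the cache; repeat runs reuse it
    python export.py --weights yolov5s.pt --include engine --device 0 --cache yolov5s.timing.cache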