17
17
to_torch_device ,
18
18
to_torch_tensorrt_device ,
19
19
)
20
- from torch_tensorrt .runtime ._cudagraphs import get_cuda_graph_module
21
20
22
21
logger = logging .getLogger (__name__ )
23
22
@@ -64,7 +63,6 @@ def __init__(
64
63
* ,
65
64
device : Optional [Union [Device , torch .device , str ]] = _defaults .DEVICE ,
66
65
use_python_runtime : bool = _defaults .USE_PYTHON_RUNTIME ,
67
- enable_cuda_graph : bool = False ,
68
66
immutable_weights : bool = False ,
69
67
strict : bool = True ,
70
68
allow_complex_guards_as_runtime_asserts : bool = False ,
@@ -160,7 +158,6 @@ def __init__(
160
158
logger .warning (
161
159
"Weight stremaing budget is not set. Using auto weight streaming budget"
162
160
)
163
- self .enable_cuda_graph = enable_cuda_graph
164
161
165
162
cls = self .__class__
166
163
self .__class__ = type (
@@ -347,8 +344,6 @@ def compile(self) -> None:
347
344
)
348
345
self .original_model .to ("cpu" )
349
346
torch .cuda .empty_cache ()
350
- if self .enable_cuda_graph :
351
- self ._enable_cuda_graph ()
352
347
if self .enable_weight_streaming :
353
348
self .set_weight_streaming_ctx (self .weight_streaming_budget )
354
349
@@ -365,9 +360,6 @@ def set_weight_streaming_ctx(self, requested_budget: Optional[int] = None) -> No
365
360
)
366
361
self .weight_streaming_ctx .device_budget = requested_budget
367
362
368
- def _enable_cuda_graph (self ) -> None :
369
- self .gm = get_cuda_graph_module (self .gm )
370
-
371
363
def _validate_inputs (self , * args : Any , ** kwargs : Any ) -> None :
372
364
373
365
if not self .arg_inputs and not self .kwarg_inputs :
0 commit comments