
Commit 0aeea36

Changed the way to enable CudaGraph for MTTM
1 parent c858f34 commit 0aeea36

File tree: 2 files changed (+9 −8 lines)

  py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
  py/torch_tensorrt/runtime/_cudagraphs.py

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py (−8 lines)

@@ -17,7 +17,6 @@
     to_torch_device,
     to_torch_tensorrt_device,
 )
-from torch_tensorrt.runtime._cudagraphs import get_cuda_graph_module

 logger = logging.getLogger(__name__)

@@ -64,7 +63,6 @@ def __init__(
         *,
         device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
         use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME,
-        enable_cuda_graph: bool = False,
         immutable_weights: bool = False,
         strict: bool = True,
         allow_complex_guards_as_runtime_asserts: bool = False,
@@ -160,7 +158,6 @@ def __init__(
             logger.warning(
                 "Weight stremaing budget is not set. Using auto weight streaming budget"
             )
-        self.enable_cuda_graph = enable_cuda_graph

         cls = self.__class__
         self.__class__ = type(
@@ -347,8 +344,6 @@ def compile(self) -> None:
         )
         self.original_model.to("cpu")
         torch.cuda.empty_cache()
-        if self.enable_cuda_graph:
-            self._enable_cuda_graph()
         if self.enable_weight_streaming:
             self.set_weight_streaming_ctx(self.weight_streaming_budget)

@@ -365,9 +360,6 @@ def set_weight_streaming_ctx(self, requested_budget: Optional[int] = None) -> No
             )
         self.weight_streaming_ctx.device_budget = requested_budget

-    def _enable_cuda_graph(self) -> None:
-        self.gm = get_cuda_graph_module(self.gm)
-
     def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:

         if not self.arg_inputs and not self.kwarg_inputs:
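Net effect of this file: the `enable_cuda_graph` constructor argument and the `_enable_cuda_graph()` helper are removed, so `compile()` no longer wraps `self.gm` with `get_cuda_graph_module`. For reference, the construction-time pattern these deleted lines supported looked roughly like the sketch below; the placeholder model is an illustrative assumption, and only `enable_cuda_graph` itself comes from the diff.

    import torch
    import torch_tensorrt

    # Old path, deleted by this commit: ask for CUDA Graphs when building the
    # mutable module. torch.nn.Linear here is just a placeholder model.
    mutable_module = torch_tensorrt.MutableTorchTensorRTModule(
        torch.nn.Linear(16, 16).eval().cuda(),
        enable_cuda_graph=True,  # constructor flag removed in this commit
    )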

py/torch_tensorrt/runtime/_cudagraphs.py (+9 lines)

@@ -68,13 +68,22 @@ def __init__(self, compiled_module: torch.nn.Module) -> None:
         global _PY_RT_CUDAGRAPHS
         self.old_mode = _PY_RT_CUDAGRAPHS
         self.compiled_module = compiled_module
+        self.old_module = None

     def __enter__(self) -> torch.nn.Module | torch.fx.GraphModule:
+
+        if isinstance(self.compiled_module, torch_tensorrt.MutableTorchTensorRTModule):
+            self.old_module = self.compiled_module.gm
+            self.compiled_module.gm = get_cuda_graph_module(self.compiled_module.gm)
+            return self.compiled_module
+
         return get_cuda_graph_module(self.compiled_module)

     def __exit__(self, *args: Any) -> None:
         # Set cudagraphs back to old mode
         set_cudagraphs_mode(self.old_mode)
+        if self.old_module: # MutableTorchTRTModule
+            self.compiled_module.gm = self.old_module


 def get_cuda_graph_module(
