Skip to content

Commit 297adef

Browse files
authored
fix: Destroy cuda graphs before setting weight streaming (#3461)
1 parent 282ca74 commit 297adef

9 files changed

+38
-14
lines changed

core/runtime/TRTEngine.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,10 @@ std::vector<std::string> TRTEngine::serialize() {
453453
return serialized_info;
454454
}
455455

456+
void TRTEngine::reset_captured_graph() {
457+
cudagraph.reset();
458+
}
459+
456460
} // namespace runtime
457461
} // namespace core
458462
} // namespace torch_tensorrt

core/runtime/TRTEngine.h

+1
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ struct TRTEngine : torch::CustomClassHolder {
185185
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
186186

187187
void set_profiling_paths();
188+
void reset_captured_graph();
188189
#ifndef NDEBUG
189190
bool profile_execution = true;
190191
#else

core/runtime/register_jit_hooks.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
8888
.def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
8989
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
9090
.def("infer_outputs", &TRTEngine::infer_outputs)
91+
.def("reset_captured_graph", &TRTEngine::reset_captured_graph)
9192
.def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs)
9293
.def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs)
9394
.def_property(

py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
103103

104104
return False
105105

106-
def __del__(self) -> None:
106+
def _reset_captured_graph(self) -> None:
107107
if self.cudagraph:
108108
self.cudagraph.reset()
109+
self.cudagraph = None
110+
111+
def __del__(self) -> None:
112+
self._reset_captured_graph()
109113

110114
def set_use_output_allocator(self, enable: bool) -> None:
111115
self.use_output_allocator_outputs = enable
@@ -119,8 +123,7 @@ def forward(
119123
shape_changed = self.validate_input_shapes(inputs)
120124
need_cudagraphs_record = shape_changed or self.is_weight_streaming_set
121125
if need_cudagraphs_record:
122-
if self.cudagraph:
123-
self.cudagraph.reset()
126+
self._reset_captured_graph()
124127
self._input_buffers = [None] * len(inputs)
125128

126129
self.is_weight_streaming_set = False
@@ -196,7 +199,5 @@ def forward(
196199
return outputs[0]
197200
return outputs
198201
else:
199-
if self.cudagraph:
200-
self.cudagraph.reset()
201-
self.cudagraph = None
202+
self._reset_captured_graph()
202203
return self.compiled_module(*args, **kwargs)

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -333,9 +333,13 @@ def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule:
333333
result.__setstate__(self.__getstate__())
334334
return result
335335

336-
def __del__(self) -> None:
336+
def _reset_captured_graph(self) -> None:
337337
if self.cudagraph:
338338
self.cudagraph.reset()
339+
self.cudagraph = None
340+
341+
def __del__(self) -> None:
342+
self._reset_captured_graph()
339343

340344
def setup_input_tensors(
341345
self,
@@ -426,9 +430,8 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
426430
self.cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
427431
)
428432

429-
if need_cudagraphs_reset and self.cudagraph:
430-
self.cudagraph.reset()
431-
self.cudagraph = None
433+
if need_cudagraphs_reset:
434+
self._reset_captured_graph()
432435

433436
if need_cudagraphs_record:
434437
self._input_buffers = [None] * len(self.input_names)

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

+3
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,9 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
209209

210210
return budget_bytes
211211

212+
def _reset_captured_graph(self) -> None:
213+
self.engine.reset_captured_graph()
214+
212215
def setup_engine(self) -> None:
213216
"""
214217
Setup engine for a module which has deferred engine setup.

py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py

+3
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ def automatic_device_memory_budget_getter(self) -> Any:
142142
def infer_outputs(self, input_shapes: List[Any]) -> Any:
143143
pass
144144

145+
def reset_captured_graph(self) -> Any:
146+
pass
147+
145148
def __setstate__(self, serialized_state: List[str]) -> Any:
146149
pass
147150

py/torch_tensorrt/runtime/_cudagraphs.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any, Union
2+
from typing import Any, Optional, Union
33

44
import torch
55
import torch_tensorrt
@@ -68,6 +68,7 @@ def __init__(self, compiled_module: torch.nn.Module) -> None:
6868
global _PY_RT_CUDAGRAPHS
6969
self.old_mode = _PY_RT_CUDAGRAPHS
7070
self.compiled_module = compiled_module
71+
self.cudagraphs_module: Optional[CudaGraphsTorchTensorRTModule] = None
7172

7273
def __enter__(self) -> torch.nn.Module:
7374
global _PY_RT_CUDAGRAPHS
@@ -98,7 +99,8 @@ def __enter__(self) -> torch.nn.Module:
9899
logger.debug(
99100
"Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule"
100101
)
101-
return CudaGraphsTorchTensorRTModule(self.compiled_module)
102+
self.cudagraphs_module = CudaGraphsTorchTensorRTModule(self.compiled_module)
103+
return self.cudagraphs_module
102104
else:
103105
if num_trt_module > 0:
104106
logger.debug("No graph breaks detected, using runtime cudagraphs mode")
@@ -113,6 +115,9 @@ def __enter__(self) -> torch.nn.Module:
113115
def __exit__(self, *args: Any) -> None:
114116
# Set cudagraphs back to old mode
115117
set_cudagraphs_mode(self.old_mode)
118+
# __del__ is not entirely predictable, so we reset cudagraph here
119+
if self.cudagraphs_module:
120+
self.cudagraphs_module._reset_captured_graph()
116121

117122

118123
def enable_cudagraphs(

py/torch_tensorrt/runtime/_weight_streaming.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,15 @@ def _set_streamable_weight_bytes(self, requested_budget: int) -> int:
7676
int(streamable_bytes / total_bytes * requested_budget)
7777
for streamable_bytes in self.streamable_budget
7878
]
79+
if self.cuda_graphs_module:
80+
self.cuda_graphs_module.is_weight_streaming_set = True
81+
self.cuda_graphs_module._reset_captured_graph()
82+
7983
for i, (name, rt_mod) in enumerate(self.rt_mods):
84+
rt_mod._reset_captured_graph()
8085
ws_budget_bytes += rt_mod.set_device_memory_budget(normalized_size[i])
8186
logger.debug(f"Set weight streaming size {normalized_size[i]} for {name}")
8287

83-
if self.cuda_graphs_module:
84-
self.cuda_graphs_module.is_weight_streaming_set = True
8588
return ws_budget_bytes
8689

8790
def __setattr__(self, name: str, value: Any) -> None:

0 commit comments

Comments
 (0)