1
1
import logging
2
- from typing import Any , Union
2
+ from typing import Any , Optional , Union
3
3
4
4
import torch
5
5
import torch_tensorrt
@@ -68,6 +68,7 @@ def __init__(self, compiled_module: torch.nn.Module) -> None:
68
68
global _PY_RT_CUDAGRAPHS
69
69
self .old_mode = _PY_RT_CUDAGRAPHS
70
70
self .compiled_module = compiled_module
71
+ self .cudagraphs_module : Optional [CudaGraphsTorchTensorRTModule ] = None
71
72
72
73
def __enter__ (self ) -> torch .nn .Module :
73
74
global _PY_RT_CUDAGRAPHS
@@ -98,7 +99,8 @@ def __enter__(self) -> torch.nn.Module:
98
99
logger .debug (
99
100
"Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule"
100
101
)
101
- return CudaGraphsTorchTensorRTModule (self .compiled_module )
102
+ self .cudagraphs_module = CudaGraphsTorchTensorRTModule (self .compiled_module )
103
+ return self .cudagraphs_module
102
104
else :
103
105
if num_trt_module > 0 :
104
106
logger .debug ("No graph breaks detected, using runtime cudagraphs mode" )
@@ -113,6 +115,9 @@ def __enter__(self) -> torch.nn.Module:
113
115
def __exit__(self, *args: Any) -> None:
    """Restore the saved cudagraphs mode and reset the wrapper module, if any.

    Args:
        *args: Exception type, value and traceback supplied by the ``with``
            statement. They are ignored; returning ``None`` lets any
            in-flight exception propagate.
    """
    # Set cudagraphs back to old mode
    set_cudagraphs_mode(self.old_mode)
    # __del__ is not entirely predictable, so we reset cudagraph here.
    # Compare against None explicitly: the attribute is Optional, and a
    # truthiness test would skip the reset for any module object that
    # happens to evaluate falsy (e.g. one defining __len__ or __bool__).
    if self.cudagraphs_module is not None:
        self.cudagraphs_module._reset_captured_graph()
116
121
117
122
118
123
def enable_cudagraphs (
0 commit comments