File tree: 4 files changed, with 11 additions and 10 deletions.
Original file line number Diff line number Diff line change 41
41
"use_fp32_acc" : True ,
42
42
"use_explicit_typing" : True ,
43
43
"debug" : False ,
44
- "use_python_runtime" : False ,
44
+ "use_python_runtime" : True ,
45
45
"immutable_weights" : False ,
46
46
# "cache_built_engines": True,
47
47
# "reuse_cached_engines": True,
Original file line number Diff line number Diff line change 112
112
min_block_size = 1 ,
113
113
use_fp32_acc = True ,
114
114
use_explicit_typing = True ,
115
+ use_python_runtime = True ,
115
116
)
116
117
117
118
# %%
126
127
torch .cuda .empty_cache ()
127
128
pipe .transformer = trt_gm
128
129
pipe .transformer .config = config
129
-
130
+ trt_gm . device = torch . device ( "cuda" )
130
131
# %%
131
132
# Image generation using prompt
132
133
# ---------------------------
Original file line number Diff line number Diff line change @@ -507,23 +507,22 @@ def refit_module_weights(
507
507
serialization_config .clear_flag (trt .SerializationFlag .EXCLUDE_WEIGHTS )
508
508
serialized_engine = engine .serialize_with_config (serialization_config )
509
509
510
- del engine
511
- gc .collect ()
512
- torch .cuda .empty_cache ()
513
-
514
- if isinstance (
515
- compiled_submodule , (PythonTorchTensorRTModule , TorchTensorRTModule )
516
- ):
510
+ if isinstance (compiled_submodule , PythonTorchTensorRTModule ):
511
+ compiled_submodule .serialized_engine = bytes (serialized_engine )
512
+ elif isinstance (compiled_submodule , TorchTensorRTModule ):
517
513
compiled_submodule .engine = None # Clear the engine for TorchTensorRTModule, otherwise it won't be updated
518
514
compiled_submodule .serialized_engine = bytes (serialized_engine )
519
515
compiled_submodule .setup_engine ()
520
-
521
516
elif inline_module :
522
517
new_engine_info = list (engine_info )
523
518
new_engine_info [ENGINE_IDX ] = bytes (serialized_engine )
524
519
refitted_engine = torch .classes .tensorrt .Engine (tuple (new_engine_info ))
525
520
setattr (compiled_module , f"{ name } _engine" , refitted_engine )
526
521
522
+ del engine
523
+ gc .collect ()
524
+ torch .cuda .empty_cache ()
525
+
527
526
# TODO: Memory control prototyping. Under discussion
528
527
if settings .offload_module_to_cpu :
529
528
del new_partitioned_module
Original file line number Diff line number Diff line change @@ -763,6 +763,7 @@ def forward(self, x):
763
763
debug = True ,
764
764
min_block_size = 1 ,
765
765
immutable_weights = False ,
766
+ offload_module_to_cpu = False ,
766
767
)
767
768
768
769
num_pyt_segments = len (
You can’t perform that action at this time.
0 commit comments