Skip to content

Commit 4cd089f

Browse files
committed
Finalize the refit revision
1 parent b9768d5 commit 4cd089f

File tree

4 files changed

+11
-10
lines changed

4 files changed

+11
-10
lines changed

examples/apps/flux-demo.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@
4141
"use_fp32_acc": True,
4242
"use_explicit_typing": True,
4343
"debug": False,
44-
"use_python_runtime": False,
44+
"use_python_runtime": True,
4545
"immutable_weights": False,
4646
# "cache_built_engines": True,
4747
# "reuse_cached_engines": True,

examples/dynamo/torch_export_flux_dev.py

+2-1
Original file line number | Diff line number | Diff line change
@@ -112,6 +112,7 @@
112112
min_block_size=1,
113113
use_fp32_acc=True,
114114
use_explicit_typing=True,
115+
use_python_runtime=True,
115116
)
116117

117118
# %%
@@ -126,7 +127,7 @@
126127
torch.cuda.empty_cache()
127128
pipe.transformer = trt_gm
128129
pipe.transformer.config = config
129-
130+
trt_gm.device = torch.device("cuda")
130131
# %%
131132
# Image generation using prompt
132133
# ---------------------------

py/torch_tensorrt/dynamo/_refit.py

+7-8
Original file line number | Diff line number | Diff line change
@@ -507,23 +507,22 @@ def refit_module_weights(
507507
serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
508508
serialized_engine = engine.serialize_with_config(serialization_config)
509509

510-
del engine
511-
gc.collect()
512-
torch.cuda.empty_cache()
513-
514-
if isinstance(
515-
compiled_submodule, (PythonTorchTensorRTModule, TorchTensorRTModule)
516-
):
510+
if isinstance(compiled_submodule, PythonTorchTensorRTModule):
511+
compiled_submodule.serialized_engine = bytes(serialized_engine)
512+
elif isinstance(compiled_submodule, TorchTensorRTModule):
517513
compiled_submodule.engine = None # Clear the engine for TorchTensorRTModule, otherwise it won't be updated
518514
compiled_submodule.serialized_engine = bytes(serialized_engine)
519515
compiled_submodule.setup_engine()
520-
521516
elif inline_module:
522517
new_engine_info = list(engine_info)
523518
new_engine_info[ENGINE_IDX] = bytes(serialized_engine)
524519
refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info))
525520
setattr(compiled_module, f"{name}_engine", refitted_engine)
526521

522+
del engine
523+
gc.collect()
524+
torch.cuda.empty_cache()
525+
527526
# TODO: Memory control prototyping. Under discussion
528527
if settings.offload_module_to_cpu:
529528
del new_partitioned_module

tests/py/dynamo/models/test_model_refit.py

+1
Original file line number | Diff line number | Diff line change
@@ -763,6 +763,7 @@ def forward(self, x):
763763
debug=True,
764764
min_block_size=1,
765765
immutable_weights=False,
766+
offload_module_to_cpu=False,
766767
)
767768

768769
num_pyt_segments = len(

0 commit comments

Comments
 (0)