
Commit ffd4463

Fixed the comments
1 parent 4cd089f commit ffd4463

9 files changed: +32 -52 lines changed

examples/apps/flux-demo.py (+2 -12)

@@ -3,8 +3,7 @@
 import gradio as gr
 import torch
 import torch_tensorrt
-from diffusers import FluxPipeline, StableDiffusionPipeline
-from torch.export._trace import _export
+from diffusers import FluxPipeline

 DEVICE = "cuda:0"
 pipe = FluxPipeline.from_pretrained(
@@ -43,13 +42,7 @@
     "debug": False,
     "use_python_runtime": True,
     "immutable_weights": False,
-    # "cache_built_engines": True,
-    # "reuse_cached_engines": True,
-    # "timing_cache_path": "/home/engine_cache/flux.bin",
-    # "engine_cache_size": 40 * 1 << 30,
-    # "enable_weight_streaming": True,
-    # "weight_streaming_budget": 8 * 1 << 30
-    # "enable_cuda_graph": True,
+    "enable_cuda_graph": True,
 }

 trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
@@ -69,7 +62,6 @@ def generate_image(prompt, inference_step, batch_size=2):

 generate_image(["Test"], 2)
 torch.cuda.empty_cache()
-# torch_tensorrt.MutableTorchTensorRTModule.save(trt_gm, "weight_streaming_Flux.pkl")


 def model_change(model):
@@ -97,8 +89,6 @@ def load_lora(path):


 generate_image(["Test"], 2)
-# load_lora("")
-# generate_image(["A golden retriever holding a sign to code"], 2)

 # Create Gradio interface
 with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo:

examples/dynamo/refit_engine_example.py (+1)

@@ -101,6 +101,7 @@
 )

 # Check the output
+model2.to("cuda")
 expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs)
 for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
     assert torch.allclose(
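Context for the added line: as the _refit.py changes below suggest, the source module is now offloaded to the CPU during compilation and refit, so `model2` has to be moved back to the GPU before its outputs can be compared against the refitted engine. A minimal sketch of the verification pattern, reusing the example's names (`model2`, `exp_program2`, `new_trt_gm`, `inputs`); the tolerance is illustrative, not taken from the example:

# Sketch of the post-refit output check.
model2.to("cuda")  # weights were offloaded to CPU during compilation/refit
expected_outputs = exp_program2.module()(*inputs)
refitted_outputs = new_trt_gm(*inputs)
for expected, refitted in zip(expected_outputs, refitted_outputs):
    assert torch.allclose(expected, refitted, atol=1e-2), "refit outputs diverged"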

py/torch_tensorrt/dynamo/_compiler.py (-2)

@@ -422,7 +422,6 @@ def compile(
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
-    offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -667,7 +666,6 @@ def compile(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
-        "offload_module_to_cpu": offload_module_to_cpu,
     }

     settings = CompilationSettings(**compilation_options)
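With `offload_module_to_cpu` dropped from the signature, any caller that still passes it lands in `**kwargs`. A rough sketch of the remaining call shape, where `exported_program` and `example_inputs` are hypothetical placeholders and `arg_inputs` is assumed to be the input keyword:

import torch_tensorrt

# `exported_program` is an ExportedProgram, `example_inputs` its sample inputs
# (both hypothetical here). The keyword arguments below are the ones visible
# in the signature above, shown with their documented defaults.
trt_gm = torch_tensorrt.dynamo.compile(
    exported_program,
    arg_inputs=example_inputs,
    enable_weight_streaming=False,
    tiling_optimization_level="none",
    l2_limit_for_tiling=-1,
)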

py/torch_tensorrt/dynamo/_defaults.py (-1)

@@ -49,7 +49,6 @@
 TILING_OPTIMIZATION_LEVEL = "none"
 L2_LIMIT_FOR_TILING = -1
 USE_DISTRIBUTED_MODE_TRACE = False
-OFFLOAD_MODULE_TO_CPU = True


 def default_device() -> Device:

py/torch_tensorrt/dynamo/_refit.py (+8 -14)

@@ -35,7 +35,9 @@
     TorchTensorRTModule,
 )
 from torch_tensorrt.dynamo.utils import (
+    CPU_DEVICE,
     check_module_output,
+    delete_module,
     get_model_device,
     get_torch_inputs,
     set_log_level,
@@ -312,9 +314,6 @@ def refit_module_weights(
         get_decompositions(settings.enable_experimental_decompositions)
     )
     new_gm = new_weight_module.module()
-    # TODO: Memory control prototyping. Under discussion
-    if settings.offload_module_to_cpu:
-        new_weight_module.module().to("cpu")

     logger.debug("Input graph: " + str(new_gm.graph))
     # Apply lowering on the graph module
@@ -393,7 +392,7 @@ def refit_module_weights(

     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
-
+    new_weight_module.module().to(CPU_DEVICE)
     for name, new_submodule in new_partitioned_module.named_children():
         # Refit each submodule
         # Extract engine from the submodule
@@ -496,11 +495,7 @@ def refit_module_weights(
                 settings=settings,
                 weight_name_map=None,
             )
-            # TODO: Memory control prototyping. Under discussion
-            if settings.offload_module_to_cpu:
-                del new_submodule
-                gc.collect()
-                torch.cuda.empty_cache()
+            delete_module(new_submodule)

         # clear EXCLUDE_WEIGHTS flag
         serialization_config = engine.create_serialization_config()
@@ -523,20 +518,18 @@ def refit_module_weights(
     gc.collect()
     torch.cuda.empty_cache()

-    # TODO: Memory control prototyping. Under discussion
-    if settings.offload_module_to_cpu:
-        del new_partitioned_module
-        gc.collect()
-        torch.cuda.empty_cache()
+    delete_module(new_partitioned_module)

     if verify_output and arg_inputs is not None:
+        new_gm.to(torch.cuda.current_device())
         if check_module_output(
             new_module=new_gm,
             refitted_module=compiled_module,
             arg_inputs=torch_inputs,
             kwarg_inputs=torch_kwarg_inputs,
         ):
             logger.info("Refitting Succeed!")
+            new_gm.to(CPU_DEVICE)
         else:
             if weight_name_map:
                 logger.warning(
@@ -552,6 +545,7 @@ def refit_module_weights(
                     in_place=in_place,
                 )
             logger.error("Refitting Failed! The outputs do not match.")
+            new_gm.to(CPU_DEVICE)
     else:
         logger.info("Refitting Completed! Output verification skipped.")

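The recurring `del` / `gc.collect()` / `torch.cuda.empty_cache()` teardown is now centralized in the `delete_module` helper imported from `torch_tensorrt.dynamo.utils`. Its body is not part of this diff; a minimal sketch consistent with the inline code it replaces might look like:

import gc

import torch

def delete_module(module: torch.fx.GraphModule) -> None:
    # Sketch only: drop the reference, force a collection pass, and hand
    # freed CUDA blocks back to the driver. Note that `del` removes just
    # this frame's binding; the caller's binding must also go out of
    # scope for the module to actually be collected.
    del module
    gc.collect()
    torch.cuda.empty_cache()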

py/torch_tensorrt/dynamo/_settings.py (-2)

@@ -25,7 +25,6 @@
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
     NUM_AVG_TIMING_ITERS,
-    OFFLOAD_MODULE_TO_CPU,
     OPTIMIZATION_LEVEL,
     PASS_THROUGH_BUILD_FAILURES,
     REFIT_IDENTICAL_ENGINE_WEIGHTS,
@@ -141,7 +140,6 @@ class CompilationSettings:
     tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
     l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
     use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
-    offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU


 _SETTINGS_TO_BE_ENGINE_INVARIANT = (

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py (+3 -2)

@@ -209,7 +209,7 @@ def forward(a, b, c=0, d=0):

         self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)

-    def _get_total_dynamic_shapes(self) -> dict[str, Any] | None:
+    def _get_total_dynamic_shapes(self) -> Union[dict[str, Any], None]:
         if not self.arg_dynamic_shapes and not self.kwarg_dynamic_shapes:
             return None
         total_dynamic_shape = {}
@@ -490,7 +490,8 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
     def to(self, *args: Any, **kwargs: Any) -> None:
         logger.warning(
             "Trying to move the original PyTorch model. This will cause CPU offloading failing and increase GPU memory usage."
-            + "If this is absolute necessary, please call module.pytorch_model.to(...)"
+            + "If this is absolute necessary, please call module.pytorch_model.to(...) \n"
+            + "The model is still on the original device."
         )

     @property
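The annotation change swaps PEP 604 union syntax (`dict[str, Any] | None`) for `typing.Union`, presumably for compatibility: `X | Y` between types raises `TypeError` at runtime on Python versions before 3.10, while `Union` also works on 3.9. A standalone illustration (not from the commit):

from typing import Any, Union

# Raises TypeError at definition time on Python 3.9:
#   def shapes() -> dict[str, Any] | None: ...

# Equivalent spelling that is also valid on Python 3.9:
def shapes() -> Union[dict[str, Any], None]:
    return None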

py/torch_tensorrt/runtime/_cudagraphs.py (+2 -2)

@@ -76,8 +76,8 @@ def __enter__(self) -> torch.nn.Module | torch.fx.GraphModule:
             self.old_module = self.compiled_module.gm
             self.compiled_module.gm = get_cuda_graph_module(self.compiled_module.gm)
             return self.compiled_module
-
-        return get_cuda_graph_module(self.compiled_module)
+        else:
+            return get_cuda_graph_module(self.compiled_module)

     def __exit__(self, *args: Any) -> None:
         # Set cudagraphs back to old mode
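This `__enter__` backs the CUDA graphs context manager; the explicit `else` makes the two return paths symmetric without changing behavior, since the `if` branch already returns. A usage sketch, assuming the public entry point is `torch_tensorrt.runtime.enable_cudagraphs` and using hypothetical `trt_module` / `example_input` names:

import torch_tensorrt

# `trt_module` is a compiled Torch-TensorRT module and `example_input`
# a matching CUDA tensor (both hypothetical). The context manager wraps
# the module so its execution is captured and replayed as a CUDA graph.
with torch_tensorrt.runtime.enable_cudagraphs(trt_module) as cudagraphs_module:
    output = cudagraphs_module(example_input)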
