
Commit c698e60

Added flux demo

1 parent f4219f7 commit c698e60

File tree

3 files changed (+165 lines, -1 line)

demo/flux_demo.py (+157)

@@ -0,0 +1,157 @@
from typing import Any

import gradio as gr
import torch
import torch_tensorrt
from diffusers import FluxPipeline
from torch.export._trace import _export

# %%
# Define the FLUX-1.dev model
# -----------------------------
# Load the ``FLUX-1.dev`` pretrained pipeline using the ``FluxPipeline`` class.
# ``FluxPipeline`` bundles the components needed to generate an image: the
# ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler``.
# We load the weights in ``FP16`` precision via the ``torch_dtype`` argument.
DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)
pipe.to(DEVICE).to(torch.float16)
# Store the config and transformer backbone
config = pipe.transformer.config
backbone = pipe.transformer

# %%
# Export the backbone using torch.export
# --------------------------------------------------
# Define the dummy inputs and their respective dynamic shapes. We export the
# transformer backbone with dynamic shapes and ``batch_size=2`` because of
# `0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_
batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
# These particular min/max values for the img_id input are recommended by
# torch.dynamo during export of the model. To see this recommendation, try
# exporting with min=1, max=4096.
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
dynamic_shapes = {
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {0: SEQ_LEN},
    "img_ids": {0: IMG_ID},
    "guidance": {0: BATCH},
    "joint_attention_kwargs": {},
    "return_dict": None,
}

dummy_inputs = {
    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(
        DEVICE
    ),
    "encoder_hidden_states": torch.randn(
        (batch_size, 512, 4096), dtype=torch.float16
    ).to(DEVICE),
    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to(
        DEVICE
    ),
    "timestep": torch.tensor([1.0, 1.0], dtype=torch.float16).to(DEVICE),
    "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
    "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
    "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
    "joint_attention_kwargs": {},
    "return_dict": False,
}
# This creates an exported program which is then compiled with Torch-TensorRT
ep = _export(
    backbone,
    args=(),
    kwargs=dummy_inputs,
    dynamic_shapes=dynamic_shapes,
    strict=False,
    allow_complex_guards_as_runtime_asserts=True,
)

trt_gm = torch_tensorrt.dynamo.compile(
    ep,
    inputs=dummy_inputs,
    enabled_precisions={torch.float32},
    truncate_double=True,
    min_block_size=1,
    use_fp32_acc=True,
    use_explicit_typing=True,
    debug=False,
    use_python_runtime=True,
)
# Free GPU memory: move the eager backbone to CPU, drop the exported program,
# and install the TRT-compiled module as the pipeline's transformer
backbone.to("cpu")
del ep
pipe.transformer = trt_gm
pipe.transformer.config = config
torch.cuda.empty_cache()

def generate_image(prompt: str, inference_step: int) -> Any:
    """Generate an image from a text prompt using the Flux pipeline."""
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=inference_step,
        generator=torch.Generator("cuda"),
    ).images[0]
    return image


def model_change(model: str) -> None:
    """Swap the pipeline backbone between the eager and TRT-compiled variants."""
    if model == "Torch Model":
        pipe.transformer = backbone
        backbone.to(DEVICE)
    else:
        backbone.to("cpu")
        pipe.transformer = trt_gm
        torch.cuda.empty_cache()

# Create the Gradio interface
with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo:
    gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT")

    with gr.Row():
        with gr.Column():
            # Input components
            prompt_input = gr.Textbox(
                label="Prompt", placeholder="Enter your prompt here...", lines=3
            )
            model_dropdown = gr.Dropdown(
                choices=["Torch Model", "Torch-TensorRT Accelerated Model"],
                value="Torch-TensorRT Accelerated Model",
                label="Model Variant",
            )

            # Note: the LoRA upload widget is displayed but not yet wired to the pipeline
            lora_upload = gr.File(
                label="Upload LoRA weights (.safetensors)", file_types=[".safetensors"]
            )
            num_steps = gr.Slider(
                minimum=20, maximum=100, value=20, step=1, label="Inference Steps"
            )

            generate_btn = gr.Button("Generate Image")

        with gr.Column():
            # Output component
            output_image = gr.Image(label="Generated Image")

    # Connect the controls to their handlers
    model_dropdown.change(model_change, inputs=[model_dropdown])
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            num_steps,
        ],
        outputs=output_image,
    )

# Launch the interface
if __name__ == "__main__":
    demo.launch()
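For a quick smoke test without launching the UI, a sketch along the following lines can compare the eager and TRT-compiled backbones. It reuses ``generate_image`` and ``model_change`` from the script above and is assumed to run at the bottom of that script; the prompt, step counts, and timing helper are illustrative, not part of the demo.

import time

def time_backbone(label: str) -> None:
    # One warm-up pass, then time a short end-to-end generation.
    generate_image("a photo of an astronaut riding a horse", 2)
    torch.cuda.synchronize()
    start = time.perf_counter()
    generate_image("a photo of an astronaut riding a horse", 20)
    torch.cuda.synchronize()
    print(f"{label}: {time.perf_counter() - start:.2f} s")

model_change("Torch-TensorRT Accelerated Model")
time_backbone("Torch-TensorRT backbone")
model_change("Torch Model")
time_backbone("Torch eager backbone")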

py/torch_tensorrt/dynamo/_compiler.py (+3)

@@ -672,6 +672,8 @@ def compile(
     )

     gm = exported_program.module()
+    # TODO: Memory control prototyping. Under discussion
+    exported_program.module().to("cpu")
     logger.debug("Input graph: " + str(gm.graph))

     # Apply lowering on the graph module
@@ -808,6 +810,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     trt_modules = {}
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
+
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
         # filter on the GraphModule
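The memory-control TODO above offloads the exported module to the CPU so that only one copy of the weights occupies the GPU while the rest of compilation runs. A minimal standalone sketch of that offload pattern, using a toy ``nn.Linear`` in place of the real exported program:

import torch

def used_gb(tag: str) -> None:
    # mem_get_info returns (free, total) bytes for the current CUDA device
    free, total = torch.cuda.mem_get_info()
    print(f"{tag}: {(total - free) / 1e9:.2f} GB in use")

module = torch.nn.Linear(8192, 8192).to("cuda")
used_gb("weights on GPU")

# Offload the PyTorch weights before the expensive build step so the
# compiler has headroom; the compiled artifact keeps its own copy.
module.to("cpu")
torch.cuda.empty_cache()
used_gb("weights offloaded")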

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py (+5, -1)

@@ -710,7 +710,11 @@ def run(
         self._create_timing_cache(
             builder_config, self.compilation_settings.timing_cache_path
         )
-
+        # TODO: Memory control prototyping. Under discussion
+        self.module.to("cpu")
+        torch.cuda.empty_cache()
+        del self.module
+        gc.collect()
         serialized_engine = self.builder.build_serialized_network(
             self.ctx.net, builder_config
         )
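This hunk releases the PyTorch-side module before ``build_serialized_network`` so that TensorRT's build-time workspace does not stack on top of the model weights; note it presumes ``gc`` is imported in this file. The ordering, reduced to a sketch (the ``interpreter`` attribute names are illustrative of the pattern, not an exact API):

import gc

import torch

def build_with_offload(interpreter) -> bytes:
    # Drop the PyTorch weights first: move to CPU, release CUDA caches,
    # then delete the last reference and force a collection.
    interpreter.module.to("cpu")
    torch.cuda.empty_cache()
    del interpreter.module
    gc.collect()
    # Only now let the TensorRT builder allocate its workspace.
    return interpreter.builder.build_serialized_network(
        interpreter.ctx.net, interpreter.builder_config
    )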

0 commit comments