
Commit 77305aa

feat: Add FLUX-1.dev model to the model zoo (#3382)
1 parent 5a4dd33 commit 77305aa

6 files changed: +170, -7 lines changed

Diff for: docsrc/index.rst

+2
@@ -141,6 +141,7 @@ Model Zoo
 * :ref:`torch_export_gpt2`
 * :ref:`torch_export_llama2`
 * :ref:`torch_export_sam2`
+* :ref:`torch_export_flux_dev`
 * :ref:`notebooks`

 .. toctree::
@@ -157,6 +158,7 @@ Model Zoo
    tutorials/_rendered_examples/dynamo/torch_export_gpt2
    tutorials/_rendered_examples/dynamo/torch_export_llama2
    tutorials/_rendered_examples/dynamo/torch_export_sam2
+   tutorials/_rendered_examples/dynamo/torch_export_flux_dev
    tutorials/notebooks

 Python API Documentation

Diff for: docsrc/tutorials/_rendered_examples/dog_code.png

969 KB

Diff for: examples/dynamo/README.rst

+2-1
@@ -20,4 +20,5 @@ Model Zoo
 * :ref:`_torch_compile_gpt2`: Compiling a GPT2 model using ``torch.compile``
 * :ref:`_torch_export_gpt2`: Compiling a GPT2 model using AOT workflow (`ir=dynamo`)
 * :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`)
-* :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`)
+* :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`)
+* :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`)

Diff for: examples/dynamo/torch_export_flux_dev.py

+150
@@ -0,0 +1,150 @@
"""
.. _torch_export_flux_dev:

Compiling the FLUX.1-dev model using the Torch-TensorRT dynamo backend
=======================================================================

This example illustrates the state-of-the-art model `FLUX.1-dev <https://huggingface.co/black-forest-labs/FLUX.1-dev>`_ optimized using
Torch-TensorRT.

**FLUX.1 [dev]** is a 12-billion-parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications.

Install the following dependencies before compilation:

.. code-block:: sh

    pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2"

The ``FLUX.1-dev`` pipeline consists of several components, such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer``, and ``scheduler``. In this example,
we demonstrate optimizing the ``transformer`` component, which typically consumes more than 95% of the end-to-end diffusion latency.
"""
# %%
# Import the following libraries
# -----------------------------
import torch
import torch_tensorrt
from diffusers import FluxPipeline
from torch.export._trace import _export
# %%
# Define the FLUX.1-dev model
# ---------------------------
# Load the ``FLUX.1-dev`` pretrained pipeline using the ``FluxPipeline`` class.
# ``FluxPipeline`` includes the components (``transformer``, ``vae``, ``text_encoder``, ``tokenizer``, and ``scheduler``) necessary
# to generate an image. We load the weights in ``FP16`` precision using the ``torch_dtype`` argument.
DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)
pipe.to(DEVICE).to(torch.float16)
# Store the config and the transformer backbone
config = pipe.transformer.config
backbone = pipe.transformer

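# (Optional, illustrative check; not part of the original example.) Confirm
# the size claim above: the transformer backbone alone holds roughly 12
# billion parameters. Counting numel() is cheap and touches no weights.
num_params = sum(p.numel() for p in backbone.parameters())
print(f"Transformer backbone parameters: {num_params / 1e9:.1f}B")
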
# %%
# Export the backbone using torch.export
# --------------------------------------
# Define the dummy inputs and their respective dynamic shapes. We export the transformer backbone with dynamic shapes and a ``batch_size=2``
# due to `0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_.
batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
# These particular min and max values for the img_id input are recommended by torch dynamo during the export of the model.
# To see this recommendation, you can try exporting using min=1, max=4096.
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
dynamic_shapes = {
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {0: SEQ_LEN},
    "img_ids": {0: IMG_ID},
    "guidance": {0: BATCH},
}
# The guidance factor is of type torch.float32
dummy_inputs = {
    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(
        DEVICE
    ),
    "encoder_hidden_states": torch.randn(
        (batch_size, 512, 4096), dtype=torch.float16
    ).to(DEVICE),
    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to(
        DEVICE
    ),
    "timestep": torch.tensor([1.0, 1.0], dtype=torch.float16).to(DEVICE),
    "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
    "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
    "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
}
# This creates an exported program which is then compiled with Torch-TensorRT
ep = _export(
    backbone,
    args=(),
    kwargs=dummy_inputs,
    dynamic_shapes=dynamic_shapes,
    strict=False,
    allow_complex_guards_as_runtime_asserts=True,
)

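# (Optional, illustrative sanity check; not part of the original example.)
# Inspect the exported program's signature to confirm the dynamic dimensions
# (batch, seq_len, img_id) were captured during export.
print(ep.graph_signature)
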
# %%
# Torch-TensorRT compilation
# ---------------------------
# .. note::
#    The compilation requires a GPU with high memory (> 80GB) since TensorRT stores the weights in FP32 precision. This is a known issue and will be resolved in a future release.
#
# We enable ``FP32`` matmul accumulation using ``use_fp32_acc=True`` to preserve accuracy by introducing casts to ``FP32`` nodes.
# We also enable explicit typing (``use_explicit_typing=True``) to ensure TensorRT respects the datatypes set by the user, which is a requirement for FP32 matmul accumulation.
# Since this is a 12-billion-parameter model, compilation takes around 20-30 minutes on an H100 GPU. The model is completely convertible and results in
# a single TensorRT engine.
trt_gm = torch_tensorrt.dynamo.compile(
    ep,
    inputs=dummy_inputs,
    enabled_precisions={torch.float32},
    truncate_double=True,
    min_block_size=1,
    use_fp32_acc=True,
    use_explicit_typing=True,
)

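# (Optional, illustrative parity check; a sketch with assumptions, not part of
# the original example.) Compare the compiled module with the eager backbone
# on the dummy inputs. The eager backbone returns a Transformer2DModelOutput
# while the compiled graph module may return a plain tuple, so we normalize
# both to tensors before comparing. Requires enough GPU memory to hold both.
def _as_tensor(out):
    if hasattr(out, "sample"):
        return out.sample
    if isinstance(out, (tuple, list)):
        return out[0]
    return out


with torch.inference_mode():
    trt_out = _as_tensor(trt_gm(**dummy_inputs))
    ref_out = _as_tensor(backbone(**dummy_inputs))
print("Max abs difference vs eager:", (trt_out.float() - ref_out.float()).abs().max().item())
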
# %%
# Post Processing
# ---------------
# Release the GPU memory occupied by the exported program and the original pipe.transformer,
# then set the transformer in the Flux pipeline to the Torch-TRT compiled model.
backbone.to("cpu")
del ep
pipe.transformer = trt_gm
pipe.transformer.config = config

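# (Optional, illustrative; not part of the original example.) Moving the
# backbone to CPU and deleting the exported program drops the Python
# references, but the CUDA caching allocator may still hold the memory;
# collecting garbage and emptying the cache nudges it to be released.
import gc

gc.collect()
torch.cuda.empty_cache()
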
# %%
# Image generation using prompt
# -----------------------------
# Provide a prompt and the file name of the image to be generated. Here we use the
# prompt ``A golden retriever holding a sign to code``.


# Function which generates images from the flux pipeline
def generate_image(pipe, prompt, image_name):
    seed = 42
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        generator=torch.Generator("cuda").manual_seed(seed),
    ).images[0]
    image.save(f"{image_name}.png")
    print(f"Image generated and saved as {image_name}.png")


generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")

# %%
# The generated image is shown below
#
# .. image:: dog_code.png
#

Diff for: py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py

+4-1
@@ -15,7 +15,10 @@ def remove_assert_scalar(
     """Remove assert_scalar ops in the graph"""
     count = 0
     for node in gm.graph.nodes:
-        if node.target == torch.ops.aten._assert_scalar.default:
+        if (
+            node.target == torch.ops.aten._assert_scalar.default
+            or node.target == torch.ops.aten._assert_tensor_metadata.default
+        ):
             gm.graph.erase_node(node)
             count += 1
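
A minimal sketch of what this pass does, run on a hypothetical hand-built toy graph (the graph construction below is illustrative, not from the repository's tests, and assumes a recent PyTorch build where ``aten._assert_scalar`` is available):

import torch
import torch.fx as fx

# Build a tiny graph containing a standalone assert node with no users.
g = fx.Graph()
x = g.placeholder("x")
g.call_function(torch.ops.aten._assert_scalar.default, (True, "always true"))
out = g.call_function(torch.ops.aten.relu.default, (x,))
g.output(out)
gm = fx.GraphModule(torch.nn.Module(), g)

# The pass erases assert_scalar / assert_tensor_metadata nodes, then the
# module is recompiled so the generated code reflects the pruned graph.
for node in list(gm.graph.nodes):
    if node.target in (
        torch.ops.aten._assert_scalar.default,
        torch.ops.aten._assert_tensor_metadata.default,
    ):
        gm.graph.erase_node(node)
gm.recompile()
print(gm.code)  # only the relu call remains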

Diff for: py/torch_tensorrt/dynamo/utils.py

+12-5
@@ -243,12 +243,16 @@ def prepare_inputs(
     inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
     disable_memory_format_check: bool = False,
 ) -> Any:
-    if isinstance(inputs, Input):
+    if inputs is None:
+        return None
+
+    elif isinstance(inputs, Input):
         return inputs
 
-    elif isinstance(inputs, torch.Tensor):
+    elif isinstance(inputs, (torch.Tensor, int, float, bool)):
         return Input.from_tensor(
-            inputs, disable_memory_format_check=disable_memory_format_check
+            torch.tensor(inputs),
+            disable_memory_format_check=disable_memory_format_check,
         )
 
     elif isinstance(inputs, (list, tuple)):
@@ -395,10 +399,13 @@ def unwrap_tensor_dtype(tensor: Union[torch.Tensor, FakeTensor, torch.SymInt]) -
     """
     Returns the dtype of torch.tensor or FakeTensor. For symbolic integers, we return int64
     """
-    if isinstance(tensor, (torch.Tensor, FakeTensor)):
-        return tensor.dtype
+    if isinstance(tensor, (torch.Tensor, FakeTensor, int, float, bool)):
+        return torch.tensor(tensor).dtype
     elif isinstance(tensor, torch.SymInt):
         return torch.int64
+    elif tensor is None:
+        # Case where we explicitly pass one of the inputs to be None (eg: FLUX.1-dev)
+        return None
     else:
         raise ValueError(f"Found invalid tensor type {type(tensor)}")
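
A quick sketch of the new scalar and ``None`` handling (assuming the helpers are imported from the module path shown above; the output comments are illustrative):

import torch
from torch_tensorrt.dynamo.utils import prepare_inputs, unwrap_tensor_dtype

print(prepare_inputs(None))       # None is now passed through unchanged
print(prepare_inputs(3.0))        # a Python scalar is wrapped via Input.from_tensor(torch.tensor(3.0))
print(unwrap_tensor_dtype(True))  # torch.bool
print(unwrap_tensor_dtype(None))  # None, e.g. for an explicitly-None FLUX.1-dev input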
