|
| 1 | +""" |
| 2 | +.. _torch_export_flux_dev: |
| 3 | +
|
| 4 | +Compiling FLUX.1-dev model using the Torch-TensorRT dynamo backend |
| 5 | +=================================================================== |
| 6 | +
|
| 7 | +This example illustrates the state of the art model `FLUX.1-dev <https://huggingface.co/black-forest-labs/FLUX.1-dev>`_ optimized using |
| 8 | +Torch-TensorRT. |
| 9 | +
|
| 10 | +**FLUX.1 [dev]** is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications. |
| 11 | +
|
| 12 | +Install the following dependencies before compilation |
| 13 | +
|
| 14 | +.. code-block:: python |
| 15 | +
|
| 16 | + pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2" |
| 17 | +
|
| 18 | +There are different components of the ``FLUX.1-dev`` pipeline such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler``. In this example, |
| 19 | +we demonstrate optimizing the ``transformer`` component of the model (which typically consumes >95% of the e2e diffusion latency) |
| 20 | +""" |
| 21 | + |
| 22 | +# %% |
| 23 | +# Import the following libraries |
| 24 | +# ----------------------------- |
| 25 | +import torch |
| 26 | +import torch_tensorrt |
| 27 | +from diffusers import FluxPipeline |
| 28 | +from torch.export._trace import _export |
| 29 | + |
| 30 | +# %% |
| 31 | +# Define the FLUX-1.dev model |
| 32 | +# ----------------------------- |
| 33 | +# Load the ``FLUX-1.dev`` pretrained pipeline using ``FluxPipeline`` class. |
| 34 | +# ``FluxPipeline`` includes different components such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler`` necessary |
| 35 | +# to generate an image. We load the weights in ``FP16`` precision using ``torch_dtype`` argument |
| 36 | +DEVICE = "cuda:0" |
| 37 | +pipe = FluxPipeline.from_pretrained( |
| 38 | + "black-forest-labs/FLUX.1-dev", |
| 39 | + torch_dtype=torch.float16, |
| 40 | +) |
| 41 | +pipe.to(DEVICE).to(torch.float16) |
| 42 | +# Store the config and transformer backbone |
| 43 | +config = pipe.transformer.config |
| 44 | +backbone = pipe.transformer |
| 45 | + |
| 46 | + |
| 47 | +# %% |
| 48 | +# Export the backbone using torch.export |
| 49 | +# -------------------------------------------------- |
| 50 | +# Define the dummy inputs and their respective dynamic shapes. We export the transformer backbone with dynamic shapes with a ``batch_size=2`` |
| 51 | +# due to `0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_ |
| 52 | +batch_size = 2 |
| 53 | +BATCH = torch.export.Dim("batch", min=1, max=2) |
| 54 | +SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512) |
| 55 | +# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model. |
| 56 | +# To see this recommendation, you can try exporting using min=1, max=4096 |
| 57 | +IMG_ID = torch.export.Dim("img_id", min=3586, max=4096) |
| 58 | +dynamic_shapes = { |
| 59 | + "hidden_states": {0: BATCH}, |
| 60 | + "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN}, |
| 61 | + "pooled_projections": {0: BATCH}, |
| 62 | + "timestep": {0: BATCH}, |
| 63 | + "txt_ids": {0: SEQ_LEN}, |
| 64 | + "img_ids": {0: IMG_ID}, |
| 65 | + "guidance": {0: BATCH}, |
| 66 | +} |
| 67 | +# The guidance factor is of type torch.float32 |
| 68 | +dummy_inputs = { |
| 69 | + "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to( |
| 70 | + DEVICE |
| 71 | + ), |
| 72 | + "encoder_hidden_states": torch.randn( |
| 73 | + (batch_size, 512, 4096), dtype=torch.float16 |
| 74 | + ).to(DEVICE), |
| 75 | + "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to( |
| 76 | + DEVICE |
| 77 | + ), |
| 78 | + "timestep": torch.tensor([1.0, 1.0], dtype=torch.float16).to(DEVICE), |
| 79 | + "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE), |
| 80 | + "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE), |
| 81 | + "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE), |
| 82 | +} |
| 83 | +# This will create an exported program which is going to be compiled with Torch-TensorRT |
| 84 | +ep = _export( |
| 85 | + backbone, |
| 86 | + args=(), |
| 87 | + kwargs=dummy_inputs, |
| 88 | + dynamic_shapes=dynamic_shapes, |
| 89 | + strict=False, |
| 90 | + allow_complex_guards_as_runtime_asserts=True, |
| 91 | +) |
| 92 | + |
| 93 | +# %% |
| 94 | +# Torch-TensorRT compilation |
| 95 | +# --------------------------- |
| 96 | +# .. note:: |
| 97 | +# The compilation requires a GPU with high memory (> 80GB) since TensorRT is storing the weights in FP32 precision. This is a known issue and will be resolved in the future. |
| 98 | +# |
| 99 | +# |
| 100 | +# We enable ``FP32`` matmul accumulation using ``use_fp32_acc=True`` to ensure accuracy is preserved by introducing cast to ``FP32`` nodes. |
| 101 | +# We also enable explicit typing to ensure TensorRT respects the datatypes set by the user which is a requirement for FP32 matmul accumulation. |
| 102 | +# Since this is a 12 billion parameter model, it takes around 20-30 min to compile on H100 GPU. The model is completely convertible and results in |
| 103 | +# a single TensorRT engine. |
| 104 | +trt_gm = torch_tensorrt.dynamo.compile( |
| 105 | + ep, |
| 106 | + inputs=dummy_inputs, |
| 107 | + enabled_precisions={torch.float32}, |
| 108 | + truncate_double=True, |
| 109 | + min_block_size=1, |
| 110 | + use_fp32_acc=True, |
| 111 | + use_explicit_typing=True, |
| 112 | +) |
| 113 | + |
| 114 | +# %% |
| 115 | +# Post Processing |
| 116 | +# --------------------------- |
| 117 | +# Release the GPU memory occupied by the exported program and the pipe.transformer |
| 118 | +# Set the transformer in the Flux pipeline to the Torch-TRT compiled model |
| 119 | +backbone.to("cpu") |
| 120 | +del ep |
| 121 | +pipe.transformer = trt_gm |
| 122 | +pipe.transformer.config = config |
| 123 | + |
| 124 | +# %% |
| 125 | +# Image generation using prompt |
| 126 | +# --------------------------- |
| 127 | +# Provide a prompt and the file name of the image to be generated. Here we use the |
| 128 | +# prompt ``A golden retriever holding a sign to code``. |
| 129 | + |
| 130 | + |
| 131 | +# Function which generates images from the flux pipeline |
| 132 | +def generate_image(pipe, prompt, image_name): |
| 133 | + seed = 42 |
| 134 | + image = pipe( |
| 135 | + prompt, |
| 136 | + output_type="pil", |
| 137 | + num_inference_steps=20, |
| 138 | + generator=torch.Generator("cuda").manual_seed(seed), |
| 139 | + ).images[0] |
| 140 | + image.save(f"{image_name}.png") |
| 141 | + print(f"Image generated using {image_name} model saved as {image_name}.png") |
| 142 | + |
| 143 | + |
| 144 | +generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code") |
| 145 | + |
| 146 | +# %% |
| 147 | +# The generated image is as shown below |
| 148 | +# |
| 149 | +# .. image:: dog_code.png |
| 150 | +# |
0 commit comments