
Commit a29573a (authored Jan 8, 2025)

add quantized model save and load for SD & LCM (#2588)

1 parent 0491221

File tree: 10 files changed, +407 −373 lines changed
models_v2/pytorch/LCM/inference/cpu/README.md (+6)

````diff
@@ -62,6 +62,12 @@ bash download_dataset.sh
 | **BATCH_SIZE** (optional) | `export BATCH_SIZE=<set a value for batch size, else it will run with default batch size>` |
 | **TORCH_INDUCTOR** (optional) | `export TORCH_INDUCTOR=< 0 or 1> (Compile model with PyTorch Inductor backend)` |
 
+* NOTE:
+For `compile-inductor` mode, please do calibration to get quantized model before running `INT8-BF16` or `INT8-FP32`.
+```
+bash do_calibration.sh
+```
+
 8. Run `run_model.sh`
 
 ## Output
````
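
Taken together with run_model.sh further down, the new note implies a two-step flow. A minimal sketch, assuming the environment variables defined in this commit's scripts (`DATASET_DIR`, `OUTPUT_DIR`, `INT8_MODEL`, `PRECISION`, `RUN_MODE`); the placeholder paths are illustrative:

```bash
# Sketch of the compile-inductor INT8 flow; paths are placeholders.
export DATASET_DIR=<path to the downloaded dataset>
export OUTPUT_DIR=<path for logs and outputs>
export INT8_MODEL=quantized_model.pt2     # default name used by do_calibration.sh

# Step 1: one-time calibration quantizes the model and saves it to ${INT8_MODEL}.
bash do_calibration.sh

# Step 2: INT8 inference loads the saved quantized model instead of re-quantizing.
export PRECISION=int8-bf16                # or int8-fp32
export RUN_MODE=compile-inductor
bash run_model.sh
```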

models_v2/pytorch/LCM/inference/cpu/diffusers.patch (+23 −14)

```diff
@@ -1,5 +1,5 @@
 diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py
-index 24abf54d..3fa7df5f 100644
+index 24abf54d6..3fa7df5f3 100644
 --- a/src/diffusers/models/transformer_2d.py
 +++ b/src/diffusers/models/transformer_2d.py
 @@ -385,7 +385,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
@@ -21,7 +21,7 @@ index 24abf54d..3fa7df5f 100644
  output = hidden_states + residual
  elif self.is_input_vectorized:
 diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
-index f248b243..27d4802d 100644
+index f248b243f..7c83d2cf5 100644
 --- a/src/diffusers/models/unet_2d_condition.py
 +++ b/src/diffusers/models/unet_2d_condition.py
 @@ -799,8 +799,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
@@ -34,8 +34,17 @@ index f248b243..27d4802d 100644
  attention_mask: Optional[torch.Tensor] = None,
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
  added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+@@ -808,7 +808,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+- return_dict: bool = True,
++ return_dict: bool = False,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
 diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
-index ff5eea2d..10ea4af1 100644
+index ff5eea2d5..8a9461c87 100644
 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
 +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
 @@ -701,17 +701,33 @@ class LatentConsistencyModelPipeline(
@@ -58,16 +67,16 @@ index ff5eea2d..10ea4af1 100644
 + model_pred = self.traced_unet(
 + latents.to(memory_format=torch.channels_last).to(dtype=self.precision),
 + t,
-+ encoder_hidden_states=prompt_embeds.to(dtype=self.precision),
-+ timestep_cond=w_embedding.to(dtype=self.precision)
-+ )['sample']
++ prompt_embeds.to(dtype=self.precision),
++ w_embedding.to(dtype=self.precision)
++ )[0]
 + elif hasattr(self, 'precision'):
 + model_pred = self.unet(
 + latents.to(memory_format=torch.channels_last).to(dtype=self.precision),
 + t,
-+ encoder_hidden_states=prompt_embeds.to(dtype=self.precision),
-+ timestep_cond=w_embedding.to(dtype=self.precision)
-+ )['sample']
++ prompt_embeds.to(dtype=self.precision),
++ w_embedding.to(dtype=self.precision)
++ )[0]
 + else:
 + model_pred = self.unet(
 + latents,
@@ -91,7 +100,7 @@ index ff5eea2d..10ea4af1 100644
  if not output_type == "latent":
  image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
-index 9911cbe7..98c7f2ab 100644
+index 9911cbe75..a4e7101e3 100644
 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
 +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
 @@ -832,19 +832,33 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
@@ -116,14 +125,14 @@ index 9911cbe7..98c7f2ab 100644
 + noise_pred = self.traced_unet(
 + latent_model_input.to(memory_format=torch.channels_last).to(dtype=self.precision),
 + t,
-+ encoder_hidden_states=prompt_embeds.to(dtype=self.precision)
-+ )['sample']
++ prompt_embeds.to(dtype=self.precision)
++ )[0]
 + elif hasattr(self, 'precision'):
 + noise_pred = self.unet(
 + latent_model_input.to(memory_format=torch.channels_last).to(dtype=self.precision),
 + t,
-+ encoder_hidden_states=prompt_embeds.to(dtype=self.precision)
-+ )['sample']
++ prompt_embeds.to(dtype=self.precision)
++ )[0]
 + else:
 + noise_pred = self.unet(
 + latent_model_input,
```
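
The recurring call-site change in this patch — keyword arguments plus `['sample']` becoming positional arguments plus `[0]` — follows from the `return_dict: bool = False` default it introduces: with `return_dict=False` the diffusers UNet returns a plain tuple instead of a `UNet2DConditionOutput`, and a traced or exported UNet only supports that positional, tuple-returning form. A minimal sketch of the pattern (the helper function is illustrative, not part of the patch):

```python
# Illustrative helper mirroring the patched call sites; the dtype casting and
# channels_last usage follow the patch, the function itself is hypothetical.
import torch

def predict_noise(unet, latents, t, prompt_embeds, precision=torch.bfloat16):
    return unet(
        latents.to(memory_format=torch.channels_last).to(dtype=precision),
        t,
        prompt_embeds.to(dtype=precision),  # positional: traced graphs take no keywords
    )[0]                                    # tuple output, was ['sample'] on the eager dict
```
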
models_v2/pytorch/LCM/inference/cpu/do_calibration.sh (new file, +56)

```bash
#!/usr/bin/env bash
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

MODEL_DIR=${MODEL_DIR-$PWD}

if [ ! -e "${MODEL_DIR}/inference.py" ]; then
    echo "Could not find the script of inference.py. Please set environment variable '\${MODEL_DIR}'."
    echo "From which the inference.py exist at the: \${MODEL_DIR}/inference.py"
    exit 1
fi

if [ ! -d "${DATASET_DIR}" ]; then
    echo "The DATASET_DIR \${DATASET_DIR} does not exist"
    exit 1
fi

if [ -z "${OUTPUT_DIR}" ]; then
    echo "The required environment variable OUTPUT_DIR has not been set"
    exit 1
fi

INT8_MODEL=${INT8_MODEL:-"quantized_model.pt2"}

mkdir -p ${OUTPUT_DIR}

export DNNL_PRIMITIVE_CACHE_CAPACITY=1024
export KMP_BLOCKTIME=200
export KMP_AFFINITY=granularity=fine,compact,1,0

export TORCHINDUCTOR_FREEZING=1
export TORCHINDUCTOR_CPP_ENABLE_TILING_HEURISTIC=0
export TORCHINDUCTOR_ENABLE_LINEAR_BINARY_FOLDING=1

python -m torch.backends.xeon.run_cpu --disable-numactl \
    --log_path ${OUTPUT_DIR} \
    ${MODEL_DIR}/inference.py \
    --model_name_or_path="SimianLuo/LCM_Dreamshaper_v7" \
    --dataset_path=${DATASET_DIR} \
    --quantized_model_path=${INT8_MODEL} \
    --compile_inductor \
    --precision=int8-bf16 \
    --calibration
```
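
The inference.py diffs are not rendered on this page, so for reference, here is a minimal, hypothetical sketch of the standard PT2E (PyTorch 2 export) quantization save path that a `.pt2` calibration artifact like `quantized_model.pt2` implies. The module, example inputs, and calibration loop are placeholders, not the repository's actual code; assumes PyTorch 2.5+:

```python
# Hypothetical sketch of a PT2E calibrate-and-save flow; not the repo's
# inference.py, whose diff is not rendered here.
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq

def calibrate_and_save(model, example_inputs, calib_batches, path="quantized_model.pt2"):
    # Capture the model as a graph suitable for PT2E quantization.
    gm = torch.export.export_for_training(model, example_inputs).module()

    # Insert observers using the x86 Inductor quantizer.
    quantizer = xiq.X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
    prepared = prepare_pt2e(gm, quantizer)

    # Calibration: run representative batches through the observed model.
    for batch in calib_batches:
        prepared(*batch)

    # Convert to a quantized model and save it as a .pt2 ExportedProgram.
    converted = convert_pt2e(prepared)
    ep = torch.export.export(converted, example_inputs)
    torch.export.save(ep, path)
```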

models_v2/pytorch/LCM/inference/cpu/inference.py (+98 −167)

(Large diffs are not rendered by default.)

models_v2/pytorch/LCM/inference/cpu/run_model.sh (+28 −2)

```diff
@@ -64,10 +64,36 @@ elif [ "${PRECISION}" == "fp16" ]; then
     ARGS="$ARGS --precision=fp16"
     echo "### running fp16 datatype"
 elif [ "${PRECISION}" == "int8-bf16" ]; then
-    ARGS="$ARGS --precision=int8-bf16 --configure-dir=conv_and_linear131.json"
+    ARGS="$ARGS --precision=int8-bf16"
+    if [ "${RUN_MODE}" == "ipex-jit" ]; then
+        ARGS="$ARGS --configure-dir=conv_and_linear131.json"
+    elif [ "${RUN_MODE}" == "compile-inductor" ]; then
+        if [ ! -f "${INT8_MODEL}" ]; then
+            echo "The required file INT8_MODEL does not exist"
+            exit 1
+        fi
+        ARGS="$ARGS --quantized_model_path=${INT8_MODEL}"
+    else
+        echo "For int8-bf16 datatype, the specified mode '${RUN_MODE}' is unsupported."
+        echo "Supported mode are: ipex-jit, compile-inductor"
+        exit 1
+    fi
     echo "### running int8-bf16 datatype"
 elif [ "${PRECISION}" == "int8-fp32" ]; then
-    ARGS="$ARGS --precision=int8-fp32 --configure-dir=conv_and_linear131.json"
+    ARGS="$ARGS --precision=int8-fp32"
+    if [ "${RUN_MODE}" == "ipex-jit" ]; then
+        ARGS="$ARGS --configure-dir=conv_and_linear131.json"
+    elif [ "${RUN_MODE}" == "compile-inductor" ]; then
+        if [ ! -f "${INT8_MODEL}" ]; then
+            echo "The required file INT8_MODEL does not exist"
+            exit 1
+        fi
+        ARGS="$ARGS --quantized_model_path=${INT8_MODEL}"
+    else
+        echo "For int8-fp32 datatype, the specified mode '${RUN_MODE}' is unsupported."
+        echo "Supported mode are: ipex-jit, compile-inductor"
+        exit 1
+    fi
     echo "### running int8-fp32 datatype"
 elif [ "${PRECISION}" == "bf32" ]; then
     ARGS="$ARGS --precision=bf32"
```

models_v2/pytorch/stable_diffusion/inference/cpu/README.md (+8 −5)

````diff
@@ -43,10 +43,6 @@ export DATASET_DIR=<directory where the dataset will be saved>
 bash download_dataset.sh
 ```
 
-### **NOTE**:Int8 model
-
-Please get a quant_model.pt before run INT8-BF16 model or INT8-FP32 model. Please refer the [link](https://github.com/intel/intel-extension-for-transformers/blob/v1.5/examples/huggingface/pytorch/text-to-image/quantization/qat/README.md).
-
 # Inference
 1. `git clone https://github.com/IntelAI/models.git`
 2. `cd models/models_v2/pytorch/stable_diffusion/inference/cpu`
@@ -61,7 +57,6 @@ Please get a quant_model.pt before run INT8-BF16 model or INT8-FP32 model. Pleas
 ```
 5. Install the latest CPU versions of [torch, torchvision and intel_extension_for_pytorch](https://intel.github.io/intel-extension-for-pytorch/index.html#installation)
 
-
 6. Setup required environment paramaters
 
 | **Parameter** | **export command** |
@@ -79,6 +74,14 @@ Please get a quant_model.pt before run INT8-BF16 model or INT8-FP32 model. Pleas
 | **LOCAL_BATCH_SIZE** (optional for DISTRIBUTED) | `export LOCAL_BATCH_SIZE=64` |
 7. Run `run_model.sh`
 
+* NOTE:
+Please get quantized model before running `INT8-BF16` or `INT8-FP32`.
+For `ipex-jit` mode, please refer the [link](https://github.com/intel/intel-extension-for-transformers/blob/v1.5/examples/huggingface/pytorch/text-to-image/quantization/qat/README.md).
+For `compile-inductor` mode, please do calibration first:
+```
+bash do_calibration.sh
+```
+
 ## Output
 
 Single-tile output will typically looks like:
````
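
The note distinguishes two INT8 paths, which the run_model.sh diff below selects via `MODE`. An illustrative setup for each; the variable names come from this commit's scripts, the paths are placeholders:

```bash
export PRECISION=int8-fp32                # or int8-bf16

# Option A: IPEX JIT, using a model quantized per the linked QAT recipe.
export MODE=ipex-jit
export INT8_MODEL=<path to the pre-quantized model>

# Option B: Inductor, using the artifact produced by bash do_calibration.sh.
export MODE=compile-inductor
export INT8_MODEL=quantized_model.pt2

bash run_model.sh
```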

models_v2/pytorch/stable_diffusion/inference/cpu/diffusers.patch (+23 −14)

```diff
@@ -1,5 +1,5 @@
 diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py
-index 24abf54d..3fa7df5f 100644
+index 24abf54d6..3fa7df5f3 100644
 --- a/src/diffusers/models/transformer_2d.py
 +++ b/src/diffusers/models/transformer_2d.py
 @@ -385,7 +385,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
@@ -21,7 +21,7 @@ index 24abf54d..3fa7df5f 100644
  output = hidden_states + residual
  elif self.is_input_vectorized:
 diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
-index f248b243..27d4802d 100644
+index f248b243f..7c83d2cf5 100644
 --- a/src/diffusers/models/unet_2d_condition.py
 +++ b/src/diffusers/models/unet_2d_condition.py
 @@ -799,8 +799,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
@@ -34,8 +34,17 @@ index f248b243..27d4802d 100644
  attention_mask: Optional[torch.Tensor] = None,
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
  added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+@@ -808,7 +808,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+- return_dict: bool = True,
++ return_dict: bool = False,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
 diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
-index ff5eea2d..10ea4af1 100644
+index ff5eea2d5..8a9461c87 100644
 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
 +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
 @@ -701,17 +701,33 @@ class LatentConsistencyModelPipeline(
@@ -58,16 +67,16 @@ index ff5eea2d..10ea4af1 100644
 + model_pred = self.traced_unet(
 + latents.to(memory_format=torch.channels_last).to(dtype=self.precision),
 + t,
-+ encoder_hidden_states=prompt_embeds.to(dtype=self.precision),
-+ timestep_cond=w_embedding.to(dtype=self.precision)
-+ )['sample']
++ prompt_embeds.to(dtype=self.precision),
++ w_embedding.to(dtype=self.precision)
++ )[0]
 + elif hasattr(self, 'precision'):
 + model_pred = self.unet(
 + latents.to(memory_format=torch.channels_last).to(dtype=self.precision),
 + t,
-+ encoder_hidden_states=prompt_embeds.to(dtype=self.precision),
-+ timestep_cond=w_embedding.to(dtype=self.precision)
-+ )['sample']
++ prompt_embeds.to(dtype=self.precision),
++ w_embedding.to(dtype=self.precision)
++ )[0]
 + else:
 + model_pred = self.unet(
 + latents,
@@ -91,7 +100,7 @@ index ff5eea2d..10ea4af1 100644
  if not output_type == "latent":
  image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
-index 9911cbe7..98c7f2ab 100644
+index 9911cbe75..a4e7101e3 100644
 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
 +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
 @@ -832,19 +832,33 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
@@ -116,14 +125,14 @@ index 9911cbe7..98c7f2ab 100644
 + noise_pred = self.traced_unet(
 + latent_model_input.to(memory_format=torch.channels_last).to(dtype=self.precision),
 + t,
-+ encoder_hidden_states=prompt_embeds.to(dtype=self.precision)
-+ )['sample']
++ prompt_embeds.to(dtype=self.precision)
++ )[0]
 + elif hasattr(self, 'precision'):
 + noise_pred = self.unet(
 + latent_model_input.to(memory_format=torch.channels_last).to(dtype=self.precision),
 + t,
-+ encoder_hidden_states=prompt_embeds.to(dtype=self.precision)
-+ )['sample']
++ prompt_embeds.to(dtype=self.precision)
++ )[0]
 + else:
 + noise_pred = self.unet(
 + latent_model_input,
```
models_v2/pytorch/stable_diffusion/inference/cpu/do_calibration.sh (new file, +55)

```bash
#!/usr/bin/env bash
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

MODEL_DIR=${MODEL_DIR-$PWD}

if [ ! -e "${MODEL_DIR}/inference.py" ]; then
    echo "Could not find the script of inference.py. Please set environment variable '\${MODEL_DIR}'."
    echo "From which the inference.py exist at the: \${MODEL_DIR}/inference.py"
    exit 1
fi

if [ ! -d "${DATASET_DIR}" ]; then
    echo "The DATASET_DIR \${DATASET_DIR} does not exist"
    exit 1
fi

if [ -z "${OUTPUT_DIR}" ]; then
    echo "The required environment variable OUTPUT_DIR has not been set"
    exit 1
fi

INT8_MODEL=${INT8_MODEL:-"quantized_model.pt2"}

mkdir -p ${OUTPUT_DIR}

export DNNL_PRIMITIVE_CACHE_CAPACITY=1024
export KMP_BLOCKTIME=200
export KMP_AFFINITY=granularity=fine,compact,1,0

export TORCHINDUCTOR_FREEZING=1
export TORCHINDUCTOR_CPP_ENABLE_TILING_HEURISTIC=0
export TORCHINDUCTOR_ENABLE_LINEAR_BINARY_FOLDING=1

python -m torch.backends.xeon.run_cpu --disable-numactl \
    --log_path ${OUTPUT_DIR} \
    ${MODEL_DIR}/inference.py \
    --dataset_path=${DATASET_DIR} \
    --quantized_model_path=${INT8_MODEL} \
    --compile_inductor \
    --precision=int8-bf16 \
    --calibration
```

models_v2/pytorch/stable_diffusion/inference/cpu/inference.py (+98 −167)

(Large diffs are not rendered by default.)

models_v2/pytorch/stable_diffusion/inference/cpu/run_model.sh (+12 −4)

```diff
@@ -116,22 +116,30 @@ elif [[ "${PRECISION}" == "fp16" ]]; then
     echo "### running fp16 datatype"
 elif [[ "${PRECISION}" == "int8-bf16" ]]; then
     ARGS="$ARGS --precision=int8-bf16"
-    if [ "${MODE}" == "ipex-jit" ]; then
+    if [[ "${MODE}" == "ipex-jit" || "${MODE}" == "compile-inductor" ]]; then
         if [ ! -f "${INT8_MODEL}" ]; then
             echo "The required file INT8_MODEL does not exist"
             exit 1
         fi
-        ARGS="$ARGS --int8_model_path=${INT8_MODEL}"
+        ARGS="$ARGS --quantized_model_path=${INT8_MODEL}"
+    else
+        echo "For int8-bf16 datatype, the specified mode '${MODE}' is unsupported."
+        echo "Supported mode are: ipex-jit, compile-inductor"
+        exit 1
     fi
     echo "### running int8-bf16 datatype"
 elif [[ "${PRECISION}" == "int8-fp32" ]]; then
     ARGS="$ARGS --precision=int8-fp32"
-    if [ "${MODE}" == "ipex-jit" ]; then
+    if [[ "${MODE}" == "ipex-jit" || "${MODE}" == "compile-inductor" ]]; then
         if [ ! -f "${INT8_MODEL}" ]; then
             echo "The required file INT8_MODEL does not exist"
             exit 1
         fi
-        ARGS="$ARGS --int8_model_path=${INT8_MODEL}"
+        ARGS="$ARGS --quantized_model_path=${INT8_MODEL}"
+    else
+        echo "For int8-fp32 datatype, the specified mode '${MODE}' is unsupported."
+        echo "Supported mode are: ipex-jit, compile-inductor"
+        exit 1
     fi
     echo "### running int8-fp32 datatype"
 elif [[ "${PRECISION}" == "bf32" ]]; then
```
