diff --git a/examples/stable_diffusion/.gitignore b/examples/stable_diffusion/.gitignore index 324c1834b..79875320b 100644 --- a/examples/stable_diffusion/.gitignore +++ b/examples/stable_diffusion/.gitignore @@ -1,2 +1,3 @@ /footprints/ /result_*.png +/quantize_data/ \ No newline at end of file diff --git a/examples/stable_diffusion/README.md b/examples/stable_diffusion/README.md index 3a5617287..b67ab7ace 100644 --- a/examples/stable_diffusion/README.md +++ b/examples/stable_diffusion/README.md @@ -179,3 +179,27 @@ Inference will loop until the generated image. The result will be saved as `resu Run `python stable_diffusion.py --help` for additional options. A few particularly relevant ones: - `--image_path `: the input image path for image to image inference. - `--img_to_img_example`: image to image example. The default input image is `assets/dog.png`, the default prompt is `amazing watercolor painting`. + +## Stable Diffusion Optimization with QDQ for QNN EP + +### Generate data for static quantization + +To get better results, generate real calibration data from the original model for static quantization instead of using random data. + +First, generate the unoptimized ONNX model (this also produces an optimized model quantized with random data): + +`python stable_diffusion.py --model_id stabilityai/stable-diffusion-2-1-base --provider qnn --optimize --use_random_data --data_num 1` + +Then generate the calibration data (running this with different prompts to collect more data gives better results): + +`python stable_diffusion.py --model_id stabilityai/stable-diffusion-2-1-base --provider qnn --generate_data --num_inference_steps 5 --seed 0 --test_unoptimized --prompt "hamburger swims in the river"` + +### Optimize + +`python stable_diffusion.py --model_id stabilityai/stable-diffusion-2-1-base --provider qnn --optimize --clean_cache` + +### Test + +You can add `--test_unoptimized` first to generate an image from the original model for comparison.
+ +`python stable_diffusion.py --model_id stabilityai/stable-diffusion-2-1-base --provider qnn --num_inference_steps 5 --guidance_scale 7.5 --prompt "cat and dog" --seed 0` diff --git a/examples/stable_diffusion/config_text_encoder.json b/examples/stable_diffusion/config_text_encoder.json index 3747e4f80..72298e965 100644 --- a/examples/stable_diffusion/config_text_encoder.json +++ b/examples/stable_diffusion/config_text_encoder.json @@ -23,6 +23,12 @@ "user_script": "user_script.py", "load_dataset_config": { "type": "local_dataset" }, "dataloader_config": { "type": "text_encoder_data_loader", "batch_size": 1 } + }, + { + "name": "quantize_data_config", + "user_script": "user_script.py", + "load_dataset_config": { "type": "local_dataset" }, + "dataloader_config": { "type": "text_encoder_quantize_data_loader", "batch_size": 1 } } ], "evaluators": { @@ -38,7 +44,7 @@ } }, "passes": { - "convert": { "type": "OnnxConversion", "target_opset": 14 }, + "convert": { "type": "OnnxConversion", "target_opset": 17 }, "ov_convert": { "type": "OpenVINOConversion", "user_script": "user_script.py", @@ -83,6 +89,27 @@ "float16": true, "use_gpu": true, "keep_io_types": false + }, + "dynamic_shape_to_fixed": { + "type": "DynamicToFixedShape", + "dim_param": [ "batch", "sequence" ], + "dim_value": [ 1, 77 ] + }, + "qnn_preprocess": { + "type": "QNNPreprocess", + "fuse_layernorm": true + }, + "quantization": { + "type": "OnnxStaticQuantization", + "data_config": "quantize_data_config", + "activation_type": "QUInt16", + "weight_type": "QUInt8", + "calibrate_method": "MinMax", + "quant_preprocess": true, + "prepare_qnn_config": true, + "op_types_to_quantize": [ "MatMul", "LayerNormalization", "Reshape", "Transpose", "Mul", "Gather", "Gelu", "Flatten", "ArgMax" ], + "append_first_op_types_to_quantize_list": false, + "nodes_to_exclude": [ "Add", "Softmax" ] } }, "pass_flows": [ [ "convert", "optimize" ] ], diff --git a/examples/stable_diffusion/config_unet.json b/examples/stable_diffusion/config_unet.json index dfad5e88c..defde422c 100644 --- a/examples/stable_diffusion/config_unet.json +++ b/examples/stable_diffusion/config_unet.json @@ -32,6 +32,12 @@ "user_script": "user_script.py", "load_dataset_config": { "type": "local_dataset" }, "dataloader_config": { "type": "unet_data_loader", "batch_size": 1 } + }, + { + "name": "quantize_data_config", + "user_script": "user_script.py", + "load_dataset_config": { "type": "local_dataset" }, + "dataloader_config": { "type": "unet_quantize_data_loader", "batch_size": 1 } } ], "evaluators": { @@ -49,7 +55,7 @@ "passes": { "convert": { "type": "OnnxConversion", - "target_opset": 14, + "target_opset": 17, "save_as_external_data": true, "all_tensors_to_one_file": true, "external_data_name": "weights.pb" @@ -98,6 +104,24 @@ "float16": true, "use_gpu": true, "keep_io_types": false + }, + "dynamic_shape_to_fixed": { + "type": "DynamicToFixedShape", + "dim_param": [ "unet_sample_batch", "unet_sample_channels", "unet_sample_height", "unet_sample_width", "unet_time_batch", "unet_hidden_batch", "unet_hidden_sequence" ], + "dim_value": [ 1, 4, 64, 64, 1, 1, 77 ] + }, + "qnn_preprocess": { + "type": "QNNPreprocess", + "fuse_layernorm": true + }, + "quantization": { + "type": "OnnxStaticQuantization", + "data_config": "quantize_data_config", + "activation_type": "QUInt16", + "weight_type": "QUInt8", + "calibrate_method": "MinMax", + "quant_preprocess": true, + "prepare_qnn_config": true } }, "pass_flows": [ [ "convert", "optimize" ] ], diff --git 
a/examples/stable_diffusion/config_vae_decoder.json b/examples/stable_diffusion/config_vae_decoder.json index 362f49cb9..3f9f140a3 100644 --- a/examples/stable_diffusion/config_vae_decoder.json +++ b/examples/stable_diffusion/config_vae_decoder.json @@ -30,6 +30,12 @@ "user_script": "user_script.py", "load_dataset_config": { "type": "local_dataset" }, "dataloader_config": { "type": "vae_decoder_data_loader", "batch_size": 1 } + }, + { + "name": "quantize_data_config", + "user_script": "user_script.py", + "load_dataset_config": { "type": "local_dataset" }, + "dataloader_config": { "type": "vae_decoder_quantize_data_loader", "batch_size": 1 } } ], "evaluators": { @@ -45,7 +51,7 @@ } }, "passes": { - "convert": { "type": "OnnxConversion", "target_opset": 14 }, + "convert": { "type": "OnnxConversion", "target_opset": 17 }, "ov_convert": { "type": "OpenVINOConversion", "user_script": "user_script.py", @@ -90,6 +96,24 @@ "float16": true, "use_gpu": true, "keep_io_types": false + }, + "dynamic_shape_to_fixed": { + "type": "DynamicToFixedShape", + "dim_param": [ "decoder_batch", "decoder_channels", "decoder_height", "decoder_width" ], + "dim_value": [ 1, 4, 64, 64 ] + }, + "qnn_preprocess": { + "type": "QNNPreprocess", + "fuse_layernorm": true + }, + "quantization": { + "type": "OnnxStaticQuantization", + "data_config": "quantize_data_config", + "activation_type": "QUInt16", + "weight_type": "QUInt8", + "calibrate_method": "MinMax", + "quant_preprocess": true, + "prepare_qnn_config": true } }, "pass_flows": [ [ "convert", "optimize" ] ], diff --git a/examples/stable_diffusion/config_vae_encoder.json b/examples/stable_diffusion/config_vae_encoder.json index 61e46d298..cbfce6194 100644 --- a/examples/stable_diffusion/config_vae_encoder.json +++ b/examples/stable_diffusion/config_vae_encoder.json @@ -25,6 +25,12 @@ "user_script": "user_script.py", "load_dataset_config": { "type": "local_dataset" }, "dataloader_config": { "type": "vae_encoder_data_loader", "batch_size": 1 } + }, + { + "name": "quantize_data_config", + "user_script": "user_script.py", + "load_dataset_config": { "type": "local_dataset" }, + "dataloader_config": { "type": "vae_encoder_quantize_data_loader", "batch_size": 1 } } ], "evaluators": { @@ -40,7 +46,7 @@ } }, "passes": { - "convert": { "type": "OnnxConversion", "target_opset": 14 }, + "convert": { "type": "OnnxConversion", "target_opset": 17 }, "ov_convert": { "type": "OpenVINOConversion", "user_script": "user_script.py", @@ -85,6 +91,24 @@ "float16": true, "use_gpu": true, "keep_io_types": false + }, + "dynamic_shape_to_fixed": { + "type": "DynamicToFixedShape", + "dim_param": [ "encoder_batch", "encoder_channels", "encoder_height", "encoder_width", "Addlatent_sample_dim_0", "Addlatent_sample_dim_1", "Addlatent_sample_dim_2", "Addlatent_sample_dim_3" ], + "dim_value": [ 1, 3, 512, 512, 1, 4, 64, 64 ] + }, + "qnn_preprocess": { + "type": "QNNPreprocess", + "fuse_layernorm": true + }, + "quantization": { + "type": "OnnxStaticQuantization", + "data_config": "quantize_data_config", + "activation_type": "QUInt16", + "weight_type": "QUInt8", + "calibrate_method": "MinMax", + "quant_preprocess": true, + "prepare_qnn_config": true } }, "pass_flows": [ [ "convert", "optimize" ] ], diff --git a/examples/stable_diffusion/sd_utils/config.py b/examples/stable_diffusion/sd_utils/config.py index f8cfccd44..5730fd40d 100644 --- a/examples/stable_diffusion/sd_utils/config.py +++ b/examples/stable_diffusion/sd_utils/config.py @@ -6,3 +6,6 @@ vae_sample_size = 512 unet_sample_size = 64 
cross_attention_dim = 768 +rand_data = True +data_dir = "quantize_data" +data_num = 10 diff --git a/examples/stable_diffusion/sd_utils/ort.py b/examples/stable_diffusion/sd_utils/ort.py index 01eab8923..bbfe29785 100644 --- a/examples/stable_diffusion/sd_utils/ort.py +++ b/examples/stable_diffusion/sd_utils/ort.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- + import json import shutil import sys @@ -63,7 +64,7 @@ def save_optimized_onnx_submodel(submodel_name, provider, model_info): for footprint in footprints.values(): if footprint["from_pass"] == "OnnxConversion": conversion_footprint = footprint - elif footprint["from_pass"] == "OrtTransformersOptimization": + elif footprint["from_pass"] in ("OrtTransformersOptimization", "OnnxStaticQuantization"): optimizer_footprint = footprint assert conversion_footprint @@ -138,7 +139,7 @@ def get_ort_pipeline(model_dir, common_args, ort_args, guidance_scale): unet_sample_size = config.unet_sample_size if static_dims: - hidden_batch_size = batch_size if (guidance_scale == 0.0) else batch_size * 2 + hidden_batch_size = batch_size if (guidance_scale <= 1.0) else batch_size * 2 # Not necessary, but helps DML EP further optimize runtime performance. # batch_size is doubled for sample & hidden state because of classifier free guidance: # https://github.com/huggingface/diffusers/blob/46c52f9b9607e6ecb29c782c052aea313e6487b7/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L672 diff --git a/examples/stable_diffusion/sd_utils/qnn.py b/examples/stable_diffusion/sd_utils/qnn.py new file mode 100644 index 000000000..93ea702ac --- /dev/null +++ b/examples/stable_diffusion/sd_utils/qnn.py @@ -0,0 +1,225 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License.
+# -------------------------------------------------------------------------- + +from typing import Dict +from diffusers import OnnxStableDiffusionPipeline +import inspect +from typing import Callable, List, Optional, Union +import onnxruntime as ort +import numpy as np +import torch +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE +import os + + +def update_qnn_config(config: Dict, submodel_name: str): + # TODO onnx or onnxruntime needs to fix this + if submodel_name == "unet": + config["input_model"]["io_config"]["dynamic_axes"] = None + config["pass_flows"] = [["convert", "qnn_preprocess", "quantization"]] + else: + config["pass_flows"] = [["convert", "dynamic_shape_to_fixed", "qnn_preprocess", "quantization"]] + config["systems"]["local_system"]["accelerators"][0]["device"] = "npu" + config["systems"]["local_system"]["accelerators"][0]["execution_providers"] = ["QNNExecutionProvider"] + config["evaluator"] = None + return config + + +class QnnStableDiffusionPipeline(OnnxStableDiffusionPipeline): + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[np.random.RandomState] = None, + latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + ): + # check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
+ do_classifier_free_guidance = guidance_scale > 1.0 + + if self.data_dir: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_tensors="np", + ).input_ids.astype(np.int32) + text_inputs.tofile(self.data_dir / "text_inputs.raw") + + uncond_input = self.tokenizer( + negative_prompt if negative_prompt else "", + padding="max_length", + max_length=77, + truncation=True, + return_tensors="np", + ).input_ids.astype(np.int32) + uncond_input.tofile(self.data_dir / "uncond_input.raw") + + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # get the initial random noise unless the user supplied it + latents_dtype = prompt_embeds.dtype + latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) + if latents is None: + latents = generator.randn(*latents_shape).astype(latents_dtype) + elif latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + latents = latents * np.float64(self.scheduler.init_noise_sigma) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + if do_classifier_free_guidance: + neg_embeds, text_embeds = np.split(prompt_embeds, 2) + if self.data_dir: + neg_embeds.tofile(self.data_dir / "neg_embeds.raw") + text_embeds.tofile(self.data_dir / "text_embeds.raw") + elif self.data_dir: + prompt_embeds.tofile(self.data_dir / "text_embeds.raw") + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + latent_model_input = latents + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + + if self.data_dir: + latent_model_input.tofile(self.data_dir / f"{i}_latent.raw") + timestep.tofile(self.data_dir / f"{i}_timestep.raw") + + if do_classifier_free_guidance: + # Note that in QNN, we need to use static dimensions (batch is fixed to 1), so we need to split + noise_pred_uncond = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=neg_embeds) + noise_pred_uncond = noise_pred_uncond[0] + noise_pred_text = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeds) + noise_pred_text = noise_pred_text[0] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds) + noise_pred = noise_pred[0] + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), 
**extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + if self.data_dir: + latents[0:1].tofile(self.data_dir / "latent.raw") + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + if self.data_dir: + image.tofile(self.data_dir / "output_img.raw") + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +def get_qnn_pipeline(model_dir, common_args, qnn_args, script_dir): + ort.set_default_logger_severity(3) + + print("Loading models into ORT session...") + sess_options = ort.SessionOptions() + + # TODO diffusers needs to support new parameter for QNN + # See https://github.com/huggingface/diffusers/issues/10658 + pipeline = QnnStableDiffusionPipeline.from_pretrained( + model_dir, provider="CPUExecutionProvider", sess_options=sess_options + ) + if qnn_args.generate_data: + pipeline.data_dir = script_dir / qnn_args.data_dir / common_args.prompt + os.makedirs(pipeline.data_dir, exist_ok=True) + else: + pipeline.data_dir = None + return pipeline \ No newline at end of file diff --git a/examples/stable_diffusion/stable_diffusion.py b/examples/stable_diffusion/stable_diffusion.py index 2cd45826a..7a799f4bc 100644 --- a/examples/stable_diffusion/stable_diffusion.py +++ b/examples/stable_diffusion/stable_diffusion.py @@ -177,7 +177,7 @@ def on_generate_click(): window.mainloop() -def update_config_with_provider(config: Dict, provider: str): +def update_config_with_provider(config: Dict, provider: str, submodel_name: str): if provider == "dml": # DirectML EP is the default, so no need to update config. 
return config @@ -189,6 +189,10 @@ def update_config_with_provider(config: Dict, provider: str): from sd_utils.ov import update_ov_config return update_ov_config(config) + elif provider == "qnn": + from sd_utils.qnn import update_qnn_config + + return update_qnn_config(config, submodel_name) else: raise ValueError(f"Unsupported provider: {provider}") @@ -244,7 +248,7 @@ def optimize( olive_config = None with (script_dir / f"config_{submodel_name}.json").open() as fin: olive_config = json.load(fin) - olive_config = update_config_with_provider(olive_config, provider) + olive_config = update_config_with_provider(olive_config, provider, submodel_name) if submodel_name in ("unet", "text_encoder"): olive_config["input_model"]["model_path"] = model_id @@ -284,7 +288,7 @@ def parse_common_args(raw_args): parser.add_argument("--model_id", default="CompVis/stable-diffusion-v1-4", type=str) parser.add_argument( - "--provider", default="dml", type=str, choices=["dml", "cuda", "openvino"], help="Execution provider to use" + "--provider", default="dml", type=str, choices=["dml", "cuda", "openvino", "qnn"], help="Execution provider to use" ) parser.add_argument("--optimize", action="store_true", help="Runs the optimization step") parser.add_argument("--clean_cache", action="store_true", help="Deletes the Olive cache") @@ -352,6 +356,16 @@ def parse_ov_args(raw_args): return parser.parse_known_args(raw_args) +def parse_qnn_args(raw_args): + parser = argparse.ArgumentParser("QNN arguments") + + parser.add_argument("--generate_data", action="store_true") + parser.add_argument("--data_dir", default="quantize_data", type=str) + parser.add_argument("--data_num", default=10, type=int) + parser.add_argument("--use_random_data", action="store_true") + + return parser.parse_known_args(raw_args) + def main(raw_args=None): common_args, extra_args = parse_common_args(raw_args) @@ -372,9 +386,14 @@ def main(raw_args=None): guidance_scale = 0.0 print(f"WARNING: Classifier free guidance has been forcefully disabled since {model_id} doesn't support it.") - ov_args, ort_args = None, None + ov_args, qnn_args, ort_args = None, None, None if provider == "openvino": ov_args, extra_args = parse_ov_args(extra_args) + elif provider == "qnn": + qnn_args, extra_args = parse_qnn_args(extra_args) + config.rand_data = qnn_args.use_random_data + config.data_dir = script_dir / qnn_args.data_dir + config.data_num = qnn_args.data_num else: ort_args, extra_args = parse_ort_args(extra_args) @@ -384,7 +403,7 @@ def main(raw_args=None): # TODO(jstoecker): clean up warning filter (mostly during conversion from torch to ONNX) with warnings.catch_warnings(): warnings.simplefilter("ignore") - if provider != "openvino": + if provider != "openvino" and provider != "qnn": from sd_utils.ort import validate_args validate_args(ort_args, common_args.provider) @@ -400,6 +419,10 @@ def main(raw_args=None): from sd_utils.ov import get_ov_pipeline pipeline = get_ov_pipeline(common_args, ov_args, optimized_model_dir) + elif provider == "qnn": + from sd_utils.qnn import get_qnn_pipeline + + pipeline = get_qnn_pipeline(model_dir, common_args, qnn_args, script_dir) else: from sd_utils.ort import get_ort_pipeline diff --git a/examples/stable_diffusion/user_script.py b/examples/stable_diffusion/user_script.py index 16cc2b0a1..4c2fd786e 100644 --- a/examples/stable_diffusion/user_script.py +++ b/examples/stable_diffusion/user_script.py @@ -8,8 +8,144 @@ from huggingface_hub import model_info from sd_utils import config from transformers.models.clip.modeling_clip 
import CLIPTextModel - from olive.data.registry import Registry +import os +import numpy as np +import sys + +# Generated data helpers + +class BaseDataLoader: + def __init__(self, total): + self.data = [] + self.total = total + if not config.rand_data: + self.data_folders = [config.data_dir / f.name for f in os.scandir(config.data_dir) if f.is_dir()] + + def __getitem__(self, idx): + print("getitem: " + str(idx)) + if idx >= len(self.data) or idx >= self.total: return None + return self.data[idx] + +class UnetGeneratedDataLoader(BaseDataLoader): + def __init__(self, total): + super().__init__(total) + latent_min = sys.float_info.max + latent_max = -sys.float_info.max + time_min = sys.float_info.max + time_max = -sys.float_info.max + text_min = sys.float_info.max + text_max = -sys.float_info.max + + for f in self.data_folders: + text = torch.from_numpy(np.fromfile(f / 'text_embeds.raw', dtype=np.float32).reshape(1, 77, 1024)) + text_max = max(text_max, text.max()) + text_min = min(text_min, text.min()) + text_neg = torch.from_numpy(np.fromfile(f / 'neg_embeds.raw', dtype=np.float32).reshape(1, 77, 1024)) + text_max = max(text_max, text_neg.max()) + text_min = min(text_min, text_neg.min()) + for i in range(10000): + if not os.path.exists(f / f'{i}_latent.raw'): break + + latent = torch.from_numpy(np.fromfile(f / f'{i}_latent.raw', dtype=np.float32).reshape(1, 4, 64, 64)) + latent_max = max(latent_max, latent.max()) + latent_min = min(latent_min, latent.min()) + time = torch.from_numpy(np.fromfile(f / f'{i}_timestep.raw', dtype=np.float32).reshape(1)) + time_max = max(time_max, time.max()) + time_min = min(time_min, time.min()) + self.data.append({ "sample": latent, "timestep": time, "encoder_hidden_states": text }) + self.data.append({ "sample": latent, "timestep": time, "encoder_hidden_states": text_neg }) + print(latent_min, latent_max, time_min, time_max, text_min, text_max) + + +class TextEncoderGeneratedDataLoader(BaseDataLoader): + def __init__(self, total): + super().__init__(total) + latent_min = sys.float_info.max + latent_max = -sys.float_info.max + for f in self.data_folders: + data = torch.from_numpy(np.fromfile(f / 'text_inputs.raw', dtype=np.int32).reshape(1, 77)) + latent_max = max(latent_max, data.max()) + latent_min = min(latent_min, data.min()) + self.data.append({ "input_ids": data }) + data = torch.from_numpy(np.fromfile(f / 'uncond_input.raw', dtype=np.int32).reshape(1, 77)) + latent_max = max(latent_max, data.max()) + latent_min = min(latent_min, data.min()) + self.data.append({ "input_ids": data }) + print(latent_min, latent_max) + + +class VaeDecoderGeneratedDataLoader(BaseDataLoader): + def __init__(self, total): + super().__init__(total) + latent_min = sys.float_info.max + latent_max = -sys.float_info.max + for f in self.data_folders: + data = torch.from_numpy(np.fromfile(f / 'latent.raw', dtype=np.float32).reshape(1, 4, 64, 64)) + latent_max = max(latent_max, data.max()) + latent_min = min(latent_min, data.min()) + self.data.append({ "latent_sample": data }) + print(latent_min, latent_max) + + +class VaeEncoderGeneratedDataLoader(BaseDataLoader): + def __init__(self, total): + super().__init__(total) + latent_min = sys.float_info.max + latent_max = -sys.float_info.max + for f in self.data_folders: + data = torch.from_numpy(np.fromfile(f / 'output_img.raw', dtype=np.float32).reshape(1, 3, 512, 512)) + latent_max = max(latent_max, data.max()) + latent_min = min(latent_min, data.min()) + self.data.append({ "sample": data }) + print(latent_min, latent_max) + +# TODO
clean this up + +def get_data_list(size, torch_dtype, total, value_min, value_max): + result = [] + result.append(torch.zeros(size, dtype=torch_dtype)) + result.append(torch.zeros(size, dtype=torch_dtype) + value_min) + result.append(torch.zeros(size, dtype=torch_dtype) + value_max) + total -= 3 + if total <= 0: return result + for i in range(total): + result.append(torch.rand(size, dtype=torch_dtype) * (value_max - value_min) + value_min) + return result + + +class UnetDataRandomLoader(BaseDataLoader): + def __init__(self, total): + super().__init__(total) + samples = get_data_list((1, 4, config.unet_sample_size, config.unet_sample_size), torch.float32, total, -11, 8) + timesteps = get_data_list((1), torch.float32, total, 0, 1000) + states = get_data_list((1, 77, config.cross_attention_dim), torch.float32, total, -8, 14) + for i in range(self.total): + self.data.append({ "sample": samples[i], "timestep": timesteps[i], "encoder_hidden_states": states[i] }) + + +class TextEncoderDataRandomLoader(BaseDataLoader): + def __init__(self, total): + super().__init__(total) + samples = get_data_list((1, 77), torch.float32, total, 0, 49407) + for i in range(self.total): + self.data.append({ "input_ids": samples[i].to(torch.int32) }) + + +class VaeDecoderDataRandomLoader(BaseDataLoader): + def __init__(self, total): + super().__init__(total) + samples = get_data_list((1, 4, config.unet_sample_size, config.unet_sample_size), torch.float32, total, -62, 50) + for i in range(self.total): + self.data.append({ "latent_sample": samples[i] }) + + +class VaeEncoderDataRandomLoader(BaseDataLoader): + def __init__(self, total): + super().__init__(total) + samples = get_data_list((1, 3, 512, 512), torch.float32, total, -1, 1) + for i in range(self.total): + self.data.append({ "sample": samples[i] }) # Helper latency-only dataloader that creates random tensors with no label @@ -142,6 +278,13 @@ def text_encoder_data_loader(dataset, batch_size, *args, **kwargs): return RandomDataLoader(text_encoder_inputs, batch_size, torch.int32) +@Registry.register_dataloader() +def text_encoder_quantize_data_loader(dataset, batch_size, *args, **kwargs): + if config.rand_data: + return TextEncoderDataRandomLoader(config.data_num) + return TextEncoderGeneratedDataLoader(config.data_num) + + # ----------------------------------------------------------------------------- # UNET # ----------------------------------------------------------------------------- @@ -202,6 +345,13 @@ def unet_data_loader(dataset, batch_size, *args, **kwargs): return RandomDataLoader(unet_inputs, batch_size, torch.float16) +@Registry.register_dataloader() +def unet_quantize_data_loader(dataset, batch_size, *args, **kwargs): + if config.rand_data: + return UnetDataRandomLoader(config.data_num) + return UnetGeneratedDataLoader(config.data_num) + + # ----------------------------------------------------------------------------- # VAE ENCODER # ----------------------------------------------------------------------------- @@ -227,6 +377,13 @@ def vae_encoder_data_loader(dataset, batch_size, *args, **kwargs): return RandomDataLoader(vae_encoder_inputs, batch_size, torch.float16) +@Registry.register_dataloader() +def vae_encoder_quantize_data_loader(dataset, batch_size, *args, **kwargs): + if config.rand_data: + return VaeEncoderDataRandomLoader(config.data_num) + return VaeEncoderGeneratedDataLoader(config.data_num) + + # ----------------------------------------------------------------------------- # VAE DECODER # 
----------------------------------------------------------------------------- @@ -256,6 +413,13 @@ def vae_decoder_data_loader(dataset, batch_size, *args, **kwargs): return RandomDataLoader(vae_decoder_inputs, batch_size, torch.float16) +@Registry.register_dataloader() +def vae_decoder_quantize_data_loader(dataset, batch_size, *args, **kwargs): + if config.rand_data: + return VaeDecoderDataRandomLoader(config.data_num) + return VaeDecoderGeneratedDataLoader(config.data_num) + + # ----------------------------------------------------------------------------- # SAFETY CHECKER # -----------------------------------------------------------------------------