Commit

Implement pass search
TODO:
Remove old implementation and update documentation
shaahji committed Jan 21, 2025
1 parent 9666502 commit 64d4138
Showing 100 changed files with 5,299 additions and 540 deletions.
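Editorial summary (inferred from the diffs themselves, not from the commit description): nearly every change below retires the list-of-lists "pass_flows" field — configs now simply keep only the passes they intend to run in "passes" — and renames the search-strategy options: "search_algorithm" becomes "sampler", "num_samples" becomes "max_samples", and the old "exhaustive" algorithm maps to the "sequential" sampler in most configs (llama2_lmeval.json below keeps "exhaustive"). A rough before/after sketch, with made-up pass names:

```python
# Before: candidate flows enumerated explicitly, searched with "search_algorithm".
old_engine = {
    "passes": {"conversion": {}, "optimization": {}, "tuning": {}},
    "pass_flows": [["conversion", "optimization", "tuning"], ["conversion", "tuning"]],
    "search_strategy": {"execution_order": "joint", "search_algorithm": "tpe", "num_samples": 3, "seed": 0},
}

# After: a config keeps only the passes it wants; the search strategy takes a
# "sampler" ("tpe", "sequential", ...) and "max_samples" instead.
new_engine = {
    "passes": {"conversion": {}, "tuning": {}},
    "search_strategy": {"execution_order": "joint", "sampler": "tpe", "max_samples": 3, "seed": 0},
}
```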
1 change: 1 addition & 0 deletions .gitignore
@@ -126,6 +126,7 @@ celerybeat.pid
 # Environments
 .env
 .venv
+.vs
 env/
 venv/
 ENV/
27 changes: 27 additions & 0 deletions examples/bert/bert.py
@@ -0,0 +1,27 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import argparse
+import json
+from pathlib import Path
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--optimize",
+        action="store_true",
+        help="If set, run transformers optimization pass",
+    )
+    args = parser.parse_args()
+
+    input_filename = "bert_cuda_gpu.template.json"
+    with Path(input_filename).open("r") as f:
+        config = json.load(f)
+
+    if not args.optimize:
+        del config["passes"]["transformers_optimization"]
+
+    output_filename = input_filename.replace(".template", "")
+    with Path(output_filename).open("w") as strm:
+        json.dump(config, fp=strm, indent=4)
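Judging from the script alone (usage is not documented in this commit), the intended flow is to run `python bert.py` to materialize `bert_cuda_gpu.json` from the template — dropping the `transformers_optimization` pass unless `--optimize` is passed — and then feed the generated config to an Olive workflow run.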
examples/bert/bert_cuda_gpu.template.json
@@ -51,11 +51,7 @@
"transformers_optimization": { "type": "OrtTransformersOptimization", "float16": true },
"session_params_tuning": { "type": "OrtSessionParamsTuning", "data_config": "glue_mrpc", "io_bind": true }
},
"pass_flows": [
[ "conversion", "transformers_optimization", "session_params_tuning" ],
[ "conversion", "session_params_tuning" ]
],
"search_strategy": { "execution_order": "joint", "search_algorithm": "tpe", "num_samples": 3, "seed": 0 },
"search_strategy": { "execution_order": "joint", "sampler": "tpe", "max_samples": 3, "seed": 0 },
"host": "local_system",
"target": "local_system",
"evaluator": "common_evaluator",
2 changes: 1 addition & 1 deletion examples/bert/bert_inc_dynamic_ptq_cpu.json
@@ -34,7 +34,7 @@
"transformers_optimization": { "type": "OrtTransformersOptimization", "model_type": "bert" },
"dynamic_quantization": { "type": "IncDynamicQuantization" }
},
"search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" },
"search_strategy": { "execution_order": "joint", "sampler": "sequential" },
"evaluator": "common_evaluator",
"cache_dir": "cache",
"output_dir": "models/bert_inc_dynamic_ptq_cpu"
2 changes: 1 addition & 1 deletion examples/bert/bert_inc_ptq_cpu.json
@@ -64,7 +64,7 @@
             }
         }
     },
-    "search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" },
+    "search_strategy": { "execution_order": "joint", "sampler": "sequential" },
     "evaluator": "common_evaluator",
     "cache_dir": "cache",
     "output_dir": "models/bert_inc_ptq_cpu"
2 changes: 1 addition & 1 deletion examples/bert/bert_ptq_cpu.json
@@ -67,7 +67,7 @@
         },
         "session_params_tuning": { "type": "OrtSessionParamsTuning", "data_config": "glue_mrpc" }
     },
-    "search_strategy": { "execution_order": "joint", "search_algorithm": "tpe", "num_samples": 3, "seed": 0 },
+    "search_strategy": { "execution_order": "joint", "sampler": "tpe", "max_samples": 3, "seed": 0 },
     "evaluator": "common_evaluator",
     "host": "local_system",
     "target": "local_system",
2 changes: 1 addition & 1 deletion examples/bert/notebook/bert_auto_opt_gpu.json
@@ -43,7 +43,7 @@
             ]
         }
     },
-    "search_strategy": { "execution_order": "joint", "search_algorithm": "tpe", "num_samples": 1, "seed": 0 },
+    "search_strategy": { "execution_order": "joint", "sampler": "tpe", "max_samples": 1, "seed": 0 },
     "evaluator": "common_evaluator",
     "host": "local_system",
     "target": "local_system",
10 changes: 4 additions & 6 deletions examples/bert/notebook/multi_ep_search.ipynb
@@ -141,17 +141,15 @@
"#### Engine and search strategy\n",
"\n",
"Engine is used to manage the optimization process where we run optimization on host device, and run evaluation on target device.\n",
"Search strategy is used to search the optimal optimization among different EPs. In this notebook, we use `joint` as the `execution_order` and `tpe` as the `search_algorithm`. We set the `num_samples` to 1 and `seed` to 0.\n",
"Search strategy is used to search the optimal optimization among different EPs. In this notebook, we use `joint` as the `execution_order` and `tpe` as the `sampler`. We set the `max_samples` to 1 and `seed` to 0.\n",
"\n",
"```json\n",
"\"engine\": {\n",
" \"search_strategy\": {\n",
" \"execution_order\": \"joint\",\n",
" \"search_algorithm\": \"tpe\",\n",
" \"search_algorithm_config\": {\n",
" \"num_samples\": 1,\n",
" \"seed\": 0\n",
" }\n",
" \"sampler\": \"tpe\",\n",
" \"max_samples\": 1,\n",
" \"seed\": 0\n",
" },\n",
" \"evaluator\": \"common_evaluator\",\n",
" \"host\": \"local_system\",\n",
2 changes: 1 addition & 1 deletion examples/deberta/deberta.json
@@ -46,7 +46,7 @@
"quantization": { "type": "OnnxQuantization", "data_config": "glue_mnli_matched" },
"session_params_tuning": { "type": "OrtSessionParamsTuning", "data_config": "glue_mnli_matched" }
},
"search_strategy": { "execution_order": "joint", "search_algorithm": "tpe", "num_samples": 3, "seed": 0 },
"search_strategy": { "execution_order": "joint", "sampler": "tpe", "max_samples": 3, "seed": 0 },
"evaluator": "common_evaluator",
"cache_dir": "cache",
"output_dir": "models/microsoft-deberta"
2 changes: 1 addition & 1 deletion examples/directml/llm/config_llm.json
@@ -90,7 +90,7 @@
             }
         }
     },
-    "search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" },
+    "search_strategy": { "execution_order": "joint", "sampler": "sequential" },
    "evaluator": "common_evaluator",
     "evaluate_input_model": false,
     "host": "local_system",
@@ -112,7 +112,6 @@
"keep_io_types": true
}
},
"pass_flows": [ [ "convert", "optimize" ] ],
"evaluator": "common_evaluator",
"evaluate_input_model": false,
"host": "local_system",
@@ -152,7 +152,6 @@
"keep_io_types": true
}
},
"pass_flows": [ [ "convert", "optimize" ] ],
"evaluator": "common_evaluator",
"evaluate_input_model": false,
"host": "local_system",
1 change: 0 additions & 1 deletion examples/directml/stable_diffusion_xl/config_unet.json
@@ -96,7 +96,6 @@
"keep_io_types": true
}
},
"pass_flows": [ [ "convert", "optimize" ] ],
"evaluator": "common_evaluator",
"evaluate_input_model": false,
"host": "local_system",
@@ -108,7 +108,6 @@
"use_gpu": true
}
},
"pass_flows": [ [ "convert", "optimize" ] ],
"evaluator": "common_evaluator",
"evaluate_input_model": false,
"host": "local_system",
@@ -87,7 +87,6 @@
"keep_io_types": true
}
},
"pass_flows": [ [ "convert", "optimize" ] ],
"evaluator": "common_evaluator",
"evaluate_input_model": false,
"host": "local_system",
10 changes: 7 additions & 3 deletions examples/directml/stable_diffusion_xl/stable_diffusion_xl.py
@@ -7,6 +7,7 @@
 import shutil
 import sys
 import warnings
+from collections import OrderedDict
 from pathlib import Path
 from typing import Dict

@@ -289,22 +290,25 @@ def run_inference(


 def update_config_with_provider(config: Dict, provider: str, is_fp16: bool) -> Dict:
+    used_passes = {}
     if provider == "dml":
-        # DirectML EP is the default, so no need to update config.
-        return config
+        used_passes = {"convert", "optimize"}
     elif provider == "cuda":
         if version.parse(OrtVersion) < version.parse("1.17.0"):
             # disable skip_group_norm fusion since there is a shape inference bug which leads to invalid models
             config["passes"]["optimize_cuda"]["optimization_options"] = {"enable_skip_group_norm": False}
         # keep model fully in fp16 if use_fp16_fixed_vae is set
         if is_fp16:
             config["passes"]["optimize_cuda"].update({"float16": True, "keep_io_types": False})
-        config["pass_flows"] = [["convert", "optimize_cuda"]]
+        used_passes = {"convert", "optimize_cuda"}
         config["systems"]["local_system"]["accelerators"][0]["execution_providers"] = ["CUDAExecutionProvider"]
-        return config
     else:
         raise ValueError(f"Unsupported provider: {provider}")
 
+    config["passes"] = OrderedDict([(k, v) for k, v in config["passes"].items() if k in used_passes])
+    return config
 
 
 def optimize(
     model_id: str,
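The `OrderedDict` filter at the end of `update_config_with_provider` is the commit's replacement for `pass_flows`: rather than naming alternative flows, each script keeps only the passes it intends to run, preserving their original order. A minimal self-contained sketch of the idiom, with hypothetical pass names:

```python
from collections import OrderedDict

# Hypothetical pass table shaped like config["passes"].
passes = OrderedDict([("convert", {}), ("optimize", {}), ("optimize_cuda", {})])
used_passes = {"convert", "optimize_cuda"}

# Keep only the selected passes; the survivors' relative order is preserved.
filtered = OrderedDict([(k, v) for k, v in passes.items() if k in used_passes])
print(list(filtered))  # ['convert', 'optimize_cuda']
```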
17 changes: 13 additions & 4 deletions examples/llama2/llama2.py
@@ -173,17 +173,26 @@ def get_general_config(args):
     gqa = "gqa" if args.use_gqa else "mha"
     config_name = f"llama2_{device}_{gqa}"
 
-    # add pass flows
+    # add pass names
     if not args.use_gptq:
-        template_json["pass_flows"] = [flow for flow in SUPPORTED_WORKFLOWS[device] if "gptq" not in flow[0]]
+        used_passes = [
+            pass_name
+            for passes_names in SUPPORTED_WORKFLOWS[device]
+            if "gptq" not in passes_names[0]
+            for pass_name in passes_names
+        ]
     else:
-        template_json["pass_flows"] = [flow for flow in SUPPORTED_WORKFLOWS[device] if "gptq" in flow[0]]
+        used_passes = [
+            pass_name
+            for passes_names in SUPPORTED_WORKFLOWS[device]
+            if "gptq" in passes_names[0]
+            for pass_name in passes_names
+        ]
         auto_gptq_logger = logging.getLogger("auto_gptq")
         auto_gptq_logger.addHandler(logging.StreamHandler(sys.stdout))
         auto_gptq_logger.setLevel(logging.INFO)
 
     # remove unused passes and set gqa related configs
-    used_passes = {pass_name for pass_flow in SUPPORTED_WORKFLOWS[device] for pass_name in pass_flow}
     for pass_name in list(template_json["passes"].keys()):
         if pass_name not in used_passes:
             del template_json["passes"][pass_name]
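For readers parsing the nested comprehensions above: each one flattens every workflow whose first pass name does (or does not) contain "gptq" into a single list of pass names. A sketch with a made-up SUPPORTED_WORKFLOWS table standing in for the real one defined elsewhere in llama2.py:

```python
# Made-up workflow table; the real SUPPORTED_WORKFLOWS lives in llama2.py.
SUPPORTED_WORKFLOWS = {
    "gpu": [
        ["conversion", "transformers_optimization"],
        ["gptq_quantization", "conversion"],
    ]
}

device = "gpu"
used_passes = [
    pass_name
    for passes_names in SUPPORTED_WORKFLOWS[device]
    if "gptq" not in passes_names[0]  # keep only non-GPTQ workflows
    for pass_name in passes_names
]
print(used_passes)  # ['conversion', 'transformers_optimization']
```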
2 changes: 1 addition & 1 deletion examples/llama2/llama2_lmeval.json
@@ -57,7 +57,7 @@
         }
     },
     "auto_optimizer_config": { "opt_level": 0, "disable_auto_optimizer": true, "precision": "fp16" },
-    "search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" },
+    "search_strategy": { "execution_order": "joint", "sampler": "exhaustive" },
     "evaluator": "evaluator",
     "host": "local_system",
     "target": "local_system",
7 changes: 3 additions & 4 deletions examples/llama2/llama2_model_builder.py
@@ -5,6 +5,7 @@

 import argparse
 import json
+from collections import OrderedDict
 
 from olive.workflows import run as olive_run

@@ -38,10 +39,8 @@ def main(raw_args=None):
     template_json = json.loads(template_json_str)
 
     # add pass flows
-    if args.metadata_only:
-        template_json["pass_flows"] = [["conversion", "metadata"]]
-    else:
-        template_json["pass_flows"] = [["builder", "session_params_tuning"]]
+    used_passes = {"conversion", "metadata"} if args.metadata_only else {"builder", "session_params_tuning"}
+    template_json["passes"] = OrderedDict([(k, v) for k, v in template_json["passes"].items() if k in used_passes])
     template_json["output_dir"] = f"models/{model_name}"
 
     # dump config
32 changes: 32 additions & 0 deletions examples/llama2/notebook/llama2_multiep/llama2.py
@@ -0,0 +1,32 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import argparse
+import json
+from pathlib import Path
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--device",
+        choices=["cpu", "gpu", "multi_ep"],
+        help="Device to use",
+    )
+    parser.add_argument(
+        "--quantize",
+        action="store_true",
+        help="If set, run the blockwise int4 quantization pass",
+    )
+    args = parser.parse_args()
+
+    input_filename = f"config_{args.device}.template.json"
+    with Path(input_filename).open("r") as f:
+        config = json.load(f)
+
+    if not args.quantize:
+        del config["passes"]["blockwise_quant_int4"]
+
+    output_filename = input_filename.replace(".template", "")
+    with Path(output_filename).open("w") as strm:
+        json.dump(config, fp=strm, indent=4)
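Like examples/bert/bert.py above, this script presumably materializes `config_<device>.json` from its template, e.g. `python llama2.py --device gpu --quantize`; without `--quantize`, the `blockwise_quant_int4` pass is stripped from the generated config.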
2 changes: 1 addition & 1 deletion examples/mistral/mistral.py
@@ -108,7 +108,7 @@ def main(raw_args=None):
     optimized_model_dir = (
         script_dir
         / config["output_dir"]
-        / "-".join(config["pass_flows"][0])
+        / "-".join(config["passes"].keys())
         / f"{config['output_name']}_{accelerator}-{ep_header}_model"
     )

1 change: 0 additions & 1 deletion examples/mistral/mistral_fp16.json
@@ -45,7 +45,6 @@
"enable_profiling": false
}
},
"pass_flows": [ [ "convert", "optimize", "session_params_tuning" ] ],
"evaluate_input_model": false,
"evaluator": "common_evaluator",
"host": "local_system",
1 change: 0 additions & 1 deletion examples/mistral/mistral_int4.json
@@ -57,7 +57,6 @@
"all_tensors_to_one_file": true
}
},
"pass_flows": [ [ "convert", "optimize", "quantization" ] ],
"evaluate_input_model": false,
"evaluator": "common_evaluator",
"host": "local_system",
29 changes: 20 additions & 9 deletions examples/mobilenet/prepare_config.py
@@ -5,25 +5,29 @@
 import argparse
 import json
 import platform
+from collections import OrderedDict
 from pathlib import Path
 
 from olive.common.constants import OS
 
 
-def raw_qnn_config():
+def raw_qnn_config(mode: str):
     # pylint: disable=redefined-outer-name
 
-    with Path("./raw_qnn_sdk_template.json").open("r") as f:
+    with Path("raw_qnn_sdk_template.json").open("r") as f:
         raw_qnn_config = json.load(f)
 
     sys_platform = platform.system()
 
+    if mode == "convert":
+        used_passes = {"converter", "build_model_lib"}
+    elif mode == "quantize":
+        used_passes = {"quantization", "build_model_lib"}
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
+
     if sys_platform == OS.LINUX:
         raw_qnn_config["passes"]["qnn_context_binary"] = {
             "type": "QNNContextBinaryGenerator",
             "backend": "libQnnHtp.so",
         }
-        raw_qnn_config["pass_flows"].append(["converter", "build_model_lib", "qnn_context_binary"])
+        used_passes.add("qnn_context_binary")
         raw_qnn_config["passes"]["build_model_lib"]["lib_targets"] = "x86_64-linux-clang"
     elif sys_platform == OS.WINDOWS:
         raw_qnn_config["passes"]["build_model_lib"]["lib_targets"] = "x86_64-windows-msvc"
@@ -34,9 +38,11 @@ def raw_qnn_config():
     elif sys_platform == OS.LINUX:
         metric_config["user_config"]["inference_settings"]["qnn"]["backend"] = "libQnnCpu"
 
+    raw_qnn_config["passes"] = OrderedDict([(k, v) for k, v in raw_qnn_config["passes"].items() if k in used_passes])
+
     with Path("raw_qnn_sdk_config.json").open("w") as f:
         json_str = json.dumps(raw_qnn_config, indent=4)
-        input_file_path = Path("./data/eval/input_order.txt").resolve().as_posix()
+        input_file_path = Path("data/eval/input_order.txt").resolve().as_posix()
         f.write(json_str.replace("<input_list.txt>", str(input_file_path)))


@@ -47,6 +53,11 @@
         action="store_true",
         help="If set, use the raw qnn sdk instead of the qnn EP",
     )
+    parser.add_argument(
+        "--mode",
+        choices=["convert", "quantize"],
+        help="Mode selection",
+    )
     args = parser.parse_args()
     if args.use_raw_qnn_sdk:
-        raw_qnn_config()
+        raw_qnn_config(args.mode)
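Note that the new `--mode` flag has no default and is only consulted on the raw-QNN path, so the intended invocation is presumably `python prepare_config.py --use_raw_qnn_sdk --mode convert` (or `--mode quantize`); omitting `--mode` there would land in the `Unsupported mode` branch.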
4 changes: 2 additions & 2 deletions examples/mobilenet/raw_qnn_sdk_template.json
@@ -45,9 +45,9 @@
"passes": {
"converter": { "type": "QNNConversion" },
"quantization": { "type": "QNNConversion", "extra_args": "--input_list <input_list.txt>" },
"build_model_lib": { "type": "QNNModelLibGenerator", "lib_targets": "x86_64-linux-clang" }
"build_model_lib": { "type": "QNNModelLibGenerator", "lib_targets": "x86_64-linux-clang" },
"qnn_context_binary": { "type": "QNNContextBinaryGenerator", "backend": "libQnnHtp.so" }
},
"pass_flows": [ [ "converter", "build_model_lib" ], [ "quantization", "build_model_lib" ] ],
"log_severity_level": 0,
"host": "local_system",
"target": "local_system",