From 9fdba1aff3348d737911a69a823775700eb4608c Mon Sep 17 00:00:00 2001
From: hualxie
Date: Thu, 13 Feb 2025 19:22:02 +0800
Subject: [PATCH 01/14] add basic eval

---
 examples/bge/bge-small-en-v1.5.json | 49 ++++++++++++++++++++++++++
 examples/bge/user_script.py         | 54 +++++++++++++++++++++++++++++
 2 files changed, 103 insertions(+)
 create mode 100644 examples/bge/bge-small-en-v1.5.json
 create mode 100644 examples/bge/user_script.py

diff --git a/examples/bge/bge-small-en-v1.5.json b/examples/bge/bge-small-en-v1.5.json
new file mode 100644
index 000000000..619ad2d84
--- /dev/null
+++ b/examples/bge/bge-small-en-v1.5.json
@@ -0,0 +1,49 @@
+{
+    "input_model": {
+        "type": "HfModel",
+        "model_path": "BAAI/bge-small-en-v1.5",
+        "task": "feature-extraction"
+    },
+    "systems": {
+        "local_system": {
+            "type": "LocalSystem",
+            "accelerators": [
+                { "device": "cpu", "execution_providers": [ "CPUExecutionProvider" ] }
+            ]
+        }
+    },
+    "data_configs": [
+    ],
+    "evaluators": {
+        "common_evaluator": {
+            "metrics": [
+                {
+                    "name": "accuracy",
+                    "type": "custom",
+                    "sub_types": [
+                        {
+                            "name": "accuracy_custom",
+                            "priority": 1,
+                            "higher_is_better": true,
+                            "goal": { "type": "max-degradation", "value": 0.05 }
+                        }
+                    ],
+                    "user_config": {
+                        "user_script": "user_script.py",
+                        "evaluate_func": "eval_accuracy",
+                        "evaluate_func_kwargs": { "tasks": [ "Banking77Classification" ] }
+                    }
+                }
+            ]
+        }
+    },
+    "passes": {
+        "conversion": { "type": "OnnxConversion", "target_opset": 13 }
+    },
+    "evaluator": "common_evaluator",
+    "host": "local_system",
+    "target": "local_system",
+    "cache_dir": "cache",
+    "output_dir": "models/bge-small-en-v1.5",
+    "evaluate_input_model": true
+}
diff --git a/examples/bge/user_script.py b/examples/bge/user_script.py
new file mode 100644
index 000000000..f91040c43
--- /dev/null
+++ b/examples/bge/user_script.py
@@ -0,0 +1,54 @@
+from olive.model import OliveModelHandler, HfModelHandler
+from olive.constants import Framework
+from olive.workflows import run as olive_run
+import mteb
+from typing import List
+from transformers import AutoTokenizer, BertModel
+import numpy as np
+import json
+from pathlib import Path
+
+class OliveEncoder:
+    def __init__(self, model, session):
+        self.model = model
+        self.session = session
+        self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5')
+
+    def encode(self, corpus: List, **kwargs):
+        if self.model.framework == Framework.ONNX:
+            encoded_input = self.tokenizer(corpus, padding=True, truncation=True, return_tensors='np')
+            model_inputs = {
+                "input_ids": encoded_input.input_ids.astype(np.int64),
+                "attention_mask": encoded_input.attention_mask.astype(np.int64),
+                "token_type_ids": encoded_input.token_type_ids.astype(np.int64)
+            }
+            model_output = self.model.run_session(self.session, model_inputs)[0]
+        elif self.model.framework == Framework.PYTORCH:
+            model: HfModelHandler = self.model
+            session: BertModel = self.session  # BertLMHeadModel
+            encoded_input = self.tokenizer(corpus, padding=True, truncation=True, return_tensors='pt')
+            model_inputs = {
+                "input_ids": encoded_input.input_ids,
+                "attention_mask": encoded_input.attention_mask,
+                "token_type_ids": encoded_input.token_type_ids
+            }
+            model_output = model.run_session(session, model_inputs)
+            model_output = model_output.last_hidden_state.detach().numpy()
+        # select the last hidden state of the first token (i.e., [CLS]) as the sentence embedding.
+        model_output = model_output[:, 0, :]
+        return model_output
+
+
+def eval_accuracy(model: OliveModelHandler, device, execution_providers, tasks):
+    sess = model.prepare_session(inference_settings=None, device=device, execution_providers=execution_providers)
+
+    evaluation = mteb.MTEB(tasks=tasks)
+    oliveEncoder = OliveEncoder(model, sess)
+    results = evaluation.run(oliveEncoder, output_folder=None)
+    return results[0].scores.test[0].main_score
+
+
+if __name__ == "__main__":
+    with Path("bge-small-en-v1.5.json").open() as fin:
+        olive_config = json.load(fin)
+    olive_run(olive_config)
\ No newline at end of file
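
Judging from the signature of eval_accuracy and the "custom" metric config above, Olive imports user_script.py and calls the named evaluate_func itself for each candidate model, forwarding evaluate_func_kwargs. A minimal sketch of the equivalent direct call follows; the ONNXModelHandler usage and the output model path are illustrative assumptions based on the output_dir convention in this workflow, not part of the patch:

    # Hypothetical direct invocation of the metric defined in patch 01 (assumed paths).
    from olive.model import ONNXModelHandler
    from user_script import eval_accuracy

    # Assumed to follow this workflow's output_dir layout.
    model = ONNXModelHandler(model_path="models/bge-small-en-v1.5/output_model/model.onnx")
    score = eval_accuracy(
        model,
        device="cpu",  # matches the accelerator configured in the JSON
        execution_providers=["CPUExecutionProvider"],
        tasks=["Banking77Classification"],  # forwarded from evaluate_func_kwargs
    )
    print(score)  # MTEB main_score on the Banking77 test split
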
From f6fdc65a5be50d710416917b0ed43143ca93f685 Mon Sep 17 00:00:00 2001
From: hualxie
Date: Thu, 13 Feb 2025 19:47:27 +0800
Subject: [PATCH 02/14] 0.8574675324675324

---
 examples/bge/bge-small-en-v1.5.json |  2 +-
 examples/bge/user_script.py         | 14 +++++++++-----
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/examples/bge/bge-small-en-v1.5.json b/examples/bge/bge-small-en-v1.5.json
index 619ad2d84..a899daeac 100644
--- a/examples/bge/bge-small-en-v1.5.json
+++ b/examples/bge/bge-small-en-v1.5.json
@@ -38,7 +38,7 @@
         }
     },
     "passes": {
-        "conversion": { "type": "OnnxConversion", "target_opset": 13 }
+        "conversion": { "type": "OnnxConversion", "target_opset": 17 }
     },
     "evaluator": "common_evaluator",
     "host": "local_system",
diff --git a/examples/bge/user_script.py b/examples/bge/user_script.py
index f91040c43..f3d3aef65 100644
--- a/examples/bge/user_script.py
+++ b/examples/bge/user_script.py
@@ -6,13 +6,16 @@ from transformers import AutoTokenizer, BertModel
 import numpy as np
 import json
+import torch
 from pathlib import Path
 
+
 class OliveEncoder:
     def __init__(self, model, session):
         self.model = model
         self.session = session
         self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5')
+        self.total = 0
 
     def encode(self, corpus: List, **kwargs):
         if self.model.framework == Framework.ONNX:
@@ -24,18 +27,19 @@ def encode(self, corpus: List, **kwargs):
             }
             model_output = self.model.run_session(self.session, model_inputs)[0]
         elif self.model.framework == Framework.PYTORCH:
-            model: HfModelHandler = self.model
-            session: BertModel = self.session  # BertLMHeadModel
             encoded_input = self.tokenizer(corpus, padding=True, truncation=True, return_tensors='pt')
             model_inputs = {
                 "input_ids": encoded_input.input_ids,
                 "attention_mask": encoded_input.attention_mask,
                 "token_type_ids": encoded_input.token_type_ids
             }
-            model_output = model.run_session(session, model_inputs)
-            model_output = model_output.last_hidden_state.detach().numpy()
+            with torch.no_grad():
+                model_output = self.model.run_session(self.session, model_inputs)
+            model_output = model_output.last_hidden_state.numpy()
         # select the last hidden state of the first token (i.e., [CLS]) as the sentence embedding.
         model_output = model_output[:, 0, :]
+        self.total += len(corpus)
+        print(self.total)
         return model_output
 
 
@@ -45,7 +49,7 @@ def eval_accuracy(model: OliveModelHandler, device, execution_providers, tasks):
     evaluation = mteb.MTEB(tasks=tasks)
     oliveEncoder = OliveEncoder(model, sess)
     results = evaluation.run(oliveEncoder, output_folder=None)
-    return results[0].scores.test[0].main_score
+    return results[0].scores["test"][0]["main_score"]
 
 
 if __name__ == "__main__":

From b7c555d1f79c18c4a787a12d07c45ec546230883 Mon Sep 17 00:00:00 2001
From: hualxie
Date: Fri, 14 Feb 2025 11:26:09 +0800
Subject: [PATCH 03/14] 0.5321753246753247

---
 examples/bge/bge-small-en-v1.5.json | 23 ++++++++++++++++++++++-
 examples/bge/user_script.py         |  3 +++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/examples/bge/bge-small-en-v1.5.json b/examples/bge/bge-small-en-v1.5.json
index a899daeac..923e2acc9 100644
--- a/examples/bge/bge-small-en-v1.5.json
+++ b/examples/bge/bge-small-en-v1.5.json
@@ -13,6 +13,17 @@
         }
     },
     "data_configs": [
+        {
+            "name": "quantize_data_config",
+            "type": "HuggingfaceContainer",
+            "load_dataset_config": { "data_name": "mteb/banking77", "split": "test" },
+            "pre_process_data_config": {
+                "max_length": 384,
+                "padding": "max_length",
+                "input_cols": [ "text" ]
+            },
+            "dataloader_config": { "batch_size": 1 }
+        }
     ],
     "evaluators": {
         "common_evaluator": {
@@ -38,7 +49,17 @@
         }
     },
     "passes": {
-        "conversion": { "type": "OnnxConversion", "target_opset": 17 }
+        "conversion": { "type": "OnnxConversion", "target_opset": 17 },
+        "QNNPreprocess": { "type": "QNNPreprocess", "fuse_layernorm": true },
+        "OnnxQuantization": {
+            "type": "OnnxQuantization",
+            "data_config": "quantize_data_config",
+            "activation_type": "QUInt16",
+            "weight_type": "QUInt8",
+            "calibrate_method": "MinMax",
+            "quant_preprocess": true,
+            "prepare_qnn_config": true
+        }
     },
     "evaluator": "common_evaluator",
     "host": "local_system",
diff --git a/examples/bge/user_script.py b/examples/bge/user_script.py
index f3d3aef65..dbf7f241a 100644
--- a/examples/bge/user_script.py
+++ b/examples/bge/user_script.py
@@ -16,6 +16,7 @@ def __init__(self, model, session):
         self.session = session
         self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5')
         self.total = 0
+        self.max_len = 0
 
     def encode(self, corpus: List, **kwargs):
         if self.model.framework == Framework.ONNX:
@@ -40,6 +41,8 @@ def encode(self, corpus: List, **kwargs):
         model_output = model_output[:, 0, :]
         self.total += len(corpus)
         print(self.total)
+        self.max_len = max(self.max_len, model_output.shape[1])
+        print(self.max_len)
         return model_output

From 72f3ef4ae5df7453c056747331000c26c9694011 Mon Sep 17 00:00:00 2001
From: hualxie
Date: Thu, 20 Feb 2025 12:06:32 +0800
Subject: [PATCH 04/14] add DynamicToFixedShape

---
 examples/bge/bge-small-en-v1.5.json | 29 +++++++++++++++++++++++++++-
 examples/bge/readme.md              |  3 ++
 examples/bge/user_script.py         | 69 ++++++++++++++++++++++++-----
 3 files changed, 89 insertions(+), 12 deletions(-)
 create mode 100644 examples/bge/readme.md

diff --git a/examples/bge/bge-small-en-v1.5.json b/examples/bge/bge-small-en-v1.5.json
index 923e2acc9..efc736e03 100644
--- a/examples/bge/bge-small-en-v1.5.json
+++ b/examples/bge/bge-small-en-v1.5.json
@@ -16,13 +16,29 @@
         {
             "name": "quantize_data_config",
            "type": "HuggingfaceContainer",
-            "load_dataset_config": { "data_name": "mteb/banking77", "split": "test" },
+            "load_dataset_config": {
+                "data_name": "mteb/banking77",
+                "split": "test"
+            },
             "pre_process_data_config": {
-                "max_length": 384,
+                "max_length": 128,
"max_length", "input_cols": [ "text" ] }, "dataloader_config": { "batch_size": 1 } + }, + { + "name": "quantize_data_config_custom", + "type": "HuggingfaceContainer", + "user_script": "user_script.py", + "load_dataset_config": { + "data_name": "mteb/banking77", + "split": "test" + }, + "pre_process_data_config": { + "type": "dataset_pre_process", + "cache_key": "cache" + } } ], "evaluators": { @@ -50,6 +66,14 @@ }, "passes": { "conversion": { "type": "OnnxConversion", "target_opset": 17 }, + "dynamic_shape_to_fixed": { + "type": "DynamicToFixedShape", + "dim_param": [ + "batch_size", + "sequence_length" + ], + "dim_value": [ 1, 128 ] + }, "QNNPreprocess": { "type": "QNNPreprocess", "fuse_layernorm": true }, "OnnxQuantization": { "type": "OnnxQuantization", @@ -61,6 +85,7 @@ "prepare_qnn_config": true } }, + "pass_flows": [ [ "conversion", "dynamic_shape_to_fixed", "QNNPreprocess", "OnnxQuantization" ] ], "evaluator": "common_evaluator", "host": "local_system", "target": "local_system", diff --git a/examples/bge/readme.md b/examples/bge/readme.md new file mode 100644 index 000000000..45e9eece0 --- /dev/null +++ b/examples/bge/readme.md @@ -0,0 +1,3 @@ +# original model accuracy "accuracy-accuracy_custom": 0.8574675324675324 + +# qdq model "accuracy-accuracy_custom": 0.5315909090909091 diff --git a/examples/bge/user_script.py b/examples/bge/user_script.py index dbf7f241a..0e09375a3 100644 --- a/examples/bge/user_script.py +++ b/examples/bge/user_script.py @@ -8,7 +8,7 @@ import json import torch from pathlib import Path - +from olive.data.registry import Registry class OliveEncoder: def __init__(self, model, session): @@ -20,13 +20,18 @@ def __init__(self, model, session): def encode(self, corpus: List, **kwargs): if self.model.framework == Framework.ONNX: - encoded_input = self.tokenizer(corpus, padding=True, truncation=True, return_tensors='np') - model_inputs = { - "input_ids": encoded_input.input_ids.astype(np.int64), - "attention_mask": encoded_input.attention_mask.astype(np.int64), - "token_type_ids": encoded_input.token_type_ids.astype(np.int64) - } - model_output = self.model.run_session(self.session, model_inputs)[0] + encoded_input = self.tokenizer(corpus, padding="max_length", max_length=128, truncation=True, return_tensors='np') + # batch_size is 1 for static model + model_outputs = [] + for i in range(len(corpus)): + model_inputs = { + "input_ids": encoded_input.input_ids[i:i+1,:].astype(np.int64), + "attention_mask": encoded_input.attention_mask[i:i+1,:].astype(np.int64), + "token_type_ids": encoded_input.token_type_ids[i:i+1,:].astype(np.int64) + } + model_output = self.model.run_session(self.session, model_inputs)[0] + model_outputs.append(model_output[0]) + model_output = np.array(model_outputs) elif self.model.framework == Framework.PYTORCH: encoded_input = self.tokenizer(corpus, padding=True, truncation=True, return_tensors='pt') model_inputs = { @@ -34,6 +39,8 @@ def encode(self, corpus: List, **kwargs): "attention_mask": encoded_input.attention_mask, "token_type_ids": encoded_input.token_type_ids } + self.max_len = max(self.max_len, model_inputs["input_ids"].shape[1]) + print(self.max_len) with torch.no_grad(): model_output = self.model.run_session(self.session, model_inputs) model_output = model_output.last_hidden_state.numpy() @@ -41,8 +48,6 @@ def encode(self, corpus: List, **kwargs): model_output = model_output[:, 0, :] self.total += len(corpus) print(self.total) - self.max_len = max(self.max_len, model_output.shape[1]) - print(self.max_len) return model_output @@ 
@@ -55,6 +60,50 @@ def eval_accuracy(model: OliveModelHandler, device, execution_providers, tasks):
     return results[0].scores["test"][0]["main_score"]
 
 
+class DataLoader:
+    def __init__(self, data):
+        self.input_ids = torch.from_numpy(data["input_ids"])
+        self.attention_mask = torch.from_numpy(data["attention_mask"])
+        self.token_type_ids = torch.from_numpy(data["token_type_ids"])
+
+    def __len__(self):
+        return len(self.input_ids)
+
+    def __getitem__(self, idx):
+        data = {"input_ids": self.input_ids[idx],"attention_mask": self.attention_mask[idx],"token_type_ids": self.token_type_ids[idx]}
+        return data
+
+
+@Registry.register_pre_process()
+def dataset_pre_process(output_data, **kwargs):
+    cache_key = kwargs.get("cache_key")
+    cache_file = None
+    if cache_key:
+        cache_file = Path(f"{cache_key}.npz")
+        if cache_file.exists():
+            with np.load(Path(cache_file)) as data:
+                return DataLoader(data)
+
+    tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5')
+    input_ids = []
+    attention_mask = []
+    token_type_ids = []
+    for item in output_data:
+        encoded_input = tokenizer(item['text'], padding="max_length", max_length=128, truncation=True, return_tensors='np')
+        input_ids.append(encoded_input.input_ids[0].astype(np.int64))
+        attention_mask.append(encoded_input.attention_mask[0].astype(np.int64))
+        token_type_ids.append(encoded_input.token_type_ids[0].astype(np.int64))
+
+    data = {"input_ids": np.array(input_ids), "attention_mask": np.array(attention_mask), "token_type_ids": np.array(token_type_ids)}
+    result_data = DataLoader(data)
+
+    if cache_file:
+        cache_file.parent.resolve().mkdir(parents=True, exist_ok=True)
+        np.savez(cache_file, **data)
+
+    return result_data
+
+
 if __name__ == "__main__":
     with Path("bge-small-en-v1.5.json").open() as fin:
         olive_config = json.load(fin)

From bece5571a046f2150ca5fbdcffed5de73cac12e8 Mon Sep 17 00:00:00 2001
From: hualxie
Date: Thu, 20 Feb 2025 14:57:35 +0800
Subject: [PATCH 05/14] auto qdq

---
 examples/bge/auto_qdq_debug.py | 190 +++++++++++++++++++++++++++++++++
 examples/bge/qdq_config.json   |  48 +++++++++
 examples/bge/requirements.txt  |   1 +
 3 files changed, 239 insertions(+)
 create mode 100644 examples/bge/auto_qdq_debug.py
 create mode 100644 examples/bge/qdq_config.json
 create mode 100644 examples/bge/requirements.txt

diff --git a/examples/bge/auto_qdq_debug.py b/examples/bge/auto_qdq_debug.py
new file mode 100644
index 000000000..e6cefb198
--- /dev/null
+++ b/examples/bge/auto_qdq_debug.py
@@ -0,0 +1,190 @@
+import argparse
+import onnx
+from onnx import ModelProto
+from onnxruntime.quantization.qdq_loss_debug import (
+    collect_activations, compute_activation_error,
+    _add_pre_post_qdq_pair,
+    modify_model_output_intermediate_tensors)
+from olive.workflows import run as olive_run
+import numpy
+from pathlib import Path
+import json
+import networkx
+from collections import defaultdict, deque
+from typing import Dict, Sequence, Optional
+from olive.data.registry import Registry
+from transformers import AutoTokenizer
+from onnxruntime.quantization.quantize import CalibrationDataReader
+
+text = "How do I locate my card?"
+
+class DataReader(CalibrationDataReader):
+    def __init__(self):
+        tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5')
+        encoded_input = tokenizer(text, padding="max_length", max_length=128, truncation=True, return_tensors='np')
+        model_input = {
+            "input_ids": encoded_input.input_ids.astype(numpy.int64),
+            "attention_mask": encoded_input.attention_mask.astype(numpy.int64),
+            "token_type_ids": encoded_input.token_type_ids.astype(numpy.int64)
+        }
+        self.data = [model_input]
+        self.id = 0
+
+    def get_next(self):
+        if self.id >= len(self.data): return None
+        self.id += 1
+        return self.data[self.id - 1]
+
+    def rewind(self):
+        self.id = 0
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--float_model", type=str, default="original_model.onnx", help="Path to original floating point model"
+    )
+    parser.add_argument("--qdq_config", type=str, default="qdq_config.json", help="Path to qdq config")
+    parser.add_argument("--error", type=float, default=10, help="Error to exclude")
+    args = parser.parse_args()
+    return args
+
+
+def _generate_aug_model_path(model_path: str) -> str:
+    aug_model_path = (
+        model_path[: -len(".onnx")] if model_path.endswith(".onnx") else model_path
+    )
+    return aug_model_path + ".save_tensors.onnx"
+
+
+def bfs_nodes(model: ModelProto):
+    G = networkx.DiGraph()
+    inputs = set([inp.name for inp in model.graph.input])
+    outputs = set([out.name for out in model.graph.output])
+    for node in model.graph.node:
+        for input_name in node.input:
+            if "Constant_" in input_name: continue
+            if input_name in inputs or "_output_" in input_name or input_name in outputs:
+                G.add_edge(input_name, node.name)
+        for output_name in node.output:
+            if "Constant_" in output_name: continue
+            if output_name in outputs or "_output_" in output_name:
+                G.add_edge(node.name, output_name)
+
+    levels = []
+    visited = set()
+    queue = deque([(inp.name, 0) for inp in model.graph.input])
+    input_count = defaultdict(int)
+
+    for node in G.nodes:
+        input_count[node] = G.in_degree(node)
+
+    while queue:
+        node, level = queue.popleft()
+        if node not in visited:
+            visited.add(node)
+            if len(levels) <= level:
+                levels.append([])
+            levels[level].append(node)
+            for neighbor in G.successors(node):
+                input_count[neighbor] -= 1
+                if input_count[neighbor] == 0:
+                    queue.append((neighbor, level + 1))
+    return levels
+
+
+def augment_collect(model_path: str, input_data_reader, augment_model_path: str = None) -> Dict[str, numpy.ndarray]:
+    print(f"augment_collect {model_path}")
+    input_data_reader.rewind()
+    augment_model_path = _generate_aug_model_path(model_path) if augment_model_path is None else augment_model_path
+    modify_model_output_intermediate_tensors(model_path, augment_model_path)
+    return collect_activations(augment_model_path, input_data_reader)
+
+
+def create_activation_matching(
+    qdq_activations: Dict[str, Sequence[numpy.ndarray]],
+    float_activations: Optional[Dict[str, Sequence[numpy.ndarray]]] = None,
+) -> Dict[str, Dict[str, Sequence[numpy.ndarray]]]:
+    """Comparing activation values to help debugging accuracy loss due to quantization.
+
+    This functions takes saved activations from the QDQ model and (optionally) the
+    float point model, and provides a data structure for comparing:
+        * from the qdq model, activation values before and after QDQ operation
+        * across both models, activations from the orignal model vs the corresponding
+          activations in the QDQ model
+
+    Arg:
+        qdq_activations: Output of `collect_activations`. This must be from a quantized
+            model with QDQ format.
+        float_activations: Output of `collect_activations`. This must be from the float
+            point model.
+
+    Returns:
+        Dict for comparing pre and post quantized activation tensors. E.g.
+        ```
+        qdq_cmp = cmp_qdq_input_output(qdq_activations)
+        print(qdq_cmp['activation1']['pre_qdq'][0])
+        print(qdq_cmp['activation1'][`post_qdq'][0])
+
+
+        qdq_cmp = cmp_qdq_input_output(qdq_activations, float_activations)
+        print(qdq_cmp['activation1']['float'][0])
+        print(qdq_cmp['activation1']['pre_qdq'][0])
+        print(qdq_cmp['activation1'][`post_qdq'][0])
+        ```
+    """
+
+    qdq_cmp: Dict[str, Dict[str, Sequence[numpy.ndarray]]] = {}
+    for tensor_name, tensors in qdq_activations.items():
+        pre_name = tensor_name
+        pre_qdq_tensors = qdq_activations.get(pre_name)
+        post_qdq_tensors = tensors
+        _add_pre_post_qdq_pair(qdq_cmp, pre_name, pre_qdq_tensors, post_qdq_tensors)
+
+    if not float_activations:
+        return qdq_cmp
+
+    for act_name, act_values in qdq_cmp.items():
+        float_acts = float_activations.get(act_name)
+        if float_acts is not None:
+            act_values["float"] = float_acts
+
+    return qdq_cmp
+
+
+def compare_get(qdq_activations, float_activations, error: float, level_nodes: list[list[str]]):
+    print("Comparing activations of float model vs qdq model......")
+    act_matching = create_activation_matching(qdq_activations, float_activations)
+    act_error = compute_activation_error(act_matching)
+    return None
+
+def main():
+    # Process input parameters and setup model input data reader
+    args = get_args()
+    float_model_path = args.float_model
+
+    with Path(args.qdq_config).open() as fin:
+        olive_config = json.load(fin)
+    qdq_model_path = Path(olive_config["output_dir"]) / "output_model" / "model.onnx"
+
+    level_nodes = bfs_nodes(onnx.load(float_model_path))
+    data_reader = DataReader()
+    float_activations = augment_collect(float_model_path, data_reader)
+
+    while True:
+        olive_run(olive_config)
+        qdq_activations = augment_collect(qdq_model_path, data_reader)
+        error_node = compare_get(qdq_activations, float_activations, args.error, level_nodes)
+        if error_node is None:
+            print("No error node found")
+            break
+        if error_node in olive_config["passes"]["OnnxQuantization"]["nodes_to_exclude"]:
+            print(f"{error_node} is already excluded")
+            break
+        print(f"Error node: {error_node}")
+        olive_config["passes"]["OnnxQuantization"]["nodes_to_exclude"].append(error_node)
+
+    json.dump(olive_config, (Path(args.qdq_config) / ".final.json").open("w"))
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/examples/bge/qdq_config.json b/examples/bge/qdq_config.json
new file mode 100644
index 000000000..a9b381dc0
--- /dev/null
+++ b/examples/bge/qdq_config.json
@@ -0,0 +1,48 @@
+{
+    "input_model": {
+        "type": "OnnxModel",
+        "model_path": "original_model.onnx"
+    },
+    "systems": {
+        "local_system": {
+            "type": "LocalSystem",
+            "accelerators": [
+                { "device": "cpu", "execution_providers": [ "CPUExecutionProvider" ] }
+            ]
+        }
+    },
+    "data_configs": [
+        {
+            "name": "quantize_data_config",
+            "type": "HuggingfaceContainer",
+            "load_dataset_config": {
+                "data_name": "mteb/banking77",
+                "split": "test"
+            },
+            "pre_process_data_config": {
+                "max_length": 128,
+                "padding": "max_length",
+                "input_cols": [ "text" ],
+                "max_samples": 1
+            },
+            "dataloader_config": { "batch_size": 1 }
+        }
+    ],
+    "passes": {
+        "OnnxQuantization": {
+            "type": "OnnxQuantization",
+            "data_config": "quantize_data_config",
+            "activation_type": "QUInt16",
+            "weight_type": "QUInt8",
+            "calibrate_method": "MinMax",
"MinMax", + "quant_preprocess": true, + "prepare_qnn_config": true, + "nodes_to_exclude": [] + } + }, + "host": "local_system", + "target": "local_system", + "cache_dir": "cache", + "output_dir": "models/temp", + "evaluate_input_model": true +} diff --git a/examples/bge/requirements.txt b/examples/bge/requirements.txt new file mode 100644 index 000000000..fd146d8f6 --- /dev/null +++ b/examples/bge/requirements.txt @@ -0,0 +1 @@ +mteb From 4bcea27b5ba2cfea19508f282bf674e57dd2f2c1 Mon Sep 17 00:00:00 2001 From: hualxie Date: Thu, 20 Feb 2025 15:57:58 +0800 Subject: [PATCH 06/14] fix --- examples/bge/auto_qdq_debug.py | 97 +++++++++++----------------------- examples/bge/qdq_config.json | 22 ++++---- examples/bge/readme.md | 2 + examples/bge/user_script.py | 3 ++ 4 files changed, 46 insertions(+), 78 deletions(-) diff --git a/examples/bge/auto_qdq_debug.py b/examples/bge/auto_qdq_debug.py index e6cefb198..91f84aa8d 100644 --- a/examples/bge/auto_qdq_debug.py +++ b/examples/bge/auto_qdq_debug.py @@ -2,7 +2,7 @@ import onnx from onnx import ModelProto from onnxruntime.quantization.qdq_loss_debug import ( - collect_activations, compute_activation_error, + collect_activations, compute_signal_to_quantization_noice_ratio, _add_pre_post_qdq_pair, modify_model_output_intermediate_tensors) from olive.workflows import run as olive_run @@ -40,11 +40,8 @@ def rewind(self): def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--float_model", type=str, default="original_model.onnx", help="Path to original floating point model" - ) parser.add_argument("--qdq_config", type=str, default="qdq_config.json", help="Path to qdq config") - parser.add_argument("--error", type=float, default=10, help="Error to exclude") + parser.add_argument("--error", type=float, default=20, help="Error to exclude") args = parser.parse_args() return args @@ -60,9 +57,12 @@ def bfs_nodes(model: ModelProto): G = networkx.DiGraph() inputs = set([inp.name for inp in model.graph.input]) outputs = set([out.name for out in model.graph.output]) + inits = set([init.name for init in model.graph.initializer]) + for node in model.graph.node: for input_name in node.input: if "Constant_" in input_name: continue + if input_name in inits: continue if input_name in inputs or "_output_" in input_name or input_name in outputs: G.add_edge(input_name, node.name) for output_name in node.output: @@ -85,10 +85,12 @@ def bfs_nodes(model: ModelProto): if len(levels) <= level: levels.append([]) levels[level].append(node) + if node in outputs: outputs.remove(node) for neighbor in G.successors(node): input_count[neighbor] -= 1 if input_count[neighbor] == 0: queue.append((neighbor, level + 1)) + assert not outputs return levels @@ -100,73 +102,34 @@ def augment_collect(model_path: str, input_data_reader, augment_model_path: str return collect_activations(augment_model_path, input_data_reader) -def create_activation_matching( - qdq_activations: Dict[str, Sequence[numpy.ndarray]], - float_activations: Optional[Dict[str, Sequence[numpy.ndarray]]] = None, -) -> Dict[str, Dict[str, Sequence[numpy.ndarray]]]: - """Comparing activation values to help debugging accuracy loss due to quantization. 
-
-    This functions takes saved activations from the QDQ model and (optionally) the
-    float point model, and provides a data structure for comparing:
-        * from the qdq model, activation values before and after QDQ operation
-        * across both models, activations from the orignal model vs the corresponding
-          activations in the QDQ model
-
-    Arg:
-        qdq_activations: Output of `collect_activations`. This must be from a quantized
-            model with QDQ format.
-        float_activations: Output of `collect_activations`. This must be from the float
-            point model.
-
-    Returns:
-        Dict for comparing pre and post quantized activation tensors. E.g.
-        ```
-        qdq_cmp = cmp_qdq_input_output(qdq_activations)
-        print(qdq_cmp['activation1']['pre_qdq'][0])
-        print(qdq_cmp['activation1'][`post_qdq'][0])
-
-
-        qdq_cmp = cmp_qdq_input_output(qdq_activations, float_activations)
-        print(qdq_cmp['activation1']['float'][0])
-        print(qdq_cmp['activation1']['pre_qdq'][0])
-        print(qdq_cmp['activation1'][`post_qdq'][0])
-        ```
-    """
-
-    qdq_cmp: Dict[str, Dict[str, Sequence[numpy.ndarray]]] = {}
-    for tensor_name, tensors in qdq_activations.items():
-        pre_name = tensor_name
-        pre_qdq_tensors = qdq_activations.get(pre_name)
-        post_qdq_tensors = tensors
-        _add_pre_post_qdq_pair(qdq_cmp, pre_name, pre_qdq_tensors, post_qdq_tensors)
-
-    if not float_activations:
-        return qdq_cmp
-
-    for act_name, act_values in qdq_cmp.items():
-        float_acts = float_activations.get(act_name)
-        if float_acts is not None:
-            act_values["float"] = float_acts
-
-    return qdq_cmp
-
-
 def compare_get(qdq_activations, float_activations, error: float, level_nodes: list[list[str]]):
     print("Comparing activations of float model vs qdq model......")
-    act_matching = create_activation_matching(qdq_activations, float_activations)
-    act_error = compute_activation_error(act_matching)
-    return None
+    results = []
+    for nodes in level_nodes:
+        for node in nodes:
+            qdq_tensor = qdq_activations.get(node)
+            float_tensor = float_activations.get(node)
+            if qdq_tensor is None or float_tensor is None:
+                continue
+            ratio = compute_signal_to_quantization_noice_ratio(float_tensor, qdq_tensor)
+            if ratio < error:
+                print(f"Node {node} has error {ratio}")
+                index = node.find("_output_")
+                results.append(node[:index])
+        if results:
+            return results
+    return results
 
 def main():
     # Process input parameters and setup model input data reader
     args = get_args()
-    float_model_path = args.float_model
 
     with Path(args.qdq_config).open() as fin:
         olive_config = json.load(fin)
-    qdq_model_path = Path(olive_config["output_dir"]) / "output_model" / "model.onnx"
-
-    level_nodes = bfs_nodes(onnx.load(float_model_path))
+    qdq_model_path = olive_config["output_dir"] + "/output_model/model.onnx"
+    float_model_path = olive_config["input_model"]["model_path"]
+    model = onnx.load(float_model_path)
+    level_nodes = bfs_nodes(model)
     data_reader = DataReader()
     float_activations = augment_collect(float_model_path, data_reader)
 
     while True:
         olive_run(olive_config)
         qdq_activations = augment_collect(qdq_model_path, data_reader)
         error_node = compare_get(qdq_activations, float_activations, args.error, level_nodes)
-        if error_node is None:
+        if not error_node:
             print("No error node found")
             break
-        if error_node in olive_config["passes"]["OnnxQuantization"]["nodes_to_exclude"]:
+        if set(error_node) & set(olive_config["passes"]["OnnxQuantization"]["nodes_to_exclude"]):
             print(f"{error_node} is already excluded")
             break
         print(f"Error node: {error_node}")
olive_config["passes"]["OnnxQuantization"]["nodes_to_exclude"].append(error_node) + olive_config["passes"]["OnnxQuantization"]["nodes_to_exclude"].extend(error_node) - json.dump(olive_config, (Path(args.qdq_config) / ".final.json").open("w")) + json.dump(olive_config, Path(args.qdq_config + ".final.json").open("w")) if __name__ == "__main__": diff --git a/examples/bge/qdq_config.json b/examples/bge/qdq_config.json index a9b381dc0..8a6ef5530 100644 --- a/examples/bge/qdq_config.json +++ b/examples/bge/qdq_config.json @@ -1,7 +1,7 @@ { "input_model": { "type": "OnnxModel", - "model_path": "original_model.onnx" + "model_path": "models/preprocessed/preprocessed.onnx" }, "systems": { "local_system": { @@ -13,36 +13,36 @@ }, "data_configs": [ { - "name": "quantize_data_config", + "name": "quantize_data_config_custom", "type": "HuggingfaceContainer", + "user_script": "user_script.py", "load_dataset_config": { "data_name": "mteb/banking77", "split": "test" }, "pre_process_data_config": { - "max_length": 128, - "padding": "max_length", - "input_cols": [ "text" ], - "max_samples": 1 - }, - "dataloader_config": { "batch_size": 1 } + "type": "dataset_pre_process", + "cache_key": "qdq_cache", + "max_length": 1 + } } ], "passes": { "OnnxQuantization": { "type": "OnnxQuantization", - "data_config": "quantize_data_config", + "data_config": "quantize_data_config_custom", "activation_type": "QUInt16", "weight_type": "QUInt8", "calibrate_method": "MinMax", "quant_preprocess": true, "prepare_qnn_config": true, - "nodes_to_exclude": [] + "nodes_to_exclude": [], + "op_types_to_quantize": ["Mul", "Transpose", "Unsqueeze", "Add", "Gelu", "LayerNormalization", "Gather", "MatMul", "Sub", "Where", "Expand", "Gemm", "Tanh", "Reshape"] } }, "host": "local_system", "target": "local_system", "cache_dir": "cache", - "output_dir": "models/temp", + "output_dir": "models/auto_qdq", "evaluate_input_model": true } diff --git a/examples/bge/readme.md b/examples/bge/readme.md index 45e9eece0..ba80ea7d7 100644 --- a/examples/bge/readme.md +++ b/examples/bge/readme.md @@ -1,3 +1,5 @@ # original model accuracy "accuracy-accuracy_custom": 0.8574675324675324 # qdq model "accuracy-accuracy_custom": 0.5315909090909091 + +"op_types_to_quantize": ["Mul", "Transpose", "Unsqueeze", "Add", "Softmax", "Gelu", "LayerNormalization", "Gather", "MatMul", "Sub", "Where", "Expand", "Gemm", "Tanh", "Reshape"] \ No newline at end of file diff --git a/examples/bge/user_script.py b/examples/bge/user_script.py index 0e09375a3..4dd133c5a 100644 --- a/examples/bge/user_script.py +++ b/examples/bge/user_script.py @@ -77,6 +77,7 @@ def __getitem__(self, idx): @Registry.register_pre_process() def dataset_pre_process(output_data, **kwargs): cache_key = kwargs.get("cache_key") + max_length = kwargs.get("max_length") cache_file = None if cache_key: cache_file = Path(f"{cache_key}.npz") @@ -93,6 +94,8 @@ def dataset_pre_process(output_data, **kwargs): input_ids.append(encoded_input.input_ids[0].astype(np.int64)) attention_mask.append(encoded_input.attention_mask[0].astype(np.int64)) token_type_ids.append(encoded_input.token_type_ids[0].astype(np.int64)) + if max_length and len(input_ids) >= max_length: + break data = {"input_ids": np.array(input_ids), "attention_mask": np.array(attention_mask), "token_type_ids": np.array(token_type_ids)} result_data = DataLoader(data) From 87dc67066765e1eaef27510606ac19cac7c8932b Mon Sep 17 00:00:00 2001 From: hualxie Date: Thu, 20 Feb 2025 17:11:35 +0800 Subject: [PATCH 07/14] unsure how to debug --- 
 examples/bge/auto_qdq_debug.py | 6 +++++-
 examples/bge/qdq_config.json   | 2 +-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/examples/bge/auto_qdq_debug.py b/examples/bge/auto_qdq_debug.py
index 91f84aa8d..26ebad219 100644
--- a/examples/bge/auto_qdq_debug.py
+++ b/examples/bge/auto_qdq_debug.py
@@ -105,17 +105,21 @@ def augment_collect(model_path: str, input_data_reader, augment_model_path: str
 def compare_get(qdq_activations, float_activations, error: float, level_nodes: list[list[str]]):
     print("Comparing activations of float model vs qdq model......")
     results = []
-    for nodes in level_nodes:
+    for i, nodes in enumerate(level_nodes):
+        ratios = []
         for node in nodes:
             qdq_tensor = qdq_activations.get(node)
             float_tensor = float_activations.get(node)
             if qdq_tensor is None or float_tensor is None:
                 continue
             ratio = compute_signal_to_quantization_noice_ratio(float_tensor, qdq_tensor)
+            ratios.append((node, ratio))
             if ratio < error:
                 print(f"Node {node} has error {ratio}")
                 index = node.find("_output_")
                 results.append(node[:index])
+        if ratios:
+            print(ratios)
         if results:
             return results
     return results
diff --git a/examples/bge/qdq_config.json b/examples/bge/qdq_config.json
index 8a6ef5530..e88ac9fcb 100644
--- a/examples/bge/qdq_config.json
+++ b/examples/bge/qdq_config.json
@@ -37,7 +37,7 @@
             "quant_preprocess": true,
             "prepare_qnn_config": true,
             "nodes_to_exclude": [],
-            "op_types_to_quantize": ["Mul", "Transpose", "Unsqueeze", "Add", "Gelu", "LayerNormalization", "Gather", "MatMul", "Sub", "Where", "Expand", "Gemm", "Tanh", "Reshape"]
+            "op_types_to_quantize": null
         }
     },
     "host": "local_system",

From bf8c5ac426829f6949f7e650468acc2168c0f55e Mon Sep 17 00:00:00 2001
From: hualxie
Date: Fri, 21 Feb 2025 09:37:40 +0800
Subject: [PATCH 08/14] better error

---
 examples/bge/auto_qdq_debug.py | 72 +++++++++++++++++++++++++++-------
 examples/bge/qdq_config.json   |  4 +-
 2 files changed, 59 insertions(+), 17 deletions(-)

diff --git a/examples/bge/auto_qdq_debug.py b/examples/bge/auto_qdq_debug.py
index 26ebad219..2a59cde22 100644
--- a/examples/bge/auto_qdq_debug.py
+++ b/examples/bge/auto_qdq_debug.py
@@ -2,8 +2,7 @@ import onnx
 from onnx import ModelProto
 from onnxruntime.quantization.qdq_loss_debug import (
-    collect_activations, compute_signal_to_quantization_noice_ratio,
-    _add_pre_post_qdq_pair,
+    collect_activations,
     modify_model_output_intermediate_tensors)
 from olive.workflows import run as olive_run
 import numpy
@@ -11,10 +10,11 @@ import json
 import networkx
 from collections import defaultdict, deque
-from typing import Dict, Sequence, Optional
+from typing import Dict, Sequence, Optional, Union
 from olive.data.registry import Registry
 from transformers import AutoTokenizer
 from onnxruntime.quantization.quantize import CalibrationDataReader
+import math
 
 text = "How do I locate my card?"
@@ -41,7 +41,7 @@ def rewind(self):
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--qdq_config", type=str, default="qdq_config.json", help="Path to qdq config")
-    parser.add_argument("--error", type=float, default=20, help="Error to exclude")
+    parser.add_argument("--error", type=float, default=30, help="Error to exclude")
     args = parser.parse_args()
     return args
@@ -91,7 +91,7 @@ def bfs_nodes(model: ModelProto):
                 if input_count[neighbor] == 0:
                     queue.append((neighbor, level + 1))
     assert not outputs
-    return levels
+    return levels, G
 
 def augment_collect(model_path: str, input_data_reader, augment_model_path: str = None) -> Dict[str, numpy.ndarray]:
@@ -102,6 +102,33 @@ def augment_collect(model_path: str, input_data_reader, augment_model_path: str
     return collect_activations(augment_model_path, input_data_reader)
 
+def compute_signal_to_quantization_noice_ratio(
+    x: Union[Sequence[numpy.ndarray], numpy.ndarray], y: Union[Sequence[numpy.ndarray], numpy.ndarray]
+) -> float:
+    if isinstance(x, numpy.ndarray):
+        xlist = [x]
+    else:
+        xlist = x
+    if isinstance(y, numpy.ndarray):
+        ylist = [y]
+    else:
+        ylist = y
+    if len(xlist) != len(ylist):
+        raise RuntimeError("Unequal number of tensors to compare!")
+
+    left = numpy.concatenate(xlist).flatten()
+    right = numpy.concatenate(ylist).flatten()
+
+    epsilon = numpy.finfo("float").eps
+    tensor_norm = max(numpy.linalg.norm(left), epsilon)
+    diff_norm = max(numpy.linalg.norm(left - right), epsilon)
+
+    if diff_norm == epsilon:
+        return 100.0
+
+    res = tensor_norm / diff_norm
+    return 20 * math.log10(res)
+
 def compare_get(qdq_activations, float_activations, error: float, level_nodes: list[list[str]]):
     print("Comparing activations of float model vs qdq model......")
     results = []
@@ -116,14 +143,27 @@ def compare_get(qdq_activations, float_activations, error: float, level_nodes: l
             ratios.append((node, ratio))
             if ratio < error:
                 print(f"Node {node} has error {ratio}")
-                index = node.find("_output_")
-                results.append(node[:index])
+                results.append(node)
         if ratios:
             print(ratios)
     if results:
         return results
     return results
 
+def get_node(error_output):
+    index = error_output.find("_output_")
+    assert index != -1, f"Error output {error_output} does not contain _output_"
+    return error_output[:index]
+
+def get_nodes(error_outputs, G: networkx.DiGraph):
+    nodes = set()
+    for error_output in error_outputs:
+        node = get_node(error_output)
+        nodes.add(node)
+        for pred in G.predecessors(node):
+            nodes.add(get_node(pred))
+    return nodes
+
 def main():
     # Process input parameters and setup model input data reader
     args = get_args()
@@ -133,22 +173,24 @@ def main():
     with Path(args.qdq_config).open() as fin:
         olive_config = json.load(fin)
     qdq_model_path = olive_config["output_dir"] + "/output_model/model.onnx"
     float_model_path = olive_config["input_model"]["model_path"]
     model = onnx.load(float_model_path)
-    level_nodes = bfs_nodes(model)
+    level_nodes, G = bfs_nodes(model)
     data_reader = DataReader()
     float_activations = augment_collect(float_model_path, data_reader)
 
     while True:
         olive_run(olive_config)
         qdq_activations = augment_collect(qdq_model_path, data_reader)
+        assert len(float_activations) == len(qdq_activations)
-        error_node = compare_get(qdq_activations, float_activations, args.error, level_nodes)
+        error_outputs = compare_get(qdq_activations, float_activations, args.error, level_nodes)
-        if not error_node:
+        if not error_outputs:
-            print("No error node found")
+            print("No error outputs found")
             break
-        if set(error_node) & set(olive_config["passes"]["OnnxQuantization"]["nodes_to_exclude"]):
-            print(f"{error_node} is already excluded")
already excluded") + error_nodes = get_nodes(error_outputs, G) + if error_nodes & set(olive_config["passes"]["OnnxQuantization"]["nodes_to_exclude"]): + print(f"{error_nodes} are already excluded") break - print(f"Error node: {error_node}") - olive_config["passes"]["OnnxQuantization"]["nodes_to_exclude"].extend(error_node) + print(f"Error nodes: {error_nodes}") + olive_config["passes"]["OnnxQuantization"]["nodes_to_exclude"].extend(error_nodes) json.dump(olive_config, Path(args.qdq_config + ".final.json").open("w")) diff --git a/examples/bge/qdq_config.json b/examples/bge/qdq_config.json index e88ac9fcb..ba6d0ce3f 100644 --- a/examples/bge/qdq_config.json +++ b/examples/bge/qdq_config.json @@ -23,7 +23,7 @@ "pre_process_data_config": { "type": "dataset_pre_process", "cache_key": "qdq_cache", - "max_length": 1 + "max_length": 100 } } ], @@ -34,7 +34,7 @@ "activation_type": "QUInt16", "weight_type": "QUInt8", "calibrate_method": "MinMax", - "quant_preprocess": true, + "quant_preprocess": false, "prepare_qnn_config": true, "nodes_to_exclude": [], "op_types_to_quantize": null From 69fbd766ecf4cdde6385ce0a170e269b8247ffd2 Mon Sep 17 00:00:00 2001 From: hualxie Date: Fri, 21 Feb 2025 10:03:27 +0800 Subject: [PATCH 09/14] update code --- examples/bge/auto_qdq_debug.py | 51 ++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/examples/bge/auto_qdq_debug.py b/examples/bge/auto_qdq_debug.py index 2a59cde22..df4a5a524 100644 --- a/examples/bge/auto_qdq_debug.py +++ b/examples/bge/auto_qdq_debug.py @@ -53,23 +53,31 @@ def _generate_aug_model_path(model_path: str) -> str: return aug_model_path + ".save_tensors.onnx" -def bfs_nodes(model: ModelProto): - G = networkx.DiGraph() - inputs = set([inp.name for inp in model.graph.input]) - outputs = set([out.name for out in model.graph.output]) - inits = set([init.name for init in model.graph.initializer]) +def get_valid_values(model: ModelProto): + valid_values = set() + parents = {} + for inp in model.graph.input: + valid_values.add(inp.name) + for node in model.graph.node: + for output_name in node.output: + valid_values.add(output_name) + parents[output_name] = node.name + return valid_values, parents + +def get_graph(model: ModelProto, valid_values: set[str]) -> networkx.DiGraph: + G = networkx.DiGraph() for node in model.graph.node: for input_name in node.input: - if "Constant_" in input_name: continue - if input_name in inits: continue - if input_name in inputs or "_output_" in input_name or input_name in outputs: + if input_name in valid_values: G.add_edge(input_name, node.name) for output_name in node.output: - if "Constant_" in output_name: continue - if output_name in outputs or "_output_" in output_name: + if output_name in valid_values: G.add_edge(node.name, output_name) - + return G + + +def bfs_nodes(G: networkx.DiGraph, model: ModelProto, valid_values: set[str]) -> list[list[str]]: levels = [] visited = set() queue = deque([(inp.name, 0) for inp in model.graph.input]) @@ -78,6 +86,7 @@ def bfs_nodes(model: ModelProto): for node in G.nodes: input_count[node] = G.in_degree(node) + outputs = set([out.name for out in model.graph.output]) while queue: node, level = queue.popleft() if node not in visited: @@ -91,7 +100,7 @@ def bfs_nodes(model: ModelProto): if input_count[neighbor] == 0: queue.append((neighbor, level + 1)) assert not outputs - return levels, G + return levels def augment_collect(model_path: str, input_data_reader, augment_model_path: str = None) -> Dict[str, numpy.ndarray]: @@ -129,6 
From 69fbd766ecf4cdde6385ce0a170e269b8247ffd2 Mon Sep 17 00:00:00 2001
From: hualxie
Date: Fri, 21 Feb 2025 10:03:27 +0800
Subject: [PATCH 09/14] update code

---
 examples/bge/auto_qdq_debug.py | 51 ++++++++++++++++++++--------------
 1 file changed, 30 insertions(+), 21 deletions(-)

diff --git a/examples/bge/auto_qdq_debug.py b/examples/bge/auto_qdq_debug.py
index 2a59cde22..df4a5a524 100644
--- a/examples/bge/auto_qdq_debug.py
+++ b/examples/bge/auto_qdq_debug.py
@@ -53,23 +53,31 @@ def _generate_aug_model_path(model_path: str) -> str:
     return aug_model_path + ".save_tensors.onnx"
 
-def bfs_nodes(model: ModelProto):
-    G = networkx.DiGraph()
-    inputs = set([inp.name for inp in model.graph.input])
-    outputs = set([out.name for out in model.graph.output])
-    inits = set([init.name for init in model.graph.initializer])
+def get_valid_values(model: ModelProto):
+    valid_values = set()
+    parents = {}
+    for inp in model.graph.input:
+        valid_values.add(inp.name)
+    for node in model.graph.node:
+        for output_name in node.output:
+            valid_values.add(output_name)
+            parents[output_name] = node.name
+    return valid_values, parents
+
+
+def get_graph(model: ModelProto, valid_values: set[str]) -> networkx.DiGraph:
+    G = networkx.DiGraph()
     for node in model.graph.node:
         for input_name in node.input:
-            if "Constant_" in input_name: continue
-            if input_name in inits: continue
-            if input_name in inputs or "_output_" in input_name or input_name in outputs:
+            if input_name in valid_values:
                 G.add_edge(input_name, node.name)
         for output_name in node.output:
-            if "Constant_" in output_name: continue
-            if output_name in outputs or "_output_" in output_name:
+            if output_name in valid_values:
                 G.add_edge(node.name, output_name)
-
+    return G
+
+
+def bfs_nodes(G: networkx.DiGraph, model: ModelProto, valid_values: set[str]) -> list[list[str]]:
     levels = []
     visited = set()
     queue = deque([(inp.name, 0) for inp in model.graph.input])
     input_count = defaultdict(int)
 
     for node in G.nodes:
         input_count[node] = G.in_degree(node)
 
+    outputs = set([out.name for out in model.graph.output])
     while queue:
         node, level = queue.popleft()
 
 def augment_collect(model_path: str, input_data_reader, augment_model_path: str = None) -> Dict[str, numpy.ndarray]:
     print(f"augment_collect {model_path}")
@@ -129,6 +138,7 @@ def compute_signal_to_quantization_noice_ratio(
     res = tensor_norm / diff_norm
     return 20 * math.log10(res)
 
+
 def compare_get(qdq_activations, float_activations, error: float, level_nodes: list[list[str]]):
     print("Comparing activations of float model vs qdq model......")
     results = []
@@ -150,20 +160,17 @@ def compare_get(qdq_activations, float_activations, error: float, level_nodes: l
         return results
     return results
 
-def get_node(error_output):
-    index = error_output.find("_output_")
-    assert index != -1, f"Error output {error_output} does not contain _output_"
-    return error_output[:index]
 
-def get_nodes(error_outputs, G: networkx.DiGraph):
+def get_nodes(G: networkx.DiGraph, error_outputs: list[str], parents: Dict[str, str]) -> set[str]:
     nodes = set()
     for error_output in error_outputs:
-        node = get_node(error_output)
+        node = parents[error_output]
         nodes.add(node)
         for pred in G.predecessors(node):
-            nodes.add(get_node(pred))
+            nodes.add(parents[pred])
     return nodes
 
+
 def main():
     # Process input parameters and setup model input data reader
     args = get_args()
@@ -173,7 +180,9 @@ def main():
     qdq_model_path = olive_config["output_dir"] + "/output_model/model.onnx"
     float_model_path = olive_config["input_model"]["model_path"]
     model = onnx.load(float_model_path)
-    level_nodes, G = bfs_nodes(model)
+    valid_values, parents = get_valid_values(model)
+    G = get_graph(model, valid_values)
+    level_nodes = bfs_nodes(G, model, valid_values)
     data_reader = DataReader()
     float_activations = augment_collect(float_model_path, data_reader)
 
@@ -185,7 +194,7 @@ def main():
         if not error_outputs:
             print("No error outputs found")
             break
-        error_nodes = get_nodes(error_outputs, G)
+        error_nodes = get_nodes(G, error_outputs, parents)
         if error_nodes & set(olive_config["passes"]["OnnxQuantization"]["nodes_to_exclude"]):
             print(f"{error_nodes} are already excluded")
             break

From e9ab69daec7c166e5b66879020081e994dd8675a Mon Sep 17 00:00:00 2001
From: hualxie
Date: Fri, 21 Feb 2025 11:27:34 +0800
Subject: [PATCH 10/14] greedy search

---
 examples/bge/auto_qdq_debug.py      |  2 +-
 examples/bge/bge-small-en-v1.5.json |  2 +-
 examples/bge/readme.md              | 16 +++++++++++++---
 examples/bge/user_script.py         | 15 ++++++++++++++-
 4 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/examples/bge/auto_qdq_debug.py b/examples/bge/auto_qdq_debug.py
index df4a5a524..5bef1546e 100644
--- a/examples/bge/auto_qdq_debug.py
+++ b/examples/bge/auto_qdq_debug.py
@@ -201,7 +201,7 @@ def main():
         print(f"Error nodes: {error_nodes}")
         olive_config["passes"]["OnnxQuantization"]["nodes_to_exclude"].extend(error_nodes)
 
-    json.dump(olive_config, Path(args.qdq_config + ".final.json").open("w"))
+    json.dump(olive_config, Path(args.qdq_config + ".final.json").open("w"), indent=4)
 
 
 if __name__ == "__main__":
diff --git a/examples/bge/bge-small-en-v1.5.json b/examples/bge/bge-small-en-v1.5.json
index efc736e03..81885a7bc 100644
--- a/examples/bge/bge-small-en-v1.5.json
+++ b/examples/bge/bge-small-en-v1.5.json
@@ -91,5 +91,5 @@
     "target": "local_system",
     "cache_dir": "cache",
     "output_dir": "models/bge-small-en-v1.5",
-    "evaluate_input_model": true
+    "evaluate_input_model": false
 }
diff --git a/examples/bge/readme.md b/examples/bge/readme.md
index ba80ea7d7..81d06d1f2 100644
--- a/examples/bge/readme.md
+++ b/examples/bge/readme.md
@@ -1,5 +1,15 @@
-# original model accuracy "accuracy-accuracy_custom": 0.8574675324675324
+# original model accuracy
-# qdq model "accuracy-accuracy_custom": 0.5315909090909091
+
+"accuracy-accuracy_custom": 0.8574675324675324
-"op_types_to_quantize": ["Mul", "Transpose", "Unsqueeze", "Add", "Softmax", "Gelu", "LayerNormalization", "Gather", "MatMul", "Sub", "Where", "Expand", "Gemm", "Tanh", "Reshape"] \ No newline at end of file +"op_types_to_quantize": ["Mul", "Transpose", "Unsqueeze", "Add", "Softmax", "Gelu", "LayerNormalization", "Gather", "MatMul", "Sub", "Where", "Expand", "Gemm", "Tanh", "Reshape"] + +# QDQ + +All: 0.5315909090909091 + +[ "MatMul", "LayerNormalization", "Gemm", "Gelu" ]: "accuracy-accuracy_custom": 0.8506818181818183 + +[ "Mul", "MatMul", "LayerNormalization", "Gemm", "Gelu" ]: 0.850487012987013 + +[ "Mul", "Transpose", "MatMul", "LayerNormalization", "Gemm", "Gelu" ]: 0.8504870129870131 diff --git a/examples/bge/user_script.py b/examples/bge/user_script.py index 4dd133c5a..ae65ef79d 100644 --- a/examples/bge/user_script.py +++ b/examples/bge/user_script.py @@ -1,6 +1,7 @@ from olive.model import OliveModelHandler, HfModelHandler from olive.constants import Framework from olive.workflows import run as olive_run +from olive.engine.footprint import Footprint, FootprintNode import mteb from typing import List from transformers import AutoTokenizer, BertModel @@ -108,6 +109,18 @@ def dataset_pre_process(output_data, **kwargs): if __name__ == "__main__": + all_ops = ["Mul", "Transpose", "Unsqueeze", "Add", "Softmax", "Gelu", "LayerNormalization", "Gather", "MatMul", "Sub", "Where", "Expand", "Gemm", "Tanh", "Reshape"] + target_accuracy = 0.8 with Path("bge-small-en-v1.5.json").open() as fin: olive_config = json.load(fin) - olive_run(olive_config) \ No newline at end of file + for i, op in enumerate(all_ops): + if op in olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"]: + continue + olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"].append(op) + result = olive_run(olive_config) + footprint: Footprint = next(iter(result.values())) + node: FootprintNode = next(iter(footprint.nodes.values())) + accuracy = node.metrics.value["accuracy-accuracy_custom"].value + print(f"Ops: {olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"]} Accuracy: {accuracy}") + if accuracy < target_accuracy: + olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"].remove(op) From f2b6304f7b0aebbef471d4465f020147eb59746a Mon Sep 17 00:00:00 2001 From: hualxie Date: Fri, 21 Feb 2025 13:22:19 +0800 Subject: [PATCH 11/14] temp data --- examples/bge/readme.md | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/examples/bge/readme.md b/examples/bge/readme.md index 81d06d1f2..54682c9c7 100644 --- a/examples/bge/readme.md +++ b/examples/bge/readme.md @@ -1,15 +1,12 @@ -# original model accuracy - -"accuracy-accuracy_custom": 0.8574675324675324 - -"op_types_to_quantize": ["Mul", "Transpose", "Unsqueeze", "Add", "Softmax", "Gelu", "LayerNormalization", "Gather", "MatMul", "Sub", "Where", "Expand", "Gemm", "Tanh", "Reshape"] - -# QDQ - -All: 0.5315909090909091 - -[ "MatMul", "LayerNormalization", "Gemm", "Gelu" ]: "accuracy-accuracy_custom": 0.8506818181818183 - -[ "Mul", "MatMul", "LayerNormalization", "Gemm", "Gelu" ]: 0.850487012987013 - -[ "Mul", "Transpose", "MatMul", "LayerNormalization", "Gemm", "Gelu" ]: 0.8504870129870131 +## Precision + +| Quantized Ops | precision | +|-|-| +| None (original model) | 0.8574675324675324 | +| All ("Mul", "Transpose", "Unsqueeze", "Add", "Softmax", "Gelu", "LayerNormalization", "Gather", "MatMul", "Sub", "Where", "Expand", "Gemm", "Tanh", "Reshape") | 0.5315909090909091 | +| "MatMul", 
"LayerNormalization", "Gemm", "Gelu" | 0.8506818181818183 | +| "Mul", "MatMul", "LayerNormalization", "Gemm", "Gelu" | 0.850487012987013 | +| "Mul", "Transpose", "MatMul", "LayerNormalization", "Gemm", "Gelu" | 0.8504870129870131 | +| 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze' | 0.8504870129870131 | +| 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze', 'Add' | 0.5317207792207792 | +| 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze', 'Softmax' | 0.5313961038961039 | From c09c457ce826659bcfb55901278e203df7dd61d7 Mon Sep 17 00:00:00 2001 From: hualxie Date: Fri, 21 Feb 2025 14:36:41 +0800 Subject: [PATCH 12/14] also no Add and Softmax --- examples/bge/bge-small-en-v1.5.json | 5 +++-- examples/bge/readme.md | 2 ++ examples/bge/user_script.py | 1 + 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/bge/bge-small-en-v1.5.json b/examples/bge/bge-small-en-v1.5.json index 81885a7bc..aaec54476 100644 --- a/examples/bge/bge-small-en-v1.5.json +++ b/examples/bge/bge-small-en-v1.5.json @@ -82,7 +82,8 @@ "weight_type": "QUInt8", "calibrate_method": "MinMax", "quant_preprocess": true, - "prepare_qnn_config": true + "prepare_qnn_config": true, + "op_types_to_quantize": [ "Mul", "Transpose", "MatMul", "LayerNormalization", "Gemm", "Gelu", "Unsqueeze", "Gather", "Sub", "Where", "Expand", "Tanh", "Reshape" ] } }, "pass_flows": [ [ "conversion", "dynamic_shape_to_fixed", "QNNPreprocess", "OnnxQuantization" ] ], @@ -91,5 +92,5 @@ "target": "local_system", "cache_dir": "cache", "output_dir": "models/bge-small-en-v1.5", - "evaluate_input_model": false + "evaluate_input_model": true } diff --git a/examples/bge/readme.md b/examples/bge/readme.md index 54682c9c7..7cdee8491 100644 --- a/examples/bge/readme.md +++ b/examples/bge/readme.md @@ -10,3 +10,5 @@ | 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze' | 0.8504870129870131 | | 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze', 'Add' | 0.5317207792207792 | | 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze', 'Softmax' | 0.5313961038961039 | +| 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze', 'Gather', 'Sub' | 0.8504870129870131 | +| 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze', 'Gather', 'Sub', 'Where', 'Expand', 'Tanh', 'Reshape' | 0.8504870129870131 | diff --git a/examples/bge/user_script.py b/examples/bge/user_script.py index ae65ef79d..e96257f81 100644 --- a/examples/bge/user_script.py +++ b/examples/bge/user_script.py @@ -124,3 +124,4 @@ def dataset_pre_process(output_data, **kwargs): print(f"Ops: {olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"]} Accuracy: {accuracy}") if accuracy < target_accuracy: olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"].remove(op) + print(olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"]) From 0995c8d05a6b89a84fe24d5467792c99f03e757e Mon Sep 17 00:00:00 2001 From: hualxie Date: Fri, 21 Feb 2025 15:01:50 +0800 Subject: [PATCH 13/14] nit --- examples/bge/auto_qdq_debug.py | 208 ---------------------------- examples/bge/bge-small-en-v1.5.json | 55 +++----- examples/bge/qdq_config.json | 48 ------- examples/bge/readme.md | 54 +++++++- examples/bge/user_script.py | 104 ++++++-------- 5 files changed, 111 insertions(+), 358 deletions(-) delete mode 100644 examples/bge/auto_qdq_debug.py delete mode 100644 
 delete mode 100644 examples/bge/qdq_config.json

diff --git a/examples/bge/auto_qdq_debug.py b/examples/bge/auto_qdq_debug.py
deleted file mode 100644
index 5bef1546e..000000000
--- a/examples/bge/auto_qdq_debug.py
+++ /dev/null
@@ -1,208 +0,0 @@
-import argparse
-import onnx
-from onnx import ModelProto
-from onnxruntime.quantization.qdq_loss_debug import (
-    collect_activations,
-    modify_model_output_intermediate_tensors)
-from olive.workflows import run as olive_run
-import numpy
-from pathlib import Path
-import json
-import networkx
-from collections import defaultdict, deque
-from typing import Dict, Sequence, Optional, Union
-from olive.data.registry import Registry
-from transformers import AutoTokenizer
-from onnxruntime.quantization.quantize import CalibrationDataReader
-import math
-
-text = "How do I locate my card?"
-
-class DataReader(CalibrationDataReader):
-    def __init__(self):
-        tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5')
-        encoded_input = tokenizer(text, padding="max_length", max_length=128, truncation=True, return_tensors='np')
-        model_input = {
-            "input_ids": encoded_input.input_ids.astype(numpy.int64),
-            "attention_mask": encoded_input.attention_mask.astype(numpy.int64),
-            "token_type_ids": encoded_input.token_type_ids.astype(numpy.int64)
-        }
-        self.data = [model_input]
-        self.id = 0
-
-    def get_next(self):
-        if self.id >= len(self.data): return None
-        self.id += 1
-        return self.data[self.id - 1]
-
-    def rewind(self):
-        self.id = 0
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--qdq_config", type=str, default="qdq_config.json", help="Path to qdq config")
-    parser.add_argument("--error", type=float, default=30, help="Error to exclude")
-    args = parser.parse_args()
-    return args
-
-
-def _generate_aug_model_path(model_path: str) -> str:
-    aug_model_path = (
-        model_path[: -len(".onnx")] if model_path.endswith(".onnx") else model_path
-    )
-    return aug_model_path + ".save_tensors.onnx"
-
-
-def get_valid_values(model: ModelProto):
-    valid_values = set()
-    parents = {}
-    for inp in model.graph.input:
-        valid_values.add(inp.name)
-    for node in model.graph.node:
-        for output_name in node.output:
-            valid_values.add(output_name)
-            parents[output_name] = node.name
-    return valid_values, parents
-
-
-def get_graph(model: ModelProto, valid_values: set[str]) -> networkx.DiGraph:
-    G = networkx.DiGraph()
-    for node in model.graph.node:
-        for input_name in node.input:
-            if input_name in valid_values:
-                G.add_edge(input_name, node.name)
-        for output_name in node.output:
-            if output_name in valid_values:
-                G.add_edge(node.name, output_name)
-    return G
-
-
-def bfs_nodes(G: networkx.DiGraph, model: ModelProto, valid_values: set[str]) -> list[list[str]]:
-    levels = []
-    visited = set()
-    queue = deque([(inp.name, 0) for inp in model.graph.input])
-    input_count = defaultdict(int)
-
-    for node in G.nodes:
-        input_count[node] = G.in_degree(node)
-
-    outputs = set([out.name for out in model.graph.output])
-    while queue:
-        node, level = queue.popleft()
-        if node not in visited:
-            visited.add(node)
-            if len(levels) <= level:
-                levels.append([])
-            levels[level].append(node)
-            if node in outputs: outputs.remove(node)
-            for neighbor in G.successors(node):
-                input_count[neighbor] -= 1
-                if input_count[neighbor] == 0:
-                    queue.append((neighbor, level + 1))
-    assert not outputs
-    return levels
-
-
-def augment_collect(model_path: str, input_data_reader, augment_model_path: str = None) -> Dict[str, numpy.ndarray]:
-    print(f"augment_collect {model_path}")
input_data_reader.rewind() - augment_model_path = _generate_aug_model_path(model_path) if augment_model_path is None else augment_model_path - modify_model_output_intermediate_tensors(model_path, augment_model_path) - return collect_activations(augment_model_path, input_data_reader) - - -def compute_signal_to_quantization_noice_ratio( - x: Union[Sequence[numpy.ndarray], numpy.ndarray], y: Union[Sequence[numpy.ndarray], numpy.ndarray] -) -> float: - if isinstance(x, numpy.ndarray): - xlist = [x] - else: - xlist = x - if isinstance(y, numpy.ndarray): - ylist = [y] - else: - ylist = y - if len(xlist) != len(ylist): - raise RuntimeError("Unequal number of tensors to compare!") - - left = numpy.concatenate(xlist).flatten() - right = numpy.concatenate(ylist).flatten() - - epsilon = numpy.finfo("float").eps - tensor_norm = max(numpy.linalg.norm(left), epsilon) - diff_norm = max(numpy.linalg.norm(left - right), epsilon) - - if diff_norm == epsilon: - return 100.0 - - res = tensor_norm / diff_norm - return 20 * math.log10(res) - - -def compare_get(qdq_activations, float_activations, error: float, level_nodes: list[list[str]]): - print("Comparing activations of float model vs qdq model......") - results = [] - for i, nodes in enumerate(level_nodes): - ratios = [] - for node in nodes: - qdq_tensor = qdq_activations.get(node) - float_tensor = float_activations.get(node) - if qdq_tensor is None or float_tensor is None: - continue - ratio = compute_signal_to_quantization_noice_ratio(float_tensor, qdq_tensor) - ratios.append((node, ratio)) - if ratio < error: - print(f"Node {node} has error {ratio}") - results.append(node) - if ratios: - print(ratios) - if results: - return results - return results - - -def get_nodes(G: networkx.DiGraph, error_outputs: list[str], parents: Dict[str, str]) -> set[str]: - nodes = set() - for error_output in error_outputs: - node = parents[error_output] - nodes.add(node) - for pred in G.predecessors(node): - nodes.add(parents[pred]) - return nodes - - -def main(): - # Process input parameters and setup model input data reader - args = get_args() - - with Path(args.qdq_config).open() as fin: - olive_config = json.load(fin) - qdq_model_path = olive_config["output_dir"] + "/output_model/model.onnx" - float_model_path = olive_config["input_model"]["model_path"] - model = onnx.load(float_model_path) - valid_values, parents = get_valid_values(model) - G = get_graph(model, valid_values) - level_nodes = bfs_nodes(G, model, valid_values) - data_reader = DataReader() - float_activations = augment_collect(float_model_path, data_reader) - - while True: - olive_run(olive_config) - qdq_activations = augment_collect(qdq_model_path, data_reader) - assert len(float_activations) == len(qdq_activations) - error_outputs = compare_get(qdq_activations, float_activations, args.error, level_nodes) - if not error_outputs: - print("No error outputs found") - break - error_nodes = get_nodes(G, error_outputs, parents) - if error_nodes & set(olive_config["passes"]["OnnxQuantization"]["nodes_to_exclude"]): - print(f"{error_nodes} are already excluded") - break - print(f"Error nodes: {error_nodes}") - olive_config["passes"]["OnnxQuantization"]["nodes_to_exclude"].extend(error_nodes) - - json.dump(olive_config, Path(args.qdq_config + ".final.json").open("w"), indent=4) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/examples/bge/bge-small-en-v1.5.json b/examples/bge/bge-small-en-v1.5.json index aaec54476..057aee0f4 100644 --- a/examples/bge/bge-small-en-v1.5.json +++ 
b/examples/bge/bge-small-en-v1.5.json @@ -1,44 +1,18 @@ { - "input_model": { - "type": "HfModel", - "model_path": "BAAI/bge-small-en-v1.5", - "task": "feature-extraction" - }, + "input_model": { "type": "HfModel", "model_path": "BAAI/bge-small-en-v1.5", "task": "feature-extraction" }, "systems": { "local_system": { "type": "LocalSystem", - "accelerators": [ - { "device": "cpu", "execution_providers": [ "CPUExecutionProvider" ] } - ] + "accelerators": [ { "device": "cpu", "execution_providers": [ "CPUExecutionProvider" ] } ] } }, "data_configs": [ { "name": "quantize_data_config", "type": "HuggingfaceContainer", - "load_dataset_config": { - "data_name": "mteb/banking77", - "split": "test" - }, - "pre_process_data_config": { - "max_length": 128, - "padding": "max_length", - "input_cols": [ "text" ] - }, + "load_dataset_config": { "data_name": "mteb/banking77", "split": "test" }, + "pre_process_data_config": { "max_length": 128, "padding": "max_length", "input_cols": [ "text" ] }, "dataloader_config": { "batch_size": 1 } - }, - { - "name": "quantize_data_config_custom", - "type": "HuggingfaceContainer", - "user_script": "user_script.py", - "load_dataset_config": { - "data_name": "mteb/banking77", - "split": "test" - }, - "pre_process_data_config": { - "type": "dataset_pre_process", - "cache_key": "cache" - } } ], "evaluators": { @@ -68,10 +42,7 @@ "conversion": { "type": "OnnxConversion", "target_opset": 17 }, "dynamic_shape_to_fixed": { "type": "DynamicToFixedShape", - "dim_param": [ - "batch_size", - "sequence_length" - ], + "dim_param": [ "batch_size", "sequence_length" ], "dim_value": [ 1, 128 ] }, "QNNPreprocess": { "type": "QNNPreprocess", "fuse_layernorm": true }, @@ -83,7 +54,21 @@ "calibrate_method": "MinMax", "quant_preprocess": true, "prepare_qnn_config": true, - "op_types_to_quantize": [ "Mul", "Transpose", "MatMul", "LayerNormalization", "Gemm", "Gelu", "Unsqueeze", "Gather", "Sub", "Where", "Expand", "Tanh", "Reshape" ] + "op_types_to_quantize": [ + "Mul", + "Transpose", + "MatMul", + "LayerNormalization", + "Gemm", + "Gelu", + "Unsqueeze", + "Gather", + "Sub", + "Where", + "Expand", + "Tanh", + "Reshape" + ] } }, "pass_flows": [ [ "conversion", "dynamic_shape_to_fixed", "QNNPreprocess", "OnnxQuantization" ] ], diff --git a/examples/bge/qdq_config.json b/examples/bge/qdq_config.json deleted file mode 100644 index ba6d0ce3f..000000000 --- a/examples/bge/qdq_config.json +++ /dev/null @@ -1,48 +0,0 @@ -{ - "input_model": { - "type": "OnnxModel", - "model_path": "models/preprocessed/preprocessed.onnx" - }, - "systems": { - "local_system": { - "type": "LocalSystem", - "accelerators": [ - { "device": "cpu", "execution_providers": [ "CPUExecutionProvider" ] } - ] - } - }, - "data_configs": [ - { - "name": "quantize_data_config_custom", - "type": "HuggingfaceContainer", - "user_script": "user_script.py", - "load_dataset_config": { - "data_name": "mteb/banking77", - "split": "test" - }, - "pre_process_data_config": { - "type": "dataset_pre_process", - "cache_key": "qdq_cache", - "max_length": 100 - } - } - ], - "passes": { - "OnnxQuantization": { - "type": "OnnxQuantization", - "data_config": "quantize_data_config_custom", - "activation_type": "QUInt16", - "weight_type": "QUInt8", - "calibrate_method": "MinMax", - "quant_preprocess": false, - "prepare_qnn_config": true, - "nodes_to_exclude": [], - "op_types_to_quantize": null - } - }, - "host": "local_system", - "target": "local_system", - "cache_dir": "cache", - "output_dir": "models/auto_qdq", - "evaluate_input_model": true -} diff 
index 7cdee8491..bf9c71959 100644
--- a/examples/bge/readme.md
+++ b/examples/bge/readme.md
@@ -1,4 +1,17 @@
-## Precision
+# BAAI/bge-small-en-v1.5 Optimization
+
+This folder contains examples of [BAAI/bge-small-en-v1.5](https://huggingface.co/BAAI/bge-small-en-v1.5) optimization using different workflows.
+
+- NPU: [Optimization with PTQ using QNN EP](#ptq-using-qnn-ep)
+
+## Optimization Workflows
+
+### PTQ using QNN EP
+
+This workflow performs the optimization pipeline:
+- *PyTorch Model -> Onnx Model -> Static shaped Onnx Model -> Quantized Onnx Model*
+
+Precision drops when the Add or Softmax op types are quantized, so they are excluded.
 
 | Quantized Ops | precision |
 |-|-|
@@ -10,5 +23,42 @@
 | 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze' | 0.8504870129870131 |
 | 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze', 'Add' | 0.5317207792207792 |
 | 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze', 'Softmax' | 0.5313961038961039 |
-| 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze', 'Gather', 'Sub' | 0.8504870129870131 |
+| 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze', 'Gather' | 0.8504870129870131 |
+| ... | ... |
 | 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze', 'Gather', 'Sub', 'Where', 'Expand', 'Tanh', 'Reshape' | 0.8504870129870131 |
+
+## How to run
+### Pip requirements
+Install the necessary Python packages:
+```sh
+# [NPU]
+pip install git+https://github.com/microsoft/Olive#egg=olive-ai[qnn]
+```
+
+### Other dependencies
+```sh
+python -m pip install -r requirements.txt
+```
+
+### Run sample using config
+
+The optimization techniques to run are specified in the relevant config JSON file.
+
+First, install the packages required by the passes in the config file:
+```sh
+olive run --config <config_file>.json --setup
+```
+
+Then, optimize the model:
+```sh
+olive run --config <config_file>.json
+```
+
+or simply run it with Python code:
+```python
+from olive.workflows import run as olive_run
+olive_run("<config_file>.json")
+```
+
+After running the above command, the model candidates and corresponding configs will be saved in the output directory.
+You can then select the best model and config from the candidates and run the model with the selected config.
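To make the "Run sample using config" instructions above concrete for this example, the workflow can also be driven end to end from Python, reading the evaluated metric back through the footprint objects the same way `user_script.py` does below. This is a minimal sketch, assuming the config file name used by this series (renamed to `bge-small-en-v1.5_ptq_qnn.json` in the final patch) and the `accuracy-accuracy_custom` metric name declared in the config:

```python
# Minimal driver sketch: run the PTQ workflow and print the evaluated accuracy.
import json
from pathlib import Path

from olive.workflows import run as olive_run

# Assumes the config name after the rename in PATCH 14/14; adjust if it differs.
with Path("bge-small-en-v1.5_ptq_qnn.json").open() as fin:
    olive_config = json.load(fin)

result = olive_run(olive_config)

# olive_run returns one footprint per accelerator spec; each footprint node
# carries the metrics produced by the evaluator declared in the config.
footprint = next(iter(result.values()))
node = next(iter(footprint.nodes.values()))
print("accuracy:", node.metrics.value["accuracy-accuracy_custom"].value)
```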
diff --git a/examples/bge/user_script.py b/examples/bge/user_script.py
index e96257f81..b16baff02 100644
--- a/examples/bge/user_script.py
+++ b/examples/bge/user_script.py
@@ -1,44 +1,48 @@
-from olive.model import OliveModelHandler, HfModelHandler
-from olive.constants import Framework
-from olive.workflows import run as olive_run
-from olive.engine.footprint import Footprint, FootprintNode
-import mteb
+import json
+from pathlib import Path
 from typing import List
-from transformers import AutoTokenizer, BertModel
+
+import mteb
 import numpy as np
-import json
 import torch
-from pathlib import Path
-from olive.data.registry import Registry
+from transformers import AutoTokenizer
+
+from olive.constants import Framework
+from olive.engine.footprint import Footprint, FootprintNode
+from olive.model import OliveModelHandler
+from olive.workflows import run as olive_run
+
 
 class OliveEncoder:
     def __init__(self, model, session):
         self.model = model
         self.session = session
-        self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5')
+        self.tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5")
         self.total = 0
         self.max_len = 0
 
     def encode(self, corpus: List, **kwargs):
         if self.model.framework == Framework.ONNX:
-            encoded_input = self.tokenizer(corpus, padding="max_length", max_length=128, truncation=True, return_tensors='np')
+            encoded_input = self.tokenizer(
+                corpus, padding="max_length", max_length=128, truncation=True, return_tensors="np"
+            )
             # batch_size is 1 for static model
             model_outputs = []
             for i in range(len(corpus)):
                 model_inputs = {
-                    "input_ids": encoded_input.input_ids[i:i+1,:].astype(np.int64),
-                    "attention_mask": encoded_input.attention_mask[i:i+1,:].astype(np.int64),
-                    "token_type_ids": encoded_input.token_type_ids[i:i+1,:].astype(np.int64)
+                    "input_ids": encoded_input.input_ids[i : i + 1, :].astype(np.int64),
+                    "attention_mask": encoded_input.attention_mask[i : i + 1, :].astype(np.int64),
+                    "token_type_ids": encoded_input.token_type_ids[i : i + 1, :].astype(np.int64),
                 }
                 model_output = self.model.run_session(self.session, model_inputs)[0]
                 model_outputs.append(model_output[0])
             model_output = np.array(model_outputs)
         elif self.model.framework == Framework.PYTORCH:
-            encoded_input = self.tokenizer(corpus, padding=True, truncation=True, return_tensors='pt')
+            encoded_input = self.tokenizer(corpus, padding=True, truncation=True, return_tensors="pt")
             model_inputs = {
                 "input_ids": encoded_input.input_ids,
                 "attention_mask": encoded_input.attention_mask,
-                "token_type_ids": encoded_input.token_type_ids
+                "token_type_ids": encoded_input.token_type_ids,
             }
             self.max_len = max(self.max_len, model_inputs["input_ids"].shape[1])
             print(self.max_len)
@@ -61,55 +65,25 @@ def eval_accuracy(model: OliveModelHandler, device, execution_providers, tasks):
     return results[0].scores["test"][0]["main_score"]
 
 
-class DataLoader:
-    def __init__(self, data):
-        self.input_ids = torch.from_numpy(data["input_ids"])
-        self.attention_mask = torch.from_numpy(data["attention_mask"])
-        self.token_type_ids = torch.from_numpy(data["token_type_ids"])
-
-    def __len__(self):
-        return len(self.input_ids)
-
-    def __getitem__(self, idx):
-        data = {"input_ids": self.input_ids[idx],"attention_mask": self.attention_mask[idx],"token_type_ids": self.token_type_ids[idx]}
-        return data
-
-
-@Registry.register_pre_process()
-def dataset_pre_process(output_data, **kwargs):
-    cache_key = kwargs.get("cache_key")
-    max_length = kwargs.get("max_length")
-    cache_file = None
-    if cache_key:
-        cache_file = Path(f"{cache_key}.npz")
-        if cache_file.exists():
-            with np.load(Path(cache_file)) as data:
-                return DataLoader(data)
-
-    tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5')
-    input_ids = []
-    attention_mask = []
-    token_type_ids = []
-    for item in output_data:
-        encoded_input = tokenizer(item['text'], padding="max_length", max_length=128, truncation=True, return_tensors='np')
-        input_ids.append(encoded_input.input_ids[0].astype(np.int64))
-        attention_mask.append(encoded_input.attention_mask[0].astype(np.int64))
-        token_type_ids.append(encoded_input.token_type_ids[0].astype(np.int64))
-        if max_length and len(input_ids) >= max_length:
-            break
-
-    data = {"input_ids": np.array(input_ids), "attention_mask": np.array(attention_mask), "token_type_ids": np.array(token_type_ids)}
-    result_data = DataLoader(data)
-
-    if cache_file:
-        cache_file.parent.resolve().mkdir(parents=True, exist_ok=True)
-        np.savez(cache_file, **data)
-
-    return result_data
-
-
 if __name__ == "__main__":
-    all_ops = ["Mul", "Transpose", "Unsqueeze", "Add", "Softmax", "Gelu", "LayerNormalization", "Gather", "MatMul", "Sub", "Where", "Expand", "Gemm", "Tanh", "Reshape"]
+    # Greedy search for the best combination of ops to quantize
+    all_ops = [
+        "Mul",
+        "Transpose",
+        "Unsqueeze",
+        "Add",
+        "Softmax",
+        "Gelu",
+        "LayerNormalization",
+        "Gather",
+        "MatMul",
+        "Sub",
+        "Where",
+        "Expand",
+        "Gemm",
+        "Tanh",
+        "Reshape",
+    ]
     target_accuracy = 0.8
     with Path("bge-small-en-v1.5.json").open() as fin:
         olive_config = json.load(fin)
@@ -124,4 +98,4 @@ def dataset_pre_process(output_data, **kwargs):
         print(f"Ops: {olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"]} Accuracy: {accuracy}")
         if accuracy < target_accuracy:
             olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"].remove(op)
-    print(olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"])
+    print(f"Used Ops: {olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"]}")

From 93ae4c48aa440ae81167b063ac37febde7b3f117 Mon Sep 17 00:00:00 2001
From: hualxie
Date: Fri, 21 Feb 2025 15:12:20 +0800
Subject: [PATCH 14/14] linter

---
 ....5.json => bge-small-en-v1.5_ptq_qnn.json} |  0
 examples/bge/user_script.py                   | 30 +++++++++++----------
 2 files changed, 16 insertions(+), 14 deletions(-)
 rename examples/bge/{bge-small-en-v1.5.json => bge-small-en-v1.5_ptq_qnn.json} (100%)

diff --git a/examples/bge/bge-small-en-v1.5.json b/examples/bge/bge-small-en-v1.5_ptq_qnn.json
similarity index 100%
rename from examples/bge/bge-small-en-v1.5.json
rename to examples/bge/bge-small-en-v1.5_ptq_qnn.json
diff --git a/examples/bge/user_script.py b/examples/bge/user_script.py
index b16baff02..fa4dee622 100644
--- a/examples/bge/user_script.py
+++ b/examples/bge/user_script.py
@@ -18,8 +18,6 @@ def __init__(self, model, session):
         self.model = model
         self.session = session
         self.tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5")
-        self.total = 0
-        self.max_len = 0
 
     def encode(self, corpus: List, **kwargs):
         if self.model.framework == Framework.ONNX:
@@ -44,28 +42,30 @@ def encode(self, corpus: List, **kwargs):
                 "attention_mask": encoded_input.attention_mask,
                 "token_type_ids": encoded_input.token_type_ids,
             }
-            self.max_len = max(self.max_len, model_inputs["input_ids"].shape[1])
-            print(self.max_len)
             with torch.no_grad():
                 model_output = self.model.run_session(self.session, model_inputs)
                 model_output = model_output.last_hidden_state.numpy()
             # select the last hidden state of the first token (i.e., [CLS]) as the sentence embedding.
-            model_output = model_output[:, 0, :]
-        self.total += len(corpus)
-        print(self.total)
-        return model_output
+        return model_output[:, 0, :]
 
 
 def eval_accuracy(model: OliveModelHandler, device, execution_providers, tasks):
     sess = model.prepare_session(inference_settings=None, device=device, execution_providers=execution_providers)
 
     evaluation = mteb.MTEB(tasks=tasks)
-    oliveEncoder = OliveEncoder(model, sess)
-    results = evaluation.run(oliveEncoder, output_folder=None)
+    olive_encoder = OliveEncoder(model, sess)
+    results = evaluation.run(olive_encoder, output_folder=None)
     return results[0].scores["test"][0]["main_score"]
 
 
 if __name__ == "__main__":
+    import logging
+    import sys
+
+    logger = logging.getLogger("bge")
+    logger.addHandler(logging.StreamHandler(sys.stdout))
+    logger.setLevel(logging.INFO)
+
     # Greedy search for the best combination of ops to quantize
     all_ops = [
         "Mul",
@@ -87,7 +87,7 @@ def eval_accuracy(model: OliveModelHandler, device, execution_providers, tasks):
     target_accuracy = 0.8
-    with Path("bge-small-en-v1.5.json").open() as fin:
+    with Path("bge-small-en-v1.5_ptq_qnn.json").open() as fin:
         olive_config = json.load(fin)
-    for i, op in enumerate(all_ops):
+    for op in all_ops:
         if op in olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"]:
             continue
         olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"].append(op)
@@ -95,7 +95,9 @@ def eval_accuracy(model: OliveModelHandler, device, execution_providers, tasks):
         footprint: Footprint = next(iter(result.values()))
         node: FootprintNode = next(iter(footprint.nodes.values()))
         accuracy = node.metrics.value["accuracy-accuracy_custom"].value
-        print(f"Ops: {olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"]} Accuracy: {accuracy}")
+        logger.info(
+            "Ops: %s Accuracy: %f", olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"], accuracy
+        )
         if accuracy < target_accuracy:
             olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"].remove(op)
-    print(f"Used Ops: {olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"]}")
+    logger.info("Final Ops: %s", olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"])
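The `__main__` block above is a greedy inclusion search over ONNX op types: each candidate is tentatively appended to `op_types_to_quantize`, the workflow is re-run, and the op type is rolled back if accuracy falls below the target. Distilled from the script, with the Olive run hidden behind a callback (`greedy_op_selection` and `evaluate` are illustrative names, not part of the patch):

```python
from typing import Callable, Iterable, List


def greedy_op_selection(
    all_ops: Iterable[str],
    evaluate: Callable[[List[str]], float],
    target_accuracy: float,
) -> List[str]:
    """Greedily grow the list of op types to quantize.

    Keep each candidate op type only if quantizing it together with the ops
    accepted so far stays at or above the accuracy target; otherwise roll it
    back, mirroring the loop in user_script.py.
    """
    selected: List[str] = []
    for op in all_ops:
        selected.append(op)
        # e.g. patch op_types_to_quantize in the config, run Olive, read the metric
        if evaluate(selected) < target_accuracy:
            selected.pop()  # this op type hurts accuracy; drop it again
    return selected
```

Per the table in `readme.md`, this loop accepts thirteen of the fifteen candidates and rolls back only `Add` and `Softmax`, whose quantization drops accuracy from about 0.85 to about 0.53. Note the result is order-dependent: each op type is judged only against the set accepted before it.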