diff --git a/examples/bge/bge-small-en-v1.5_ptq_qnn.json b/examples/bge/bge-small-en-v1.5_ptq_qnn.json new file mode 100644 index 000000000..057aee0f4 --- /dev/null +++ b/examples/bge/bge-small-en-v1.5_ptq_qnn.json @@ -0,0 +1,81 @@ +{ + "input_model": { "type": "HfModel", "model_path": "BAAI/bge-small-en-v1.5", "task": "feature-extraction" }, + "systems": { + "local_system": { + "type": "LocalSystem", + "accelerators": [ { "device": "cpu", "execution_providers": [ "CPUExecutionProvider" ] } ] + } + }, + "data_configs": [ + { + "name": "quantize_data_config", + "type": "HuggingfaceContainer", + "load_dataset_config": { "data_name": "mteb/banking77", "split": "test" }, + "pre_process_data_config": { "max_length": 128, "padding": "max_length", "input_cols": [ "text" ] }, + "dataloader_config": { "batch_size": 1 } + } + ], + "evaluators": { + "common_evaluator": { + "metrics": [ + { + "name": "accuracy", + "type": "custom", + "sub_types": [ + { + "name": "accuracy_custom", + "priority": 1, + "higher_is_better": true, + "goal": { "type": "max-degradation", "value": 0.05 } + } + ], + "user_config": { + "user_script": "user_script.py", + "evaluate_func": "eval_accuracy", + "evaluate_func_kwargs": { "tasks": [ "Banking77Classification" ] } + } + } + ] + } + }, + "passes": { + "conversion": { "type": "OnnxConversion", "target_opset": 17 }, + "dynamic_shape_to_fixed": { + "type": "DynamicToFixedShape", + "dim_param": [ "batch_size", "sequence_length" ], + "dim_value": [ 1, 128 ] + }, + "QNNPreprocess": { "type": "QNNPreprocess", "fuse_layernorm": true }, + "OnnxQuantization": { + "type": "OnnxQuantization", + "data_config": "quantize_data_config", + "activation_type": "QUInt16", + "weight_type": "QUInt8", + "calibrate_method": "MinMax", + "quant_preprocess": true, + "prepare_qnn_config": true, + "op_types_to_quantize": [ + "Mul", + "Transpose", + "MatMul", + "LayerNormalization", + "Gemm", + "Gelu", + "Unsqueeze", + "Gather", + "Sub", + "Where", + "Expand", + "Tanh", + 
"Reshape" + ] + } + }, + "pass_flows": [ [ "conversion", "dynamic_shape_to_fixed", "QNNPreprocess", "OnnxQuantization" ] ], + "evaluator": "common_evaluator", + "host": "local_system", + "target": "local_system", + "cache_dir": "cache", + "output_dir": "models/bge-small-en-v1.5", + "evaluate_input_model": true +} diff --git a/examples/bge/readme.md b/examples/bge/readme.md new file mode 100644 index 000000000..bf9c71959 --- /dev/null +++ b/examples/bge/readme.md @@ -0,0 +1,64 @@ +# BAAI/bge-small-en-v1.5 Optimization + +This folder contains examples of [BAAI/bge-small-en-v1.5 ](https://huggingface.co/BAAI/bge-small-en-v1.5) optimization using different workflows. + +- NPU: [Optimization with PTQ using QNN EP](#ptq-using-qnn-ep) + +## Optimization Workflows + +### PTQ using QNN EP + +This workflow performs the optimization pipeline: +- *PyTorch Model -> Onnx Model -> Static shaped Onnx Model -> Quantized Onnx Model* + +The precision will drop when Add or Softmax types of op are quantized, so they are not included. 
+ +| Quantized Ops | precision | +|-|-| +| None (original model) | 0.8574675324675324 | +| All ("Mul", "Transpose", "Unsqueeze", "Add", "Softmax", "Gelu", "LayerNormalization", "Gather", "MatMul", "Sub", "Where", "Expand", "Gemm", "Tanh", "Reshape") | 0.5315909090909091 | +| "MatMul", "LayerNormalization", "Gemm", "Gelu" | 0.8506818181818183 | +| "Mul", "MatMul", "LayerNormalization", "Gemm", "Gelu" | 0.850487012987013 | +| "Mul", "Transpose", "MatMul", "LayerNormalization", "Gemm", "Gelu" | 0.8504870129870131 | +| 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze' | 0.8504870129870131 | +| 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze', 'Add' | 0.5317207792207792 | +| 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze', 'Softmax' | 0.5313961038961039 | +| 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze', 'Gather' | 0.8504870129870131 | +| ... | ... | +| 'Mul', 'Transpose', 'MatMul', 'LayerNormalization', 'Gemm', 'Gelu', 'Unsqueeze', 'Gather', 'Sub', 'Where', 'Expand', 'Tanh', 'Reshape' | 0.8504870129870131 | + +## How to run +### Pip requirements +Install the necessary Python packages: +```sh +# [NPU] +pip install git+https://github.com/microsoft/Olive#egg=olive-ai[qnn] +``` + +### Other dependencies +```sh +python -m pip install -r requirements.txt +``` + +### Run sample using config + +The optimization techniques to run are specified in the relevant config json file. + +First, install the packages required by the passes. +```sh +olive run --config bge-small-en-v1.5_ptq_qnn.json --setup +``` + +Then, optimize the model +```sh +olive run --config bge-small-en-v1.5_ptq_qnn.json +``` + +or simply run it from Python code: +```python +from olive.workflows import run as olive_run +olive_run("bge-small-en-v1.5_ptq_qnn.json") +``` + +After running the above command, the model candidates and corresponding config will be saved in the output directory. 
+You can then select the best model and config from the candidates and run the model with the selected config.
diff --git a/examples/bge/requirements.txt b/examples/bge/requirements.txt
new file mode 100644
index 000000000..fd146d8f6
--- /dev/null
+++ b/examples/bge/requirements.txt
@@ -0,0 +1 @@
+mteb
diff --git a/examples/bge/user_script.py b/examples/bge/user_script.py
new file mode 100644
index 000000000..fa4dee622
--- /dev/null
+++ b/examples/bge/user_script.py
@@ -0,0 +1,139 @@
+import json
+from pathlib import Path
+from typing import List
+
+import mteb
+import numpy as np
+import torch
+from transformers import AutoTokenizer
+
+from olive.constants import Framework
+from olive.engine.footprint import Footprint, FootprintNode
+from olive.model import OliveModelHandler
+from olive.workflows import run as olive_run
+
+# Hugging Face model whose tokenizer prepares the encoder inputs.
+MODEL_NAME = "BAAI/bge-small-en-v1.5"
+# Padded sequence length for the ONNX path; matches the fixed `sequence_length` in the workflow config.
+MAX_LENGTH = 128
+# Workflow config shipped next to this script; the path is resolved against the current working directory.
+CONFIG_PATH = "bge-small-en-v1.5_ptq_qnn.json"
+
+
+class OliveEncoder:
+    """Wrap an Olive model and its prepared session behind an `encode` method for `mteb.MTEB.run`."""
+
+    def __init__(self, model, session):
+        self.model = model
+        self.session = session
+        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+
+    def encode(self, corpus: List, **kwargs):
+        """Embed every entry of `corpus`.
+
+        Extra keyword arguments are accepted and ignored.
+
+        Returns:
+            The last hidden state of the first token ([CLS]) for each entry.
+
+        Raises:
+            ValueError: if the model framework is neither ONNX nor PyTorch.
+        """
+        if self.model.framework == Framework.ONNX:
+            model_output = self._run_onnx(corpus)
+        elif self.model.framework == Framework.PYTORCH:
+            model_output = self._run_pytorch(corpus)
+        else:
+            # fail with a clear message instead of an UnboundLocalError on the return below
+            raise ValueError(f"Unsupported model framework: {self.model.framework}")
+        # select the last hidden state of the first token (i.e., [CLS]) as the sentence embedding.
+        return model_output[:, 0, :]
+
+    def _run_onnx(self, corpus: List):
+        """Run the session once per entry and stack the first output of every run."""
+        encoded_input = self.tokenizer(
+            corpus, padding="max_length", max_length=MAX_LENGTH, truncation=True, return_tensors="np"
+        )
+        # cast once up front rather than once per sample
+        input_ids = encoded_input.input_ids.astype(np.int64)
+        attention_mask = encoded_input.attention_mask.astype(np.int64)
+        token_type_ids = encoded_input.token_type_ids.astype(np.int64)
+        # batch_size is 1 for static model
+        model_outputs = []
+        for i in range(len(corpus)):
+            model_inputs = {
+                "input_ids": input_ids[i : i + 1, :],
+                "attention_mask": attention_mask[i : i + 1, :],
+                "token_type_ids": token_type_ids[i : i + 1, :],
+            }
+            model_output = self.model.run_session(self.session, model_inputs)[0]
+            model_outputs.append(model_output[0])
+        return np.array(model_outputs)
+
+    def _run_pytorch(self, corpus: List):
+        """Run the whole corpus as a single batch and return `last_hidden_state` as a numpy array."""
+        encoded_input = self.tokenizer(corpus, padding=True, truncation=True, return_tensors="pt")
+        model_inputs = {
+            "input_ids": encoded_input.input_ids,
+            "attention_mask": encoded_input.attention_mask,
+            "token_type_ids": encoded_input.token_type_ids,
+        }
+        with torch.no_grad():
+            model_output = self.model.run_session(self.session, model_inputs)
+        return model_output.last_hidden_state.numpy()
+
+
+def eval_accuracy(model: OliveModelHandler, device, execution_providers, tasks):
+    """Run the MTEB `tasks` against `model` and return the test-split `main_score` of the first result."""
+    sess = model.prepare_session(inference_settings=None, device=device, execution_providers=execution_providers)
+
+    evaluation = mteb.MTEB(tasks=tasks)
+    olive_encoder = OliveEncoder(model, sess)
+    results = evaluation.run(olive_encoder, output_folder=None)
+    return results[0].scores["test"][0]["main_score"]
+
+
+if __name__ == "__main__":
+    import logging
+    import sys
+
+    logger = logging.getLogger("bge")
+    logger.addHandler(logging.StreamHandler(sys.stdout))
+    logger.setLevel(logging.INFO)
+
+    # Greedy search for the best combination of ops to quantize
+    all_ops = [
+        "Mul",
+        "Transpose",
+        "Unsqueeze",
+        "Add",
+        "Softmax",
+        "Gelu",
+        "LayerNormalization",
+        "Gather",
+        "MatMul",
+        "Sub",
+        "Where",
+        "Expand",
+        "Gemm",
+        "Tanh",
+        "Reshape",
+    ]
+    target_accuracy = 0.8
+    with Path(CONFIG_PATH).open() as fin:
+        olive_config = json.load(fin)
+    # same list object as in the config, so edits below are seen by `olive_run`
+    quantized_ops = olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"]
+    for op in all_ops:
+        if op in quantized_ops:
+            continue
+        quantized_ops.append(op)
+        result = olive_run(olive_config)
+        footprint: Footprint = next(iter(result.values()))
+        node: FootprintNode = next(iter(footprint.nodes.values()))
+        accuracy = node.metrics.value["accuracy-accuracy_custom"].value
+        logger.info("Ops: %s Accuracy: %f", quantized_ops, accuracy)
+        # keep the op only if the model still meets the accuracy target
+        if accuracy < target_accuracy:
+            quantized_ops.remove(op)
+    logger.info("Final Ops: %s", quantized_ops)