Add ReplaceAttentionMaskValue graph surgeon, Bert QNN EP example (#…
jambayk authored Feb 5, 2025
1 parent 6a9ea11 commit e2ae2e0
Showing 10 changed files with 396 additions and 16 deletions.
2 changes: 1 addition & 1 deletion docs/source/examples.md
@@ -8,7 +8,7 @@
||phi2|[Link](https://github.com/microsoft/Olive/tree/main/examples/phi2)|`CPU`: with ONNX Runtime optimizations fp32/int4<br>`GPU` with ONNX Runtime optimizations fp16/int4, with PyTorch QLoRA for model fine tune<br>`GPU` with SliceGPT for an optimized PyTorch model with sparsity
||falcon|[Link](https://github.com/microsoft/Olive/tree/main/examples/falcon)|`GPU`: with ONNX Runtime optimizations for optimized FP16 ONNX model
||red pajama|[Link](https://github.com/microsoft/Olive/tree/main/examples/red_pajama)| `CPU`: with Optimum conversion and merging and ONNX Runtime optimizations for a single optimized ONNX model
||bert|[Link](https://github.com/microsoft/Olive/tree/main/examples/bert)|`CPU`: with ONNX Runtime optimizations and quantization for optimized INT8 ONNX model<br>`CPU`: with ONNX Runtime optimizations and Intel® Neural Compressor quantization for optimized INT8 ONNX model<br>`CPU`: with PyTorch QAT Customized Training Loop and ONNX Runtime optimizations for optimized ONNX INT8 model<br>`GPU`: with ONNX Runtime optimizations for CUDA EP<br>`GPU`: with ONNX Runtime optimizations for TRT EP
||bert|[Link](https://github.com/microsoft/Olive/tree/main/examples/bert)|`CPU`: with ONNX Runtime optimizations and quantization for optimized INT8 ONNX model<br>`CPU`: with ONNX Runtime optimizations and Intel® Neural Compressor quantization for optimized INT8 ONNX model<br>`CPU`: with PyTorch QAT Customized Training Loop and ONNX Runtime optimizations for optimized ONNX INT8 model<br>`GPU`: with ONNX Runtime optimizations for CUDA EP<br>`GPU`: with ONNX Runtime optimizations for TRT EP<br>`NPU`: with ONNX Runtime optimizations for QNN EP
||deberta|[Link](https://github.com/microsoft/Olive/tree/main/examples/deberta)|`GPU`: Optimize Azureml Registry Model with ONNX Runtime optimizations and quantization
||gptj|[Link](https://github.com/microsoft/Olive/tree/main/examples/gptj)|`CPU`: with Intel® Neural Compressor static/dynamic quantization for INT8 ONNX model
|Audio|whisper|[Link](https://github.com/microsoft/Olive/tree/main/examples/whisper)|`CPU`: with ONNX Runtime optimizations for all-in-one ONNX model in FP32<br>`CPU`: with ONNX Runtime optimizations for all-in-one ONNX model in INT8<br>`CPU`: with ONNX Runtime optimizations and Intel® Neural Compressor Dynamic Quantization for all-in-one ONNX model in INT8<br>`GPU`: with ONNX Runtime optimizations for all-in-one ONNX model in FP32<br>`GPU`: with ONNX Runtime optimizations for all-in-one ONNX model in FP16<br>`GPU`: with ONNX Runtime optimizations for all-in-one ONNX model in INT8
106 changes: 106 additions & 0 deletions docs/source/how-to/configure-workflows/onnx-graph-surgeon.md
@@ -697,3 +697,109 @@ Transformed model graph:
```
[Root] --> LpNormalization --> Mul
           (p=2, axis=-1)
```

### ReplaceAttentionMaskValue

#### Description

Replace the large negative constant used in the extended attention mask (by default, any scalar float below `-3e30`) with a new value (default `-1e4`). This surgery is useful when the default mask value, such as the float32 minimum, does not quantize well due to its extreme magnitude.

#### Example

Initial model graph:

```
graph {
  node {
    input: "input1"
    output: "output1"
    name: "ConstantOfShape"
    op_type: "ConstantOfShape"
    attribute {
      name: "value"
      t {
        dims: 1
        data_type: 1
        float_data: -3.4028234663852886e+38
        name: ""
      }
      type: TENSOR
    }
  }
  node {
    output: "Constant_output"
    name: "Constant"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 1
        float_data: -3.4028234663852886e+38
        name: ""
      }
      type: TENSOR
    }
  }
  initializer {
    data_type: 1
    float_data: -3.4028234663852886e+38
    name: "init"
  }
}
```

After applying:

```json
{
    "type": "GraphSurgeries",
    "surgeries": [
        {
            "surgeon": "ReplaceAttentionMaskValue"
        }
    ]
}
```


Transformed model graph:

```
graph {
  node {
    input: "input1"
    output: "output1"
    name: "ConstantOfShape"
    op_type: "ConstantOfShape"
    attribute {
      name: "value"
      t {
        dims: 1
        data_type: 1
        float_data: -10000.0
        name: ""
      }
      type: TENSOR
    }
  }
  node {
    output: "Constant_output"
    name: "Constant"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 1
        float_data: -10000.0
        name: ""
      }
      type: TENSOR
    }
  }
  initializer {
    data_type: 1
    float_data: -10000.0
    name: "init"
  }
}
```
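
Both the detection threshold and the replacement value are configurable via the surgeon's optional `threshold` (default `-3e30`) and `replacement` (default `-1e4`) parameters, which can be passed alongside `"surgeon"` in the surgery entry now that optional keyword arguments are forwarded. As a minimal sketch of applying the surgeon class directly in Python (the toy one-node model below is illustrative, not from the Olive docs; assumes `olive-ai` and `onnx` are installed):

```python
# Minimal sketch: build a one-node model holding the float32-minimum mask
# constant, then apply the surgeon class directly with its optional
# parameters (defaults shown).
import numpy as np
import onnx
from onnx import TensorProto, helper

from olive.passes.onnx.graph_surgeries import ReplaceAttentionMaskValue

mask_node = helper.make_node(
    "Constant",
    inputs=[],
    outputs=["mask"],
    name="mask_const",
    value=onnx.numpy_helper.from_array(np.array([np.finfo(np.float32).min])),
)
graph = helper.make_graph(
    [mask_node],
    "g",
    inputs=[],
    outputs=[helper.make_tensor_value_info("mask", TensorProto.FLOAT, [1])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])

surgeon = ReplaceAttentionMaskValue(threshold=-3e30, replacement=-1e4)
model = surgeon(model)
print(onnx.numpy_helper.to_array(model.graph.node[0].attribute[0].t))  # [-10000.]
```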
14 changes: 14 additions & 0 deletions examples/bert/README.md
@@ -5,6 +5,7 @@ This folder contains examples of BERT optimization using different workflows.
- CPU: [Optimization with Intel® Neural Compressor PTQ](#bert-optimization-with-intel®-neural-compressor-ptq-on-cpu)
- CPU: [Optimization with QAT Customized Training Loop](#bert-optimization-with-qat-customized-training-loop-on-cpu)
- GPU: [Optimization with CUDA/TensorRT](#bert-optimization-with-cudatensorrt-on-gpu)
- NPU: [Optimization with PTQ on Qualcomm NPU using QNN EP](#bert-optimization-with-ptq-on-npu)

Go to [How to run](#how-to-run)

@@ -98,6 +99,16 @@ This workflow performs BERT optimization on GPU with CUDA/TensorRT. It performs
- *PyTorch Model -> Onnx Model -> ONNX Runtime performance tuning with trt_fp16_enable*
Config file: [bert_trt_gpu.json](bert_trt_gpu.json)

### BERT optimization with PTQ on NPU
This workflow performs BERT optimization on Qualcomm NPU with ONNX Runtime PTQ. It performs the optimization pipeline:
- *PyTorch Model -> Onnx Model -> Static shaped Onnx Model -> Quantized Onnx Model*

It requires an x86 Python environment on a Windows ARM machine with `onnxruntime-qnn` installed.

Config file: [bert_ptq_qnn.json](bert_ptq_qnn.json)

**NOTE:** The model optimization part of the workflow can also be done on a Linux/Windows machine with a different onnxruntime package installed. Remove the `"evaluators"` and `"evaluator"` sections from the configuration file to skip the evaluation step.
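
For reference, the workflow can be launched with the `olive run` CLI (see [How to run](#how-to-run)) or from Python; a minimal sketch of the Python entry point, assuming `olive-ai` is installed:

```python
# Minimal sketch: equivalent to `olive run --config bert_ptq_qnn.json`.
from olive.workflows import run as olive_run

olive_run("bert_ptq_qnn.json")
```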

## How to run
### Pip requirements
Install the necessary python packages:
@@ -106,6 +117,9 @@ Install the necessary python packages:
```
# [CPU]
pip install git+https://github.com/microsoft/Olive#egg=olive-ai[cpu]
# [GPU]
pip install git+https://github.com/microsoft/Olive#egg=olive-ai[gpu]
# [NPU]
pip install git+https://github.com/microsoft/Olive#egg=olive-ai[qnn]

# Other dependencies
python -m pip install -r requirements.txt
```
75 changes: 75 additions & 0 deletions examples/bert/bert_ptq_qnn.json
@@ -0,0 +1,75 @@
{
    "input_model": {
        "type": "HfModel",
        "model_path": "Intel/bert-base-uncased-mrpc",
        "task": "text-classification",
        "load_kwargs": { "attn_implementation": "eager" }
    },
    "systems": {
        "local_system": {
            "type": "LocalSystem",
            "accelerators": [ { "device": "npu", "execution_providers": [ "QNNExecutionProvider" ] } ]
        }
    },
    "data_configs": [
        {
            "name": "glue_mrpc",
            "type": "HuggingfaceContainer",
            "load_dataset_config": { "data_name": "glue", "subset": "mrpc", "split": "validation" },
            "pre_process_data_config": {
                "max_length": 128,
                "padding": "max_length",
                "input_cols": [ "sentence1", "sentence2" ],
                "max_samples": 100
            },
            "dataloader_config": { "batch_size": 1 }
        }
    ],
    "evaluators": {
        "common_evaluator": {
            "metrics": [
                {
                    "name": "accuracy",
                    "type": "accuracy",
                    "data_config": "glue_mrpc",
                    "sub_types": [ { "name": "accuracy_score", "priority": 1 } ]
                },
                {
                    "name": "latency",
                    "type": "latency",
                    "data_config": "glue_mrpc",
                    "sub_types": [ { "name": "avg", "priority": 2 } ]
                },
                {
                    "name": "latency_cpu",
                    "type": "latency",
                    "data_config": "glue_mrpc",
                    "sub_types": [ { "name": "avg", "priority": 3 } ],
                    "inference_settings": { "onnx": { "execution_provider": "CPUExecutionProvider" } }
                }
            ]
        }
    },
    "passes": {
        "conversion": { "type": "OnnxConversion", "target_opset": 17 },
        "dynamic_shape_to_fixed": {
            "type": "DynamicToFixedShape",
            "dim_param": [ "batch_size", "sequence_length" ],
            "dim_value": [ 1, 128 ]
        },
        "surgery": { "type": "GraphSurgeries", "surgeries": [ { "surgeon": "ReplaceAttentionMaskValue" } ] },
        "qnn_preprocess": { "type": "QNNPreprocess" },
        "quantization": {
            "type": "OnnxStaticQuantization",
            "data_config": "glue_mrpc",
            "activation_type": "QUInt16",
            "weight_type": "QUInt8"
        }
    },
    "evaluator": "common_evaluator",
    "evaluate_input_model": false,
    "host": "local_system",
    "target": "local_system",
    "cache_dir": "cache",
    "output_dir": "models/bert_ptq_qnn"
}
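
A few notes on this config: the passes run in order — ONNX conversion at opset 17, fixing the dynamic `batch_size`/`sequence_length` dims to 1/128 to match the data config, the `ReplaceAttentionMaskValue` surgery, QNN preprocessing, and finally static quantization with 16-bit activations and 8-bit weights (a combination commonly used with QNN EP). The third metric reports CPU EP latency for comparison by overriding the execution provider in `inference_settings`.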
2 changes: 1 addition & 1 deletion examples/mobilenet/README.md
@@ -32,4 +32,4 @@ Run the following command to quantize the model and evaluate it on the NPU:
```
olive run --config mobilenet_qnn_ep.json
```

**NOTE:** The model optimization part of the workflow can also be done on a Linux machine with a different onnxruntime package installed. Remove the `"evaluators"` and `"evaluator`" sections from the `mobilenet_qnn_ep.json` configuration file to skip the evaluation step.
**NOTE:** The model optimization part of the workflow can also be done on a Linux/Windows machine with a different onnxruntime package installed. Remove the `"evaluators"` and `"evaluator"` sections from the `mobilenet_qnn_ep.json` configuration file to skip the evaluation step.
6 changes: 4 additions & 2 deletions olive/evaluator/metric.py
@@ -5,7 +5,7 @@
import logging
from typing import Any, Dict, List, Optional, Union

from olive.common.config_utils import ConfigBase, validate_config
from olive.common.config_utils import ConfigBase, NestedConfig, validate_config
from olive.common.pydantic_v1 import validator
from olive.common.utils import StrEnumBase
from olive.data.config import DataConfig
@@ -92,7 +92,9 @@ def validate_goal(cls, v, values):
        return v


class Metric(ConfigBase):
class Metric(NestedConfig):
    _nested_field_name = "user_config"

    name: str
    type: MetricType
    backend: Optional[str] = "torch_metrics"
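
For context, a hedged reading of this change (the exact semantics live in `NestedConfig` in `olive/common/config_utils.py`): with `_nested_field_name = "user_config"`, metric entries can presumably supply `user_config` fields at the top level, along the lines of:

```python
# Hedged sketch: both forms would parse to the same Metric, assuming
# NestedConfig gathers unrecognized top-level keys into "user_config".
metric_nested = {
    "name": "accuracy",
    "type": "custom",
    "user_config": {"user_script": "user_script.py"},
}
metric_flat = {
    "name": "accuracy",
    "type": "custom",
    "user_script": "user_script.py",  # collected into user_config
}
```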
1 change: 1 addition & 0 deletions olive/olive_config.json
@@ -379,6 +379,7 @@
"optimum": [ "optimum" ],
"ort-genai": [ "onnxruntime-genai" ],
"ort": [ "onnxruntime", "onnxruntime-directml", "onnxruntime-gpu", "onnxruntime-openvino", "numpy<2.0" ],
"qnn": [ "onnxruntime-qnn" ],
"tf": [ "tensorflow==1.15.0" ],
"torch-tensorrt": [ "torch-tensorrt" ],
"tune-session-params": [ "psutil" ]
90 changes: 80 additions & 10 deletions olive/passes/onnx/graph_surgeries.py
@@ -554,9 +554,7 @@ def __call__(self, model: ModelProto):
            rmsnorm_weight = np.array([1], dtype=rmsnorm_weight.dtype)
            rmsnorm_weight = sqrt_n * rmsnorm_weight

            dag.replace_initializer(
                onnx.numpy_helper.from_array(rmsnorm_weight, name=rmsnorm_weight_name), graph_idx
            )
            dag.replace_initializer(onnx.numpy_helper.from_array(rmsnorm_weight, name=rmsnorm_weight_name))
            replaced_initializers.add(rmsnorm_weight_name)

            # add and replace nodes
@@ -597,6 +595,67 @@ def get_rmsnorm_nodes(pow_node: str, dag: OnnxDAG) -> Optional[List[str]]:
        return rmsnorm_nodes if len(rmsnorm_nodes) >= (len(pattern) - 1) else []


class ReplaceAttentionMaskValue(Surgeon):
    """Replace the value of extended attention mask with a new value.
    This surgery is useful if the default mask value does not quantize well due to numerical instability.
    """

    def __init__(self, threshold: float = -3e30, replacement: float = -1e4):
        self.threshold = threshold
        self.replacement = replacement

    def __call__(self, model: ModelProto):
        dag = OnnxDAG(model)
        modified = 0

        # update any Constant or ConstantOfShape nodes holding a scalar below the threshold
        for node_name in dag.get_node_names():
            op_type = dag.get_node_op_type(node_name)
            node_proto = dag.get_node_proto(node_name)
            if not (
                op_type in {"Constant", "ConstantOfShape"}
                and node_proto.attribute
                and node_proto.attribute[0].t
                and node_proto.attribute[0].t.data_type == onnx.TensorProto.FLOAT
                and node_proto.attribute[0].t.dims in [[], [1]]
            ):
                continue

            value = onnx.helper.get_attribute_value(node_proto.attribute[0])
            tensor_value = onnx.numpy_helper.to_array(value)
            if tensor_value < self.threshold:
                node_proto.ClearField("attribute")
                node_proto.attribute.extend(
                    [
                        onnx.helper.make_attribute(
                            "value", onnx.numpy_helper.from_array(np.full_like(tensor_value, self.replacement))
                        )
                    ]
                )
                modified += 1

        # update any initializers holding a scalar below the threshold
        for init_name in dag.get_initializer_names():
            init_proto = dag.get_initializer_proto(init_name)
            if not (init_proto.data_type == onnx.TensorProto.FLOAT and init_proto.dims in [[], [1]]):
                continue

            tensor_value = onnx.numpy_helper.to_array(init_proto)
            if tensor_value < self.threshold:
                dag.replace_initializer(
                    onnx.numpy_helper.from_array(np.full_like(tensor_value, self.replacement), name=init_name)
                )
                modified += 1

        if modified > 0:
            logger.debug("Replaced %d values below threshold with replacement.", modified)

        dag.update()
        return dag.model


class GraphSurgeries(Pass):
"""ONNX graph surgeries collections.
Expand Down Expand Up @@ -657,22 +716,33 @@ def init_surgeon_instance(self, surgery):
        if not surgeon_class:
            raise ValueError(f"Surgeon '{surgeon_name}' does not exist. Available surgeons: {Surgeon.registry.keys()}")

        required_params = self.get_surgeon_parameters(surgeon_class)
        required_params, optional_params = self.get_surgeon_parameters(surgeon_class)
        provided_params = set(surgery.keys()) - {"surgeon"}
        missing_params = set(required_params) - provided_params
        extra_params = provided_params - set(required_params)
        extra_params = provided_params - set(required_params) - set(optional_params)

        if missing_params:
            raise ValueError(f"Missing parameters for surgery '{surgeon_name}': {missing_params}")
        if extra_params:
            raise ValueError(f"Ignoring extra parameters for surgery '{surgeon_name}': {extra_params}")
            logger.warning("Ignoring extra parameters for surgery '%s': %s", surgeon_name, extra_params)

        init_params = {param: surgery[param] for param in required_params}
        init_params.update({param: surgery[param] for param in optional_params if param in surgery})
        return surgeon_class(**init_params)

    @staticmethod
    def get_surgeon_parameters(surgeon_class):
        signature = inspect.signature(surgeon_class.__init__)
        params = list(signature.parameters.keys())
        params.remove("self")
        return params
        parameters = inspect.signature(surgeon_class.__init__).parameters

        positional_args = [
            name
            for name, param in parameters.items()
            if param.default == param.empty and param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD)
        ]
        positional_args.remove("self")
        keyword_args = [
            name
            for name, param in parameters.items()
            if param.default != param.empty or param.kind == param.KEYWORD_ONLY
        ]
        return positional_args, keyword_args
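
To make the new required/optional split concrete, here is a small self-contained check (the stub class is hypothetical, standing in for a surgeon like `ReplaceAttentionMaskValue` so the snippet runs without Olive installed):

```python
import inspect


class StubSurgeon:
    """Stand-in with the same __init__ shape as ReplaceAttentionMaskValue."""

    def __init__(self, threshold: float = -3e30, replacement: float = -1e4):
        self.threshold = threshold
        self.replacement = replacement


def get_surgeon_parameters(surgeon_class):
    # same logic as the updated static method above
    parameters = inspect.signature(surgeon_class.__init__).parameters
    positional_args = [
        name
        for name, param in parameters.items()
        if param.default == param.empty and param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD)
    ]
    positional_args.remove("self")
    keyword_args = [
        name for name, param in parameters.items() if param.default != param.empty or param.kind == param.KEYWORD_ONLY
    ]
    return positional_args, keyword_args


print(get_surgeon_parameters(StubSurgeon))  # ([], ['threshold', 'replacement'])
```

Because `threshold` and `replacement` have defaults, they land in the optional list, so a surgery entry may omit them entirely or pass either one.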