Skip to content

Commit

Permalink
Set onnxoptimizer.optimize as default for the peephole optimizer (#1550)
Browse files Browse the repository at this point in the history
## Describe your changes

Set onnxoptimizer.optimize as default for the peephole optimizer (OnnxPeepholeOptimizer)

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.

## (Optional) Issue link
  • Loading branch information
xiaoyu-work authored Jan 15, 2025
1 parent 6c2b14b commit 3d6e731
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ The `OnnxPeepholeOptimizer` leverages `onnxscript` (https://onnxscript.ai/tutori
| **Fuse Consecutive Slices** | Fuses consecutive Slice operations. |
| **Eliminate Unused Initializer** | Eliminates unused initializers. |
| **Eliminate Duplicate Initializer** | Eliminates duplicate initializers. |
| **Broadcast to MatMul** | Converts broadcast patterns into MatMul operations for better efficiency. |
| **Cast Constant of Shape** | Simplifies constant casting for shape operations. |
| **GEMM to MatMul+Add** | Converts GEMM operations into MatMul and Add for improved compatibility. |
| **No-Op Removal** | Removes redundant or no-op operations in the computation graph. |

Please refer to [OnnxPeepholeOptimizer](../../../reference/pass.rst#onnx_peephole_optimizer) for more details about the pass and its config parameters.

Expand Down
45 changes: 14 additions & 31 deletions olive/passes/onnx/peephole_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# --------------------------------------------------------------------------
import logging
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict

import numpy as np
import onnx
Expand Down Expand Up @@ -236,35 +236,27 @@ def onnxscript_optimize(self):
try:
import onnxscript
except ImportError:
logger.warning("Please install onnxscript to use the ONNX optimizer feature. Skip onnxscript optimization.")
logger.warning("Please install `onnxscript` to apply more optimization.")
return

onnxscript.optimizer.optimize(self.model)

def onnxoptimizer_optimize(self):
try:
from onnxoptimizer import optimize
except ImportError:
logger.warning("Please install `onnxoptimizer` to apply more optimization.")
return

self.model = optimize(self.model)


class OnnxPeepholeOptimizer(Pass):
"""Optimize ONNX model by fusing nodes."""

@classmethod
def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
return {
"onnxoptimizer": PassConfigParam(
type_=bool,
default_value=False,
description="Whether to run the ONNX optimizer (https://github.com/onnx/optimizer/). Default is False.",
),
"passes": PassConfigParam(
type_=List[str],
default_value=None,
description="List of passes of ONNX optimizer to run. Default is None.",
),
"fixed_point": PassConfigParam(
type_=bool,
default_value=False,
description="Whether to run the fixed-point optimization of ONNX optimizer. Default is False.",
),
**get_external_data_config(),
}
return get_external_data_config()

def _run_for_config(
self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str
Expand All @@ -273,20 +265,11 @@ def _run_for_config(

# optimize model
peephole_optimizer = ModelOptimizer(model.model_path)
peephole_optimizer.onnxscript_optimize()
peephole_optimizer.onnxoptimizer_optimize()
peephole_optimizer.fuse_transpose_qat()
peephole_optimizer.patch_unsupported_argmax_operator()
peephole_optimizer.fuse_reshape_operations()
peephole_optimizer.onnxscript_optimize()

if config["onnxoptimizer"]:
try:
from onnxoptimizer import optimize

peephole_optimizer.model = optimize(peephole_optimizer.model, config["passes"], config["fixed_point"])
except ImportError:
logger.warning(
"Please install onnxoptimizer to use the ONNX optimizer feature. Skip onnxoptimizer optimization."
)

# save the model to the output path and return the model
return model_proto_to_olive_model(peephole_optimizer.model, output_model_path, config)
31 changes: 15 additions & 16 deletions test/unit_test/passes/onnx/test_peephole_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,19 @@ def _make_model_for_patch_unsupported_argmax_operator(
return model_proto_to_olive_model(model, filepath, config)


@patch("onnxoptimizer.optimize")
@patch("onnxscript.optimizer.optimize")
def test_onnx_peephole_optimizer_pass_patch_unsupported_argmax_operator_modified(
mock_optimize, tmp_path, external_data_config
mock_onnxscript, mock_onnxoptimizer, tmp_path, external_data_config
):
m = _make_model_for_patch_unsupported_argmax_operator(
TensorProto.INT64, str(tmp_path / "input.onnx"), external_data_config
)
p = create_pass_from_dict(
OnnxPeepholeOptimizer, external_data_config, disable_search=True, accelerator_spec=DEFAULT_GPU_CUDA_ACCELERATOR
)
mock_onnxscript.return_value = m.load_model()
mock_onnxoptimizer.return_value = m.load_model()

actual_model = p.run(m, str(tmp_path / "onnx"))
assert Path(actual_model.model_path).exists()
Expand Down Expand Up @@ -193,40 +196,36 @@ def test_onnx_peephole_optimizer_pass_fuse_reshape_operations(tmp_path, external


@patch("olive.passes.onnx.peephole_optimizer.model_proto_to_olive_model")
@patch("onnxoptimizer.optimize")
@patch("onnxscript.optimizer.optimize")
def test_onnxscript(mock_optimize, mock_model_proto_to_olive_model, tmp_path):
def test_onnxscript(mock_onnxscript, mock_onnxoptimizer, mock_model_proto_to_olive_model, tmp_path):
# setup
input_model = get_onnx_model()
p = create_pass_from_dict(OnnxPeepholeOptimizer, {}, disable_search=True)
mock_onnxscript.return_value = input_model.load_model()
mock_onnxoptimizer.return_value = input_model.load_model()
output_folder = str(tmp_path / "onnx")

# execute
p.run(input_model, output_folder)

# assert
mock_optimize.assert_called_once_with(input_model.load_model())
mock_onnxscript.assert_called_once_with(input_model.load_model())


@patch("olive.passes.onnx.peephole_optimizer.model_proto_to_olive_model")
@patch("onnxoptimizer.optimize")
def test_onnxoptimizer(mock_optimize, mock_model_proto_to_olive_model, tmp_path):
@patch("onnxscript.optimizer.optimize")
def test_onnxoptimizer(mock_onnxscript, mock_onnxoptimizer, mock_model_proto_to_olive_model, tmp_path):
# setup
input_model = get_onnx_model()
passes = ["pass"]
fixed_point = True
p = create_pass_from_dict(
OnnxPeepholeOptimizer,
{
"onnxoptimizer": True,
"passes": passes,
"fixed_point": fixed_point,
},
disable_search=True,
)
p = create_pass_from_dict(OnnxPeepholeOptimizer, {}, disable_search=True)
mock_onnxscript.return_value = input_model.load_model()
mock_onnxoptimizer.return_value = input_model.load_model()
output_folder = str(tmp_path / "onnx")

# execute
p.run(input_model, output_folder)

# assert
mock_optimize.assert_called_once_with(input_model.load_model(), passes, fixed_point)
mock_onnxoptimizer.assert_called_once()

0 comments on commit 3d6e731

Please sign in to comment.