From 50f360aeacfb949abc0d845e4070922555f7c58a Mon Sep 17 00:00:00 2001
From: Jambay Kinley
Date: Mon, 10 Feb 2025 14:04:19 -0800
Subject: [PATCH] DynamicToFixedShape: use ort shape infer tool, accept 0 as dim_value (#1600)

## Describe your changes

- `onnxruntime.tools.onnx_model_utils.fix_output_shapes` cannot handle models larger than 2GB (#1595), so we run the ORT symbolic shape inference helper and update the output shapes ourselves. This also means the pass can now handle models with contrib operators.
- Allow passing 0 as `dim_value`. This case arises when creating a prompt-processing model from a dynamic-shaped LLM, where we want the past KV cache inputs to be empty (see the sketch after the diff).

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR.

## (Optional) Issue link

---
 olive/passes/onnx/dynamic_to_fixed_shape.py | 28 +++++++++++++++----
 .../onnx/test_dynamic_to_fixed_shape.py     | 11 ++++++--
 2 files changed, 32 insertions(+), 7 deletions(-)

diff --git a/olive/passes/onnx/dynamic_to_fixed_shape.py b/olive/passes/onnx/dynamic_to_fixed_shape.py
index 45bf4cb41..2c343c3f9 100644
--- a/olive/passes/onnx/dynamic_to_fixed_shape.py
+++ b/olive/passes/onnx/dynamic_to_fixed_shape.py
@@ -4,7 +4,7 @@
 # --------------------------------------------------------------------------
 import logging
-from typing import Any, Callable, Dict, List
+from typing import TYPE_CHECKING, Any, Callable, Dict, List
 
 from olive.common.pydantic_v1 import root_validator
 from olive.hardware import AcceleratorSpec
@@ -14,6 +14,9 @@
 from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model
 from olive.passes.pass_config import PassConfigParam
 
+if TYPE_CHECKING:
+    from onnx import ModelProto
+
 logger = logging.getLogger(__name__)
@@ -66,7 +69,7 @@ def _run_for_config(
         config: Dict[str, Any],
         output_model_path: str,
     ) -> ONNXModelHandler:
-        from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_dim_param_fixed, make_input_shape_fixed
+        from onnxruntime.tools.onnx_model_utils import make_dim_param_fixed, make_input_shape_fixed
 
         onnx_model = model.load_model()
         output_model_path = resolve_onnx_path(output_model_path)
@@ -78,9 +81,24 @@ def _run_for_config(
         for name, shape in zip(config["input_name"], config["input_shape"]):
             make_input_shape_fixed(onnx_model.graph, name, shape)
         # update the output shapes to make them fixed
-        fix_output_shapes(onnx_model)
+        # onnxruntime.tools.onnx_model_utils.fix_output_shapes cannot handle models > 2GB
+        self.fix_output_shapes(onnx_model)
         return model_proto_to_olive_model(onnx_model, output_model_path, config)
 
+    def fix_output_shapes(self, model_proto: "ModelProto"):
+        """Run shape inference on the model and update the output shapes to make them fixed."""
+        from onnxruntime.tools.onnx_model_utils import is_fixed_size_tensor
+        from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
+
+        # use the onnxruntime shape inference tool since it can handle large models as well as contrib ops
+        inferred_proto = SymbolicShapeInference.infer_shapes(model_proto, auto_merge=True, guess_output_rank=True)
+
+        for idx, o in enumerate(model_proto.graph.output):
+            if not is_fixed_size_tensor(o):
+                new_o = inferred_proto.graph.output[idx]
+                if is_fixed_size_tensor(new_o):
+                    o.type.tensor_type.shape.CopyFrom(new_o.type.tensor_type.shape)
+
     def _jointly_validate_configs(cls, values):
         if values.get("input_name") and values.get("dim_param"):
@@ -98,8 +116,8 @@ def _jointly_validate_configs(cls, values):
         if values["dim_param"] and values["dim_value"]:
             if len(values["dim_param"]) != len(values["dim_value"]):
                 raise ValueError("dim_param and dim_value must have the same number of elements.")
-            if any(i <= 0 for i in values["dim_value"]):
-                raise ValueError("dim_value must be all > 0 when dim_param is provided.")
+            if any(i < 0 for i in values["dim_value"]):
+                raise ValueError("dim_value must be all >= 0 when dim_param is provided.")
 
         if values["input_name"] and values["input_shape"]:
             if len(values["input_name"]) != len(values["input_shape"]):
diff --git a/test/unit_test/passes/onnx/test_dynamic_to_fixed_shape.py b/test/unit_test/passes/onnx/test_dynamic_to_fixed_shape.py
index c01b2bc9d..d8b4c0dbb 100644
--- a/test/unit_test/passes/onnx/test_dynamic_to_fixed_shape.py
+++ b/test/unit_test/passes/onnx/test_dynamic_to_fixed_shape.py
@@ -30,9 +30,16 @@
         (
             {
                 "dim_param": ["batch_size"],
-                "dim_value": [0],
+                "dim_value": [-1],
             },
-            "dim_value must be all > 0 when dim_param is provided.",
+            "dim_value must be all >= 0 when dim_param is provided.",
         ),
+        (
+            {
+                "input_name": ["input"],
+                "input_shape": [[1, 0, 256, 256]],
+            },
+            "input_shape must be all > 0 when input_name is provided.",
+        ),
     ],
 )
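
As referenced in the description, here is a minimal sketch (not part of the patch) of the new `dim_value == 0` case. The one-node `Identity` model and the `past_sequence_length` dim name are made up for illustration; `make_dim_param_fixed` is the same ORT helper the pass calls:

```python
# Sketch: fixing a dynamic dim_param to 0, e.g. to make a past-KV-cache input empty
# when deriving a prompt-processing model from a dynamic-shaped LLM.
from onnx import TensorProto, helper
from onnxruntime.tools.onnx_model_utils import make_dim_param_fixed

# hypothetical one-node model with a symbolic "past_sequence_length" dimension
inp = helper.make_tensor_value_info("past_kv", TensorProto.FLOAT, [1, "past_sequence_length", 64])
out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, "past_sequence_length", 64])
graph = helper.make_graph([helper.make_node("Identity", ["past_kv"], ["out"])], "g", [inp], [out])
model = helper.make_model(graph)

# with this PR, DynamicToFixedShape's validator accepts 0 here (previously rejected)
make_dim_param_fixed(model.graph, "past_sequence_length", 0)
print(model.graph.input[0].type.tensor_type.shape)  # dims become 1, 0, 64
```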