# DynamicToFixedShape: use ort shape infer tool, accept 0 as dim_value (#1600)

## Describe your changes
- `onnxruntime.tools.onnx_model_utils.fix_output_shapes` cannot handle models larger than 2GB (#1595), so we use the ORT shape inference helper and handle the output-shape fixing logic ourselves. This also means the pass can now handle models with contrib operators.
- Allow passing 0 as a `dim_value`. This case arises when creating a prompt processing model from a dynamically shaped LLM, where we want to make the past KV cache empty. A minimal sketch covering both changes follows this list.
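
To make this concrete, here is a minimal, self-contained sketch of the new approach, not code from this PR: the model path and the `batch_size`/`past_sequence_length` dim names are hypothetical placeholders, while the ORT helpers and their arguments mirror the ones the pass now uses.

```python
# Minimal sketch of the approach (hypothetical model path and dim names).
import onnx
from onnxruntime.tools.onnx_model_utils import is_fixed_size_tensor, make_dim_param_fixed
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

model = onnx.load("model.onnx")  # hypothetical path

# pin the symbolic dims; 0 is now accepted, e.g. to make the past kv cache empty
make_dim_param_fixed(model.graph, "batch_size", 1)
make_dim_param_fixed(model.graph, "past_sequence_length", 0)

# ORT's symbolic shape inference handles >2GB models and contrib ops,
# unlike onnx_model_utils.fix_output_shapes
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True, guess_output_rank=True)

# copy over any output shapes that inference was able to make fixed
for out, inferred_out in zip(model.graph.output, inferred.graph.output):
    if not is_fixed_size_tensor(out) and is_fixed_size_tensor(inferred_out):
        out.type.tensor_type.shape.CopyFrom(inferred_out.type.tensor_type.shape)

# (saving a >2GB model would additionally require external data; omitted here)
onnx.save(model, "model_fixed.onnx")
```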

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.

## (Optional) Issue link
jambayk authored Feb 10, 2025
commit 50f360a (parent 59effd4)
Showing 2 changed files with 32 additions and 7 deletions.
28 changes: 23 additions & 5 deletions olive/passes/onnx/dynamic_to_fixed_shape.py
```diff
@@ -4,7 +4,7 @@
 # --------------------------------------------------------------------------

 import logging
-from typing import Any, Callable, Dict, List
+from typing import TYPE_CHECKING, Any, Callable, Dict, List

 from olive.common.pydantic_v1 import root_validator
 from olive.hardware import AcceleratorSpec
@@ -14,6 +14,9 @@
 from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model
 from olive.passes.pass_config import PassConfigParam

+if TYPE_CHECKING:
+    from onnx import ModelProto
+
 logger = logging.getLogger(__name__)


@@ -66,7 +69,7 @@ def _run_for_config(
         config: Dict[str, Any],
         output_model_path: str,
     ) -> ONNXModelHandler:
-        from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_dim_param_fixed, make_input_shape_fixed
+        from onnxruntime.tools.onnx_model_utils import make_dim_param_fixed, make_input_shape_fixed

         onnx_model = model.load_model()
         output_model_path = resolve_onnx_path(output_model_path)
@@ -78,9 +81,24 @@
             for name, shape in zip(config["input_name"], config["input_shape"]):
                 make_input_shape_fixed(onnx_model.graph, name, shape)
         # update the output shapes to make them fixed
-        fix_output_shapes(onnx_model)
+        # onnxruntime.tools.onnx_model_utils.fix_output_shapes cannot handle models > 2GB
+        self.fix_output_shapes(onnx_model)
         return model_proto_to_olive_model(onnx_model, output_model_path, config)

+    def fix_output_shapes(self, model_proto: "ModelProto"):
+        """Run shape inference on the model and update the output shapes to make them fixed."""
+        from onnxruntime.tools.onnx_model_utils import is_fixed_size_tensor
+        from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
+
+        # use the onnxruntime shape inference tool since it can handle large models as well as contrib ops
+        inferred_proto = SymbolicShapeInference.infer_shapes(model_proto, auto_merge=True, guess_output_rank=True)
+
+        for idx, o in enumerate(model_proto.graph.output):
+            if not is_fixed_size_tensor(o):
+                new_o = inferred_proto.graph.output[idx]
+                if is_fixed_size_tensor(new_o):
+                    o.type.tensor_type.shape.CopyFrom(new_o.type.tensor_type.shape)
+

 def _jointly_validate_configs(cls, values):
     if values.get("input_name") and values.get("dim_param"):
@@ -98,8 +116,8 @@ def _jointly_validate_configs(cls, values):
     if values["dim_param"] and values["dim_value"]:
         if len(values["dim_param"]) != len(values["dim_value"]):
             raise ValueError("dim_param and dim_value must have the same number of elements.")
-        if any(i <= 0 for i in values["dim_value"]):
-            raise ValueError("dim_value must be all > 0 when dim_param is provided.")
+        if any(i < 0 for i in values["dim_value"]):
+            raise ValueError("dim_value must be all >= 0 when dim_param is provided.")

     if values["input_name"] and values["input_shape"]:
         if len(values["input_name"]) != len(values["input_shape"]):
```
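To illustrate the relaxed validation, here is a hedged sketch of the boundary behavior. It assumes Olive's `create_pass_from_dict` helper and the `disable_search=True` convention from the existing unit tests (an assumption, not part of this diff); the error messages are the ones raised by the validator above.

```python
import pytest

from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.onnx.dynamic_to_fixed_shape import DynamicToFixedShape

# a dim_value of 0 now passes validation (previously rejected by the `> 0` check)
create_pass_from_dict(DynamicToFixedShape, {"dim_param": ["past_sequence_length"], "dim_value": [0]}, disable_search=True)

# negative dim_value entries still fail
with pytest.raises(ValueError, match="dim_value must be all >= 0 when dim_param is provided."):
    create_pass_from_dict(DynamicToFixedShape, {"dim_param": ["batch_size"], "dim_value": [-1]}, disable_search=True)

# input_shape entries must remain strictly positive
with pytest.raises(ValueError, match="input_shape must be all > 0 when input_name is provided."):
    create_pass_from_dict(DynamicToFixedShape, {"input_name": ["input"], "input_shape": [[1, 0, 256, 256]]}, disable_search=True)
```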
11 changes: 9 additions & 2 deletions test/unit_test/passes/onnx/test_dynamic_to_fixed_shape.py
```diff
@@ -30,9 +30,16 @@
         (
             {
                 "dim_param": ["batch_size"],
-                "dim_value": [0],
+                "dim_value": [-1],
             },
-            "dim_value must be all > 0 when dim_param is provided.",
+            "dim_value must be all >= 0 when dim_param is provided.",
         ),
+        (
+            {
+                "input_name": ["input"],
+                "input_shape": [[1, 0, 256, 256]],
+            },
+            "input_shape must be all > 0 when input_name is provided.",
+        ),
     ],
 )
```
