diff --git a/olive/passes/onnx/dynamic_to_fixed_shape.py b/olive/passes/onnx/dynamic_to_fixed_shape.py index 45bf4cb41..2c343c3f9 100644 --- a/olive/passes/onnx/dynamic_to_fixed_shape.py +++ b/olive/passes/onnx/dynamic_to_fixed_shape.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- import logging -from typing import Any, Callable, Dict, List +from typing import TYPE_CHECKING, Any, Callable, Dict, List from olive.common.pydantic_v1 import root_validator from olive.hardware import AcceleratorSpec @@ -14,6 +14,9 @@ from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model from olive.passes.pass_config import PassConfigParam +if TYPE_CHECKING: + from onnx import ModelProto + logger = logging.getLogger(__name__) @@ -66,7 +69,7 @@ def _run_for_config( config: Dict[str, Any], output_model_path: str, ) -> ONNXModelHandler: - from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_dim_param_fixed, make_input_shape_fixed + from onnxruntime.tools.onnx_model_utils import make_dim_param_fixed, make_input_shape_fixed onnx_model = model.load_model() output_model_path = resolve_onnx_path(output_model_path) @@ -78,9 +81,24 @@ def _run_for_config( for name, shape in zip(config["input_name"], config["input_shape"]): make_input_shape_fixed(onnx_model.graph, name, shape) # update the output shapes to make them fixed - fix_output_shapes(onnx_model) + # onnxruntime.tools.onnx_model_utils.fix_output_shapes cannot handle models > 2GB + self.fix_output_shapes(onnx_model) return model_proto_to_olive_model(onnx_model, output_model_path, config) + def fix_output_shapes(self, model_proto: "ModelProto"): + """Run shape inference on the model and update the output shapes to make them fixed.""" + from onnxruntime.tools.onnx_model_utils import is_fixed_size_tensor + from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference + + # use the onnxruntime shape inference tool since it can handle large models as well as contrib ops + inferred_proto = SymbolicShapeInference.infer_shapes(model_proto, auto_merge=True, guess_output_rank=True) + + for idx, o in enumerate(model_proto.graph.output): + if not is_fixed_size_tensor(o): + new_o = inferred_proto.graph.output[idx] + if is_fixed_size_tensor(new_o): + o.type.tensor_type.shape.CopyFrom(new_o.type.tensor_type.shape) + def _jointly_validate_configs(cls, values): if values.get("input_name") and values.get("dim_param"): @@ -98,8 +116,8 @@ def _jointly_validate_configs(cls, values): if values["dim_param"] and values["dim_value"]: if len(values["dim_param"]) != len(values["dim_value"]): raise ValueError("dim_param and dim_value must have the same number of elements.") - if any(i <= 0 for i in values["dim_value"]): - raise ValueError("dim_value must be all > 0 when dim_param is provided.") + if any(i < 0 for i in values["dim_value"]): + raise ValueError("dim_value must be all >= 0 when dim_param is provided.") if values["input_name"] and values["input_shape"]: if len(values["input_name"]) != len(values["input_shape"]): diff --git a/test/unit_test/passes/onnx/test_dynamic_to_fixed_shape.py b/test/unit_test/passes/onnx/test_dynamic_to_fixed_shape.py index c01b2bc9d..d8b4c0dbb 100644 --- a/test/unit_test/passes/onnx/test_dynamic_to_fixed_shape.py +++ b/test/unit_test/passes/onnx/test_dynamic_to_fixed_shape.py @@ -30,9 +30,16 @@ ( { "dim_param": ["batch_size"], - "dim_value": [0], + "dim_value": [-1], }, - "dim_value must be all > 0 when dim_param is provided.", + "dim_value must be all >= 0 when dim_param is provided.", + ), + ( + { + "input_name": ["input"], + "input_shape": [[1, 0, 256, 256]], + }, + "input_shape must be all > 0 when input_name is provided.", ), ], )