From 63bf1c2b564bd4c4df0ec78c53e1f1b69f4cf410 Mon Sep 17 00:00:00 2001
From: Ti-Tai Wang <titaiwang@microsoft.com>
Date: Thu, 30 Jan 2025 02:28:24 +0000
Subject: [PATCH] support torch.export.Dim.AUTO

---
 olive/common/hf/model_io.py                   | 51 ++++++++++++++--
 olive/passes/onnx/conversion.py               | 60 +++++++------------
 olive/passes/pytorch/common.py                | 23 ++++---
 requirements.txt                              |  1 +
 test/requirements-test.txt                    |  2 -
 test/unit_test/model/test_hf_model.py         |  6 +-
 test/unit_test/model/test_pytorch_model.py    |  6 +-
 test/unit_test/passes/onnx/test_conversion.py | 33 +++++-----
 8 files changed, 102 insertions(+), 80 deletions(-)

diff --git a/olive/common/hf/model_io.py b/olive/common/hf/model_io.py
index 5a786458a..5d6690731 100644
--- a/olive/common/hf/model_io.py
+++ b/olive/common/hf/model_io.py
@@ -4,7 +4,7 @@
 # --------------------------------------------------------------------------
 import logging
 from itertools import chain
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 from olive.common.hf.mlflow import get_pretrained_name_or_path
 from olive.common.hf.peft import is_peft_model
@@ -124,10 +124,51 @@ def get_model_io_config(model_name: str, task: str, model: "PreTrainedModel", **
         for axis, axis_name in value.items():
             if axis_name == "past_sequence_length + 1":
                 value[axis] = "past_sequence_length + sequence_length"
-    # NOTE: Due to the complexity of dynamic_shapes, we don't provide it here.
-    # torch-onnx converter has a naive approach to auto-gen dynamic shapes based on input and
-    # dynamic_axes, so we don't need to provide dynamic shapes here.
-    return {"input_names": input_names, "output_names": output_names, "dynamic_axes": dynamic_axes}
+    # dynamic_shapes should follow input order and format
+    dynamic_shapes = _unflatten_past_key_values_with_check(inputs)
+    return {
+        "input_names": input_names,
+        "output_names": output_names,
+        "dynamic_axes": dynamic_axes,
+        "dynamic_shapes": dynamic_shapes,
+    }
+
+
+def _unflatten_past_key_values_with_check(flattened_inputs: Dict[str, Any]) -> Dict[str, Any]:
+    max_idx = -1
+    past_key_value_count = 0  # Track number of key-value pairs
+
+    # Find the max index for generating unflatten past_key_values later
+    # and record the total number of past_key_values entries for validation
+    for input_name in flattened_inputs:
+        if input_name.startswith("past_key_values"):
+            # From Optimum: past_key_values.0.key, past_key_values.0.value,
+            #               past_key_values.1.key, past_key_values.1.value, ...
+            idx = int(input_name.split(".")[1])
+            max_idx = max(max_idx, idx)
+            past_key_value_count += 1
+
+    # Check if we have exactly 2 * (max_idx + 1) key-value pairs
+    expected_count = 2 * (max_idx + 1)
+    if past_key_value_count != expected_count or past_key_value_count % 2 != 0:
+        raise ValueError(
+            f"Expected {expected_count} past_key_values entries, but found {past_key_value_count} from Optimum inputs."
+        )
+    if max_idx == -1:
+        # No past_key_values found
+        return flattened_inputs
+
+    unflattened = {}
+    for input_name, dynamic_shapes in flattened_inputs.items():
+        # Replace flattened past_key_values with unflattened past_key_values
+        if not input_name.startswith("past_key_values"):
+            unflattened[input_name] = dynamic_shapes
+    # Based on Optimum's implementation:
+    # https://github.com/huggingface/optimum/blob/b755036ae12e0959d61085e597e7b96473c4b46d/optimum/exporters/onnx/base.py#L629
+    # past_key_values is a list of lists, and it locates at the end of the input list/dict
+    # Generate the past_key_values list using the max index
+    unflattened["past_key_values"] = [[dynamic_shapes, dynamic_shapes] for _ in range(max_idx + 1)]
+    return unflattened
 
 
 def get_model_dummy_input(model_name: str, task: str, **kwargs) -> Optional[Dict]:
diff --git a/olive/passes/onnx/conversion.py b/olive/passes/onnx/conversion.py
index a32dc88cf..f808803a5 100644
--- a/olive/passes/onnx/conversion.py
+++ b/olive/passes/onnx/conversion.py
@@ -212,12 +212,11 @@ def _export_pytorch_model(
             # The "legacy dynamo" is the torch.onnx_dynamo_export API
             legacy_dynamo_supported_version = version.parse("2.2.0").release
             # The new "dynamo" api is torch.onnx.export with dynamo=True
-            # TODO(#1478): Change 2.6.0 back to 2.5.0 when dynamic_shapes are supported in Olive
             dynamo_supported_version = version.parse("2.6.0").release
             if torch_version < legacy_dynamo_supported_version:
                 raise ImportError(
-                    f"torch.onnx.dynamo_export is not available for torch version {torch_version}. "
-                    "Please upgrade your torch version to 2.5.0 or above."
+                    f"torch.onnx.export(..., dynamo=True) is not available for torch version {torch_version}. "
+                    "Please upgrade your torch version to 2.6.0 or above."
                 )
             from torch._dynamo import config as dynamo_config
 
@@ -597,14 +596,7 @@ def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs):
 
     def is_dict_axes(x) -> bool:
         return isinstance(x, dict) and all(
-            isinstance(key, str)
-            and len(key) == 1
-            and isinstance(value, list)
-            and len(value) == 3
-            and isinstance(value[0], str)
-            and isinstance(value[1], int)
-            and isinstance(value[2], int)
-            for key, value in x.items()
+            isinstance(key, (str, int)) and isinstance(value, str) for key, value in x.items()
         )
 
     flat_dynamic_shapes, _ = _pytree.tree_flatten(dynamic_shapes, is_leaf=is_dict_axes)
@@ -615,15 +607,20 @@ def is_dict_axes(x) -> bool:
             continue
         new_axes = {}
         for axis, dynamic_shape in axes.items():
-            new_axes[int(axis)] = dynamic_shape
+            try:
+                axis_int = int(axis)
+            except ValueError as e:
+                raise ValueError(
+                    f"Please check dynamic_shapes in configuration. Invalid axis {axis}. The axis should be an integer."
+                ) from e
+            new_axes[axis_int] = dynamic_shape
         new_dynamic_shapes.append(new_axes)
-
     _, tree_structure = _pytree.tree_flatten(dummy_inputs, is_leaf=is_dict_axes)
     return _pytree.tree_unflatten(new_dynamic_shapes, tree_structure)
 
 
 def _convert_dynamic_shapes_to_torch_export_dims(
-    dynamic_shapes: Dict[str, Dict[int, torch.export.Dim]]
+    dynamic_shapes: Dict[str, Dict[int, torch.export.Dim]],
 ) -> Dict[str, Dict[int, torch.export.Dim]]:
     """Convert dynamic_shapes to torch export dims.
 
@@ -634,8 +631,12 @@ def _convert_dynamic_shapes_to_torch_export_dims(
 
     For a single axis:
 
-    before: ["axis_name", min_value, max_value]
-    after: torch.export.Dim("axis_name", min=min_value, max=max_value)
+    before: {0: "batch_size"}
+    after: {0: torch.export.Dim.AUTO)
+
+    NOTE: The user specified dim name is not respected at the moment due to
+          technical issue in torch.export. Follow up
+          https://github.com/pytorch/pytorch/issues/144273
 
     # Please check `dynamic_shapes` in torch.export.export
     # https://pytorch.org/docs/stable/export.html#torch.export.export
@@ -646,9 +647,6 @@ def _convert_dynamic_shapes_to_torch_export_dims(
     if dynamic_shapes is None:
         return None
 
-    # If the axes has the same name, they should be the same torch.export.Dim
-    torch_export_dim_farm: Dict[str, torch.export.Dim] = {}
-
     # dynamic_shapes follows input format, which could be nested
     def _from_tuple_to_dim(data: Union[Dict, List, Tuple, Any]) -> Union[Dict, List, Tuple, Any]:
         if isinstance(data, dict):
@@ -657,26 +655,14 @@ def _from_tuple_to_dim(data: Union[Dict, List, Tuple, Any]) -> Union[Dict, List,
         # TODO(titaiwang): Can we use `dummy_inputs` to align the dynamic_shapes format?
         # JSON foramt does not accept tuple.
         elif isinstance(data, (tuple, list)):
-            # We assume the tuple/list is in the format of (name, min, max)
-            # TODO(titaiwang): This format could potentially be used as model
-            # inputs (would string be used as model input?)
-            if len(data) == 3 and isinstance(data[0], str) and isinstance(data[1], int) and isinstance(data[2], int):
-                if data[0] in torch_export_dim_farm:
-                    if torch_export_dim_farm[data[0]].min == data[1] and torch_export_dim_farm[data[0]].max == data[2]:
-                        return torch_export_dim_farm[data[0]]
-                    raise ValueError(
-                        f"Found different boundary for the same axis name {data[0]}. "
-                        f"Previous min: {torch_export_dim_farm[data[0]].min} and "
-                        f"max: {torch_export_dim_farm[data[0]].max}. "
-                        f"Current min: {data[1]} and max: {data[2]}."
-                    )
-                dim = torch.export.Dim(data[0], min=data[1], max=data[2])
-                torch_export_dim_farm[data[0]] = dim
-                return dim
             if isinstance(data, tuple):
                 return tuple(_from_tuple_to_dim(item) for item in data)
-            if isinstance(data, list):
-                return [_from_tuple_to_dim(item) for item in data]
+            return [_from_tuple_to_dim(item) for item in data]
+        elif isinstance(data, str):
+            # NOTE: AUTO does not guarantee the axis will be dynamic, but
+            #       it is the only option that not only avoid crashing during
+            #       export, but also does not require user to specify min/max.
+            return torch.export.Dim.AUTO
         return data
 
     return _from_tuple_to_dim(dynamic_shapes)
diff --git a/olive/passes/pytorch/common.py b/olive/passes/pytorch/common.py
index 246635c98..9e1fcb3f9 100644
--- a/olive/passes/pytorch/common.py
+++ b/olive/passes/pytorch/common.py
@@ -51,18 +51,6 @@ def inherit_pytorch_from_hf(
         hf_io_config = deepcopy(model.io_config)
         hf_dummy_inputs = model.get_dummy_inputs()
 
-        dynamic_shapes = hf_io_config.get("dynamic_shapes", {})
-        if isinstance(dynamic_shapes, dict):
-            {
-                k: v
-                for k, v in hf_io_config.get("dynamic_axes", {}).items()
-                if not k.startswith(("present", "past_key_values"))
-            }
-        else:
-            # TODO(titaiwang): fix this when we have a better way to handle dynamic_shapes
-            # If the dynamic_shapes is a list, we don't inherit it since
-            # we do not know the exact index of the past_key_values in the list
-            dynamic_shapes = {}
         # kv cache will be handled by the kv_cache flag in io_config
         io_config = {
             "input_names": [i for i in hf_io_config.get("input_names", []) if not i.startswith("past_key_values")],
@@ -74,7 +62,6 @@ def inherit_pytorch_from_hf(
                 for k, v in hf_io_config.get("dynamic_axes", {}).items()
                 if not k.startswith(("present", "past_key_values"))
             },
-            "dynamic_shapes": dynamic_shapes,
         }
 
         for i_name in io_config["input_names"]:
@@ -84,6 +71,16 @@ def inherit_pytorch_from_hf(
     if io_config and not io_config.get("kv_cache") and model.task.endswith("-with-past"):
         io_config["kv_cache"] = True
 
+    # dynamic_shapes deals with kv_cache here. If kv_cache is False,
+    # we remove past_key_values from dynamic_shapes
+    if not io_config.get("kv_cache", False):
+        dynamic_shapes = {
+            k: v for k, v in hf_io_config.get("dynamic_shapes", {}).items() if not k.startswith("past_key_values")
+        }
+    else:
+        dynamic_shapes = hf_io_config.get("dynamic_shapes", {})
+    io_config["dynamic_shapes"] = dynamic_shapes
+
     return PyTorchModelHandler(
         model_path=model_path,
         model_file_format=model_file_format,
diff --git a/requirements.txt b/requirements.txt
index 786904514..54b97185f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
 numpy
 onnx
+onnxscript
 optuna
 pandas
 protobuf<4.0.0
diff --git a/test/requirements-test.txt b/test/requirements-test.txt
index f2d073aef..70dd4dd0c 100644
--- a/test/requirements-test.txt
+++ b/test/requirements-test.txt
@@ -29,8 +29,6 @@ onnxconverter_common
 onnxmltools
 onnxoptimizer
 onnxruntime_extensions
-# TODO(titaiwai): Add onnxscript to requirements.txt once it's released
-onnxscript
 openvino==2023.2.0
 optimum>=1.17.0
 pandas
diff --git a/test/unit_test/model/test_hf_model.py b/test/unit_test/model/test_hf_model.py
index bd8ad17b6..343c18b1c 100644
--- a/test/unit_test/model/test_hf_model.py
+++ b/test/unit_test/model/test_hf_model.py
@@ -145,9 +145,9 @@ def setup(self):
                 "token_type_ids": {"0": "batch_size", "1": "seq_length"},
             },
             "dynamic_shapes": {
-                "input_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]},
-                "attention_mask": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]},
-                "token_type_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]},
+                "input_ids": {"0": "batch_size", "1": "seq_length"},
+                "attention_mask": {"0": "batch_size", "1": "seq_length"},
+                "token_type_ids": {"0": "batch_size", "1": "seq_length"},
             },
         }
 
diff --git a/test/unit_test/model/test_pytorch_model.py b/test/unit_test/model/test_pytorch_model.py
index ac26454e1..4bbdd00d4 100644
--- a/test/unit_test/model/test_pytorch_model.py
+++ b/test/unit_test/model/test_pytorch_model.py
@@ -23,9 +23,9 @@ def io_config_fixture():
             "token_type_ids": {"0": "batch_size", "1": "seq_length"},
         },
         "dynamic_shapes": {
-            "input_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]},
-            "attention_mask": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]},
-            "token_type_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]},
+            "input_ids": {"0": "batch_size", "1": "seq_length"},
+            "attention_mask": {"0": "batch_size", "1": "seq_length"},
+            "token_type_ids": {"0": "batch_size", "1": "seq_length"},
         },
     }
 
diff --git a/test/unit_test/passes/onnx/test_conversion.py b/test/unit_test/passes/onnx/test_conversion.py
index 92a384415..8b7f63f00 100644
--- a/test/unit_test/passes/onnx/test_conversion.py
+++ b/test/unit_test/passes/onnx/test_conversion.py
@@ -4,7 +4,6 @@
 # --------------------------------------------------------------------------
 import platform
 import shutil
-import sys
 from itertools import chain
 from pathlib import Path
 from test.unit_test.utils import ONNX_MODEL_PATH, get_hf_model, get_onnx_model, get_pytorch_model, pytorch_model_loader
@@ -21,7 +20,7 @@
 from olive.passes.onnx.conversion import OnnxConversion, OnnxOpVersionConversion
 
 
-@pytest.mark.skipif(sys.version_info > (3, 8), reason="Failed with Python 3.10, need to investigate.")
+# @pytest.mark.skipif(sys.version_info > (3, 8), reason="Failed with Python 3.10, need to investigate.")
 @pytest.mark.parametrize(
     ("input_model", "use_dynamo_exporter"),
     [(get_pytorch_model(), True), (get_hf_model(), True), (get_pytorch_model(), False), (get_hf_model(), False)],
@@ -189,10 +188,10 @@ def mock_onnx_export_func(*args, **kwargs):
 @pytest.mark.parametrize(
     "dynamic_shapes",
     [
-        [{"0": ["axis_batch", 0, 1024], "1": ["x_axis", 0, 8]}, {"0": ["axis_batch", 0, 1024], "1": ["y_axis", 0, 6]}],
+        [{"0": "axis_batch", "1": "x_axis"}, {"0": "axis_batch", "1": "y_axis"}],
         {
-            "input_x": {"0": ["axis_batch", 0, 1024], "1": ["x_axis", 0, 8]},
-            "input_y": {"0": ["axis_batch", 0, 1024], "1": ["y_axis", 0, 6]},
+            "input_x": {"0": "axis_batch", "1": "x_axis"},
+            "input_y": {"0": "axis_batch", "1": "y_axis"},
         },
     ],
 )
@@ -224,30 +223,30 @@ def _get_simulate_torch_float_tensor_inputs(return_tuple: bool = False):
     [
         (
             [
-                {"0": ["axis_batch", 0, 1024], "1": ["x_axis", 0, 8]},
-                [{"1": ["x_axis", 0, 8]}, {"0": ["axis_batch", 0, 1024]}],
-                {"a": {"0": ["axis_batch", 0, 1024]}, "b": {"1": ["x_axis", 0, 8]}},
+                {"0": "axis_batch", "1": "x_axis"},
+                [{"1": "x_axis"}, {"0": "axis_batch"}],
+                {"a": {"0": "axis_batch"}, "b": {"1": "x_axis"}},
                 None,
             ],
             (
-                {0: ["axis_batch", 0, 1024], 1: ["x_axis", 0, 8]},
-                ({1: ["x_axis", 0, 8]}, {0: ["axis_batch", 0, 1024]}),
-                {"a": {0: ["axis_batch", 0, 1024]}, "b": {1: ["x_axis", 0, 8]}},
+                {0: "axis_batch", 1: "x_axis"},
+                ({1: "x_axis"}, {0: "axis_batch"}),
+                {"a": {0: "axis_batch"}, "b": {1: "x_axis"}},
                 None,
             ),
             _get_simulate_torch_float_tensor_inputs(return_tuple=True),
         ),
         (
             {
-                "w": {"0": ["axis_batch", 0, 1024], "1": ["x_axis", 0, 8]},
-                "x": [{"1": ["x_axis", 0, 8]}, {"0": ["axis_batch", 0, 1024]}],
-                "y": {"a": {"0": ["axis_batch", 0, 1024]}, "b": {"1": ["x_axis", 0, 8]}},
+                "w": {"0": "axis_batch", "1": "x_axis"},
+                "x": [{"1": "x_axis"}, {"0": "axis_batch"}],
+                "y": {"a": {"0": "axis_batch"}, "b": {"1": "x_axis"}},
                 "z": None,
             },
             {
-                "w": {0: ["axis_batch", 0, 1024], 1: ["x_axis", 0, 8]},
-                "x": ({1: ["x_axis", 0, 8]}, {0: ["axis_batch", 0, 1024]}),
-                "y": {"a": {0: ["axis_batch", 0, 1024]}, "b": {1: ["x_axis", 0, 8]}},
+                "w": {0: "axis_batch", 1: "x_axis"},
+                "x": ({1: "x_axis"}, {0: "axis_batch"}),
+                "y": {"a": {0: "axis_batch"}, "b": {1: "x_axis"}},
                 "z": None,
             },
             _get_simulate_torch_float_tensor_inputs(return_tuple=False),