Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support torch.export.Dim.AUTO in ONNX conversion pass #1586

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions olive/common/hf/model_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# --------------------------------------------------------------------------
import logging
from itertools import chain
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional

from olive.common.hf.mlflow import get_pretrained_name_or_path
from olive.common.hf.peft import is_peft_model
Expand Down Expand Up @@ -124,10 +124,51 @@
for axis, axis_name in value.items():
if axis_name == "past_sequence_length + 1":
value[axis] = "past_sequence_length + sequence_length"
# NOTE: Due to the complexity of dynamic_shapes, we don't provide it here.
# torch-onnx converter has a naive approach to auto-gen dynamic shapes based on input and
# dynamic_axes, so we don't need to provide dynamic shapes here.
return {"input_names": input_names, "output_names": output_names, "dynamic_axes": dynamic_axes}
# dynamic_shapes should follow input order and format
dynamic_shapes = _unflatten_past_key_values_with_check(inputs)
return {
"input_names": input_names,
"output_names": output_names,
"dynamic_axes": dynamic_axes,
"dynamic_shapes": dynamic_shapes,
}


def _unflatten_past_key_values_with_check(flattened_inputs: Dict[str, Any]) -> Dict[str, Any]:
max_idx = -1
past_key_value_count = 0 # Track number of key-value pairs

# Find the max index for generating unflatten past_key_values later
# and record the total number of past_key_values entries for validation
for input_name in flattened_inputs:
if input_name.startswith("past_key_values"):
# From Optimum: past_key_values.0.key, past_key_values.0.value,
# past_key_values.1.key, past_key_values.1.value, ...
idx = int(input_name.split(".")[1])
max_idx = max(max_idx, idx)
past_key_value_count += 1

# Check if we have exactly 2 * (max_idx + 1) key-value pairs
expected_count = 2 * (max_idx + 1)
if past_key_value_count != expected_count or past_key_value_count % 2 != 0:
raise ValueError(
f"Expected {expected_count} past_key_values entries, but found {past_key_value_count} from Optimum inputs."
)
if max_idx == -1:
# No past_key_values found
return flattened_inputs

unflattened = {}
for input_name, dynamic_shapes in flattened_inputs.items():
# Replace flattened past_key_values with unflattened past_key_values
if not input_name.startswith("past_key_values"):
unflattened[input_name] = dynamic_shapes

Check warning

Code scanning / lintrunner

RUFF/PERF403 Warning

Use a dictionary comprehension instead of a for-loop.
See https://docs.astral.sh/ruff/rules/manual-dict-comprehension
# Based on Optimum's implementation:
# https://github.com/huggingface/optimum/blob/b755036ae12e0959d61085e597e7b96473c4b46d/optimum/exporters/onnx/base.py#L629
# past_key_values is a list of lists, and it locates at the end of the input list/dict
# Generate the past_key_values list using the max index
unflattened["past_key_values"] = [[dynamic_shapes, dynamic_shapes] for _ in range(max_idx + 1)]

Check warning

Code scanning / lintrunner

PYLINT/W0631 Warning

Using possibly undefined loop variable 'dynamic_shapes' (undefined-loop-variable)
See undefined-loop-variable.

Check warning

Code scanning / lintrunner

PYLINT/W0631 Warning

Using possibly undefined loop variable 'dynamic_shapes' (undefined-loop-variable)
See undefined-loop-variable.
return unflattened


def get_model_dummy_input(model_name: str, task: str, **kwargs) -> Optional[Dict]:
Expand Down
60 changes: 23 additions & 37 deletions olive/passes/onnx/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,11 @@ def _export_pytorch_model(
# The "legacy dynamo" is the torch.onnx_dynamo_export API
legacy_dynamo_supported_version = version.parse("2.2.0").release
# The new "dynamo" api is torch.onnx.export with dynamo=True
# TODO(#1478): Change 2.6.0 back to 2.5.0 when dynamic_shapes are supported in Olive
dynamo_supported_version = version.parse("2.6.0").release
if torch_version < legacy_dynamo_supported_version:
raise ImportError(
f"torch.onnx.dynamo_export is not available for torch version {torch_version}. "
"Please upgrade your torch version to 2.5.0 or above."
f"torch.onnx.export(..., dynamo=True) is not available for torch version {torch_version}. "
"Please upgrade your torch version to 2.6.0 or above."
)
from torch._dynamo import config as dynamo_config

Expand Down Expand Up @@ -597,14 +596,7 @@ def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs):

def is_dict_axes(x) -> bool:
return isinstance(x, dict) and all(
isinstance(key, str)
and len(key) == 1
and isinstance(value, list)
and len(value) == 3
and isinstance(value[0], str)
and isinstance(value[1], int)
and isinstance(value[2], int)
for key, value in x.items()
isinstance(key, (str, int)) and isinstance(value, str) for key, value in x.items()
)

flat_dynamic_shapes, _ = _pytree.tree_flatten(dynamic_shapes, is_leaf=is_dict_axes)
Expand All @@ -615,15 +607,20 @@ def is_dict_axes(x) -> bool:
continue
new_axes = {}
for axis, dynamic_shape in axes.items():
new_axes[int(axis)] = dynamic_shape
try:
axis_int = int(axis)
except ValueError as e:
raise ValueError(
f"Please check dynamic_shapes in configuration. Invalid axis {axis}. The axis should be an integer."
) from e
new_axes[axis_int] = dynamic_shape
new_dynamic_shapes.append(new_axes)

_, tree_structure = _pytree.tree_flatten(dummy_inputs, is_leaf=is_dict_axes)
return _pytree.tree_unflatten(new_dynamic_shapes, tree_structure)


def _convert_dynamic_shapes_to_torch_export_dims(
dynamic_shapes: Dict[str, Dict[int, torch.export.Dim]]
dynamic_shapes: Dict[str, Dict[int, torch.export.Dim]],
) -> Dict[str, Dict[int, torch.export.Dim]]:
"""Convert dynamic_shapes to torch export dims.

Expand All @@ -634,8 +631,12 @@ def _convert_dynamic_shapes_to_torch_export_dims(

For a single axis:

before: ["axis_name", min_value, max_value]
after: torch.export.Dim("axis_name", min=min_value, max=max_value)
before: {0: "batch_size"}
after: {0: torch.export.Dim.AUTO}

NOTE: The user specified dim name is not respected at the moment due to
technical issue in torch.export. Follow up
https://github.com/pytorch/pytorch/issues/144273

# Please check `dynamic_shapes` in torch.export.export
# https://pytorch.org/docs/stable/export.html#torch.export.export
Expand All @@ -646,9 +647,6 @@ def _convert_dynamic_shapes_to_torch_export_dims(
if dynamic_shapes is None:
return None

# If the axes has the same name, they should be the same torch.export.Dim
torch_export_dim_farm: Dict[str, torch.export.Dim] = {}

# dynamic_shapes follows input format, which could be nested
def _from_tuple_to_dim(data: Union[Dict, List, Tuple, Any]) -> Union[Dict, List, Tuple, Any]:
if isinstance(data, dict):
Expand All @@ -657,26 +655,14 @@ def _from_tuple_to_dim(data: Union[Dict, List, Tuple, Any]) -> Union[Dict, List,
# TODO(titaiwang): Can we use `dummy_inputs` to align the dynamic_shapes format?
# JSON format does not accept tuple.
elif isinstance(data, (tuple, list)):
# We assume the tuple/list is in the format of (name, min, max)
# TODO(titaiwang): This format could potentially be used as model
# inputs (would string be used as model input?)
if len(data) == 3 and isinstance(data[0], str) and isinstance(data[1], int) and isinstance(data[2], int):
if data[0] in torch_export_dim_farm:
if torch_export_dim_farm[data[0]].min == data[1] and torch_export_dim_farm[data[0]].max == data[2]:
return torch_export_dim_farm[data[0]]
raise ValueError(
f"Found different boundary for the same axis name {data[0]}. "
f"Previous min: {torch_export_dim_farm[data[0]].min} and "
f"max: {torch_export_dim_farm[data[0]].max}. "
f"Current min: {data[1]} and max: {data[2]}."
)
dim = torch.export.Dim(data[0], min=data[1], max=data[2])
torch_export_dim_farm[data[0]] = dim
return dim
if isinstance(data, tuple):
return tuple(_from_tuple_to_dim(item) for item in data)
if isinstance(data, list):
return [_from_tuple_to_dim(item) for item in data]
return [_from_tuple_to_dim(item) for item in data]
elif isinstance(data, str):
# NOTE: AUTO does not guarantee the axis will be dynamic, but
# it is the only option that not only avoid crashing during
# export, but also does not require user to specify min/max.
return torch.export.Dim.AUTO
return data

return _from_tuple_to_dim(dynamic_shapes)
23 changes: 10 additions & 13 deletions olive/passes/pytorch/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,6 @@ def inherit_pytorch_from_hf(
hf_io_config = deepcopy(model.io_config)
hf_dummy_inputs = model.get_dummy_inputs()

dynamic_shapes = hf_io_config.get("dynamic_shapes", {})
if isinstance(dynamic_shapes, dict):
{
k: v
for k, v in hf_io_config.get("dynamic_axes", {}).items()
if not k.startswith(("present", "past_key_values"))
}
else:
# TODO(titaiwang): fix this when we have a better way to handle dynamic_shapes
# If the dynamic_shapes is a list, we don't inherit it since
# we do not know the exact index of the past_key_values in the list
dynamic_shapes = {}
# kv cache will be handled by the kv_cache flag in io_config
io_config = {
"input_names": [i for i in hf_io_config.get("input_names", []) if not i.startswith("past_key_values")],
Expand All @@ -74,7 +62,6 @@ def inherit_pytorch_from_hf(
for k, v in hf_io_config.get("dynamic_axes", {}).items()
if not k.startswith(("present", "past_key_values"))
},
"dynamic_shapes": dynamic_shapes,
}

for i_name in io_config["input_names"]:
Expand All @@ -84,6 +71,16 @@ def inherit_pytorch_from_hf(
if io_config and not io_config.get("kv_cache") and model.task.endswith("-with-past"):
io_config["kv_cache"] = True

# dynamic_shapes deals with kv_cache here. If kv_cache is False,
# we remove past_key_values from dynamic_shapes
if not io_config.get("kv_cache", False):
dynamic_shapes = {
k: v for k, v in hf_io_config.get("dynamic_shapes", {}).items() if not k.startswith("past_key_values")
}
else:
dynamic_shapes = hf_io_config.get("dynamic_shapes", {})
io_config["dynamic_shapes"] = dynamic_shapes

return PyTorchModelHandler(
model_path=model_path,
model_file_format=model_file_format,
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
numpy
onnx
onnxscript
optuna
pandas
protobuf<4.0.0
Expand Down
2 changes: 0 additions & 2 deletions test/requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ onnxconverter_common
onnxmltools
onnxoptimizer
onnxruntime_extensions
# TODO(titaiwai): Add onnxscript to requirements.txt once it's released
onnxscript
openvino==2023.2.0
optimum>=1.17.0
pandas
Expand Down
6 changes: 3 additions & 3 deletions test/unit_test/model/test_hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ def setup(self):
"token_type_ids": {"0": "batch_size", "1": "seq_length"},
},
"dynamic_shapes": {
"input_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]},
"attention_mask": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]},
"token_type_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]},
"input_ids": {"0": "batch_size", "1": "seq_length"},
"attention_mask": {"0": "batch_size", "1": "seq_length"},
"token_type_ids": {"0": "batch_size", "1": "seq_length"},
},
}

Expand Down
6 changes: 3 additions & 3 deletions test/unit_test/model/test_pytorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def io_config_fixture():
"token_type_ids": {"0": "batch_size", "1": "seq_length"},
},
"dynamic_shapes": {
"input_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]},
"attention_mask": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]},
"token_type_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]},
"input_ids": {"0": "batch_size", "1": "seq_length"},
"attention_mask": {"0": "batch_size", "1": "seq_length"},
"token_type_ids": {"0": "batch_size", "1": "seq_length"},
},
}

Expand Down
33 changes: 16 additions & 17 deletions test/unit_test/passes/onnx/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# --------------------------------------------------------------------------
import platform
import shutil
import sys
from itertools import chain
from pathlib import Path
from test.unit_test.utils import ONNX_MODEL_PATH, get_hf_model, get_onnx_model, get_pytorch_model, pytorch_model_loader
Expand All @@ -21,7 +20,7 @@
from olive.passes.onnx.conversion import OnnxConversion, OnnxOpVersionConversion


@pytest.mark.skipif(sys.version_info > (3, 8), reason="Failed with Python 3.10, need to investigate.")
# @pytest.mark.skipif(sys.version_info > (3, 8), reason="Failed with Python 3.10, need to investigate.")
@pytest.mark.parametrize(
("input_model", "use_dynamo_exporter"),
[(get_pytorch_model(), True), (get_hf_model(), True), (get_pytorch_model(), False), (get_hf_model(), False)],
Expand Down Expand Up @@ -189,10 +188,10 @@ def mock_onnx_export_func(*args, **kwargs):
@pytest.mark.parametrize(
"dynamic_shapes",
[
[{"0": ["axis_batch", 0, 1024], "1": ["x_axis", 0, 8]}, {"0": ["axis_batch", 0, 1024], "1": ["y_axis", 0, 6]}],
[{"0": "axis_batch", "1": "x_axis"}, {"0": "axis_batch", "1": "y_axis"}],
{
"input_x": {"0": ["axis_batch", 0, 1024], "1": ["x_axis", 0, 8]},
"input_y": {"0": ["axis_batch", 0, 1024], "1": ["y_axis", 0, 6]},
"input_x": {"0": "axis_batch", "1": "x_axis"},
"input_y": {"0": "axis_batch", "1": "y_axis"},
},
],
)
Expand Down Expand Up @@ -224,30 +223,30 @@ def _get_simulate_torch_float_tensor_inputs(return_tuple: bool = False):
[
(
[
{"0": ["axis_batch", 0, 1024], "1": ["x_axis", 0, 8]},
[{"1": ["x_axis", 0, 8]}, {"0": ["axis_batch", 0, 1024]}],
{"a": {"0": ["axis_batch", 0, 1024]}, "b": {"1": ["x_axis", 0, 8]}},
{"0": "axis_batch", "1": "x_axis"},
[{"1": "x_axis"}, {"0": "axis_batch"}],
{"a": {"0": "axis_batch"}, "b": {"1": "x_axis"}},
None,
],
(
{0: ["axis_batch", 0, 1024], 1: ["x_axis", 0, 8]},
({1: ["x_axis", 0, 8]}, {0: ["axis_batch", 0, 1024]}),
{"a": {0: ["axis_batch", 0, 1024]}, "b": {1: ["x_axis", 0, 8]}},
{0: "axis_batch", 1: "x_axis"},
({1: "x_axis"}, {0: "axis_batch"}),
{"a": {0: "axis_batch"}, "b": {1: "x_axis"}},
None,
),
_get_simulate_torch_float_tensor_inputs(return_tuple=True),
),
(
{
"w": {"0": ["axis_batch", 0, 1024], "1": ["x_axis", 0, 8]},
"x": [{"1": ["x_axis", 0, 8]}, {"0": ["axis_batch", 0, 1024]}],
"y": {"a": {"0": ["axis_batch", 0, 1024]}, "b": {"1": ["x_axis", 0, 8]}},
"w": {"0": "axis_batch", "1": "x_axis"},
"x": [{"1": "x_axis"}, {"0": "axis_batch"}],
"y": {"a": {"0": "axis_batch"}, "b": {"1": "x_axis"}},
"z": None,
},
{
"w": {0: ["axis_batch", 0, 1024], 1: ["x_axis", 0, 8]},
"x": ({1: ["x_axis", 0, 8]}, {0: ["axis_batch", 0, 1024]}),
"y": {"a": {0: ["axis_batch", 0, 1024]}, "b": {1: ["x_axis", 0, 8]}},
"w": {0: "axis_batch", 1: "x_axis"},
"x": ({1: "x_axis"}, {0: "axis_batch"}),
"y": {"a": {0: "axis_batch"}, "b": {1: "x_axis"}},
"z": None,
},
_get_simulate_torch_float_tensor_inputs(return_tuple=False),
Expand Down
Loading