From 63bf1c2b564bd4c4df0ec78c53e1f1b69f4cf410 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang <titaiwang@microsoft.com> Date: Thu, 30 Jan 2025 02:28:24 +0000 Subject: [PATCH] support torch.export.Dim.AUTO --- olive/common/hf/model_io.py | 51 ++++++++++++++-- olive/passes/onnx/conversion.py | 60 +++++++------------ olive/passes/pytorch/common.py | 23 ++++--- requirements.txt | 1 + test/requirements-test.txt | 2 - test/unit_test/model/test_hf_model.py | 6 +- test/unit_test/model/test_pytorch_model.py | 6 +- test/unit_test/passes/onnx/test_conversion.py | 33 +++++----- 8 files changed, 102 insertions(+), 80 deletions(-) diff --git a/olive/common/hf/model_io.py b/olive/common/hf/model_io.py index 5a786458a..5d6690731 100644 --- a/olive/common/hf/model_io.py +++ b/olive/common/hf/model_io.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- import logging from itertools import chain -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional from olive.common.hf.mlflow import get_pretrained_name_or_path from olive.common.hf.peft import is_peft_model @@ -124,10 +124,51 @@ def get_model_io_config(model_name: str, task: str, model: "PreTrainedModel", ** for axis, axis_name in value.items(): if axis_name == "past_sequence_length + 1": value[axis] = "past_sequence_length + sequence_length" - # NOTE: Due to the complexity of dynamic_shapes, we don't provide it here. - # torch-onnx converter has a naive approach to auto-gen dynamic shapes based on input and - # dynamic_axes, so we don't need to provide dynamic shapes here. - return {"input_names": input_names, "output_names": output_names, "dynamic_axes": dynamic_axes} + # dynamic_shapes should follow input order and format + dynamic_shapes = _unflatten_past_key_values_with_check(inputs) + return { + "input_names": input_names, + "output_names": output_names, + "dynamic_axes": dynamic_axes, + "dynamic_shapes": dynamic_shapes, + } + + +def _unflatten_past_key_values_with_check(flattened_inputs: Dict[str, Any]) -> Dict[str, Any]: + max_idx = -1 + past_key_value_count = 0 # Track number of key-value pairs + + # Find the max index for generating unflatten past_key_values later + # and record the total number of past_key_values entries for validation + for input_name in flattened_inputs: + if input_name.startswith("past_key_values"): + # From Optimum: past_key_values.0.key, past_key_values.0.value, + # past_key_values.1.key, past_key_values.1.value, ... + idx = int(input_name.split(".")[1]) + max_idx = max(max_idx, idx) + past_key_value_count += 1 + + # Check if we have exactly 2 * (max_idx + 1) key-value pairs + expected_count = 2 * (max_idx + 1) + if past_key_value_count != expected_count or past_key_value_count % 2 != 0: + raise ValueError( + f"Expected {expected_count} past_key_values entries, but found {past_key_value_count} from Optimum inputs." + ) + if max_idx == -1: + # No past_key_values found + return flattened_inputs + + unflattened = {} + for input_name, dynamic_shapes in flattened_inputs.items(): + # Replace flattened past_key_values with unflattened past_key_values + if not input_name.startswith("past_key_values"): + unflattened[input_name] = dynamic_shapes + # Based on Optimum's implementation: + # https://github.com/huggingface/optimum/blob/b755036ae12e0959d61085e597e7b96473c4b46d/optimum/exporters/onnx/base.py#L629 + # past_key_values is a list of lists, and it locates at the end of the input list/dict + # Generate the past_key_values list using the max index + unflattened["past_key_values"] = [[dynamic_shapes, dynamic_shapes] for _ in range(max_idx + 1)] + return unflattened def get_model_dummy_input(model_name: str, task: str, **kwargs) -> Optional[Dict]: diff --git a/olive/passes/onnx/conversion.py b/olive/passes/onnx/conversion.py index a32dc88cf..f808803a5 100644 --- a/olive/passes/onnx/conversion.py +++ b/olive/passes/onnx/conversion.py @@ -212,12 +212,11 @@ def _export_pytorch_model( # The "legacy dynamo" is the torch.onnx_dynamo_export API legacy_dynamo_supported_version = version.parse("2.2.0").release # The new "dynamo" api is torch.onnx.export with dynamo=True - # TODO(#1478): Change 2.6.0 back to 2.5.0 when dynamic_shapes are supported in Olive dynamo_supported_version = version.parse("2.6.0").release if torch_version < legacy_dynamo_supported_version: raise ImportError( - f"torch.onnx.dynamo_export is not available for torch version {torch_version}. " - "Please upgrade your torch version to 2.5.0 or above." + f"torch.onnx.export(..., dynamo=True) is not available for torch version {torch_version}. " + "Please upgrade your torch version to 2.6.0 or above." ) from torch._dynamo import config as dynamo_config @@ -597,14 +596,7 @@ def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs): def is_dict_axes(x) -> bool: return isinstance(x, dict) and all( - isinstance(key, str) - and len(key) == 1 - and isinstance(value, list) - and len(value) == 3 - and isinstance(value[0], str) - and isinstance(value[1], int) - and isinstance(value[2], int) - for key, value in x.items() + isinstance(key, (str, int)) and isinstance(value, str) for key, value in x.items() ) flat_dynamic_shapes, _ = _pytree.tree_flatten(dynamic_shapes, is_leaf=is_dict_axes) @@ -615,15 +607,20 @@ def is_dict_axes(x) -> bool: continue new_axes = {} for axis, dynamic_shape in axes.items(): - new_axes[int(axis)] = dynamic_shape + try: + axis_int = int(axis) + except ValueError as e: + raise ValueError( + f"Please check dynamic_shapes in configuration. Invalid axis {axis}. The axis should be an integer." + ) from e + new_axes[axis_int] = dynamic_shape new_dynamic_shapes.append(new_axes) - _, tree_structure = _pytree.tree_flatten(dummy_inputs, is_leaf=is_dict_axes) return _pytree.tree_unflatten(new_dynamic_shapes, tree_structure) def _convert_dynamic_shapes_to_torch_export_dims( - dynamic_shapes: Dict[str, Dict[int, torch.export.Dim]] + dynamic_shapes: Dict[str, Dict[int, torch.export.Dim]], ) -> Dict[str, Dict[int, torch.export.Dim]]: """Convert dynamic_shapes to torch export dims. @@ -634,8 +631,12 @@ def _convert_dynamic_shapes_to_torch_export_dims( For a single axis: - before: ["axis_name", min_value, max_value] - after: torch.export.Dim("axis_name", min=min_value, max=max_value) + before: {0: "batch_size"} + after: {0: torch.export.Dim.AUTO) + + NOTE: The user specified dim name is not respected at the moment due to + technical issue in torch.export. Follow up + https://github.com/pytorch/pytorch/issues/144273 # Please check `dynamic_shapes` in torch.export.export # https://pytorch.org/docs/stable/export.html#torch.export.export @@ -646,9 +647,6 @@ def _convert_dynamic_shapes_to_torch_export_dims( if dynamic_shapes is None: return None - # If the axes has the same name, they should be the same torch.export.Dim - torch_export_dim_farm: Dict[str, torch.export.Dim] = {} - # dynamic_shapes follows input format, which could be nested def _from_tuple_to_dim(data: Union[Dict, List, Tuple, Any]) -> Union[Dict, List, Tuple, Any]: if isinstance(data, dict): @@ -657,26 +655,14 @@ def _from_tuple_to_dim(data: Union[Dict, List, Tuple, Any]) -> Union[Dict, List, # TODO(titaiwang): Can we use `dummy_inputs` to align the dynamic_shapes format? # JSON foramt does not accept tuple. elif isinstance(data, (tuple, list)): - # We assume the tuple/list is in the format of (name, min, max) - # TODO(titaiwang): This format could potentially be used as model - # inputs (would string be used as model input?) - if len(data) == 3 and isinstance(data[0], str) and isinstance(data[1], int) and isinstance(data[2], int): - if data[0] in torch_export_dim_farm: - if torch_export_dim_farm[data[0]].min == data[1] and torch_export_dim_farm[data[0]].max == data[2]: - return torch_export_dim_farm[data[0]] - raise ValueError( - f"Found different boundary for the same axis name {data[0]}. " - f"Previous min: {torch_export_dim_farm[data[0]].min} and " - f"max: {torch_export_dim_farm[data[0]].max}. " - f"Current min: {data[1]} and max: {data[2]}." - ) - dim = torch.export.Dim(data[0], min=data[1], max=data[2]) - torch_export_dim_farm[data[0]] = dim - return dim if isinstance(data, tuple): return tuple(_from_tuple_to_dim(item) for item in data) - if isinstance(data, list): - return [_from_tuple_to_dim(item) for item in data] + return [_from_tuple_to_dim(item) for item in data] + elif isinstance(data, str): + # NOTE: AUTO does not guarantee the axis will be dynamic, but + # it is the only option that not only avoid crashing during + # export, but also does not require user to specify min/max. + return torch.export.Dim.AUTO return data return _from_tuple_to_dim(dynamic_shapes) diff --git a/olive/passes/pytorch/common.py b/olive/passes/pytorch/common.py index 246635c98..9e1fcb3f9 100644 --- a/olive/passes/pytorch/common.py +++ b/olive/passes/pytorch/common.py @@ -51,18 +51,6 @@ def inherit_pytorch_from_hf( hf_io_config = deepcopy(model.io_config) hf_dummy_inputs = model.get_dummy_inputs() - dynamic_shapes = hf_io_config.get("dynamic_shapes", {}) - if isinstance(dynamic_shapes, dict): - { - k: v - for k, v in hf_io_config.get("dynamic_axes", {}).items() - if not k.startswith(("present", "past_key_values")) - } - else: - # TODO(titaiwang): fix this when we have a better way to handle dynamic_shapes - # If the dynamic_shapes is a list, we don't inherit it since - # we do not know the exact index of the past_key_values in the list - dynamic_shapes = {} # kv cache will be handled by the kv_cache flag in io_config io_config = { "input_names": [i for i in hf_io_config.get("input_names", []) if not i.startswith("past_key_values")], @@ -74,7 +62,6 @@ def inherit_pytorch_from_hf( for k, v in hf_io_config.get("dynamic_axes", {}).items() if not k.startswith(("present", "past_key_values")) }, - "dynamic_shapes": dynamic_shapes, } for i_name in io_config["input_names"]: @@ -84,6 +71,16 @@ def inherit_pytorch_from_hf( if io_config and not io_config.get("kv_cache") and model.task.endswith("-with-past"): io_config["kv_cache"] = True + # dynamic_shapes deals with kv_cache here. If kv_cache is False, + # we remove past_key_values from dynamic_shapes + if not io_config.get("kv_cache", False): + dynamic_shapes = { + k: v for k, v in hf_io_config.get("dynamic_shapes", {}).items() if not k.startswith("past_key_values") + } + else: + dynamic_shapes = hf_io_config.get("dynamic_shapes", {}) + io_config["dynamic_shapes"] = dynamic_shapes + return PyTorchModelHandler( model_path=model_path, model_file_format=model_file_format, diff --git a/requirements.txt b/requirements.txt index 786904514..54b97185f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ numpy onnx +onnxscript optuna pandas protobuf<4.0.0 diff --git a/test/requirements-test.txt b/test/requirements-test.txt index f2d073aef..70dd4dd0c 100644 --- a/test/requirements-test.txt +++ b/test/requirements-test.txt @@ -29,8 +29,6 @@ onnxconverter_common onnxmltools onnxoptimizer onnxruntime_extensions -# TODO(titaiwai): Add onnxscript to requirements.txt once it's released -onnxscript openvino==2023.2.0 optimum>=1.17.0 pandas diff --git a/test/unit_test/model/test_hf_model.py b/test/unit_test/model/test_hf_model.py index bd8ad17b6..343c18b1c 100644 --- a/test/unit_test/model/test_hf_model.py +++ b/test/unit_test/model/test_hf_model.py @@ -145,9 +145,9 @@ def setup(self): "token_type_ids": {"0": "batch_size", "1": "seq_length"}, }, "dynamic_shapes": { - "input_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]}, - "attention_mask": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]}, - "token_type_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]}, + "input_ids": {"0": "batch_size", "1": "seq_length"}, + "attention_mask": {"0": "batch_size", "1": "seq_length"}, + "token_type_ids": {"0": "batch_size", "1": "seq_length"}, }, } diff --git a/test/unit_test/model/test_pytorch_model.py b/test/unit_test/model/test_pytorch_model.py index ac26454e1..4bbdd00d4 100644 --- a/test/unit_test/model/test_pytorch_model.py +++ b/test/unit_test/model/test_pytorch_model.py @@ -23,9 +23,9 @@ def io_config_fixture(): "token_type_ids": {"0": "batch_size", "1": "seq_length"}, }, "dynamic_shapes": { - "input_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]}, - "attention_mask": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]}, - "token_type_ids": {"0": ["batch_size", 1, 32], "1": ["seq_length", 1, 256]}, + "input_ids": {"0": "batch_size", "1": "seq_length"}, + "attention_mask": {"0": "batch_size", "1": "seq_length"}, + "token_type_ids": {"0": "batch_size", "1": "seq_length"}, }, } diff --git a/test/unit_test/passes/onnx/test_conversion.py b/test/unit_test/passes/onnx/test_conversion.py index 92a384415..8b7f63f00 100644 --- a/test/unit_test/passes/onnx/test_conversion.py +++ b/test/unit_test/passes/onnx/test_conversion.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- import platform import shutil -import sys from itertools import chain from pathlib import Path from test.unit_test.utils import ONNX_MODEL_PATH, get_hf_model, get_onnx_model, get_pytorch_model, pytorch_model_loader @@ -21,7 +20,7 @@ from olive.passes.onnx.conversion import OnnxConversion, OnnxOpVersionConversion -@pytest.mark.skipif(sys.version_info > (3, 8), reason="Failed with Python 3.10, need to investigate.") +# @pytest.mark.skipif(sys.version_info > (3, 8), reason="Failed with Python 3.10, need to investigate.") @pytest.mark.parametrize( ("input_model", "use_dynamo_exporter"), [(get_pytorch_model(), True), (get_hf_model(), True), (get_pytorch_model(), False), (get_hf_model(), False)], @@ -189,10 +188,10 @@ def mock_onnx_export_func(*args, **kwargs): @pytest.mark.parametrize( "dynamic_shapes", [ - [{"0": ["axis_batch", 0, 1024], "1": ["x_axis", 0, 8]}, {"0": ["axis_batch", 0, 1024], "1": ["y_axis", 0, 6]}], + [{"0": "axis_batch", "1": "x_axis"}, {"0": "axis_batch", "1": "y_axis"}], { - "input_x": {"0": ["axis_batch", 0, 1024], "1": ["x_axis", 0, 8]}, - "input_y": {"0": ["axis_batch", 0, 1024], "1": ["y_axis", 0, 6]}, + "input_x": {"0": "axis_batch", "1": "x_axis"}, + "input_y": {"0": "axis_batch", "1": "y_axis"}, }, ], ) @@ -224,30 +223,30 @@ def _get_simulate_torch_float_tensor_inputs(return_tuple: bool = False): [ ( [ - {"0": ["axis_batch", 0, 1024], "1": ["x_axis", 0, 8]}, - [{"1": ["x_axis", 0, 8]}, {"0": ["axis_batch", 0, 1024]}], - {"a": {"0": ["axis_batch", 0, 1024]}, "b": {"1": ["x_axis", 0, 8]}}, + {"0": "axis_batch", "1": "x_axis"}, + [{"1": "x_axis"}, {"0": "axis_batch"}], + {"a": {"0": "axis_batch"}, "b": {"1": "x_axis"}}, None, ], ( - {0: ["axis_batch", 0, 1024], 1: ["x_axis", 0, 8]}, - ({1: ["x_axis", 0, 8]}, {0: ["axis_batch", 0, 1024]}), - {"a": {0: ["axis_batch", 0, 1024]}, "b": {1: ["x_axis", 0, 8]}}, + {0: "axis_batch", 1: "x_axis"}, + ({1: "x_axis"}, {0: "axis_batch"}), + {"a": {0: "axis_batch"}, "b": {1: "x_axis"}}, None, ), _get_simulate_torch_float_tensor_inputs(return_tuple=True), ), ( { - "w": {"0": ["axis_batch", 0, 1024], "1": ["x_axis", 0, 8]}, - "x": [{"1": ["x_axis", 0, 8]}, {"0": ["axis_batch", 0, 1024]}], - "y": {"a": {"0": ["axis_batch", 0, 1024]}, "b": {"1": ["x_axis", 0, 8]}}, + "w": {"0": "axis_batch", "1": "x_axis"}, + "x": [{"1": "x_axis"}, {"0": "axis_batch"}], + "y": {"a": {"0": "axis_batch"}, "b": {"1": "x_axis"}}, "z": None, }, { - "w": {0: ["axis_batch", 0, 1024], 1: ["x_axis", 0, 8]}, - "x": ({1: ["x_axis", 0, 8]}, {0: ["axis_batch", 0, 1024]}), - "y": {"a": {0: ["axis_batch", 0, 1024]}, "b": {1: ["x_axis", 0, 8]}}, + "w": {0: "axis_batch", 1: "x_axis"}, + "x": ({1: "x_axis"}, {0: "axis_batch"}), + "y": {"a": {0: "axis_batch"}, "b": {1: "x_axis"}}, "z": None, }, _get_simulate_torch_float_tensor_inputs(return_tuple=False),