Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support torch.export.Dim.AUTO in ONNX conversion pass #1586

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

titaiwangms
Copy link
Contributor

Describe your changes

Checklist before requesting a review

  • Add unit tests for this change.
  • Make sure all tests can pass.
  • Update documents if necessary.
  • Lint and apply fixes to your code by running lintrunner -a
  • Is this a user-facing change? If yes, give a description of this change to be included in the release notes.
  • Is this PR including examples changes? If yes, please remember to update example documentation in a follow-up PR.

(Optional) Issue link

for input_name, dynamic_shapes in flattened_inputs.items():
# Replace flattened past_key_values with unflattened past_key_values
if not input_name.startswith("past_key_values"):
unflattened[input_name] = dynamic_shapes

Check warning

Code scanning / lintrunner

RUFF/PERF403 Warning

Use a dictionary comprehension instead of a for-loop.
See https://docs.astral.sh/ruff/rules/manual-dict-comprehension
# https://github.com/huggingface/optimum/blob/b755036ae12e0959d61085e597e7b96473c4b46d/optimum/exporters/onnx/base.py#L629
# past_key_values is a list of lists, and it locates at the end of the input list/dict
# Generate the past_key_values list using the max index
unflattened["past_key_values"] = [[dynamic_shapes, dynamic_shapes] for _ in range(max_idx + 1)]

Check warning

Code scanning / lintrunner

PYLINT/W0631 Warning

Using possibly undefined loop variable 'dynamic_shapes' (undefined-loop-variable)
See undefined-loop-variable.
# https://github.com/huggingface/optimum/blob/b755036ae12e0959d61085e597e7b96473c4b46d/optimum/exporters/onnx/base.py#L629
# past_key_values is a list of lists, and it locates at the end of the input list/dict
# Generate the past_key_values list using the max index
unflattened["past_key_values"] = [[dynamic_shapes, dynamic_shapes] for _ in range(max_idx + 1)]

Check warning

Code scanning / lintrunner

PYLINT/W0631 Warning

Using possibly undefined loop variable 'dynamic_shapes' (undefined-loop-variable)
See undefined-loop-variable.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant