BUG: torch: fix result_type with python scalars #277

Merged: 2 commits into data-apis:main, Mar 20, 2025

Conversation

@ev-br (Member) commented Mar 17, 2025

Fixes gh-273, fixes gh-274.

Cross-ref data-apis/array-api-tests#349, which adds tests.

@ev-br (Member, Author) commented Mar 17, 2025

cc @mdhaber

return torch.result_type(x, y)
if isinstance(x, _py_scalars):
    if isinstance(y, _py_scalars):
        raise ValueError("At least one array or dtype is required.")
@mdhaber commented Mar 17, 2025

If I'm understanding the recursion correctly, won't this raise in cases where it just hasn't gotten to the array/dtype yet? E.g. it looks like it would fail for xp.result_type(xp.float64, 1, 2) because neither of the last two arguments is an array or dtype?

@seberg commented:

True, but you must fundamentally change the logic here anyway. While torch avoids tricky promotions (unlike NumPy), and promotions are associative because of that, you are currently breaking that by adding scalar promotion rules.
That isn't hard to fix: change the binary promotion to

    promote(python_float, torch_integer) -> python_float

I.e. the recursive part of the algorithm must drag python_float around as a proper dtype. Only at the very end can you convert it to the default float.
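
A minimal sketch of this "drag python_float around as a sentinel" idea (the _PY_FLOAT sentinel and helper names are hypothetical, only dtypes and Python floats are handled, and this is not the actual array-api-compat implementation):

    import functools
    import torch

    _PY_FLOAT = object()  # hypothetical sentinel: "a Python float took part"

    def _promote(a, b):
        # Hypothetical binary rule: a Python float only wins against integer
        # dtypes, and stays a sentinel instead of becoming the default float.
        if a is _PY_FLOAT and b is _PY_FLOAT:
            return _PY_FLOAT
        if a is _PY_FLOAT:
            a, b = b, a
        if b is _PY_FLOAT:
            return a if (a.is_floating_point or a.is_complex) else _PY_FLOAT
        return torch.promote_types(a, b)

    def result_type_with_sentinel(*dtypes_and_floats):
        xs = [_PY_FLOAT if isinstance(x, float) else x for x in dtypes_and_floats]
        out = functools.reduce(_promote, xs)
        # Only at the very end does the sentinel collapse to the default float.
        return torch.get_default_dtype() if out is _PY_FLOAT else out

    print(result_type_with_sentinel(torch.int64, 1.0, torch.float16))  # torch.float16
    print(result_type_with_sentinel(1.0, torch.float16, torch.int64))  # torch.float16

Because the sentinel never collapses until the end, the int64/float16 result stays float16 regardless of argument order.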

@ev-br (Member, Author) commented Mar 18, 2025

Given that one way or another we need to add an arbitrary rule, and we cannot fully satisfy all of the Array API promotion rules, the pytorch promotion rules, and associativity at once, how about always sorting the inputs to put all scalars to the left? This way, the pytorch-specific additions only kick in for binary promotions ("a conforming library may support additional type promotion rules", so pytorch does --- for binary promotions only), and we definitely have associativity. Thoughts?

EDIT (fat-fingered, sorry): so if we do this, we have: (1) the order of inputs does not matter; (2) for Array-API-conforming promotions, the result is Array API compatible: python scalars do not influence the result; (3) pytorch's additional promotions do get used, in a well-defined way; and (4) the additional rule is fully explained in one sentence.

@seberg commented:

I don't understand the first premise of needing an arbitrary rule or that you can't satisfy everything...

... but it doesn't matter? Currently, torch (by its rules) is associative, I believe. The only odd one out is the custom behavior for float due to int_dtype + float -> default_float being wrong (breaking associativity).

And sure, since that is all you need to deal with, you can deal with it by ensuring Python scalar promotions happen last.

@ev-br (Member, Author) commented:

Well, torch.result_type is a binary function; it does not accept longer sequences of arguments. Thus there is no associativity in torch proper, and we need to come up with a rule.

That said, if enforcing the order does not sound crazy, let's roll with it --- that's the last commit in this PR.

PS. There is a PR in array-api-tests, https://github.com/data-apis/array-api-tests/pull/349/files#diff-5f3973d098b702a92ef5d67d9cc03bfb452de213a23478d541746f32dbc8023dR219, which checks that result_type is associative. Are you saying that numpy does not guarantee it? (I only ran it with 10_000 hypothesis examples a couple of times.)
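
The property being checked there is, roughly, the following (a plain-Python paraphrase over a handful of dtypes; the actual test in data-apis/array-api-tests#349 is hypothesis-based and much broader):

    import itertools
    from array_api_compat import torch as xp

    dtypes = [xp.int32, xp.int64, xp.float32, xp.float64]
    for a, b, c in itertools.product(dtypes, repeat=3):
        # result_type(a, result_type(b, c)) == result_type(result_type(a, b), c)
        nested_right = xp.result_type(a, xp.result_type(b, c))
        nested_left = xp.result_type(xp.result_type(a, b), c)
        assert nested_left == nested_right, (a, b, c)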

@seberg commented:

This whole block looks wrong for: result_type(1., xp.int64) == xp.int64. You should just revert the whole thing to the old code, I think.

Maybe since that is "unspecified" and because of that completely untested?

I also have no idea why the code after this doesn't use torch.promote_types...

@seberg left a comment:

Sorry, I had some other comments; I meant to make a single "review".

(The point about being associative was that, for the binary operation, result_type(a, result_type(b, c)) == result_type(result_type(a, b), c), if you ignore python float.)

@ev-br force-pushed the torch_result_type branch 2 times, most recently from 8c7fe4e to fe83ca7, on March 19, 2025 at 10:25
@ev-br (Member, Author) commented Mar 19, 2025

Okay, third time's the charm, hopefully.

Now that I look closely, torch in fact has a result_type(Number, Number) overload. So there's no need for sorting etc., just some gymnastics to map from *arrays_and_dtypes_or_scalars to torch's Tensor | Number arguments.

> Maybe since that is "unspecified" and because of that completely untested?

Exactly.

Which is why previous iterations attempted to add only minimal changes on top of what was already effectively tested by usage. That did not work, so here's a (still minimal) rework plus some light smoke-testing of the "unspecified" logic.
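
The "gymnastics" could look roughly like this (hypothetical helper names, not the actual array-api-compat code: tensors and Python scalars pass straight through to torch.result_type, bare dtypes are wrapped in empty tensors; no sorting, just a left-to-right fold):

    import functools
    import torch

    def _to_tensor_or_number(x):
        # torch.result_type only accepts Tensor | Number, so a bare dtype is
        # smuggled in as an empty (but not 0-dim) tensor of that dtype, so
        # torch does not treat it as scalar-like (hypothetical helper).
        return torch.empty(0, dtype=x) if isinstance(x, torch.dtype) else x

    def result_type_reduce(*arrays_and_dtypes_or_scalars):
        def pair(x, y):
            return torch.result_type(_to_tensor_or_number(x), _to_tensor_or_number(y))
        return functools.reduce(pair, arrays_and_dtypes_or_scalars)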

@seberg commented Mar 19, 2025

Looks like nice code improvements, but:

> Now that I look closely, torch in fact has a result_type(Number, Number) overload. So there's no need for sorting

is incorrect. In making everything nicer, the current code doesn't actually fix the original issue :)?!
The point is that result_type(int_tensor, 1.0) will return float64, which breaks associativity of _result_type.

You still need the sorting if you want the order to not matter, and you should maybe add a test for that:

    result_type([1., 5, 3, float16(3), 5, 6, 1.]) == float16

(or stay within the float64 default dtype context and use float32)
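
A hedged, runnable version of that suggested test might look like this (it uses a float16 tensor where the sketch above says float16(3), and separate arguments since result_type takes *args; the expectation reflects the "scalars promote last" rule this PR eventually adopts):

    import itertools
    import torch
    from array_api_compat import torch as xp

    def test_result_type_order_independent():
        operands = (1.0, 5, 3, torch.ones(1, dtype=torch.float16), 5, 6, 1.0)
        # Every permutation of the arguments should promote to float16: the
        # Python scalars must not upcast the result past the float16 tensor.
        results = {xp.result_type(*perm) for perm in itertools.permutations(operands)}
        assert results == {torch.float16}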

@ev-br (Member, Author) commented Mar 19, 2025

Well, I (now) disagree it's incorrect.

Let's start from the beginning:

a. torch allows binary result_type promotions between ints and floats
b. torch does not allow result_type(*args) with len(args) > 2 at all.

We want a result_type which:

  1. allows len(args) > 2;
  2. is consistent with the spec for spec-defined promotions;
  3. is consistent with torch-specific additional promotions.

Note that the spec does not say anything about associativity, not explicitly anyway.
Given (a) and (b), we need to define the order of evaluation for 1.

Normally, we would not make any decisions at the level of array-api-compat. Here we are forced to, because the spec mandates things which are not in pytorch itself.

If we can define it so that the order does not matter (associativity), great. If we define some other order --- as long as it's defined and is consistent with 2. --- I'd say it's fine.

This PR, as is, defines result_type(*args) via functools.reduce from left to right. This does break associativity, yes, for things which are not defined in either the Array API spec or pytorch.

If we want to enforce associativity, let's define the order of evaluation and make sure it indeed ensures associativity (not sure how, TBH, given that we're limited in what we can test).

@ev-br (Member, Author) commented Mar 19, 2025

The only reasonable update to the rule I see now could be "1. sort the arguments so that scalars are after the array/dtype arguments; 2. combine arguments pairwise from right to left". This will fix an annoying edge case:

In [27]: f16 = xp.ones(1, dtype=xp.float16)

In [28]: xp.result_type(f16, 1.0, 1.0)
Out[28]: torch.float16

In [29]: xp.result_type(1.0, 1.0, f16)      # ouch
Out[29]: torch.float32

@seberg commented Mar 19, 2025

Yeah, I agree, you can choose not to guarantee that order doesn't matter! I am pretty sure that torch promotion is associative, so it isn't order dependent (except for Python float).

I thought that was the original issue/goal here, to make order not matter in this context :).

@ev-br (Member, Author) commented Mar 19, 2025

Yes, it started with this goal in mind, but then it became clear that previously result_type was just broken for python scalars, and the minimum goal is to have some version consistent with the 2024.12 spec :-). If we can make it associative, great: downstream UX is definitely better if it is.
To this end: is numpy's result_type guaranteed to be associative?

@seberg commented Mar 19, 2025

NumPy's result_type is guaranteed not to be order dependent (unless you add custom dtypes that break this by not being careful).
(The binary promotion operation in NumPy itself is not associative.)

@ev-br (Member, Author) commented Mar 19, 2025

Okay, thanks for clarifying it! Let's try to follow NumPy here and add sorting back.

@ev-br force-pushed the torch_result_type branch 3 times, most recently from 65b3e2b to 7c9e572, on March 19, 2025 at 15:29
@ev-br (Member, Author) commented Mar 19, 2025

I'm going to skip tests/test_all.py::test_all for now. The test starts failing on CI (but not locally) when I add an unrelated test module (tests/test_torch.py); the test itself does something rather strange with direct access to sys.modules, and loops over private implementation modules. I frankly cannot tell what the intention is. Why do we care whether common/_helpers.py has matching __all__ and dir()? Also, it picks up common/_aliases.py on CI but misses it locally.

If the intention is to check that high-level modules have correct __all__ lists, then that's what the test should probe, I suppose.

@ev-br (Member, Author) commented Mar 19, 2025

Test weirdness aside, I think the current version is as good as it gets. Does this look reasonable to you, @seberg?
It would be great if you could test it, @mdhaber.
Unless somebody sees further problems, I'll keep it open for a short while, then merge and proceed.

@seberg commented Mar 19, 2025

Yes, LGTM. I can always find more nits if you like though ;p.
(Behavior nit: result_type(1, 1) doesn't raise, but result_type(1, 1, 1) does.)

@ev-br (Member, Author) commented Mar 19, 2025

> (Behavior nit: result_type(1, 1) doesn't raise, but result_type(1, 1, 1) does.)

Yeah. We cannot prohibit the former because torch allows it; allowing the latter goes against the "no default dtype" stance as I understand it --- heck, even where(x > 0, 1, 2) is not allowed with two scalars.
So I'd rather block it, at least until pytorch defines result_type(1, 1, 1), if it ever does.
This is actually what I meant by "arbitrary rules" in #277 (comment).

> I can always find more nits if you like though ;p.

Nope, PR review graphs are directed and nits only flow along one direction :-).

@ev-br force-pushed the torch_result_type branch from 7c9e572 to c9081d6 on March 19, 2025 at 18:32
ev-br added 2 commits March 19, 2025 19:42
1. Allow inputs to be arrays or dtypes or python scalars
2. Keep the pytorch-specific additions, e.g.
   `result_type(int, float) -> float`, `result_type(scalar, scalar) -> dtype`
   which are unspecified in the standard
3. Since pytorch only defines a binary `result_type` function, add a version
   with multiple inputs.

The latter is a bit tricky because we want to
- keep allowing "unspecified" behaviors
- keep standard-allowed promotions compliant
- (preferably) make result_type independent of the argument order

The last point matters because of the `int, float -> float` promotions, which
break associativity.

So what we do is: always promote all scalars after all array/dtype arguments.
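
A rough sketch of the approach this commit describes (hypothetical names, not the actual implementation; the all-scalars guard mirrors the behavior discussed below for result_type(1, 1, 1)):

    import functools
    import torch

    _py_scalars = (bool, int, float, complex)

    def _binary_result_type(x, y):
        # Binary step: defer to torch, which also supplies the pytorch-specific
        # promotions (int/float mixing, scalar/scalar) left unspecified by the
        # standard. Bare dtypes are wrapped in empty (but not 0-dim) tensors so
        # torch does not treat them as scalar-like.
        if isinstance(x, torch.dtype):
            x = torch.empty(0, dtype=x)
        if isinstance(y, torch.dtype):
            y = torch.empty(0, dtype=y)
        return torch.result_type(x, y)

    def result_type_sorted(*arrays_and_dtypes_or_scalars):
        args = arrays_and_dtypes_or_scalars
        if len(args) > 2 and all(isinstance(a, _py_scalars) for a in args):
            raise ValueError("At least one array or dtype is required.")
        # Stable sort: array/dtype arguments keep their relative order and come
        # first, Python scalars come last, so the argument order cannot matter.
        ordered = sorted(args, key=lambda a: isinstance(a, _py_scalars))
        return functools.reduce(_binary_result_type, ordered)

    f16 = torch.ones(1, dtype=torch.float16)
    print(result_type_sorted(f16, 1.0, 1.0))  # torch.float16
    print(result_type_sorted(1.0, 1.0, f16))  # torch.float16, no longer float32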
@ev-br force-pushed the torch_result_type branch from c9081d6 to 5473d84 on March 19, 2025 at 18:42
@mdhaber commented Mar 19, 2025

Can we just merge it? Then I can just change the commit of array-api-compat that SciPy is using and see whether it passes tests.

@ev-br (Member, Author) commented Mar 20, 2025

Okay, this is okay'd by Sebastian, and has been tested via data-apis/array-api-tests#349 locally, so hopefully it just works.
Merging.
Thanks @seberg for the (multiple-stage) review!

@ev-br merged commit 18fbec4 into data-apis:main on Mar 20, 2025 (40 checks passed).
@ev-br added this to the 1.12 milestone on Mar 20, 2025.
@mdhaber commented Mar 26, 2025

I think this did the trick! I am still having one problem that I thought was due to gh-273, but maybe not - I'll have to take a closer look. Thanks!

@mdhaber commented Mar 27, 2025

The one remaining problem was a different issue, but I found another: torch.result_type with two Python scalar arguments doesn't raise, and we get an unexpected type of error if there is only one Python scalar argument.

from array_api_compat import torch as xp

xp.result_type(1)
# AttributeError: 'int' object has no attribute 'dtype'

xp.result_type(1, 1)
# torch.int64

xp.result_type(1, 1, 1)
# ValueError: At least one array or dtype must be provided

While this behavior is not governed by the standard, it would be helpful for testing purposes if it were self-consistent.

@ev-br (Member, Author) commented Mar 28, 2025

Ah yes, this was noted by Sebastian: #277 (comment)
I stand by #277 (comment): we cannot make it truly consistent, sadly.

For a one-argument version, turning the AttributeError into a ValueError is totally possible, of course; PR welcome if you feel strongly enough!
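
A hedged sketch of what that one-argument special case could look like (hypothetical, not what the merged PR does):

    import torch

    _py_scalars = (bool, int, float, complex)

    def result_type_one_arg(x):
        # Raise the same ValueError as the multi-scalar case instead of letting
        # an AttributeError escape from accessing `x.dtype`.
        if isinstance(x, _py_scalars):
            raise ValueError("At least one array or dtype is required.")
        return x if isinstance(x, torch.dtype) else x.dtype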

@seberg commented Mar 28, 2025

I think you could check for result_type(1, 1) and raise an error? If that is less confusing.

@ev-br (Member, Author) commented Mar 28, 2025

That would set a precedent for array-api-compat raising on operations which are unspecified in the spec but allowed by the array provider.

And I do believe that the binary result_type(x, y) is far more common in practice than result_type(x, y, z, ...), so from a pragmatic POV the status quo is the lesser evil.

@mdhaber commented Mar 28, 2025

Maybe we need an array-api-compat-compat that does this sort of thing.
