BUG: torch: fix result_type with python scalars #277
Conversation
cc @mdhaber
array_api_compat/torch/_aliases.py (Outdated)

```python
return torch.result_type(x, y)
if isinstance(x, _py_scalars):
    if isinstance(y, _py_scalars):
        raise ValueError("At least one array or dtype is required.")
```
If I'm understanding the recursion correctly, won't this raise in cases where it just hasn't gotten to the array/dtype yet? E.g. it looks like it would fail for `xp.result_type(xp.float64, 1, 2)`, because neither of the last two arguments is an array or dtype?
True, but you must fundamentally change the logic here anyway. While torch avoids tricky promotions (unlike NumPy), and promotions are associative because of that, you are currently breaking that by adding scalar promotion rules.
That isn't hard to fix, though: you can fix it by changing the binary promotion to `promote(python_float, torch_integer) -> python_float`.
I.e. the recursive part of the algorithm must drag around `python_float` as a proper dtype. Only at the very end can you convert it to the default float.
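A minimal sketch of what "dragging `python_float` around as a pseudo-dtype" could look like, assuming dtype/`float` inputs only; the helper names are made up and this is not the PR's actual code:

```python
import torch

def _pair(a, b):
    # keep the Python ``float`` type itself as a pseudo-dtype while reducing
    if a is float and b is float:
        return float
    if a is float:
        a, b = b, a                      # put the real torch dtype first
    if b is float:
        # python float + floating/complex dtype keeps the dtype;
        # python float + integer/bool dtype stays a pseudo-float for now
        return a if (a.is_floating_point or a.is_complex) else float
    return torch.promote_types(a, b)

def result_type_sketch(first, *rest):
    out = first
    for d in rest:
        out = _pair(out, d)
    # only at the very end does the pseudo-float become the default float dtype
    return torch.get_default_dtype() if out is float else out

result_type_sketch(torch.int64, float, torch.float16)   # -> torch.float16, in any argument order
```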
Given that one way or another we need to add an arbitrary rule, and we cannot fully satisfy all of the Array API promotion rules, pytorch promotion rules and associativity, how about always sorting the inputs to put all scalars to the left? This way, the pytorch addition only kicks in for a binary promotion ("a conforming library may support additional type promotion rules", so pytorch does---for binary promotions only), and we definitely have associativity. Thoughts?
EDIT (fat-fingered, sorry): so if we do this, we have (1) the order of inputs does not matter; (2) for Array API conforming promotions, the result is Array API compatible: python scalars do not influence the result; (3) pytorch additional promotions do get used in a well-defined way; and (4) the additional rule is fully explained in one sentence.
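To see why a plain left-to-right fold makes the result depend on where the scalar sits (which the sorting avoids), here is a rough illustration; the exact outputs assume torch's default float dtype is float32:

```python
import torch

a = torch.empty((), dtype=torch.int64)
b = torch.empty((), dtype=torch.float16)

# scalar first: (1.0, int64) -> default float, then float32 + float16 -> float32
step = torch.empty((), dtype=torch.result_type(1.0, a))
print(torch.result_type(step, b))     # torch.float32

# scalar last: (int64, float16) -> float16, then float16 + 1.0 -> float16
step = torch.empty((), dtype=torch.result_type(a, b))
print(torch.result_type(step, 1.0))   # torch.float16
```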
I don't understand the first premise of needing an arbitrary rule or that you can't satisfy everything...
... but it doesn't matter? Currently, torch (by its rules) is associative, I believe. The only odd one out is the custom behavior for float, due to `int_dtype + float -> default_float` being wrong (breaking associativity).
And sure, since that is all you need to deal with, you can deal with it by ensuring Python scalar promotions happen last.
Well, `torch.result_type` is a binary function; it does not accept longer sequences of arguments. Thus there is no associativity in torch proper, and we need to come up with a rule.
That said, if enforcing the order does not sound crazy, let's roll with it---the last commit in this PR.
PS. There is a PR in array-api-tests, https://github.com/data-apis/array-api-tests/pull/349/files#diff-5f3973d098b702a92ef5d67d9cc03bfb452de213a23478d541746f32dbc8023dR219, which checks that `result_type` is associative. Are you saying that numpy does not guarantee it? (I only ran it with 10_000 hypothesis examples a couple of times.)
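Since `torch.result_type` is strictly binary, any n-ary wrapper has to pick a reduction order itself. A naive sketch of such a wrapper (ignoring dtypes and the Python-scalar subtleties discussed above; not the PR's code) would be:

```python
import functools
import torch

def result_type_naive(*args):
    # naive left-to-right fold over the strictly binary torch.result_type;
    # only handles tensors and Python numbers
    def pair(a, b):
        dt = torch.result_type(a, b)
        return torch.empty((), dtype=dt)   # re-wrap so the next step sees a tensor
    return functools.reduce(pair, args).dtype

result_type_naive(torch.ones(3, dtype=torch.int32), 2, 3.0)   # torch.float32 on a default setup
```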
This whole block looks wrong for: `result_type(1., xp.int64) == xp.int64`. You should just revert the whole thing to the old code, I think.
Maybe because that is "unspecified" and, because of that, completely untested?
I also have no idea why the code after this doesn't use `torch.promote_types`...
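For reference, `torch.promote_types` only accepts dtype pairs, which is presumably part of why the scalar handling exists at all; a quick check:

```python
import torch

# torch.promote_types works on dtype pairs only ...
torch.promote_types(torch.int64, torch.float16)   # -> torch.float16
torch.promote_types(torch.int32, torch.int64)     # -> torch.int64

# ... and rejects tensors and Python scalars:
# torch.promote_types(torch.int64, 1.0) raises a TypeError,
# so scalars still need to be handled separately before it can be used.
```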
Sorry had some other comments, meant to make a single "review".
(The point about associativity was that, for the binary operation, `result_type(a, result_type(b, c)) == result_type(result_type(a, b), c)`, if you ignore Python `float`.)
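A hedged spot-check of that property for the dtype-only case (using `torch.promote_types`; these dtypes form a simple promotion chain in torch):

```python
import itertools
import torch

dtypes = [torch.bool, torch.int32, torch.int64, torch.float16, torch.float32]
for a, b, c in itertools.product(dtypes, repeat=3):
    lhs = torch.promote_types(a, torch.promote_types(b, c))
    rhs = torch.promote_types(torch.promote_types(a, b), c)
    assert lhs == rhs, (a, b, c)   # associativity holds for this chain
```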
Okay, third time's the charm, hopefully. Now that I look closely, torch in fact has a

Exactly. Which is why previous iterations attempted to only add minimal changes on top of what was already effectively tested by usage. That did not work, so here's a (still minimal) rework + some light smoke-testing of the "unspecified" logic.
Looks like nice code improvements, but:

is incorrect. In making everything nicer, the current code doesn't actually fix the original issue :)?! You still need the sorting if you want order not to matter, and you should maybe add a test for that:

(or stay within that
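An illustrative permutation-style test for the order-independence point (hypothetical, and not the snippet from the original comment) could be:

```python
import itertools
from array_api_compat import torch as xp

# every argument order should yield the same dtype
args = (xp.int64, xp.float16, 3.0)
results = {xp.result_type(*p) for p in itertools.permutations(args)}
assert len(results) == 1, results
```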
Well, I (now) disagree it's incorrect. Let's start from the beginning: a. torch allows binary We want a

Note that the spec does not say anything about associativity, not explicitly anyway. Normally, we would not make any decisions at the level of

If we can define it so that the order does not matter (associativity), great. If we define some other order --- as long as it's defined and is consistent with 2. --- I'd say it's fine. This PR as is defines

If we want to enforce associativity, let's define the order of evaluation and make sure it indeed ensures associativity (not sure how TBH, given that we're limited in what we can test).

The only reasonable update to the rule I see now could be "1. sort the arguments so that scalars are after the array/dtype arguments; 2. combine arguments pairwise from right to left". This will fix an annoying edge case:
Yeah, I agree, you can choose not to guarantee that order doesn't matter! I am pretty sure that torch promotion is associative, so that it isn't order dependent (besides for Python `float`). I thought that was the original issue/goal here, to make order not matter in this context :).
Yes, it started with this goal in mind, but then it became clear that previously
NumPy's result type is guaranteed to not be order dependent (unless you add custom dtypes that break this by not being careful).
Okay, thanks for clarifying it! Let's try to follow NumPy here and add sorting back.
I'm going to skip

If the intention is to check that high-level modules have correct
Yes, LGTM. I can always find more nits if you like though ;p.
Yeah. We cannot prohibit the former because torch allows it; allowing the latter goes against the "no default dtype" stance as I understand it --- heck, even
Nope, PR review graphs are directed and nits only flow along one direction :-).
1. Allow inputs to be arrays or dtypes or python scalars.
2. Keep the pytorch-specific additions, e.g. `result_type(int, float) -> float` and `result_type(scalar, scalar) -> dtype`, which are unspecified in the standard.
3. Since pytorch only defines a binary `result_type` function, add a version with multiple inputs.

The latter is a bit tricky because we want to:

- keep allowing "unspecified" behaviors
- keep standard-allowed promotions compliant
- (preferably) make `result_type` independent of the argument order

The last point is important because of `int, float -> float` promotions, which break associativity. So what we do is always promote all scalars after all array/dtype arguments.
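A minimal sketch of that rule (scalars folded in only after all array/dtype arguments); the helper is illustrative and not the PR's actual implementation:

```python
import torch

_py_scalars = (bool, int, float, complex)

def result_type_sketch(*args):
    # split the arguments: arrays/dtypes first, Python scalars last
    others = [a for a in args if not isinstance(a, _py_scalars)]
    scalars = [a for a in args if isinstance(a, _py_scalars)]
    if not others:
        raise ValueError("At least one array or dtype is required.")

    def as_dtype(a):
        return a if isinstance(a, torch.dtype) else a.dtype

    # promote all array/dtype arguments pairwise first
    out = as_dtype(others[0])
    for a in others[1:]:
        out = torch.promote_types(out, as_dtype(a))
    # then fold the scalars in, relying on the binary torch.result_type
    for s in scalars:
        out = torch.result_type(torch.empty((), dtype=out), s)
    return out

result_type_sketch(torch.int64, 1.0, torch.float16)   # -> torch.float16, in any argument order
```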
Can we just merge it? Then I can just change the commit of
Okay, this is okay'd by Sebastian, and has been tested via data-apis/array-api-tests#349 locally, so hopefully it just works.
I think this did the trick! I am still having one problem that I thought was due to gh-273, but maybe not - I'll have to take a closer look. Thanks!
The one remaining problem was a different issue, but I found another:

```python
from array_api_compat import torch as xp

xp.result_type(1)
# AttributeError: 'int' object has no attribute 'dtype'
xp.result_type(1, 1)
# torch.int64
xp.result_type(1, 1, 1)
# ValueError: At least one array or dtype must be provided
```

While this behavior is not governed by the standard, it would be helpful for testing purposes if it were self-consistent.
Ah yes, this was noted by Sebastian: #277 (comment). For a one-argument version, turning the AttributeError into a ValueError is totally possible of course; PR welcome if you feel strongly enough!
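If someone picks that up, the fix is presumably a guard along these lines (illustrative only; `_py_scalars` stands in for whatever the module actually uses):

```python
_py_scalars = (bool, int, float, complex)

def result_type(*arrays_and_dtypes):
    # fail consistently whenever no array or dtype is present,
    # including the one-argument and zero-argument cases
    if all(isinstance(a, _py_scalars) for a in arrays_and_dtypes):
        raise ValueError("At least one array or dtype must be provided.")
    ...  # rest of the existing implementation
```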
I think you could check for
That would make a precedent for

And I do believe that the binary
Maybe we need an
fixes gh-273, fixes gh-274
cross-ref data-apis/array-api-tests#349 which adds tests.