Clarification for non-portable behavior for in-place Python operations #828
Comments
Did we agree that it should always be implementation defined if the types are different, or should …
@asmeurer I believe that is what @oleksandr-pavlyk was getting at in the OP. IMO, it should be okay to allow type promotion to …
So there are two cases.

Case 1:

```python
x1 = asarray([0], dtype=int64)
x2 = asarray([0], dtype=int32)
x1 += x2
```

Case 2:

```python
x1 = asarray([0], dtype=int32)
x2 = asarray([0], dtype=int64)
x1 += x2
```

I think @oleksandr-pavlyk was asking about case 2:
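For concreteness, here is a sketch of how the two cases play out in NumPy today (assumed behavior, consistent with the examples quoted later in this thread; other libraries differ):

```python
import numpy as np

# Case 1: the in-place array already has the promoted dtype.
x1 = np.asarray([0], dtype=np.int64)
x2 = np.asarray([0], dtype=np.int32)
x1 += x2          # result_type(int64, int32) == int64 == x1.dtype
print(x1.dtype)   # int64

# Case 2: the promoted dtype (int64) differs from x1.dtype; NumPy
# silently downcasts the result back into the int32 buffer.
x1 = np.asarray([0], dtype=np.int32)
x2 = np.asarray([0], dtype=np.int64)
x1 += x2
print(x1.dtype)   # int32
```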
In my opinion, this is actually spelled out already in https://data-apis.org/array-api/latest/API_specification/array_object.html#in-place-operators:
In other words, case 2 is currently required to error (which is stronger than implementation defined). Case 1 is perhaps more ambiguous as to whether it is required or not. My reading of the current text is that it is. If we want to make it implementation defined, we should explicitly state that. I'm not aware of any reason why it would be a problem, though.
Does this make current NumPy behavior non-compliant, then? Because in NumPy, `+=` simply casts the second array into the type of the first array. And per the spec, this seems like a very strong condition, implying that …
I think the sentence before that, which I quoted ("An in-place operation must not change the data type or shape of the in-place array..."), does indeed imply that NumPy is currently noncompliant, because it uses *must* and not *should*. We could potentially loosen this to be implementation defined. I guess one question is whether the NumPy team agrees that this is nonideal behavior and should be deprecated. I certainly find the behavior surprising (it really is an in-place change of dtype: even views of …).

Also, we should check if other libraries allow this. This is the behavior in PyTorch:

```python
>>> x1 = torch.asarray([0], dtype=torch.int32)
>>> x2 = torch.asarray([1], dtype=torch.int64)
>>> x1 += x2
>>> x1
tensor([1], dtype=torch.int32)
```

That actually could arguably be within what the spec says, because it didn't change the dtype of `x1`.
Yes, it would be good to clarify there that the dtype may be different, due to:
Can you give an example of NumPy's non-compliance with that? The example in #828 (comment) looks to be doing exactly what is specified:

```python
In [1]: import numpy as np

In [2]: x = np.arange(10, dtype=np.int32)

In [3]: y = np.ones(10, dtype=np.int64)

In [4]: (x + y).dtype
Out[4]: dtype('int64')

In [5]: x += y

In [6]: x.dtype
Out[6]: dtype('int32')
```
Does this mean that libraries which don't explicitly define the in-place operators are non-compliant? For example, in JAX:

```python
In [1]: import jax

In [2]: jax.config.update('jax_enable_x64', True)

In [3]: import jax.numpy as jnp

In [4]: x = jnp.arange(10, dtype=jnp.int32)

In [5]: y = jnp.ones(10, dtype=jnp.int64)

In [6]: (x + y).dtype
Out[6]: dtype('int64')

In [7]: x += y

In [8]: x.dtype  # type is promoted, unlike NumPy
Out[8]: dtype('int64')
```

The reason is that in Python, when `__iadd__` is not defined, `x += y` falls back to `x = x + y`.
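This fallback is plain Python data-model behavior, independent of any array library. A minimal sketch (the class name `NoIAdd` is hypothetical, chosen just for illustration):

```python
# If a class defines __add__ but no __iadd__, `x += y` executes as
# `x = x.__add__(y)`: the name is rebound to a new object and the
# original object is never mutated.
class NoIAdd:
    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        return NoIAdd(self.value + other.value)

x = NoIAdd(1)
orig = x
x += NoIAdd(2)     # falls back to x = x + NoIAdd(2)
print(x is orig)   # False: `x` now names a new object
print(orig.value)  # 1: the original object is untouched
```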
That does look like non-compliance currently, yes.
So effectively then, the array API specification is saying "compliant libraries must overload the in-place operators" …
Hmm, that's an issue that wasn't thought about before, I think. On the one hand, requiring that … On the other hand, no one brought up this problem and intended to mark JAX as non-compliant here. The Python language was also not really designed with this in mind; casting rules like the ones we need for numerical libraries weren't really a consideration, and even the simpler numerical tower for Python's scalars came later (and is now considered a mistake).

Perhaps the easy way out is to allow both, and warn that only equal dtypes result in fully portable behavior, and the rest is implementation defined?
I do agree that the downcasting NumPy does is bad. It uses `same_kind` casting for the in-place result, which is why a float result cannot be stored back into an integer array:

```python
>>> x = np.arange(3)
>>> x += np.array([0, 1, 2.5])
...
UFuncTypeError: Cannot cast ufunc 'add' output from dtype('float64') to dtype('int64') with casting rule 'same_kind'
```
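The `same_kind` rule in the error above also explains why the integer-to-integer downcast in case 2 is accepted: integers are the same "kind", while float-to-integer is not. A quick sketch with `np.can_cast`:

```python
import numpy as np

# same_kind permits casts within a kind (e.g. int64 -> int32),
# which is why NumPy's in-place int32 += int64 silently downcasts...
print(np.can_cast(np.int64, np.int32, casting="same_kind"))    # True

# ...but forbids float -> int, which is why `x += 2.5`-style
# operations on an integer array raise UFuncTypeError instead.
print(np.can_cast(np.float64, np.int64, casting="same_kind"))  # False
```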
The JAX behavior isn't an "in-place" change, though, which is what the current standard text talks about. This is just the Python variable `x` being rebound to a new array.
The spec doesn't seem to make this distinction currently; it talks about requirements for … If that's the intent, then I think it should be clarified.
The thinking has always been that as long as there is no aliased memory, there is no meaningful distinction between a real mutation (like NumPy's) and rebinding the variable (like JAX's). The two behaviors in

```python
a = <an array>
b = a
a += x
```

are very different if `b` shares memory with `a`. But, crucially, …
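The aliasing point can be made concrete with NumPy, whose `+=` is a true mutation, versus the rebinding fallback that JAX relies on (a sketch, using plain NumPy for both sides):

```python
import numpy as np

# True in-place mutation: the change is visible through the alias.
a = np.array([1, 2, 3])
b = a            # b aliases a's buffer
a += 10          # mutates the shared buffer
print(b)         # [11 12 13]

# Rebinding (what `a += x` means when __iadd__ is absent, as in JAX):
a = np.array([1, 2, 3])
b = a
a = a + 10       # a is rebound to a new array
print(b)         # [1 2 3] -- the alias is unaffected
```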
We do not require any operation to be actually in-place, including `+=`. A common pattern is a function with a branch that returns `x` unchanged:

```python
def somefunc(x, ...):
    ...
    if some_condition:
        return x
    ...
```

This has now come up multiple times, so I'd like to document more explicitly in the section on operators that in-place mutation is not required. I'd even be happy to add a note that …
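The hazard in that pattern is that the branch returning `x` unchanged hands the caller an alias, so a later in-place update of the result silently mutates the caller's array. A runnable sketch (the helper name `clip_nonneg` is hypothetical):

```python
import numpy as np

def clip_nonneg(x):
    # Branch that returns the input unchanged: the caller gets an alias.
    if (x >= 0).all():
        return x
    return np.maximum(x, 0)  # otherwise, a new array

a = np.array([1, 2, 3])
r = clip_nonneg(a)   # all non-negative, so r is a
r += 100             # "in-place update of the result"...
print(a)             # [101 102 103] -- ...also changed the caller's array
```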
Pathological cases aside, I think the root of the issue is that the current guidance is self-contradictory. i.e. these three statements in the spec cannot all be true:
One possible resolution is changing "may" to "must" in the last statement. That would be quite problematic for libraries like JAX, whose arrays are immutable. Another possible resolution is to weaken the requirements in the first or second statements, e.g. along the lines that @asmeurer mentioned, that if …
Unless you can make that work in JAX somehow, I think we have to consider this option blocked - the standard must be implementable by a functional library which is backed by an underlying immutable array implementation like JAX.
Not completely ideal, but also not really a problem, and probably the best we can do.
It may be helpful to phrase it as what this means for the user (and focus on that): users must not use in-place operations unless both the shape and dtype of the result (of the non-in-place version) are unchanged. Of course, the minimal/strict implementation should assert this (and should probably return a copy); I am not sure if it does.

For implementers, this means both true in-place (NumPy) and non-in-place (JAX) is fine: there is de facto no limitation on the implementing object, because anything stricter is just not practical (and things are obvious enough).

N.B.: One could specify further that the array object behavior must be truly in-place if in-place capability is flagged, and truly the same as not having …
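That user-facing rule can be sketched as a small guard, written here against NumPy (the helper name `iadd_portable` is hypothetical; it checks the dtype condition, and the same idea extends to shape under broadcasting):

```python
import numpy as np

def iadd_portable(x1, x2):
    # Only use the in-place operator when the out-of-place result
    # would keep x1's dtype -- the portable case per the guidance above.
    if np.result_type(x1, x2) == x1.dtype:
        x1 += x2        # dtype unchanged: safe across libraries
        return x1
    return x1 + x2      # dtype would change: return a new array instead

a = np.zeros(3, dtype=np.int64)
b = np.ones(3, dtype=np.int32)
print(iadd_portable(a, b).dtype)  # int64, done in place

c = np.zeros(3, dtype=np.int32)
d = np.ones(3, dtype=np.int64)
print(iadd_portable(c, d).dtype)  # int64, computed out of place
```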
If people are writing NumPy code like that and then translating it to the array API, it will be run with JAX too. Actually, I already know there are functions like that in array-api-compat (which is not directly relevant to JAX, but just to say I agree with your point that this is a common pattern).
The wording for in-place operators could be more explicit, to warn users that in-place operations such as `x1 += x2`, where the Type Promotion Rules require `x1 + x2` to have a data type different from the data type of `x1`, are implementation defined.

The present wording hints at it, but states that the result of the in-place operation must "equal" the result of the out-of-place operation, and the equality may hold true for arrays of different data types.