
Clarification for non-portable behavior for in-place Python operations #828

Open
oleksandr-pavlyk opened this issue Jul 30, 2024 · 18 comments
Labels
Maintenance Bug fix, typo fix, or general maintenance. Narrative Content Narrative documentation content. topic: Type Promotion Type promotion.
Comments

@oleksandr-pavlyk
Contributor

The wording for in-place operators could be more explicit in warning users that in-place operations such as x1 += x2 are implementation-defined when the Type Promotion Rules require x1 + x2 to have a data type different from that of x1.

The present wording hints at this:

For example, after in-place addition x1 += x2, the modified array x1 must always equal the result of the equivalent binary arithmetic operation x1 = x1 + x2.

but it only states that the result of the in-place operation must "equal" the result of the out-of-place operation, and that equality may hold even for arrays of different data types.

@asmeurer
Member

Did we agree that it should always be implementation defined if the types are different, or should x1 += x2 be OK if the promoted type is the same as x1.dtype (i.e., x2 upcasts to x1)?

@kgryte kgryte added this to the v2024 milestone Sep 19, 2024
@kgryte kgryte added Maintenance Bug fix, typo fix, or general maintenance. Narrative Content Narrative documentation content. topic: Type Promotion Type promotion. labels Sep 19, 2024
@kgryte
Contributor

kgryte commented Sep 19, 2024

should x1 += x2 be OK if the type promoted type is the same as x1.dtype (i.e., x2 upcasts to x1)?

@asmeurer I believe that is what @oleksandr-pavlyk was getting at in the OP. IMO, it should be okay to allow type promotion to x1.dtype and we can be more explicit in stating that the relationship x1 = x1 + x2 must hold provided the constraints of "limited" type promotion are satisfied.

@asmeurer
Member

So there are two cases:

Case 1: x1 + x2 promotes to x1.dtype

x1 = asarray([0], dtype=int64)
x2 = asarray([0], dtype=int32)
x1 += x2

Case 2: x1 + x2 promotes to x2.dtype

x1 = asarray([0], dtype=int32)
x2 = asarray([0], dtype=int64)
x1 += x2

I think @oleksandr-pavlyk was asking about case 2:

... where Type Promotion Rules require x1 + x2 to have data type different from data type of x1 ...

In my opinion, this is actually spelled out already https://data-apis.org/array-api/latest/API_specification/array_object.html#in-place-operators:

An in-place operation must not change the data type or shape of the in-place array as a result of Type Promotion Rules or Broadcasting.

In other words, case 2 is currently required to error (which is stronger than implementation defined).

Case 1 is perhaps more ambiguous whether it is required or not. My reading of the current text is that it is. If we want to make it implementation defined, we should explicitly state that. I'm not aware of any reasons why it would be a problem, though.
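The two cases above can be distinguished mechanically; here is a sketch using NumPy's result_type as a stand-in for the spec's Type Promotion Rules:

```python
import numpy as np

# Case 1: int64 + int32 promotes to int64 == x1.dtype, so the in-place
# array can hold the promoted result without changing type.
assert np.result_type(np.int64, np.int32) == np.int64

# Case 2: int32 + int64 promotes to int64 != x1.dtype (int32) -- the
# situation the spec currently requires to error.
assert np.result_type(np.int32, np.int64) != np.dtype(np.int32)
```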

@ndgrigorian

ndgrigorian commented Sep 19, 2024

In other words, case 2 is currently required to error (which is stronger than implementation defined).

Does this make current NumPy behavior non-compliant, then?

Because in NumPy

In [26]: x_np = np.arange(10, dtype="i4")

In [27]: x_np += np.ones(10, dtype="i8")

In [28]: x_np
Out[28]: array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int32)

it simply casts the second array into the type of the first array per same_kind casting.

And per the spec

An in-place operation must have the same behavior (including special cases) as its respective binary (i.e., two operand, non-assignment) operation. For example, after in-place addition x1 += x2, the modified array x1 must always equal the result of the equivalent binary arithmetic operation x1 = x1 + x2.

This seems like a very strong condition implying that x1 += x2 in this case should be disallowed, and it may be an even stronger condition than the one you quoted, because here the data type of x1 (the in-place array) does not change.

@asmeurer
Member

asmeurer commented Sep 19, 2024

I think the x1 = x1 + x2 sentence is not really intended to be saying anything about the data type of x1. It's only saying that whatever behavior is specified in __add__ (and thus add()) also applies to the += operator, i.e., all the special cases, nan behaviors, and so on. The wording should probably be improved since this isn't clear.

The sentence before that, which I quoted, "An in-place operation must not change the data type or shape of the in-place array..." does indeed imply that NumPy is currently noncompliant, because it uses must and not should. We could potentially loosen this to be implementation defined. I guess one question is whether the NumPy team agrees that this is nonideal behavior and should be deprecated. I certainly find the behavior surprising (it really is an in-place change of dtype: even views of x1 are updated to the promoted dtype).

Also we should check if other libraries allow this. This is the behavior in PyTorch:

>>> x1 = torch.asarray([0], dtype=torch.int32)
>>> x2 = torch.asarray([1], dtype=torch.int64)
>>> x1 += x2
>>> x1
tensor([1], dtype=torch.int32)

That actually could arguably be within what the spec says, because it didn't change the dtype of x1. So we should be clear whether this is actually OK, or whether it should be an error (or implementation defined).

@lucascolley
Contributor

I think the x1 = x1 + x2 sentence is not really intended to be saying anything about the data type of x1. It's only saying that whatever behavior is specified in __add__ (and thus add()) also applies to the += operator, i.e., all the special cases, nan behaviors, and so on. The wording should probably be improved since this isn't clear.

Yes, it would be good to clarify there that the dtype may be different, due to:

An in-place operation must not change the data type or shape of the in-place array as a result of Type Promotion Rules or Broadcasting.


The sentence before that, which I quoted, "An in-place operation must not change the data type or shape of the in-place array..." does indeed imply that NumPy is currently noncompliant, because it uses must and not should.

Can you give an example of NumPy's non-compliance with that? The example in #828 (comment) looks to be doing exactly what is specified:

In [1]: import numpy as np

In [2]: x = np.arange(10, dtype=np.int32)

In [3]: y = np.ones(10, dtype=np.int64)

In [4]: (x + y).dtype
Out[4]: dtype('int64')

In [5]: x += y

In [6]: x.dtype
Out[6]: dtype('int32')

@jakevdp

jakevdp commented Dec 24, 2024

Does this mean that libraries which don't explicitly define __iadd__ are out of compliance? JAX is an example: __iadd__ is undefined, so it falls back to the __add__ behavior, which differs from NumPy's behavior above:

In [1]: import jax

In [2]: jax.config.update('jax_enable_x64', True)

In [3]: import jax.numpy as jnp

In [4]: x = jnp.arange(10, dtype=jnp.int32)

In [5]: y = jnp.ones(10, dtype=jnp.int64)

In [6]: (x + y).dtype
Out[6]: dtype('int64')

In [7]: x += y

In [8]: x.dtype  # type is promoted, unlike NumPy
Out[8]: dtype('int64')

The reason is that in Python when __iadd__ is not defined, x += y falls back to x = x + y. JAX cannot define __iadd__ because its arrays are immutable.
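This fallback can be demonstrated with a minimal Python class (a toy stand-in, not JAX) that defines __add__ but not __iadd__:

```python
# Toy stand-in for an immutable array type: __add__ is defined,
# __iadd__ is not, so `a += b` falls back to `a = a + b` and
# rebinds the name `a` to a brand-new object.
class Arr:
    def __init__(self, val):
        self.val = val

    def __add__(self, other):
        return Arr(self.val + other.val)

a = Arr(1)
old_id = id(a)
a += Arr(2)              # equivalent to a = a + Arr(2)
assert a.val == 3
assert id(a) != old_id   # `a` now refers to a new object
```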

@lucascolley
Contributor

lucascolley commented Dec 24, 2024

that does look like non-compliance currently, yes

@jakevdp

jakevdp commented Dec 24, 2024

So effectively then, the array API specification is saying "compliant libraries must overload __iadd__", because the behavior when falling back to a compliant implementation of __add__ is non-compliant.

@rgommers
Member

Hmm, that's an issue that wasn't thought about before I think. On the one hand, requiring that x += y doesn't upcast x makes perfect sense to me as a design rule, because (a) it's what the user would most likely expect when reading that code (they're called "in-place operators" after all, so changing dtype and memory footprint is unexpected), and (b) if we leave it free then the dtype becomes unpredictable, depending on whether __iadd__ was defined or not.

On the other hand, no one brought up this problem before, and there was no intent to mark JAX as non-compliant here. The Python language was also not really designed with this in mind; casting rules like the ones we need for numerical libraries weren't really a consideration, and even the simpler numeric tower for Python's scalars came later (and is now considered a mistake).

Perhaps the easy way out is to allow both, and warn that only equal dtypes result in fully portable behavior and the rest is implementation-defined?

I guess one question is whether the NumPy team agrees that this is nonideal behavior and should be deprecated. I certainly find the behavior surprising (it really is an in-place change of dtype: even views of x1 are updated to the promoted dtype).

I do agree that the downcasting NumPy does is bad. It uses 'same_kind' casting, and that should probably be 'safe' instead:

>>> x = np.arange(3)
>>> x += np.array([0, 1, 2.5])
...
UFuncTypeError: Cannot cast ufunc 'add' output from dtype('float64') to dtype('int64') with casting rule 'same_kind'
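The 'same_kind' vs. 'safe' distinction can be checked directly with np.can_cast; a sketch:

```python
import numpy as np

# NumPy applies casting='same_kind' to the output of in-place ufunc calls.
# int64 -> int32 passes 'same_kind' (both are integers) but fails 'safe',
# which is why `int32_arr += int64_arr` silently downcasts today.
assert np.can_cast(np.int64, np.int32, casting='same_kind')
assert not np.can_cast(np.int64, np.int32, casting='safe')

# float64 -> int64 is a different kind entirely, so even 'same_kind'
# rejects it, producing the UFuncTypeError shown earlier.
assert not np.can_cast(np.float64, np.int64, casting='same_kind')
```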

@asmeurer
Member

asmeurer commented Jan 2, 2025

The JAX behavior isn't an "in-place" change, though, which is what the current standard text talks about. This is just the Python variable x being updated to point to a new object (it will have a different id()).

An in-place operation must not change the data type or shape of the in-place array as a result of Type Promotion Rules or Broadcasting.

@jakevdp

jakevdp commented Jan 2, 2025

The JAX behavior isn't an "in-place" change

The spec doesn't seem to make this distinction currently; it talks about requirements for x1 += x2, and later says that __iadd__ "may" be implemented, but does not specifically call out that if __iadd__ is left unimplemented, the previously-stated requirements don't apply.

If that's the intent, then I think it should be clarified.

@asmeurer
Member

asmeurer commented Jan 2, 2025

The thinking has always been that as long as there is no aliased memory, then there is no meaningful distinction between a real mutation (like NumPy __iadd__) and a copy (i.e., replacing x1 with a new array object, which is what x1 += x2 does when __iadd__ is not implemented). But this isn't really completely the case here. There could be other Python variables referring to the same Python object as x1. In that case, the semantics of something like

a = <an array>
b = a
a += x

are very different if a += x is implemented using __iadd__ vs. if it is implemented as the default a = a + x. In the former case, the values of x are added to b and in the latter they aren't.

But, crucially, b is not aliased memory to a in the traditional sense. It's just another Python variable referencing the same Python object. There's no way for the array library to know that b exists, unless it actually traces the source/byte code (or introspects the locals or something).
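The difference between the two semantics can be seen with NumPy on one side and an immutable Python type (tuples, standing in for immutable arrays) on the other:

```python
import numpy as np

# Mutating __iadd__ (NumPy): every reference to the object sees the update.
a = np.array([1, 2])
b = a
a += 10
assert b[1] == 12        # b is the same object and was mutated

# No __iadd__ (tuples, like immutable arrays): += rebinds `a`; `b` is untouched.
a = (1, 2)
b = a
a += (3,)                # falls back to a = a + (3,)
assert a == (1, 2, 3)
assert b == (1, 2)       # the original object is unchanged
```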

@rgommers
Member

rgommers commented Jan 2, 2025

We do not require any operation to be actually in-place, including __iadd__ & co. We're using the term "in-place operators" simply because that's how they are called by Python: https://docs.python.org/3/library/operator.html#in-place-operators.

The b = a case is a bit pathological, and not really interesting. I think that's a corner case in JAX, and it should be very rare to hit that in practice. It's more common in NumPy et al. to have two references to the same array object, because of code like:

def somefunc(x, ...):
    # A function with a branch that returns `x` unchanged:
    ...
    if some_condition:
        return x
    ...

This has now come up multiple times, so I'd like to document more explicitly in the section on operators that in-place mutation is not required. I'd even be happy to add a note that b = a is a corner case that results in behavior that isn't fully portable, so don't do that (there's no functional reason to do this after all).

@jakevdp

jakevdp commented Jan 2, 2025

Pathological cases aside, I think the root of the issue is that the current guidance is self-contradictory. i.e. these three statements in the spec cannot all be true:

An in-place operation must not change the data type or shape of the in-place array as a result of Type Promotion Rules or Broadcasting.

An in-place operation must have the same behavior (including special cases) as its respective binary (i.e., two operand, non-assignment) operation.

+=. May be implemented via __iadd__.

One possible resolution is changing "may" to "must" in the last statement. That would be quite problematic for libraries like JAX whose arrays are immutable.

Another possible resolution is to weaken the requirements in the first or second statements, e.g. along the lines that @asmeurer mentioned, that if __iadd__ is not implemented, then the notes about "in-place" no longer apply to x1 += x2.

@rgommers
Member

rgommers commented Jan 2, 2025

One possible resolution is changing "may" to "must" in the last statement. That would be quite problematic for libraries like JAX whose arrays are immutable.

Unless you can make that work in JAX somehow, I think we have to consider this option blocked - the standard must be implementable by a functional library which is backed by an underlying immutable array implementation like JAX.

Another possible resolution is to weaken the requirements in the first or second statements, e.g. along the lines that @asmeurer mentioned, that if __iadd__ is not implemented, then the notes about "in-place" no longer apply to x1 += x2.

Not completely ideal, but also not really a problem and the best we can do probably.

@seberg
Contributor

seberg commented Jan 2, 2025

It may be helpful to phrase it in terms of what this means for the user (and focus on that): users must not use in-place operations unless both the shape and dtype of the result (of the non-in-place version) are unchanged.
Users must also never rely on in-place modification of the original object unless they check capabilities first.

Of course the minimal/strict implementation should assert this (and should return a copy probably); I am not sure if it does.

For implementers this means both true in-place (NumPy) and non-in-place (JAX) are fine: there is de facto no limitation on the implementing object, because enforcing one is just not practical (and the behavior is obvious enough).

N.B.: One could specify further that the array object behavior must be truly in-place if in-place capability is flagged, and truly the same as not having __iadd__ if not.
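The user-facing rule above could be enforced with a small guard. This is a sketch: safe_iadd is a hypothetical helper name, and the dtype check via result_type stands in for the spec's promotion rules (a full version would also check the broadcast shape):

```python
import numpy as np

def safe_iadd(x1, x2, xp=np):
    # Hypothetical helper: only perform += when the promoted dtype
    # matches x1.dtype, so behavior is portable across mutating
    # (NumPy-style) and rebinding (JAX-style) implementations.
    if xp.result_type(x1.dtype, x2.dtype) != x1.dtype:
        raise TypeError("in-place add would change x1's dtype")
    x1 += x2
    return x1

x = np.zeros(3, dtype=np.int64)
x = safe_iadd(x, np.ones(3, dtype=np.int32))  # OK: int64 + int32 -> int64
assert x.dtype == np.int64
```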

@asmeurer
Member

asmeurer commented Jan 2, 2025

It's more common in NumPy et al. to have two references to the same array object, because of code like:

If people are writing NumPy code like that and then translating it to the array API, it will be run with JAX too. Actually, I already know there are functions like that in array-api-compat (which is not directly relevant to JAX, but just to say I agree with your point that this is a common pattern).


8 participants