Skip to content

Commit 3daa84b

Browse files
committed
Add support for copy kwarg in astype to match Array API
1 parent 51352fa commit 3daa84b

File tree

6 files changed

+82
-11
lines changed

6 files changed

+82
-11
lines changed

Diff for: CHANGELOG.md

+7
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ Remember to align the itemized text with the first line of an item within a list
3131
now leads to an error rather than a warning.
3232
* The minimum jaxlib version is now 0.4.23.
3333

34+
* Bug fixes
35+
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
36+
Previously, no copy would be made when the output array would have the same
37+
dtype as the input array. This may result in some increased memory usage.
38+
To prevent copying when possible, set `copy=None`. To error when a copy is
39+
required, set `copy=False`.
40+
3441
## jaxlib 0.4.27
3542

3643
## jax 0.4.26 (April 3, 2024)

Diff for: jax/_src/numpy/array_methods.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,13 @@
3131
import numpy as np
3232
import jax
3333
from jax import lax
34+
from jax.sharding import Sharding
3435
from jax._src import core
3536
from jax._src import dtypes
3637
from jax._src.api_util import _ensure_index_tuple
3738
from jax._src.array import ArrayImpl
3839
from jax._src.lax import lax as lax_internal
40+
from jax._src.lib import xla_client as xc
3941
from jax._src.numpy import lax_numpy
4042
from jax._src.numpy import reductions
4143
from jax._src.numpy import ufuncs
@@ -55,15 +57,15 @@
5557
# functions, which can themselves handle instances from any of these classes.
5658

5759

58-
def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array:
60+
def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = True, device: xc.Device | Sharding | None = None) -> Array:
5961
"""Copy the array and cast to a specified dtype.
6062
6163
This is implemented via :func:`jax.lax.convert_element_type`, which may
6264
have slightly different behavior than :meth:`numpy.ndarray.astype` in
6365
some cases. In particular, the details of float-to-int and int-to-float
6466
casts are implementation dependent.
6567
"""
66-
return lax_numpy.astype(arr, dtype)
68+
return lax_numpy.astype(arr, dtype, copy=copy, device=device)
6769

6870

6971
def _nbytes(arr: ArrayLike) -> int:

Diff for: jax/_src/numpy/lax_numpy.py

+30-6
Original file line numberDiff line numberDiff line change
@@ -2262,17 +2262,41 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
22622262
In particular, the details of float-to-int and int-to-float casts are
22632263
implementation dependent.
22642264
""")
2265-
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
2265+
def astype(x: ArrayLike, dtype: DTypeLike | None,
2266+
/, *, copy: bool = True,
2267+
device: xc.Device | Sharding | None = None) -> Array:
22662268
util.check_arraylike("astype", x)
22672269
x_arr = asarray(x)
2268-
del copy # unused in JAX
22692270
if dtype is None:
22702271
dtype = dtypes.canonicalize_dtype(float_)
22712272
dtypes.check_user_dtype_supported(dtype, "astype")
2272-
# convert_element_type(complex, bool) has the wrong semantics.
2273-
if np.dtype(dtype) == bool and issubdtype(x_arr.dtype, complexfloating):
2274-
return (x_arr != _lax_const(x_arr, 0))
2275-
return lax.convert_element_type(x_arr, dtype)
2273+
if issubdtype(x_arr.dtype, complexfloating):
2274+
if dtypes.isdtype(dtype, ("integral", "real floating")):
2275+
warnings.warn(
2276+
"Casting from complex to real dtypes will soon raise a ValueError. "
2277+
"Please first use jnp.real or jnp.imag to take the real/imaginary "
2278+
"component of your input.",
2279+
DeprecationWarning, stacklevel=2
2280+
)
2281+
elif np.dtype(dtype) == bool:
2282+
# convert_element_type(complex, bool) has the wrong semantics.
2283+
x_arr = (x_arr != _lax_const(x_arr, 0))
2284+
2285+
# We offer a more specific warning than the usual ComplexWarning so we prefer
2286+
# to issue our warning.
2287+
with warnings.catch_warnings():
2288+
warnings.simplefilter("ignore", ComplexWarning)
2289+
return _place_array(
2290+
lax.convert_element_type(x_arr, dtype),
2291+
device=device, copy=copy,
2292+
)
2293+
2294+
def _place_array(x, device=None, copy=None):
2295+
# TODO(micky774): Implement in future PRs as we formalize device placement
2296+
# semantics
2297+
if copy:
2298+
return _array_copy(x)
2299+
return x
22762300

22772301

22782302
@util.implements(np.asarray, lax_description=_ARRAY_DOC)

Diff for: jax/experimental/array_api/_data_type_functions.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import builtins
1618
import functools
1719
from typing import NamedTuple
1820
import jax
1921
import jax.numpy as jnp
2022

2123

24+
from jax._src.lib import xla_client as xc
25+
from jax._src.sharding import Sharding
26+
from jax._src import dtypes as _dtypes
2227
from jax.experimental.array_api._dtypes import (
2328
bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64,
2429
float32, float64, complex64, complex128
@@ -124,8 +129,19 @@ def _promote_types(t1, t2):
124129
raise ValueError("No promotion path for {t1} & {t2}")
125130

126131

127-
def astype(x, dtype, /, *, copy=True):
128-
return jnp.array(x, dtype=dtype, copy=copy)
132+
def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Sharding | None = None):
133+
src_dtype = x.dtype if hasattr(x, "dtype") else _dtypes.dtype(x)
134+
if (
135+
src_dtype is not None
136+
and _dtypes.isdtype(src_dtype, "complex floating")
137+
and _dtypes.isdtype(dtype, ("integral", "real floating"))
138+
):
139+
raise ValueError(
140+
"Casting from complex to non-complex dtypes is not permitted. Please "
141+
"first use jnp.real or jnp.imag to take the real/imaginary component of "
142+
"your input."
143+
)
144+
return jnp.astype(x, dtype, copy=copy, device=device)
129145

130146

131147
def can_cast(from_, to, /):

Diff for: jax/numpy/__init__.pyi

+3-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ from jax._src.typing import (
1313
Array, ArrayLike, DType, DTypeLike,
1414
DimSize, DuckTypedArray, Shape, DeprecatedArg
1515
)
16+
from jax._src.sharding import Sharding
17+
from jax._src.lib import xla_client as xc
1618
from jax.numpy import fft as fft, linalg as linalg
1719
from jax.sharding import Sharding as _Sharding
1820
import numpy as _np
@@ -115,7 +117,7 @@ def asarray(
115117
) -> Array: ...
116118
def asin(x: ArrayLike, /) -> Array: ...
117119
def asinh(x: ArrayLike, /) -> Array: ...
118-
def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ...) -> Array: ...
120+
def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ..., device: xc.Device | Sharding | None = ...) -> Array: ...
119121
def atan(x: ArrayLike, /) -> Array: ...
120122
def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ...
121123
def atanh(x: ArrayLike, /) -> Array: ...

Diff for: tests/lax_numpy_test.py

+20
Original file line numberDiff line numberDiff line change
@@ -3840,6 +3840,26 @@ def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'):
38403840
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
38413841
self._CompileAndCheck(jnp_op, args_maker)
38423842

3843+
@jtu.sample_product(
3844+
change_dtype=[True, False],
3845+
copy=[True, False],
3846+
)
3847+
def testAstypeCopy(self, change_dtype, copy):
3848+
dtype = 'float32' if change_dtype else 'int32'
3849+
expect_copy = change_dtype or copy
3850+
x = jnp.arange(5, dtype='int32')
3851+
y = x.astype(dtype, copy=copy)
3852+
3853+
assert y.dtype == dtype
3854+
y.delete()
3855+
assert x.is_deleted() != expect_copy
3856+
3857+
def testAstypeComplexDowncast(self):
3858+
x = jnp.array(2.0+1.5j, dtype='complex64')
3859+
msg = "Casting from complex to non-complex dtypes will soon raise "
3860+
with self.assertWarns(DeprecationWarning, msg=msg):
3861+
x.astype('float32')
3862+
38433863
def testAstypeInt4(self):
38443864
# Test converting from int4 to int8
38453865
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)

0 commit comments

Comments
 (0)