Skip to content

Commit bc07964

Browse files
committed
Add support for copy kwarg in astype to match Array API
1 parent 32922f6 commit bc07964

File tree

7 files changed

+89
-13
lines changed

7 files changed

+89
-13
lines changed

Diff for: CHANGELOG.md

+6
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ Remember to align the itemized text with the first line of an item within a list
4949
* Scalar arguments to {func}`jax.numpy.nonzero`, {func}`jax.numpy.where`, and
5050
related functions now raise an error, following a similar change in NumPy.
5151

52+
* Bug fixes
53+
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
54+
Previously, no copy would be made when the output array would have the same
55+
dtype as the input array. This may result in some increased memory usage.
56+
The default value is set to `copy=False` to preserve backwards compatability.
57+
5258
## jaxlib 0.4.27
5359

5460
## jax 0.4.26 (April 3, 2024)

Diff for: jax/_src/numpy/array_methods.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,13 @@
3131
import numpy as np
3232
import jax
3333
from jax import lax
34+
from jax.sharding import Sharding
3435
from jax._src import core
3536
from jax._src import dtypes
3637
from jax._src.api_util import _ensure_index_tuple
3738
from jax._src.array import ArrayImpl
3839
from jax._src.lax import lax as lax_internal
40+
from jax._src.lib import xla_client as xc
3941
from jax._src.numpy import lax_numpy
4042
from jax._src.numpy import reductions
4143
from jax._src.numpy import ufuncs
@@ -55,15 +57,15 @@
5557
# functions, which can themselves handle instances from any of these classes.
5658

5759

58-
def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array:
60+
def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = True, device: xc.Device | Sharding | None = None) -> Array:
5961
"""Copy the array and cast to a specified dtype.
6062
6163
This is implemented via :func:`jax.lax.convert_element_type`, which may
6264
have slightly different behavior than :meth:`numpy.ndarray.astype` in
6365
some cases. In particular, the details of float-to-int and int-to-float
6466
casts are implementation dependent.
6567
"""
66-
return lax_numpy.astype(arr, dtype)
68+
return lax_numpy.astype(arr, dtype, copy=copy, device=device)
6769

6870

6971
def _nbytes(arr: ArrayLike) -> int:

Diff for: jax/_src/numpy/lax_numpy.py

+31-6
Original file line numberDiff line numberDiff line change
@@ -2272,17 +2272,42 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
22722272
In particular, the details of float-to-int and int-to-float casts are
22732273
implementation dependent.
22742274
""")
2275-
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
2275+
def astype(x: ArrayLike, dtype: DTypeLike | None,
2276+
/, *, copy: bool = False,
2277+
device: xc.Device | Sharding | None = None) -> Array:
22762278
util.check_arraylike("astype", x)
22772279
x_arr = asarray(x)
2278-
del copy # unused in JAX
2280+
22792281
if dtype is None:
22802282
dtype = dtypes.canonicalize_dtype(float_)
22812283
dtypes.check_user_dtype_supported(dtype, "astype")
2282-
# convert_element_type(complex, bool) has the wrong semantics.
2283-
if np.dtype(dtype) == bool and issubdtype(x_arr.dtype, complexfloating):
2284-
return (x_arr != _lax_const(x_arr, 0))
2285-
return lax.convert_element_type(x_arr, dtype)
2284+
if issubdtype(x_arr.dtype, complexfloating):
2285+
if dtypes.isdtype(dtype, ("integral", "real floating")):
2286+
warnings.warn(
2287+
"Casting from complex to real dtypes will soon raise a ValueError. "
2288+
"Please first use jnp.real or jnp.imag to take the real/imaginary "
2289+
"component of your input.",
2290+
DeprecationWarning, stacklevel=2
2291+
)
2292+
elif np.dtype(dtype) == bool:
2293+
# convert_element_type(complex, bool) has the wrong semantics.
2294+
x_arr = (x_arr != _lax_const(x_arr, 0))
2295+
2296+
# We offer a more specific warning than the usual ComplexWarning so we prefer
2297+
# to issue our warning.
2298+
with warnings.catch_warnings():
2299+
warnings.simplefilter("ignore", ComplexWarning)
2300+
return _place_array(
2301+
lax.convert_element_type(x_arr, dtype),
2302+
device=device, copy=copy,
2303+
)
2304+
2305+
def _place_array(x, device=None, copy=None):
2306+
# TODO(micky774): Implement in future PRs as we formalize device placement
2307+
# semantics
2308+
if copy:
2309+
return _array_copy(x)
2310+
return x
22862311

22872312

22882313
@util.implements(np.asarray, lax_description=_ARRAY_DOC)

Diff for: jax/experimental/array_api/_data_type_functions.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import builtins
1618
import functools
1719
from typing import NamedTuple
1820
import jax
1921
import jax.numpy as jnp
2022

2123

24+
from jax._src.lib import xla_client as xc
25+
from jax._src.sharding import Sharding
26+
from jax._src import dtypes as _dtypes
2227
from jax.experimental.array_api._dtypes import (
2328
bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64,
2429
float32, float64, complex64, complex128
@@ -124,8 +129,19 @@ def _promote_types(t1, t2):
124129
raise ValueError("No promotion path for {t1} & {t2}")
125130

126131

127-
def astype(x, dtype, /, *, copy=True):
128-
return jnp.array(x, dtype=dtype, copy=copy)
132+
def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Sharding | None = None):
133+
src_dtype = x.dtype if hasattr(x, "dtype") else _dtypes.dtype(x)
134+
if (
135+
src_dtype is not None
136+
and _dtypes.isdtype(src_dtype, "complex floating")
137+
and _dtypes.isdtype(dtype, ("integral", "real floating"))
138+
):
139+
raise ValueError(
140+
"Casting from complex to non-complex dtypes is not permitted. Please "
141+
"first use jnp.real or jnp.imag to take the real/imaginary component of "
142+
"your input."
143+
)
144+
return jnp.astype(x, dtype, copy=copy, device=device)
129145

130146

131147
def can_cast(from_, to, /):

Diff for: jax/numpy/__init__.pyi

+3-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ from jax._src.typing import (
1313
Array, ArrayLike, DType, DTypeLike,
1414
DimSize, DuckTypedArray, Shape, DeprecatedArg
1515
)
16+
from jax._src.sharding import Sharding
17+
from jax._src.lib import xla_client as xc
1618
from jax.numpy import fft as fft, linalg as linalg
1719
from jax.sharding import Sharding as _Sharding
1820
import numpy as _np
@@ -115,7 +117,7 @@ def asarray(
115117
) -> Array: ...
116118
def asin(x: ArrayLike, /) -> Array: ...
117119
def asinh(x: ArrayLike, /) -> Array: ...
118-
def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ...) -> Array: ...
120+
def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ..., device: xc.Device | Sharding | None = ...) -> Array: ...
119121
def atan(x: ArrayLike, /) -> Array: ...
120122
def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ...
121123
def atanh(x: ArrayLike, /) -> Array: ...

Diff for: tests/lax_numpy_reducers_test.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -776,8 +776,13 @@ def test_f16_mean(self, dtype):
776776
for axis in list(
777777
range(-len(shape), len(shape))
778778
) + ([None] if len(shape) == 1 else [])],
779-
dtype=all_dtypes + [None],
780-
out_dtype=all_dtypes,
779+
[dict(dtype=dtype, out_dtype=out_dtype)
780+
for dtype in (all_dtypes+[None])
781+
for out_dtype in (
782+
complex_dtypes if np.issubdtype(dtype, np.complexfloating)
783+
else all_dtypes
784+
)
785+
],
781786
include_initial=[False, True],
782787
)
783788
@jtu.ignore_warning(category=NumpyComplexWarning)

Diff for: tests/lax_numpy_test.py

+20
Original file line numberDiff line numberDiff line change
@@ -3870,6 +3870,26 @@ def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'):
38703870
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
38713871
self._CompileAndCheck(jnp_op, args_maker)
38723872

3873+
@jtu.sample_product(
3874+
change_dtype=[True, False],
3875+
copy=[True, False],
3876+
)
3877+
def testAstypeCopy(self, change_dtype, copy):
3878+
dtype = 'float32' if change_dtype else 'int32'
3879+
expect_copy = change_dtype or copy
3880+
x = jnp.arange(5, dtype='int32')
3881+
y = x.astype(dtype, copy=copy)
3882+
3883+
assert y.dtype == dtype
3884+
y.delete()
3885+
assert x.is_deleted() != expect_copy
3886+
3887+
def testAstypeComplexDowncast(self):
3888+
x = jnp.array(2.0+1.5j, dtype='complex64')
3889+
msg = "Casting from complex to non-complex dtypes will soon raise "
3890+
with self.assertWarns(DeprecationWarning, msg=msg):
3891+
x.astype('float32')
3892+
38733893
def testAstypeInt4(self):
38743894
# Test converting from int4 to int8
38753895
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)

0 commit comments

Comments
 (0)