Skip to content

Commit 34db8df

Browse files
ENH: add create_diagonal (#19)
* ENH: add `create_diagonal` Co-authored-by: Jake Bowhay <[email protected]> * TST: `create_diagonal`: add tests * DOC: `create_diagonal`: add docs * DOC: `create_diagonal`: add `xp` param * appease linter * TST: `create_diagonal`: add test from SciPy Co-authored-by: Jake Bowhay <[email protected]> --------- Co-authored-by: Jake Bowhay <[email protected]>
1 parent bd00f3a commit 34db8df

File tree

4 files changed

+99
-4
lines changed

4 files changed

+99
-4
lines changed

docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
99
atleast_nd
1010
cov
11+
create_diagonal
1112
expand_dims
1213
kron
1314
sinc

src/array_api_extra/__init__.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
from __future__ import annotations
22

3-
from ._funcs import atleast_nd, cov, expand_dims, kron, sinc
3+
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, sinc
44

55
__version__ = "0.1.2.dev0"
66

7-
__all__ = ["__version__", "atleast_nd", "cov", "expand_dims", "kron", "sinc"]
7+
__all__ = [
8+
"__version__",
9+
"atleast_nd",
10+
"cov",
11+
"create_diagonal",
12+
"expand_dims",
13+
"kron",
14+
"sinc",
15+
]

src/array_api_extra/_funcs.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
if TYPE_CHECKING:
77
from ._typing import Array, ModuleType
88

9-
__all__ = ["atleast_nd", "cov", "expand_dims", "kron", "sinc"]
9+
__all__ = ["atleast_nd", "cov", "create_diagonal", "expand_dims", "kron", "sinc"]
1010

1111

1212
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
@@ -140,6 +140,55 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
140140
return xp.squeeze(c, axis=axes)
141141

142142

143+
def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
144+
"""
145+
Construct a diagonal array.
146+
147+
Parameters
148+
----------
149+
x : array
150+
A 1-D array
151+
offset : int, optional
152+
Offset from the leading diagonal (default is ``0``).
153+
Use positive ints for diagonals above the leading diagonal,
154+
and negative ints for diagonals below the leading diagonal.
155+
xp : array_namespace
156+
The standard-compatible namespace for `x`.
157+
158+
Returns
159+
-------
160+
res : array
161+
A 2-D array with `x` on the diagonal (offset by `offset`).
162+
163+
Examples
164+
--------
165+
>>> import array_api_strict as xp
166+
>>> import array_api_extra as xpx
167+
>>> x = xp.asarray([2, 4, 8])
168+
169+
>>> xpx.create_diagonal(x, xp=xp)
170+
Array([[2, 0, 0],
171+
[0, 4, 0],
172+
[0, 0, 8]], dtype=array_api_strict.int64)
173+
174+
>>> xpx.create_diagonal(x, offset=-2, xp=xp)
175+
Array([[0, 0, 0, 0, 0],
176+
[0, 0, 0, 0, 0],
177+
[2, 0, 0, 0, 0],
178+
[0, 4, 0, 0, 0],
179+
[0, 0, 8, 0, 0]], dtype=array_api_strict.int64)
180+
181+
"""
182+
if x.ndim != 1:
183+
err_msg = "`x` must be 1-dimensional."
184+
raise ValueError(err_msg)
185+
n = x.shape[0] + abs(offset)
186+
diag = xp.zeros(n**2, dtype=x.dtype)
187+
i = offset if offset >= 0 else abs(offset) * n
188+
diag[i : min(n * (n - offset), diag.shape[0]) : n + 1] = x
189+
return xp.reshape(diag, (n, n))
190+
191+
143192
def _mean(
144193
x: Array,
145194
/,

tests/test_funcs.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66

77
# array-api-strict#6
88
import array_api_strict as xp # type: ignore[import-untyped]
9+
import numpy as np
910
import pytest
1011
from numpy.testing import assert_allclose, assert_array_equal, assert_equal
1112

12-
from array_api_extra import atleast_nd, cov, expand_dims, kron, sinc
13+
from array_api_extra import atleast_nd, cov, create_diagonal, expand_dims, kron, sinc
1314

1415
if TYPE_CHECKING:
1516
Array = Any # To be changed to a Protocol later (see array-api#589)
@@ -112,6 +113,42 @@ def test_combination(self):
112113
assert_allclose(cov(y, xp=xp), xp.asarray(2.144133), rtol=1e-6)
113114

114115

116+
class TestCreateDiagonal:
117+
def test_1d(self):
118+
# from np.diag tests
119+
vals = 100 * xp.arange(5, dtype=xp.float64)
120+
b = xp.zeros((5, 5))
121+
for k in range(5):
122+
b[k, k] = vals[k]
123+
assert_array_equal(create_diagonal(vals, xp=xp), b)
124+
b = xp.zeros((7, 7))
125+
c = xp.asarray(b, copy=True)
126+
for k in range(5):
127+
b[k, k + 2] = vals[k]
128+
c[k + 2, k] = vals[k]
129+
assert_array_equal(create_diagonal(vals, offset=2, xp=xp), b)
130+
assert_array_equal(create_diagonal(vals, offset=-2, xp=xp), c)
131+
132+
@pytest.mark.parametrize("n", range(1, 10))
133+
@pytest.mark.parametrize("offset", range(1, 10))
134+
def test_create_diagonal(self, n, offset):
135+
# from scipy._lib tests
136+
rng = np.random.default_rng(2347823)
137+
one = xp.asarray(1.0)
138+
x = rng.random(n)
139+
A = create_diagonal(xp.asarray(x, dtype=one.dtype), offset=offset, xp=xp)
140+
B = xp.asarray(np.diag(x, offset), dtype=one.dtype)
141+
assert_array_equal(A, B)
142+
143+
def test_0d(self):
144+
with pytest.raises(ValueError, match="1-dimensional"):
145+
create_diagonal(xp.asarray(1), xp=xp)
146+
147+
def test_2d(self):
148+
with pytest.raises(ValueError, match="1-dimensional"):
149+
create_diagonal(xp.asarray([[1]]), xp=xp)
150+
151+
115152
class TestKron:
116153
def test_basic(self):
117154
# Using 0-dimensional array

0 commit comments

Comments
 (0)