Skip to content

Commit 0a53e57

Browse files
committed
TST: create_diagonal: add tests
1 parent 57e26a8 commit 0a53e57

File tree

3 files changed

+30
-3
lines changed

3 files changed

+30
-3
lines changed

src/array_api_extra/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3-
from ._funcs import atleast_nd, cov, expand_dims, kron
3+
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron
44

55
__version__ = "0.1.2.dev0"
66

7-
__all__ = ["__version__", "atleast_nd", "cov", "expand_dims", "kron"]
7+
__all__ = ["__version__", "atleast_nd", "cov", "create_diagonal", "expand_dims", "kron"]

src/array_api_extra/_funcs.py

+3
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
141141

142142

143143
def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
144+
if x.ndim != 1:
145+
err_msg = "`x` must be 1-dimensional."
146+
raise ValueError(err_msg)
144147
n = x.shape[0] + abs(offset)
145148
diag = xp.zeros(n**2, dtype=x.dtype)
146149
i = offset if offset >= 0 else abs(offset) * n

tests/test_funcs.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytest
1010
from numpy.testing import assert_allclose, assert_array_equal, assert_equal
1111

12-
from array_api_extra import atleast_nd, cov, expand_dims, kron
12+
from array_api_extra import atleast_nd, cov, create_diagonal, expand_dims, kron
1313

1414
if TYPE_CHECKING:
1515
Array = Any # To be changed to a Protocol later (see array-api#589)
@@ -112,6 +112,30 @@ def test_combination(self):
112112
assert_allclose(cov(y, xp=xp), xp.asarray(2.144133), rtol=1e-6)
113113

114114

115+
class TestCreateDiagonal:
116+
def test_1d(self):
117+
vals = 100 * xp.arange(5, dtype=xp.float64)
118+
b = xp.zeros((5, 5))
119+
for k in range(5):
120+
b[k, k] = vals[k]
121+
assert_array_equal(create_diagonal(vals, xp=xp), b)
122+
b = xp.zeros((7, 7))
123+
c = xp.asarray(b, copy=True)
124+
for k in range(5):
125+
b[k, k + 2] = vals[k]
126+
c[k + 2, k] = vals[k]
127+
assert_array_equal(create_diagonal(vals, offset=2, xp=xp), b)
128+
assert_array_equal(create_diagonal(vals, offset=-2, xp=xp), c)
129+
130+
def test_0d(self):
131+
with pytest.raises(ValueError, match="1-dimensional"):
132+
print(create_diagonal(xp.asarray(1), xp=xp))
133+
134+
def test_2d(self):
135+
with pytest.raises(ValueError, match="1-dimensional"):
136+
print(create_diagonal(xp.asarray([[1]]), xp=xp))
137+
138+
115139
class TestKron:
116140
def test_basic(self):
117141
# Using 0-dimensional array

0 commit comments

Comments
 (0)