Skip to content

Commit f4a3bad

Browse files
lucascolleyj-bowhay
andcommitted
TST: create_diagonal: add test from SciPy
Co-authored-by: Jake Bowhay <[email protected]>
1 parent 0b28fc3 commit f4a3bad

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

tests/test_funcs.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# array-api-strict#6
88
import array_api_strict as xp # type: ignore[import-untyped]
9+
import numpy as np
910
import pytest
1011
from numpy.testing import assert_allclose, assert_array_equal, assert_equal
1112

@@ -114,6 +115,7 @@ def test_combination(self):
114115

115116
class TestCreateDiagonal:
116117
def test_1d(self):
118+
# from np.diag tests
117119
vals = 100 * xp.arange(5, dtype=xp.float64)
118120
b = xp.zeros((5, 5))
119121
for k in range(5):
@@ -127,13 +129,24 @@ def test_1d(self):
127129
assert_array_equal(create_diagonal(vals, offset=2, xp=xp), b)
128130
assert_array_equal(create_diagonal(vals, offset=-2, xp=xp), c)
129131

132+
@pytest.mark.parametrize("n", range(1, 10))
133+
@pytest.mark.parametrize("offset", range(1, 10))
134+
def test_create_diagonal(self, n, offset):
135+
# from scipy._lib tests
136+
rng = np.random.default_rng(2347823)
137+
one = xp.asarray(1.0)
138+
x = rng.random(n)
139+
A = create_diagonal(xp.asarray(x, dtype=one.dtype), offset=offset, xp=xp)
140+
B = xp.asarray(np.diag(x, offset), dtype=one.dtype)
141+
assert_array_equal(A, B)
142+
130143
def test_0d(self):
131144
with pytest.raises(ValueError, match="1-dimensional"):
132-
print(create_diagonal(xp.asarray(1), xp=xp))
145+
create_diagonal(xp.asarray(1), xp=xp)
133146

134147
def test_2d(self):
135148
with pytest.raises(ValueError, match="1-dimensional"):
136-
print(create_diagonal(xp.asarray([[1]]), xp=xp))
149+
create_diagonal(xp.asarray([[1]]), xp=xp)
137150

138151

139152
class TestKron:

0 commit comments

Comments
 (0)