|
9 | 9 | import pytest
|
10 | 10 | from numpy.testing import assert_allclose, assert_array_equal, assert_equal
|
11 | 11 |
|
12 |
| -from array_api_extra import atleast_nd, cov, expand_dims, kron |
| 12 | +from array_api_extra import atleast_nd, cov, create_diagonal, expand_dims, kron |
13 | 13 |
|
14 | 14 | if TYPE_CHECKING:
|
15 | 15 | Array = Any # To be changed to a Protocol later (see array-api#589)
|
@@ -112,6 +112,30 @@ def test_combination(self):
|
112 | 112 | assert_allclose(cov(y, xp=xp), xp.asarray(2.144133), rtol=1e-6)
|
113 | 113 |
|
114 | 114 |
|
| 115 | +class TestCreateDiagonal: |
| 116 | + def test_1d(self): |
| 117 | + vals = 100 * xp.arange(5, dtype=xp.float64) |
| 118 | + b = xp.zeros((5, 5)) |
| 119 | + for k in range(5): |
| 120 | + b[k, k] = vals[k] |
| 121 | + assert_array_equal(create_diagonal(vals, xp=xp), b) |
| 122 | + b = xp.zeros((7, 7)) |
| 123 | + c = xp.asarray(b, copy=True) |
| 124 | + for k in range(5): |
| 125 | + b[k, k + 2] = vals[k] |
| 126 | + c[k + 2, k] = vals[k] |
| 127 | + assert_array_equal(create_diagonal(vals, offset=2, xp=xp), b) |
| 128 | + assert_array_equal(create_diagonal(vals, offset=-2, xp=xp), c) |
| 129 | + |
| 130 | + def test_0d(self): |
| 131 | + with pytest.raises(ValueError, match="1-dimensional"): |
| 132 | + print(create_diagonal(xp.asarray(1), xp=xp)) |
| 133 | + |
| 134 | + def test_2d(self): |
| 135 | + with pytest.raises(ValueError, match="1-dimensional"): |
| 136 | + print(create_diagonal(xp.asarray([[1]]), xp=xp)) |
| 137 | + |
| 138 | + |
115 | 139 | class TestKron:
|
116 | 140 | def test_basic(self):
|
117 | 141 | # Using 0-dimensional array
|
|
0 commit comments