@@ -89,10 +89,10 @@ def test_allclose_error(self):
89
89
arr1 = math .astensor ([1 , 2 , 3 ])
90
90
arr2 = math .astensor ([[1 , 2 ], [1 , 2 ]])
91
91
92
- if math .backend_name != "jax " :
93
- with pytest .raises (ValueError , match = "Cannot compare " ):
92
+ if math .backend_name == "numpy " :
93
+ with pytest .raises (ValueError , match = "could not be broadcast " ):
94
94
math .allclose (arr1 , arr2 )
95
- else :
95
+ elif math . backend_name == "jax" :
96
96
with pytest .raises (ValueError , match = "Incompatible shapes" ):
97
97
math .allclose (arr2 , arr1 )
98
98
@@ -262,6 +262,14 @@ def test_block_diag(self):
262
262
assert R .shape == (8 , 8 )
263
263
assert math .allclose (math .block ([[I , O ], [O , 1j * I ]]), R )
264
264
265
+ def test_broadcast_to (self ):
266
+ r"""
267
+ Tests the ``broadcast_to`` method.
268
+ """
269
+ arr = math .astensor ([1 , 2 , 3 ])
270
+ res = math .broadcast_to (arr , (3 , 3 ))
271
+ assert math .allclose (res , math .astensor ([[1 , 2 , 3 ], [1 , 2 , 3 ], [1 , 2 , 3 ]]))
272
+
265
273
@pytest .mark .parametrize ("t" , types )
266
274
def test_cast (self , t ):
267
275
r"""
@@ -674,6 +682,16 @@ def test_sqrtm(self):
674
682
res = math .asnumpy (math .sqrtm (arr ))
675
683
assert math .allclose (res , 2 * np .eye (3 ))
676
684
685
+ def test_stack (self ):
686
+ r"""
687
+ Tests the ``stack`` method.
688
+ """
689
+ arr1 = np .eye (3 )
690
+ arr2 = 2 * np .eye (3 )
691
+ res = math .asnumpy (math .stack ([arr1 , arr2 ], axis = 0 ))
692
+ exp = np .stack ([arr1 , arr2 ], axis = 0 )
693
+ assert np .allclose (res , exp )
694
+
677
695
def test_sum (self ):
678
696
r"""
679
697
Tests the ``sum`` method.
0 commit comments