@@ -5,9 +5,38 @@ import (
5
5
"github.com/gomlx/gopjrt/dtypes"
6
6
. "github.com/gomlx/gopjrt/xlabuilder"
7
7
"github.com/stretchr/testify/require"
8
+ "math"
8
9
"testing"
9
10
)
10
11
12
+ // TestMax tests the Max function, as part of the ReduceMax test.
13
+ // See https://github.com/openxla/xla/issues/21461
14
+ func TestMax (t * testing.T ) {
15
+ client := getPJRTClient (t )
16
+ {
17
+ builder := New (fmt .Sprintf ("%s-Max(NaN, 1) as Constant" , t .Name ()))
18
+ input0 := capture (Constant (builder , NewScalarLiteral (math .NaN ()))).Test (t )
19
+ input1 := capture (Constant (builder , NewScalarLiteral (1.0 ))).Test (t )
20
+ output := capture (Max (input0 , input1 )).Test (t )
21
+ exec := compile (t , client , capture (builder .Build (output )).Test (t ))
22
+ got := execScalarOutput [float64 ](t , client , exec )
23
+ require .True (t , math .IsNaN (got ))
24
+ builder .Destroy ()
25
+ }
26
+ {
27
+ builder := New (fmt .Sprintf ("%s-Max(NaN, 1) as Parameter" , t .Name ()))
28
+ input0 := capture (Parameter (builder , "x" , 0 , MakeShape (dtypes .Float64 ))).Test (t )
29
+ input1 := capture (Parameter (builder , "y" , 1 , MakeShape (dtypes .Float64 ))).Test (t )
30
+ input0 = capture (Sqrt (input0 )).Test (t )
31
+ input1 = capture (Sqrt (input1 )).Test (t )
32
+ output := capture (Max (input0 , input1 )).Test (t )
33
+ exec := compile (t , client , capture (builder .Build (output )).Test (t ))
34
+ got := execWithScalars (t , client , exec , - 1.0 , 1.0 )
35
+ require .True (t , math .IsNaN (got ))
36
+ builder .Destroy ()
37
+ }
38
+ }
39
+
11
40
func TestReduce (t * testing.T ) {
12
41
client := getPJRTClient (t )
13
42
@@ -33,6 +62,35 @@ func TestReduce(t *testing.T) {
33
62
builder .Destroy ()
34
63
}
35
64
65
+ {
66
+ builder := New (fmt .Sprintf ("%s-ReduceMax with NaN as constant" , t .Name ()))
67
+ literal := capture (NewArrayLiteral ([]float32 {float32 (math .NaN ()), 1 }, 2 )).Test (t )
68
+ input := capture (Constant (builder , literal )).Test (t )
69
+ output := capture (ReduceMax (input , 0 )).Test (t )
70
+ comp := capture (builder .Build (output )).Test (t )
71
+ fmt .Printf ("HLO:\n %s\n " , comp .TextHLO ())
72
+ exec := compile (t , client , comp )
73
+ got := execWithScalars [float32 ](t , client , exec )
74
+ require .True (t , math .IsNaN (float64 (got )))
75
+ builder .Destroy ()
76
+ }
77
+
78
+ {
79
+ builder := New (fmt .Sprintf ("%s-ReduceMax with NaN as parameter" , t .Name ()))
80
+ input := capture (Parameter (builder , "x" , 0 , MakeShape (dtypes .Float32 , 2 ))).Test (t )
81
+ output := capture (ReduceMax (input , 0 )).Test (t )
82
+ comp := capture (builder .Build (output )).Test (t )
83
+ fmt .Printf ("HLO:\n %s\n " , comp .TextHLO ())
84
+ exec := compile (t , client , comp )
85
+ got , dims := execWithSlices (t , client , exec , []float32 {float32 (math .NaN ()), 1 })
86
+ require .Empty (t , dims )
87
+ fmt .Printf ("got: %f -- Should be NAN, but with CPU PJRT it's not\n " , got [0 ])
88
+ // TODO: re-enable this test when bug is fixed.
89
+ // See https://github.com/openxla/xla/issues/21461
90
+ // require.True(t, math.IsNaN(float64(got[0])))
91
+ builder .Destroy ()
92
+ }
93
+
36
94
// Test with ReduceSum and ReduceProduct
37
95
{
38
96
builder := New (fmt .Sprintf ("%s-ReduceProduct-ReduceSum" , t .Name ()))
@@ -99,6 +157,25 @@ func TestReduce(t *testing.T) {
99
157
}
100
158
}
101
159
160
+ func TestReduceMaxBuggy (t * testing.T ) {
161
+ client := getPJRTClient (t )
162
+ {
163
+ builder := New (fmt .Sprintf ("%s-ReduceMax with NaN as parameter" , t .Name ()))
164
+ input := capture (Parameter (builder , "x" , 0 , MakeShape (dtypes .Float32 , 2 ))).Test (t )
165
+ output := capture (ReduceMax (input , 0 )).Test (t )
166
+ comp := capture (builder .Build (output )).Test (t )
167
+ fmt .Printf ("HLO:\n %s\n " , comp .TextHLO ())
168
+ exec := compile (t , client , comp )
169
+ got , dims := execWithSlices (t , client , exec , []float32 {float32 (math .NaN ()), 1 })
170
+ require .Empty (t , dims )
171
+ fmt .Printf ("got: %f -- Should be NAN, but with CPU PJRT it's not\n " , got [0 ])
172
+ // TODO: re-enable this test when bug is fixed.
173
+ // See https://github.com/openxla/xla/issues/21461
174
+ //require.True(t, math.IsNaN(float64(got[0])))
175
+ builder .Destroy ()
176
+ }
177
+ }
178
+
102
179
func TestReduceWindow (t * testing.T ) {
103
180
client := getPJRTClient (t )
104
181
builder := New (t .Name ())
0 commit comments