Skip to content

Commit 063546c

Browse files
committed
Added test (disabled for now) to openxla/xla#21461
1 parent 0cbbb76 commit 063546c

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

xlabuilder/reduce_test.go

+77
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,38 @@ import (
55
"github.com/gomlx/gopjrt/dtypes"
66
. "github.com/gomlx/gopjrt/xlabuilder"
77
"github.com/stretchr/testify/require"
8+
"math"
89
"testing"
910
)
1011

12+
// TestMax tests the Max function, as part of the ReduceMax test.
13+
// See https://github.com/openxla/xla/issues/21461
14+
func TestMax(t *testing.T) {
15+
client := getPJRTClient(t)
16+
{
17+
builder := New(fmt.Sprintf("%s-Max(NaN, 1) as Constant", t.Name()))
18+
input0 := capture(Constant(builder, NewScalarLiteral(math.NaN()))).Test(t)
19+
input1 := capture(Constant(builder, NewScalarLiteral(1.0))).Test(t)
20+
output := capture(Max(input0, input1)).Test(t)
21+
exec := compile(t, client, capture(builder.Build(output)).Test(t))
22+
got := execScalarOutput[float64](t, client, exec)
23+
require.True(t, math.IsNaN(got))
24+
builder.Destroy()
25+
}
26+
{
27+
builder := New(fmt.Sprintf("%s-Max(NaN, 1) as Parameter", t.Name()))
28+
input0 := capture(Parameter(builder, "x", 0, MakeShape(dtypes.Float64))).Test(t)
29+
input1 := capture(Parameter(builder, "y", 1, MakeShape(dtypes.Float64))).Test(t)
30+
input0 = capture(Sqrt(input0)).Test(t)
31+
input1 = capture(Sqrt(input1)).Test(t)
32+
output := capture(Max(input0, input1)).Test(t)
33+
exec := compile(t, client, capture(builder.Build(output)).Test(t))
34+
got := execWithScalars(t, client, exec, -1.0, 1.0)
35+
require.True(t, math.IsNaN(got))
36+
builder.Destroy()
37+
}
38+
}
39+
1140
func TestReduce(t *testing.T) {
1241
client := getPJRTClient(t)
1342

@@ -33,6 +62,35 @@ func TestReduce(t *testing.T) {
3362
builder.Destroy()
3463
}
3564

65+
{
66+
builder := New(fmt.Sprintf("%s-ReduceMax with NaN as constant", t.Name()))
67+
literal := capture(NewArrayLiteral([]float32{float32(math.NaN()), 1}, 2)).Test(t)
68+
input := capture(Constant(builder, literal)).Test(t)
69+
output := capture(ReduceMax(input, 0)).Test(t)
70+
comp := capture(builder.Build(output)).Test(t)
71+
fmt.Printf("HLO:\n%s\n", comp.TextHLO())
72+
exec := compile(t, client, comp)
73+
got := execWithScalars[float32](t, client, exec)
74+
require.True(t, math.IsNaN(float64(got)))
75+
builder.Destroy()
76+
}
77+
78+
{
79+
builder := New(fmt.Sprintf("%s-ReduceMax with NaN as parameter", t.Name()))
80+
input := capture(Parameter(builder, "x", 0, MakeShape(dtypes.Float32, 2))).Test(t)
81+
output := capture(ReduceMax(input, 0)).Test(t)
82+
comp := capture(builder.Build(output)).Test(t)
83+
fmt.Printf("HLO:\n%s\n", comp.TextHLO())
84+
exec := compile(t, client, comp)
85+
got, dims := execWithSlices(t, client, exec, []float32{float32(math.NaN()), 1})
86+
require.Empty(t, dims)
87+
fmt.Printf("got: %f -- Should be NAN, but with CPU PJRT it's not\n", got[0])
88+
// TODO: re-enable this test when bug is fixed.
89+
// See https://github.com/openxla/xla/issues/21461
90+
// require.True(t, math.IsNaN(float64(got[0])))
91+
builder.Destroy()
92+
}
93+
3694
// Test with ReduceSum and ReduceProduct
3795
{
3896
builder := New(fmt.Sprintf("%s-ReduceProduct-ReduceSum", t.Name()))
@@ -99,6 +157,25 @@ func TestReduce(t *testing.T) {
99157
}
100158
}
101159

160+
func TestReduceMaxBuggy(t *testing.T) {
161+
client := getPJRTClient(t)
162+
{
163+
builder := New(fmt.Sprintf("%s-ReduceMax with NaN as parameter", t.Name()))
164+
input := capture(Parameter(builder, "x", 0, MakeShape(dtypes.Float32, 2))).Test(t)
165+
output := capture(ReduceMax(input, 0)).Test(t)
166+
comp := capture(builder.Build(output)).Test(t)
167+
fmt.Printf("HLO:\n%s\n", comp.TextHLO())
168+
exec := compile(t, client, comp)
169+
got, dims := execWithSlices(t, client, exec, []float32{float32(math.NaN()), 1})
170+
require.Empty(t, dims)
171+
fmt.Printf("got: %f -- Should be NAN, but with CPU PJRT it's not\n", got[0])
172+
// TODO: re-enable this test when bug is fixed.
173+
// See https://github.com/openxla/xla/issues/21461
174+
//require.True(t, math.IsNaN(float64(got[0])))
175+
builder.Destroy()
176+
}
177+
}
178+
102179
func TestReduceWindow(t *testing.T) {
103180
client := getPJRTClient(t)
104181
builder := New(t.Name())

0 commit comments

Comments
 (0)