|
| 1 | +package arrowutils |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "testing" |
| 6 | + |
| 7 | + "github.com/apache/arrow/go/v14/arrow" |
| 8 | + "github.com/apache/arrow/go/v14/arrow/array" |
| 9 | + "github.com/apache/arrow/go/v14/arrow/memory" |
| 10 | + "github.com/stretchr/testify/require" |
| 11 | +) |
| 12 | + |
| 13 | +func TestSortRecord(t *testing.T) { |
| 14 | + ctx := context.Background() |
| 15 | + schema := arrow.NewSchema( |
| 16 | + []arrow.Field{ |
| 17 | + {Name: "int", Type: arrow.PrimitiveTypes.Int64}, |
| 18 | + {Name: "string", Type: arrow.BinaryTypes.String}, |
| 19 | + }, |
| 20 | + nil, |
| 21 | + ) |
| 22 | + |
| 23 | + mem := memory.DefaultAllocator |
| 24 | + ib := array.NewInt64Builder(mem) |
| 25 | + ib.Append(0) |
| 26 | + ib.AppendNull() |
| 27 | + ib.Append(3) |
| 28 | + ib.Append(5) |
| 29 | + ib.Append(1) |
| 30 | + |
| 31 | + sb := array.NewStringBuilder(mem) |
| 32 | + sb.Append("d") |
| 33 | + sb.Append("c") |
| 34 | + sb.Append("b") |
| 35 | + sb.AppendNull() |
| 36 | + sb.Append("a") |
| 37 | + |
| 38 | + record := array.NewRecord(schema, []arrow.Array{ib.NewArray(), sb.NewArray()}, int64(5)) |
| 39 | + |
| 40 | + // Sort the record by the first column - int64 |
| 41 | + { |
| 42 | + sortedIndices, err := SortRecord(mem, record, []int{record.Schema().FieldIndices("int")[0]}) |
| 43 | + require.NoError(t, err) |
| 44 | + require.Equal(t, []int64{0, 4, 2, 3, 1}, sortedIndices.Int64Values()) |
| 45 | + |
| 46 | + sortedByInts, err := ReorderRecord(ctx, record, sortedIndices) |
| 47 | + require.NoError(t, err) |
| 48 | + |
| 49 | + // check that the column got sortedIndices |
| 50 | + intCol := sortedByInts.Column(0).(*array.Int64) |
| 51 | + require.Equal(t, []int64{0, 1, 3, 5, 0}, intCol.Int64Values()) |
| 52 | + require.True(t, intCol.IsNull(intCol.Len()-1)) // last is NULL |
| 53 | + // make sure the other column got updated too |
| 54 | + strings := make([]string, sortedByInts.NumRows()) |
| 55 | + stringCol := sortedByInts.Column(1).(*array.String) |
| 56 | + for i := 0; i < int(sortedByInts.NumRows()); i++ { |
| 57 | + strings[i] = stringCol.Value(i) |
| 58 | + } |
| 59 | + require.Equal(t, []string{"d", "a", "b", "", "c"}, strings) |
| 60 | + } |
| 61 | + |
| 62 | + // Sort the record by the second column - string |
| 63 | + { |
| 64 | + sortedIndices, err := SortRecord(mem, record, []int{record.Schema().FieldIndices("string")[0]}) |
| 65 | + require.NoError(t, err) |
| 66 | + require.Equal(t, []int64{4, 2, 1, 0, 3}, sortedIndices.Int64Values()) |
| 67 | + |
| 68 | + sortedByStrings, err := ReorderRecord(ctx, record, sortedIndices) |
| 69 | + require.NoError(t, err) |
| 70 | + |
| 71 | + // check that the column got sortedByInts |
| 72 | + intCol := sortedByStrings.Column(0).(*array.Int64) |
| 73 | + require.Equal(t, []int64{1, 3, 0, 0, 5}, intCol.Int64Values()) |
| 74 | + // make sure the other column got updated too |
| 75 | + strings := make([]string, sortedByStrings.NumRows()) |
| 76 | + stringCol := sortedByStrings.Column(1).(*array.String) |
| 77 | + for i := 0; i < int(sortedByStrings.NumRows()); i++ { |
| 78 | + strings[i] = stringCol.Value(i) |
| 79 | + } |
| 80 | + require.Equal(t, []string{"a", "b", "c", "d", ""}, strings) |
| 81 | + require.True(t, stringCol.IsNull(stringCol.Len()-1)) // last is NULL |
| 82 | + } |
| 83 | +} |
0 commit comments