From f28e4641e76a27e17f16a78f6843bc3a3c9178ae Mon Sep 17 00:00:00 2001
From: Tyler Rockwood <rockwood@redpanda.com>
Date: Tue, 17 Dec 2024 20:45:01 +0000
Subject: [PATCH] decimal: fix default values

Unlike timestamp types which accepts native values, decimal types must
be big.Rat, which means defaults need to be converted.

This fixes: https://github.com/linkedin/goavro/issues/202
---
 binary_test.go       |  2 +-
 logical_type_test.go |  7 ++++++-
 record.go            | 28 +++++++++++++++++++++-------
 record_test.go       | 21 +++++++++++++++------
 4 files changed, 43 insertions(+), 15 deletions(-)

diff --git a/binary_test.go b/binary_test.go
index 45f12cf9..2b735a2c 100644
--- a/binary_test.go
+++ b/binary_test.go
@@ -88,7 +88,7 @@ func testBinaryDecodePass(t *testing.T, schema string, datum interface{}, encode
 	t.Helper()
 	codec, err := NewCodec(schema)
 	if err != nil {
-		t.Fatal(err)
+		t.Fatalf("unable to create codec: %s", err)
 	}
 
 	value, remaining, err := codec.NativeFromBinary(encoded)
diff --git a/logical_type_test.go b/logical_type_test.go
index 20f7903b..a52fc1d5 100644
--- a/logical_type_test.go
+++ b/logical_type_test.go
@@ -151,7 +151,6 @@ func TestDecimalBytesLogicalTypeEncode(t *testing.T) {
 	d, _ := new(big.Int).SetString("100000000000000000000000000000000000000", 10)
 	largeRat := new(big.Rat).SetFrac(n, d)
 	testBinaryCodecPass(t, largeDecimalSchema, largeRat, []byte("\x40\x1b\x4b\x68\x19\x26\x11\xfa\xea\x20\x8f\xca\x21\x62\x7b\xe9\xda\xee\x32\x19\x83\x83\x95\x5d\xe8\x13\x1f\x4b\xf1\xc7\x1c\x71\xc7"))
-
 }
 
 func TestDecimalFixedLogicalTypeEncode(t *testing.T) {
@@ -178,6 +177,12 @@ func TestDecimalBytesLogicalTypeInRecordEncode(t *testing.T) {
 	testBinaryCodecPass(t, schema, map[string]interface{}{"mydecimal": big.NewRat(617, 50)}, []byte("\x04\x04\xd2"))
 }
 
+func TestDecimalBytesLogicalTypeInRecordDecodeWithDefault(t *testing.T) {
+	schema := `{"type": "record", "name": "myrecord", "fields" : [
+    {"name": "mydecimal", "type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2, "default":"\u0000"}]}`
+	testBinaryCodecPass(t, schema, map[string]interface{}{"mydecimal": big.NewRat(617, 50)}, []byte("\x04\x04\xd2"))
+}
+
 func TestValidatedStringLogicalTypeInRecordEncode(t *testing.T) {
 	schema := `{
 		"type": "record",
diff --git a/record.go b/record.go
index e5ac9e41..fd522ef7 100644
--- a/record.go
+++ b/record.go
@@ -66,43 +66,43 @@ func makeRecordCodec(st map[string]*Codec, enclosingNamespace string, schemaMap
 			case "boolean":
 				v, ok := defaultValue.(bool)
 				if !ok {
-					return nil, fmt.Errorf("Record %q field %q: default value ought to encode using field schema: %s", c.typeName, fieldName, err)
+					return nil, fmt.Errorf("Record %q field %q: default value ought to have a bool type, got: %T", c.typeName, fieldName, defaultValue)
 				}
 				defaultValue = v
 			case "bytes":
 				v, ok := defaultValue.(string)
 				if !ok {
-					return nil, fmt.Errorf("Record %q field %q: default value ought to encode using field schema: %s", c.typeName, fieldName, err)
+					return nil, fmt.Errorf("Record %q field %q: default value ought to have a string type got: %T", c.typeName, fieldName, defaultValue)
 				}
 				defaultValue = []byte(v)
 			case "double":
 				v, ok := defaultValue.(float64)
 				if !ok {
-					return nil, fmt.Errorf("Record %q field %q: default value ought to encode using field schema: %s", c.typeName, fieldName, err)
+					return nil, fmt.Errorf("Record %q field %q: default value ought to have a double type got: %T", c.typeName, fieldName, defaultValue)
 				}
 				defaultValue = v
 			case "float":
 				v, ok := defaultValue.(float64)
 				if !ok {
-					return nil, fmt.Errorf("Record %q field %q: default value ought to encode using field schema: %s", c.typeName, fieldName, err)
+					return nil, fmt.Errorf("Record %q field %q: default value ought to have a float type got: %T", c.typeName, fieldName, defaultValue)
 				}
 				defaultValue = float32(v)
 			case "int":
 				v, ok := defaultValue.(float64)
 				if !ok {
-					return nil, fmt.Errorf("Record %q field %q: default value ought to encode using field schema: %s", c.typeName, fieldName, err)
+					return nil, fmt.Errorf("Record %q field %q: default value ought to have a number type got: %T", c.typeName, fieldName, defaultValue)
 				}
 				defaultValue = int32(v)
 			case "long":
 				v, ok := defaultValue.(float64)
 				if !ok {
-					return nil, fmt.Errorf("Record %q field %q: default value ought to encode using field schema: %s", c.typeName, fieldName, err)
+					return nil, fmt.Errorf("Record %q field %q: default value ought to have a number type got: %T", c.typeName, fieldName, defaultValue)
 				}
 				defaultValue = int64(v)
 			case "string":
 				v, ok := defaultValue.(string)
 				if !ok {
-					return nil, fmt.Errorf("Record %q field %q: default value ought to encode using field schema: %s", c.typeName, fieldName, err)
+					return nil, fmt.Errorf("Record %q field %q: default value ought to have a string type got: %T", c.typeName, fieldName, defaultValue)
 				}
 				defaultValue = v
 			case "union":
@@ -118,6 +118,20 @@ func makeRecordCodec(st map[string]*Codec, enclosingNamespace string, schemaMap
 				defaultValue = Union(fieldCodec.schemaOriginal, defaultValue)
 			default:
 				debug("fieldName: %q; type: %q; defaultValue: %T(%#v)\n", fieldName, c.typeName, defaultValue, defaultValue)
+
+				// Support defaults for logical types
+				if logicalType, ok := fieldSchemaMap["logicalType"]; ok {
+					if logicalType == "decimal" {
+						v, ok := defaultValue.(string)
+						if !ok {
+							return nil, fmt.Errorf("Record %q field %q: default value ought to have a string type got: %T", c.typeName, fieldName, defaultValue)
+						}
+						defaultValue, _, err = fieldCodec.nativeFromBinary([]byte(v))
+						if err != nil {
+							return nil, fmt.Errorf("Record %q field %q: default value ought to decode from textual: %w", c.typeName, fieldName, err)
+						}
+					}
+				}
 			}
 
 			// attempt to encode default value using codec
diff --git a/record_test.go b/record_test.go
index f8145c03..c9726513 100644
--- a/record_test.go
+++ b/record_test.go
@@ -12,6 +12,7 @@ package goavro
 import (
 	"bytes"
 	"fmt"
+	"math/big"
 	"testing"
 )
 
@@ -389,7 +390,7 @@ func TestRecordFieldDefaultValue(t *testing.T) {
 	testSchemaValid(t, `{"type":"record","name":"r1","fields":[{"name":"f1","type":"string","default":"foo"}]}`)
 	testSchemaInvalid(t,
 		`{"type":"record","name":"r1","fields":[{"name":"f1","type":"int","default":"foo"}]}`,
-		"default value ought to encode using field schema")
+		"default value ought to have a number type")
 }
 
 func TestRecordFieldUnionDefaultValue(t *testing.T) {
@@ -618,7 +619,7 @@ func TestRecordFieldFixedDefaultValue(t *testing.T) {
 
 func TestRecordFieldDefaultValueTypes(t *testing.T) {
 	t.Run("success", func(t *testing.T) {
-		codec, err := NewCodec(`{"type": "record", "name": "r1", "fields":[{"name": "someBoolean", "type": "boolean", "default": true},{"name": "someBytes", "type": "bytes", "default": "0"},{"name": "someDouble", "type": "double", "default": 0},{"name": "someFloat", "type": "float", "default": 0},{"name": "someInt", "type": "int", "default": 0},{"name": "someLong", "type": "long", "default": 0},{"name": "someString", "type": "string", "default": "0"}]}`)
+		codec, err := NewCodec(`{"type": "record", "name": "r1", "fields":[{"name": "someBoolean", "type": "boolean", "default": true},{"name": "someBytes", "type": "bytes", "default": "0"},{"name": "someDouble", "type": "double", "default": 0},{"name": "someFloat", "type": "float", "default": 0},{"name": "someInt", "type": "int", "default": 0},{"name": "someLong", "type": "long", "default": 0},{"name": "someString", "type": "string", "default": "0"}, {"name":"someTimestamp", "type":"long", "logicalType":"timestamp-millis","default":0}, {"name": "someDecimal", "type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2, "default":"\u0000"}]}`)
 		ensureError(t, err)
 
 		r1, _, err := codec.NativeFromTextual([]byte("{}"))
@@ -660,24 +661,32 @@ func TestRecordFieldDefaultValueTypes(t *testing.T) {
 		if _, ok := someString.(string); !ok {
 			t.Errorf("GOT: %T; WANT: string", someString)
 		}
+		someTimestamp := r1m["someTimestamp"]
+		if _, ok := someTimestamp.(float64); !ok {
+			t.Errorf("GOT: %T; WANT: float64", someTimestamp)
+		}
+		someDecimal := r1m["someDecimal"]
+		if _, ok := someDecimal.(*big.Rat); !ok {
+			t.Errorf("GOT: %T; WANT: *big.Rat", someDecimal)
+		}
 	})
 
 	t.Run("provided default is wrong type", func(t *testing.T) {
 		t.Run("long", func(t *testing.T) {
 			_, err := NewCodec(`{"type": "record", "name": "r1", "fields":[{"name": "someLong", "type": "long", "default": "0"},{"name": "someInt", "type": "int", "default": 0},{"name": "someFloat", "type": "float", "default": 0},{"name": "someDouble", "type": "double", "default": 0}]}`)
-			ensureError(t, err, "field schema")
+			ensureError(t, err, "default value ought to have a number type")
 		})
 		t.Run("int", func(t *testing.T) {
 			_, err := NewCodec(`{"type": "record", "name": "r1", "fields":[{"name": "someLong", "type": "long", "default": 0},{"name": "someInt", "type": "int", "default": "0"},{"name": "someFloat", "type": "float", "default": 0},{"name": "someDouble", "type": "double", "default": 0}]}`)
-			ensureError(t, err, "field schema")
+			ensureError(t, err, "default value ought to have a number type")
 		})
 		t.Run("float", func(t *testing.T) {
 			_, err := NewCodec(`{"type": "record", "name": "r1", "fields":[{"name": "someLong", "type": "long", "default": 0},{"name": "someInt", "type": "int", "default": 0},{"name": "someFloat", "type": "float", "default": "0"},{"name": "someDouble", "type": "double", "default": 0}]}`)
-			ensureError(t, err, "field schema")
+			ensureError(t, err, "default value ought to have a float type")
 		})
 		t.Run("double", func(t *testing.T) {
 			_, err := NewCodec(`{"type": "record", "name": "r1", "fields":[{"name": "someLong", "type": "long", "default": 0},{"name": "someInt", "type": "int", "default": 0},{"name": "someFloat", "type": "float", "default": 0},{"name": "someDouble", "type": "double", "default": "0"}]}`)
-			ensureError(t, err, "field schema")
+			ensureError(t, err, "default value ought to have a double type")
 		})
 	})