diff --git a/async_producer_test.go b/async_producer_test.go index 42b034807..b23def1ec 100644 --- a/async_producer_test.go +++ b/async_producer_test.go @@ -4,7 +4,6 @@ package sarama import ( "errors" - "github.com/stretchr/testify/assert" "log" "math" "os" @@ -18,6 +17,7 @@ import ( "github.com/fortytw2/leaktest" "github.com/rcrowley/go-metrics" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -638,6 +638,68 @@ func TestAsyncProducerMultipleRetriesWithBackoffFunc(t *testing.T) { } } +func TestAsyncProducerWithExponentialBackoffDurations(t *testing.T) { + var backoffDurations []time.Duration + var mu sync.Mutex + + topic := "my_topic" + maxBackoff := 2 * time.Second + config := NewTestConfig() + + innerBackoffFunc := NewExponentialBackoff(defaultRetryBackoff, maxBackoff) + backoffFunc := func(retries, maxRetries int) time.Duration { + duration := innerBackoffFunc(retries, maxRetries) + mu.Lock() + backoffDurations = append(backoffDurations, duration) + mu.Unlock() + return duration + } + + config.Producer.Flush.Messages = 5 + config.Producer.Return.Successes = true + config.Producer.Retry.Max = 3 + config.Producer.Retry.BackoffFunc = backoffFunc + + broker := NewMockBroker(t, 1) + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(broker.Addr(), broker.BrokerID()) + metadataResponse.AddTopicPartition(topic, 0, broker.BrokerID(), nil, nil, nil, ErrNoError) + broker.Returns(metadataResponse) + + producer, err := NewAsyncProducer([]string{broker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + failResponse := new(ProduceResponse) + failResponse.AddTopicPartition(topic, 0, ErrNotLeaderForPartition) + successResponse := new(ProduceResponse) + successResponse.AddTopicPartition(topic, 0, ErrNoError) + + broker.Returns(failResponse) + broker.Returns(metadataResponse) + broker.Returns(failResponse) + broker.Returns(metadataResponse) + broker.Returns(successResponse) + + for i := 0; i < 5; i++ { + producer.Input() <- &ProducerMessage{Topic: topic, Value: StringEncoder("test")} + } + + expectResults(t, producer, 5, 0) + closeProducer(t, producer) + broker.Close() + + assert.Greater(t, backoffDurations[0], time.Duration(0), + "Expected first backoff duration to be greater than 0") + for i := 1; i < len(backoffDurations); i++ { + assert.Greater(t, backoffDurations[i], time.Duration(0)) + assert.GreaterOrEqual(t, backoffDurations[i], backoffDurations[i-1]) + assert.LessOrEqual(t, backoffDurations[i], maxBackoff) + } +} + // https://github.com/IBM/sarama/issues/2129 func TestAsyncProducerMultipleRetriesWithConcurrentRequests(t *testing.T) { // Logger = log.New(os.Stdout, "[sarama] ", log.LstdFlags) diff --git a/utils.go b/utils.go index 7f0a84fe2..83a992a5f 100644 --- a/utils.go +++ b/utils.go @@ -3,8 +3,15 @@ package sarama import ( "bufio" "fmt" + "math/rand" "net" "regexp" + "time" +) + +const ( + defaultRetryBackoff = 100 * time.Millisecond + defaultRetryMaxBackoff = 1000 * time.Millisecond ) type none struct{} @@ -344,3 +351,39 @@ func (v KafkaVersion) String() string { return fmt.Sprintf("%d.%d.%d", v.version[0], v.version[1], v.version[2]) } + +// NewExponentialBackoff returns a function that implements an exponential backoff strategy with jitter. +// It follows KIP-580, implementing the formula: +// MIN(retry.backoff.max.ms, (retry.backoff.ms * 2**(failures - 1)) * random(0.8, 1.2)) +// This ensures retries start with `backoff` and exponentially increase until `maxBackoff`, with added jitter. +// The behavior when `failures = 0` is not explicitly defined in KIP-580 and is left to implementation discretion. +// +// Example usage: +// +// backoffFunc := sarama.NewExponentialBackoff(config.Producer.Retry.Backoff, 2*time.Second) +// config.Producer.Retry.BackoffFunc = backoffFunc +func NewExponentialBackoff(backoff time.Duration, maxBackoff time.Duration) func(retries, maxRetries int) time.Duration { + if backoff <= 0 { + backoff = defaultRetryBackoff + } + if maxBackoff <= 0 { + maxBackoff = defaultRetryMaxBackoff + } + + if backoff > maxBackoff { + Logger.Println("Warning: backoff is greater than maxBackoff, using maxBackoff instead.") + backoff = maxBackoff + } + + return func(retries, maxRetries int) time.Duration { + if retries <= 0 { + return backoff + } + + calculatedBackoff := backoff * time.Duration(1<<(retries-1)) + jitter := 0.8 + 0.4*rand.Float64() + calculatedBackoff = time.Duration(float64(calculatedBackoff) * jitter) + + return min(calculatedBackoff, maxBackoff) + } +} diff --git a/utils_test.go b/utils_test.go index 96b2c4b54..18d014d9e 100644 --- a/utils_test.go +++ b/utils_test.go @@ -2,7 +2,10 @@ package sarama -import "testing" +import ( + "testing" + "time" +) func TestVersionCompare(t *testing.T) { if V0_8_2_0.IsAtLeast(V0_8_2_1) { @@ -95,3 +98,47 @@ func TestVersionParsing(t *testing.T) { } } } + +func TestExponentialBackoffValidCases(t *testing.T) { + testCases := []struct { + retries int + maxRetries int + minBackoff time.Duration + maxBackoffExpected time.Duration + }{ + {1, 5, 80 * time.Millisecond, 120 * time.Millisecond}, + {3, 5, 320 * time.Millisecond, 480 * time.Millisecond}, + {5, 5, 1280 * time.Millisecond, 1920 * time.Millisecond}, + } + + for _, tc := range testCases { + backoffFunc := NewExponentialBackoff(100*time.Millisecond, 2*time.Second) + backoff := backoffFunc(tc.retries, tc.maxRetries) + if backoff < tc.minBackoff || backoff > tc.maxBackoffExpected { + t.Errorf("backoff(%d, %d): expected between %v and %v, got %v", tc.retries, tc.maxRetries, tc.minBackoff, tc.maxBackoffExpected, backoff) + } + } +} + +func TestExponentialBackoffDefaults(t *testing.T) { + testCases := []struct { + backoff time.Duration + maxBackoff time.Duration + }{ + {-100 * time.Millisecond, 2 * time.Second}, + {100 * time.Millisecond, -2 * time.Second}, + {-100 * time.Millisecond, -2 * time.Second}, + {0 * time.Millisecond, 2 * time.Second}, + {100 * time.Millisecond, 0 * time.Second}, + {0 * time.Millisecond, 0 * time.Second}, + } + + for _, tc := range testCases { + backoffFunc := NewExponentialBackoff(tc.backoff, tc.maxBackoff) + backoff := backoffFunc(2, 5) + if backoff < defaultRetryBackoff || backoff > defaultRetryMaxBackoff { + t.Errorf("backoff(%v, %v): expected between %v and %v, got %v", + tc.backoff, tc.maxBackoff, defaultRetryBackoff, defaultRetryMaxBackoff, backoff) + } + } +}