diff --git a/balancer.go b/balancer.go
index f4768cf8..4136fce7 100644
--- a/balancer.go
+++ b/balancer.go
@@ -36,11 +36,14 @@ func (f BalancerFunc) Balance(msg Message, partitions ...int) int {
 }
 
-// RoundRobin is an Balancer implementation that equally distributes messages
-// across all available partitions.
+// RoundRobin is a Balancer implementation that equally distributes messages
+// across all available partitions. It can take an optional chunk size to send
+// ChunkSize messages to the same partition before moving to the next partition.
+// This can be used to improve batch sizes. A ChunkSize below 1 is treated as 1.
 type RoundRobin struct {
+	ChunkSize int
 	// Use a 32 bits integer so RoundRobin values don't need to be aligned to
 	// apply atomic increments.
-	offset uint32
+	counter uint32
 }
 
 // Balance satisfies the Balancer interface.
@@ -49,8 +52,20 @@ func (rr *RoundRobin) Balance(msg Message, partitions ...int) int {
 }
 
+// balance returns the partition for the next message. Each call atomically
+// claims one slot of the shared counter, so concurrent callers never observe
+// the same slot, and ChunkSize consecutive slots map to the same partition.
+// A ChunkSize below 1 is treated as 1; the field itself is never written, so
+// no unsynchronized write happens on the hot path.
 func (rr *RoundRobin) balance(partitions []int) int {
-	length := uint32(len(partitions))
-	offset := atomic.AddUint32(&rr.offset, 1) - 1
+	chunkSize := uint64(1)
+	if rr.ChunkSize > 1 {
+		chunkSize = uint64(rr.ChunkSize)
+	}
+
+	// Do the arithmetic in uint64 so the index can never go negative,
+	// even where int is 32 bits wide.
+	length := uint64(len(partitions))
+	offset := uint64(atomic.AddUint32(&rr.counter, 1)-1) / chunkSize
 	return partitions[offset%length]
 }
 
@@ -122,7 +137,7 @@ var (
 //
 // The logic to calculate the partition is:
 //
-// 		hasher.Sum32() % len(partitions) => partition
+//	hasher.Sum32() % len(partitions) => partition
 //
 // By default, Hash uses the FNV-1a algorithm. This is the same algorithm used
 // by the Sarama Producer and ensures that messages produced by kafka-go will
@@ -173,7 +188,7 @@ func (h *Hash) Balance(msg Message, partitions ...int) int {
 //
 // The logic to calculate the partition is:
 //
-// 		(int32(hasher.Sum32()) & 0x7fffffff) % len(partitions) => partition
+//	(int32(hasher.Sum32()) & 0x7fffffff) % len(partitions) => partition
 //
 // By default, ReferenceHash uses the FNV-1a algorithm. This is the same algorithm as
 // the Sarama NewReferenceHashPartitioner and ensures that messages produced by kafka-go will
diff --git a/balancer_test.go b/balancer_test.go
index a078f192..149bc680 100644
--- a/balancer_test.go
+++ b/balancer_test.go
@@ -411,3 +411,68 @@ func TestLeastBytes(t *testing.T) {
 		})
 	}
 }
+
+func TestRoundRobin(t *testing.T) {
+	testCases := map[string]struct {
+		Partitions []int
+		ChunkSize  int
+	}{
+		"default - odd partition count": {
+			Partitions: []int{0, 1, 2, 3, 4, 5, 6},
+		},
+		"negative chunk size - odd partition count": {
+			Partitions: []int{0, 1, 2, 3, 4, 5, 6},
+			ChunkSize:  -1,
+		},
+		"0 chunk size - odd partition count": {
+			Partitions: []int{0, 1, 2, 3, 4, 5, 6},
+			ChunkSize:  0,
+		},
+		"5 chunk size - odd partition count": {
+			Partitions: []int{0, 1, 2, 3, 4, 5, 6},
+			ChunkSize:  5,
+		},
+		"12 chunk size - odd partition count": {
+			Partitions: []int{0, 1, 2, 3, 4, 5, 6},
+			ChunkSize:  12,
+		},
+		"default - even partition count": {
+			Partitions: []int{0, 1, 2, 3, 4, 5, 6, 7},
+		},
+		"negative chunk size - even partition count": {
+			Partitions: []int{0, 1, 2, 3, 4, 5, 6, 7},
+			ChunkSize:  -1,
+		},
+		"0 chunk size - even partition count": {
+			Partitions: []int{0, 1, 2, 3, 4, 5, 6, 7},
+			ChunkSize:  0,
+		},
+		"5 chunk size - even partition count": {
+			Partitions: []int{0, 1, 2, 3, 4, 5, 6, 7},
+			ChunkSize:  5,
+		},
+		"12 chunk size - even partition count": {
+			Partitions: []int{0, 1, 2, 3, 4, 5, 6, 7},
+			ChunkSize:  12,
+		},
+	}
+	for label, test := range testCases {
+		t.Run(label, func(t *testing.T) {
+			lb := &RoundRobin{ChunkSize: test.ChunkSize}
+			msg := Message{}
+			var partition int
+			var i int
+			expectedChunkSize := test.ChunkSize
+			if expectedChunkSize < 1 {
+				expectedChunkSize = 1
+			}
+			partitions := test.Partitions
+			for i = 0; i < 50; i++ {
+				partition = lb.Balance(msg, partitions...)
+				if partition != i/expectedChunkSize%len(partitions) {
+					t.Error("Returned partition", partition, "expecting", i/expectedChunkSize%len(partitions))
+				}
+			}
+		})
+	}
+}