diff --git a/consumergroup.go b/consumergroup.go index f4bb382c..705cc0d0 100644 --- a/consumergroup.go +++ b/consumergroup.go @@ -497,7 +497,7 @@ func (g *Generation) heartbeatLoop(interval time.Duration) { // a bad spot and should rebalance. Commonly you will see an error here if there // is a problem with the connection to the coordinator and a rebalance will // establish a new connection to the coordinator. -func (g *Generation) partitionWatcher(interval time.Duration, topic string) { +func (g *Generation) partitionWatcher(interval time.Duration, topic string, startPartitions int) { g.Start(func(ctx context.Context) { g.log(func(l Logger) { l.Printf("started partition watcher for group, %v, topic %v [%v]", g.GroupID, topic, interval) @@ -509,14 +509,6 @@ func (g *Generation) partitionWatcher(interval time.Duration, topic string) { ticker := time.NewTicker(interval) defer ticker.Stop() - ops, err := g.conn.readPartitions(topic) - if err != nil { - g.logError(func(l Logger) { - l.Printf("Problem getting partitions during startup, %v\n, Returning and setting up nextGeneration", err) - }) - return - } - oParts := len(ops) for { select { case <-ctx.Done(): @@ -525,7 +517,7 @@ func (g *Generation) partitionWatcher(interval time.Duration, topic string) { ops, err := g.conn.readPartitions(topic) switch { case err == nil, errors.Is(err, UnknownTopicOrPartition): - if len(ops) != oParts { + if len(ops) != startPartitions { g.log(func(l Logger) { l.Printf("Partition changes found, rebalancing group: %v.", g.GroupID) }) @@ -651,11 +643,13 @@ func NewConsumerGroup(config ConsumerGroupConfig) (*ConsumerGroup, error) { } cg := &ConsumerGroup{ - config: config, - next: make(chan *Generation), - errs: make(chan error), - done: make(chan struct{}), + config: config, + partitionsPerTopic: make(map[string]int, len(config.Topics)), + next: make(chan *Generation), + errs: make(chan error), + done: make(chan struct{}), } + cg.wg.Add(1) go func() { cg.run() @@ -670,9 +664,10 @@ func NewConsumerGroup(config ConsumerGroupConfig) (*ConsumerGroup, error) { // Generation is where partition assignments and offset management occur. // Callers will use Next to get a handle to the Generation. type ConsumerGroup struct { - config ConsumerGroupConfig - next chan *Generation - errs chan error + config ConsumerGroupConfig + partitionsPerTopic map[string]int + next chan *Generation + errs chan error closeOnce sync.Once wg sync.WaitGroup @@ -789,13 +784,9 @@ func (cg *ConsumerGroup) nextGeneration(memberID string) (string, error) { } defer conn.Close() - var generationID int32 - var groupAssignments GroupMemberAssignments - var assignments map[string][]int32 - // join group. this will join the group and prepare assignments if our // consumer is elected leader. it may also change or assign the member ID. - memberID, generationID, groupAssignments, err = cg.joinGroup(conn, memberID) + memberID, generationID, groupAssignments, err := cg.joinGroup(conn, memberID) if err != nil { cg.withErrorLogger(func(log Logger) { log.Printf("Failed to join group %s: %v", cg.config.ID, err) @@ -807,7 +798,7 @@ func (cg *ConsumerGroup) nextGeneration(memberID string) (string, error) { }) // sync group - assignments, err = cg.syncGroup(conn, memberID, generationID, groupAssignments) + assignments, err := cg.syncGroup(conn, memberID, generationID, groupAssignments) if err != nil { cg.withErrorLogger(func(log Logger) { log.Printf("Failed to sync group %s: %v", cg.config.ID, err) @@ -844,8 +835,8 @@ func (cg *ConsumerGroup) nextGeneration(memberID string) (string, error) { // complete. gen.heartbeatLoop(cg.config.HeartbeatInterval) if cg.config.WatchPartitionChanges { - for _, topic := range cg.config.Topics { - gen.partitionWatcher(cg.config.PartitionWatchInterval, topic) + for topic, startPartitions := range cg.partitionsPerTopic { + gen.partitionWatcher(cg.config.PartitionWatchInterval, topic, startPartitions) } } @@ -953,6 +944,7 @@ func (cg *ConsumerGroup) joinGroup(conn coordinator, memberID string) (string, i }) var assignments GroupMemberAssignments + if iAmLeader := response.MemberID == response.LeaderID; iAmLeader { v, err := cg.assignTopicPartitions(conn, response) if err != nil { @@ -1036,6 +1028,14 @@ func (cg *ConsumerGroup) assignTopicPartitions(conn coordinator, group joinGroup return nil, err } + // resetting old values of partitions per topic + cg.partitionsPerTopic = make(map[string]int, len(topics)) + + // setting new values of partitions per topic + for _, partition := range partitions { + cg.partitionsPerTopic[partition.Topic] += 1 + } + cg.withLogger(func(l Logger) { l.Printf("using '%v' balancer to assign group, %v", group.GroupProtocol, cg.config.ID) for _, member := range members { diff --git a/consumergroup_test.go b/consumergroup_test.go index 0d3e290a..884e6546 100644 --- a/consumergroup_test.go +++ b/consumergroup_test.go @@ -637,7 +637,7 @@ func TestGenerationExitsOnPartitionChange(t *testing.T) { done := make(chan struct{}) go func() { - gen.partitionWatcher(watchTime, "topic-1") + gen.partitionWatcher(watchTime, "topic-1", 1) close(done) }()