
dsmr: implement chunk rate limiter #1922

Merged
merged 16 commits on Feb 13, 2025
4 changes: 4 additions & 0 deletions x/dsmr/node.go
@@ -65,6 +65,7 @@ type Validator struct {

type Rules interface {
GetValidityWindow() int64
GetMaxAccumulatedProducerChunkWeight() uint64
}

type RuleFactory interface {
@@ -183,6 +184,9 @@ func (n *Node[T]) BuildChunk(
// we have duplicates
return ErrDuplicateChunk
}
if err := n.storage.CheckRateLimit(chunk); err != nil {
return fmt.Errorf("failed to meet chunk rate limit threshold: %w", err)
}

packer := wrappers.Packer{MaxSize: MaxMessageSize}
if err := codec.LinearCodec.MarshalInto(chunkRef, &packer); err != nil {
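For orientation, here is a minimal sketch of the two sides of the new limit: a Rules implementation supplying the cap, and a producer reacting to the error BuildChunk now returns. The exampleRules type, its values, and handleBuildErr are illustrative only, not part of this change.

package dsmr

import (
	"errors"
	"time"
)

// exampleRules is a hypothetical Rules implementation wiring the new getter.
type exampleRules struct{}

func (exampleRules) GetValidityWindow() int64 { return int64(5 * time.Second) }

// Cap the accumulated size of a producer's pending chunks at 1 MiB (illustrative value).
func (exampleRules) GetMaxAccumulatedProducerChunkWeight() uint64 { return 1024 * 1024 }

// handleBuildErr sketches a producer's reaction to a rate-limited BuildChunk
// call; since BuildChunk wraps the failure with %w, errors.Is can match it.
func handleBuildErr(err error) bool {
	if errors.Is(err, ErrChunkRateLimitSurpassed) {
		// Too many unaccepted chunk bytes are pending for this producer:
		// back off until pending chunks are accepted or expire.
		return true
	}
	return false
}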
14 changes: 9 additions & 5 deletions x/dsmr/node_test.go
@@ -35,6 +35,7 @@ import (
const (
networkID = uint32(123)
testingDefaultValidityWindowDuration = 5 * time.Second
testingDefaultMaxProducerChunkWeight = 1024 * 1024
)

var (
@@ -45,7 +46,8 @@
chainID = ids.Empty
testRuleFactory = ruleFactory{
rules: rules{
validityWindow:         int64(testingDefaultValidityWindowDuration),
maxProducerChunkWeight: testingDefaultMaxProducerChunkWeight,
},
}

@@ -559,7 +561,7 @@ func TestNode_GetChunkSignature_SignValidChunk(t *testing.T) {
},
}

chunkStorage, err := NewChunkStorage[dsmrtest.Tx](tt.verifier, memdb.New())
chunkStorage, err := NewChunkStorage[dsmrtest.Tx](tt.verifier, memdb.New(), testRuleFactory)
r.NoError(err)

chainState := newTestChainState(validators, 1, 1)
@@ -1395,7 +1397,7 @@ func newTestNodes(t *testing.T, n int) []*Node[dsmrtest.Tx] {
chainState,
testRuleFactory,
)
chunkStorage, err := NewChunkStorage[dsmrtest.Tx](verifier, memdb.New())
chunkStorage, err := NewChunkStorage[dsmrtest.Tx](verifier, memdb.New(), testRuleFactory)
require.NoError(t, err)

getChunkHandler := &GetChunkHandler[dsmrtest.Tx]{
@@ -1587,7 +1589,9 @@ type ruleFactory struct
func (r ruleFactory) GetRules(int64) Rules { return r.rules }

type rules struct {
validityWindow         int64
maxProducerChunkWeight uint64
}

func (r rules) GetValidityWindow() int64                     { return r.validityWindow }
func (r rules) GetMaxAccumulatedProducerChunkWeight() uint64 { return r.maxProducerChunkWeight }
4 changes: 1 addition & 3 deletions x/dsmr/p2p.go
@@ -127,9 +127,7 @@ func (c ChunkSignatureRequestVerifier[T]) Verify(
return ErrInvalidChunk
}

// check to see if this chunk was already accepted.
_, err = c.storage.GetChunkBytes(chunk.Expiry, chunk.id)
if err != nil && !errors.Is(err, database.ErrNotFound) {
if err := c.storage.CheckRateLimit(chunk); err != nil {
return &common.AppError{
Code: p2p.ErrUnexpected.Code,
Message: err.Error(),
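Note that the refusal reaches the requesting peer as a generic unexpected error rather than a dedicated code. A sketch of the consequence, assuming the import paths match the avalanchego packages this file already uses; rateLimitAppError is hypothetical:

package dsmr

import (
	"github.com/ava-labs/avalanchego/network/p2p"
	"github.com/ava-labs/avalanchego/snow/engine/common"
)

// rateLimitAppError shows the shape of the refusal a requesting peer receives.
func rateLimitAppError() *common.AppError {
	return &common.AppError{
		Code:    p2p.ErrUnexpected.Code, // generic code: no dedicated rate-limit code
		Message: ErrChunkRateLimitSurpassed.Error(),
	}
}

// AppError matching compares codes, so a peer cannot match the sentinel:
//
//	errors.Is(rateLimitAppError(), ErrChunkRateLimitSurpassed) // false: only the message text survives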
78 changes: 58 additions & 20 deletions x/dsmr/storage.go
@@ -35,6 +35,7 @@ var minSlotKey []byte = []byte{metadataByte, minSlotByte}
var (
ErrChunkProducerNotValidator = errors.New("chunk producer is not in the validator set")
ErrInvalidChunkCertificate = errors.New("invalid chunk certificate")
ErrChunkRateLimitSurpassed = errors.New("chunk rate limit surpassed")
)

type Verifier[T Tx] interface {
@@ -81,8 +82,6 @@ func (c ChunkVerifier[T]) Verify(chunk Chunk[T]) error {
return fmt.Errorf("%w: producer node id %v", ErrChunkProducerNotValidator, chunk.UnsignedChunk.Producer)
}

// TODO:
// add rate limiting for a given producer.
return chunk.Verify(c.chainState.GetNetworkID(), c.chainState.GetChainID())
}

@@ -123,12 +122,18 @@ type ChunkStorage[T Tx] struct {

// TODO do we need caching
// Chunk + signature + cert
chunkMap map[ids.ID]*StoredChunkSignature[T]
pendingChunkMap map[ids.ID]*StoredChunkSignature[T]

// pendingChunksSizes maps a chunk producer to the total size of storage being used for its pending chunks.
pendingChunksSizes map[ids.NodeID]uint64

ruleFactory RuleFactory
}

func NewChunkStorage[T Tx](
verifier Verifier[T],
db database.Database,
ruleFactory RuleFactory,
) (*ChunkStorage[T], error) {
minSlot := int64(0)
minSlotBytes, err := db.Get(minSlotKey)
@@ -145,11 +150,13 @@
}

storage := &ChunkStorage[T]{
minimumExpiry: minSlot,
chunkEMap: emap.NewEMap[emapChunk[T]](),
chunkMap: make(map[ids.ID]*StoredChunkSignature[T]),
chunkDB: db,
verifier: verifier,
minimumExpiry: minSlot,
chunkEMap: emap.NewEMap[emapChunk[T]](),
pendingChunkMap: make(map[ids.ID]*StoredChunkSignature[T]),
pendingChunksSizes: make(map[ids.NodeID]uint64),
chunkDB: db,
verifier: verifier,
ruleFactory: ruleFactory,
}
return storage, storage.init()
}
@@ -168,7 +175,9 @@ func (s *ChunkStorage[T]) init() error {
return fmt.Errorf("failed to parse chunk %s at slot %d", chunkID, slot)
}
s.chunkEMap.Add([]emapChunk[T]{{chunk: chunk}})
s.chunkMap[chunk.id] = &StoredChunkSignature[T]{Chunk: chunk}
storedChunkSig := &StoredChunkSignature[T]{Chunk: chunk}
s.pendingChunkMap[chunk.id] = storedChunkSig
s.pendingChunksSizes[chunk.Producer] += uint64(len(chunk.bytes))
}

if err := iter.Error(); err != nil {
@@ -192,7 +201,7 @@ func (s *ChunkStorage[T]) SetChunkCert(ctx context.Context, chunkID ids.ID, cert
s.lock.Lock()
defer s.lock.Unlock()

storedChunk, ok := s.chunkMap[chunkID]
storedChunk, ok := s.pendingChunkMap[chunkID]
if !ok {
return fmt.Errorf("failed to store cert for non-existent chunk: %s", chunkID)
}
@@ -211,11 +220,12 @@
// 4. Return the local signature share
// TODO refactor and merge with AddLocalChunkWithCert
// Assumes caller has already verified this does not add a duplicate chunk
// Assumes that if the given chunk is a pending chunk, it does not surpass the producer's rate limit.
func (s *ChunkStorage[T]) VerifyRemoteChunk(c Chunk[T]) (*warp.BitSetSignature, error) {
s.lock.Lock()
defer s.lock.Unlock()

chunkCertInfo, ok := s.chunkMap[c.id]
chunkCertInfo, ok := s.pendingChunkMap[c.id]
if ok {
return chunkCertInfo.Cert.Signature, nil
}
@@ -228,20 +238,25 @@ func (s *ChunkStorage[T]) VerifyRemoteChunk(c Chunk[T]) (*warp.BitSetSignature, error) {
return nil, nil
}

// putVerifiedChunk assumes that the given chunk is guaranteed not to surpass the producer's rate limit.
// The rate limit is checked by CheckRateLimit, called from BuildChunk (for locally generated chunks)
// and from ChunkSignatureRequestVerifier.Verify (for incoming chunk signature requests).
func (s *ChunkStorage[T]) putVerifiedChunk(c Chunk[T], cert *ChunkCertificate) error {
if err := s.chunkDB.Put(pendingChunkKey(c.Expiry, c.id), c.bytes); err != nil {
return err
}
s.chunkEMap.Add([]emapChunk[T]{{chunk: c}})

if chunkCert, ok := s.chunkMap[c.id]; ok {
if chunkCert, ok := s.pendingChunkMap[c.id]; ok {
if cert != nil {
chunkCert.Cert = cert
}
return nil
}
chunkCert := &StoredChunkSignature[T]{Chunk: c, Cert: cert}
s.chunkMap[c.id] = chunkCert
s.pendingChunkMap[c.id] = chunkCert
s.pendingChunksSizes[c.Producer] += uint64(len(c.bytes))

return nil
}

@@ -260,22 +275,22 @@ func (s *ChunkStorage[T]) SetMin(updatedMin int64, saveChunks []ids.ID) error {
return fmt.Errorf("failed to update persistent min slot: %w", err)
}
for _, saveChunkID := range saveChunks {
chunk, ok := s.chunkMap[saveChunkID]
chunk, ok := s.pendingChunkMap[saveChunkID]
if !ok {
return fmt.Errorf("failed to save chunk %s", saveChunkID)
}
if err := batch.Put(acceptedChunkKey(chunk.Chunk.Expiry, chunk.Chunk.id), chunk.Chunk.bytes); err != nil {
return fmt.Errorf("failed to save chunk %s: %w", saveChunkID, err)
}
delete(s.chunkMap, saveChunkID)
s.discardPendingChunk(saveChunkID)
}
expiredChunks := s.chunkEMap.SetMin(updatedMin)
for _, chunkID := range expiredChunks {
chunk, ok := s.chunkMap[chunkID]
chunk, ok := s.pendingChunkMap[chunkID]
if !ok {
continue
}
delete(s.chunkMap, chunkID)
s.discardPendingChunk(chunkID)
// TODO: switch to using DeleteRange(nil, pendingChunkKey(updatedMin, ids.Empty)) after
// merging main
if err := batch.Delete(pendingChunkKey(chunk.Chunk.Expiry, chunk.Chunk.id)); err != nil {
@@ -290,15 +305,29 @@ func (s *ChunkStorage[T]) SetMin(updatedMin int64, saveChunks []ids.ID) error {
return nil
}

// discardPendingChunk removes the given chunkID from the pending chunk map
// and subtracts its size from the producer's total in pendingChunksSizes.
func (s *ChunkStorage[T]) discardPendingChunk(chunkID ids.ID) {
chunk, ok := s.pendingChunkMap[chunkID]
if !ok {
return
}
delete(s.pendingChunkMap, chunkID)
s.pendingChunksSizes[chunk.Chunk.Producer] -= uint64(len(chunk.Chunk.bytes))
if s.pendingChunksSizes[chunk.Chunk.Producer] == 0 {
delete(s.pendingChunksSizes, chunk.Chunk.Producer)
}
}
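One invariant worth keeping in mind: the per-producer totals in pendingChunksSizes must always equal the sum of the pending chunk sizes tracked in pendingChunkMap. A sketch of how a package-internal test might assert this; checkPendingSizesInvariant is a hypothetical helper relying on the unexported fields above:

package dsmr

import "github.com/ava-labs/avalanchego/ids"

// checkPendingSizesInvariant recomputes per-producer totals from
// pendingChunkMap and compares them with the incrementally maintained
// pendingChunksSizes, including the zero-entry cleanup done above.
func checkPendingSizesInvariant[T Tx](s *ChunkStorage[T]) bool {
	totals := make(map[ids.NodeID]uint64)
	for _, stored := range s.pendingChunkMap {
		totals[stored.Chunk.Producer] += uint64(len(stored.Chunk.bytes))
	}
	if len(totals) != len(s.pendingChunksSizes) {
		return false // an entry leaked or was dropped early
	}
	for producer, size := range totals {
		if s.pendingChunksSizes[producer] != size {
			return false
		}
	}
	return true
}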

// GatherChunkCerts provides a slice of chunk certificates to build
// a chunk based block.
// TODO: switch from returning random chunk certs to ordered by expiry
func (s *ChunkStorage[T]) GatherChunkCerts() []*ChunkCertificate {
s.lock.RLock()
defer s.lock.RUnlock()

chunkCerts := make([]*ChunkCertificate, 0, len(s.chunkMap))
for _, chunk := range s.chunkMap {
chunkCerts := make([]*ChunkCertificate, 0, len(s.pendingChunkMap))
for _, chunk := range s.pendingChunkMap {
if chunk.Cert == nil {
continue
}
@@ -314,7 +343,7 @@ func (s *ChunkStorage[T]) GetChunkBytes(expiry int64, chunkID ids.ID) ([]byte, error) {
s.lock.RLock()
defer s.lock.RUnlock()

chunk, ok := s.chunkMap[chunkID]
chunk, ok := s.pendingChunkMap[chunkID]
if ok {
return chunk.Chunk.bytes, nil
}
@@ -326,6 +355,15 @@ func (s *ChunkStorage[T]) GetChunkBytes(expiry int64, chunkID ids.ID) ([]byte, error) {
return chunkBytes, nil
}

func (s *ChunkStorage[T]) CheckRateLimit(chunk Chunk[T]) error {
weightLimit := s.ruleFactory.GetRules(chunk.Expiry).GetMaxAccumulatedProducerChunkWeight()

if uint64(len(chunk.bytes))+s.pendingChunksSizes[chunk.Producer] > weightLimit {
return ErrChunkRateLimitSurpassed
}
return nil
}
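For intuition about the comparison above: a chunk is admitted as long as the producer's pending bytes plus the new chunk's size stay at or below the configured weight. A small worked example; fitsUnderLimit is hypothetical and the numbers are illustrative:

package dsmr

// fitsUnderLimit mirrors the comparison in CheckRateLimit.
func fitsUnderLimit(pendingBytes, chunkLen, weightLimit uint64) bool {
	return chunkLen+pendingBytes <= weightLimit
}

// With a 1 MiB cap (1,048,576 bytes) and 1,000,000 bytes already pending:
//
//	fitsUnderLimit(1_000_000, 48_576, 1<<20) // true: lands exactly on the cap
//	fitsUnderLimit(1_000_000, 48_577, 1<<20) // false: ErrChunkRateLimitSurpassed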

func createChunkKey(prefix byte, slot int64, chunkID ids.ID) []byte {
b := make([]byte, chunkKeySize)
b[0] = prefix