Skip to content

Commit

Permalink
dsmr: implement chunk rate limiter (#1922)
Browse files Browse the repository at this point in the history
Signed-off-by: Tsachi Herman <[email protected]>
  • Loading branch information
tsachiherman authored Feb 13, 2025
1 parent cdd2245 commit 2ad6559
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 37 deletions.
4 changes: 4 additions & 0 deletions x/dsmr/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ type Validator struct {

// Rules exposes the chain parameters that DSMR consults when validating
// and rate limiting chunks. Implementations are obtained per-timestamp
// via RuleFactory.GetRules.
type Rules interface {
	// GetValidityWindow returns the length of the chunk validity window
	// (same time units as chunk expiry slots).
	GetValidityWindow() int64
	// GetMaxAccumulatedProducerChunkWeight returns the maximum total size,
	// in bytes, of pending chunks a single producer may have outstanding;
	// used by ChunkStorage.CheckRateLimit.
	GetMaxAccumulatedProducerChunkWeight() uint64
}

type RuleFactory interface {
Expand Down Expand Up @@ -183,6 +184,9 @@ func (n *Node[T]) BuildChunk(
// we have duplicates
return ErrDuplicateChunk
}
if err := n.storage.CheckRateLimit(chunk); err != nil {
return fmt.Errorf("failed to meet chunk rate limits threshold : %w", err)
}

packer := wrappers.Packer{MaxSize: MaxMessageSize}
if err := codec.LinearCodec.MarshalInto(chunkRef, &packer); err != nil {
Expand Down
14 changes: 9 additions & 5 deletions x/dsmr/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
const (
networkID = uint32(123)
testingDefaultValidityWindowDuration = 5 * time.Second
testingDefaultMaxProducerChunkWeight = 1024 * 1024
)

var (
Expand All @@ -45,7 +46,8 @@ var (
chainID = ids.Empty
testRuleFactory = ruleFactory{
rules: rules{
validityWindow: int64(testingDefaultValidityWindowDuration),
validityWindow: int64(testingDefaultValidityWindowDuration),
maxProducerChunkWeight: testingDefaultMaxProducerChunkWeight,
},
}

Expand Down Expand Up @@ -559,7 +561,7 @@ func TestNode_GetChunkSignature_SignValidChunk(t *testing.T) {
},
}

chunkStorage, err := NewChunkStorage[dsmrtest.Tx](tt.verifier, memdb.New())
chunkStorage, err := NewChunkStorage[dsmrtest.Tx](tt.verifier, memdb.New(), testRuleFactory)
r.NoError(err)

chainState := newTestChainState(validators, 1, 1)
Expand Down Expand Up @@ -1395,7 +1397,7 @@ func newTestNodes(t *testing.T, n int) []*Node[dsmrtest.Tx] {
chainState,
testRuleFactory,
)
chunkStorage, err := NewChunkStorage[dsmrtest.Tx](verifier, memdb.New())
chunkStorage, err := NewChunkStorage[dsmrtest.Tx](verifier, memdb.New(), testRuleFactory)
require.NoError(t, err)

getChunkHandler := &GetChunkHandler[dsmrtest.Tx]{
Expand Down Expand Up @@ -1587,7 +1589,9 @@ type ruleFactory struct {
func (r ruleFactory) GetRules(int64) Rules { return r.rules }

// rules is a static test implementation of the Rules interface.
type rules struct {
	validityWindow         int64  // chunk validity window length
	maxProducerChunkWeight uint64 // per-producer pending-chunk byte limit
}

// GetValidityWindow returns the configured validity window.
func (r rules) GetValidityWindow() int64 { return r.validityWindow }

// GetMaxAccumulatedProducerChunkWeight returns the configured per-producer
// pending chunk weight limit.
func (r rules) GetMaxAccumulatedProducerChunkWeight() uint64 { return r.maxProducerChunkWeight }
4 changes: 1 addition & 3 deletions x/dsmr/p2p.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,7 @@ func (c ChunkSignatureRequestVerifier[T]) Verify(
return ErrInvalidChunk
}

// check to see if this chunk was already accepted.
_, err = c.storage.GetChunkBytes(chunk.Expiry, chunk.id)
if err != nil && !errors.Is(err, database.ErrNotFound) {
if err := c.storage.CheckRateLimit(chunk); err != nil {
return &common.AppError{
Code: p2p.ErrUnexpected.Code,
Message: err.Error(),
Expand Down
78 changes: 58 additions & 20 deletions x/dsmr/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ var minSlotKey []byte = []byte{metadataByte, minSlotByte}
// Sentinel errors returned by chunk verification and storage; callers
// match them with errors.Is.
var (
	// ErrChunkProducerNotValidator: the chunk's producer is not part of
	// the current validator set.
	ErrChunkProducerNotValidator = errors.New("chunk producer is not in the validator set")
	// ErrInvalidChunkCertificate: a chunk certificate failed verification.
	ErrInvalidChunkCertificate = errors.New("invalid chunk certificate")
	// ErrChunkRateLimitSurpassed: accepting the chunk would exceed the
	// producer's accumulated pending-chunk weight limit.
	ErrChunkRateLimitSurpassed = errors.New("chunk rate limit surpassed")
)

type Verifier[T Tx] interface {
Expand Down Expand Up @@ -81,8 +82,6 @@ func (c ChunkVerifier[T]) Verify(chunk Chunk[T]) error {
return fmt.Errorf("%w: producer node id %v", ErrChunkProducerNotValidator, chunk.UnsignedChunk.Producer)
}

// TODO:
// add rate limiting for a given producer.
return chunk.Verify(c.chainState.GetNetworkID(), c.chainState.GetChainID())
}

Expand Down Expand Up @@ -123,12 +122,18 @@ type ChunkStorage[T Tx] struct {

// TODO do we need caching
// Chunk + signature + cert
chunkMap map[ids.ID]*StoredChunkSignature[T]
pendingChunkMap map[ids.ID]*StoredChunkSignature[T]

// pendingChunksSizes maps a chunk producer to the total size of storage being used for its pending chunks.
pendingChunksSizes map[ids.NodeID]uint64

ruleFactory RuleFactory
}

func NewChunkStorage[T Tx](
verifier Verifier[T],
db database.Database,
ruleFactory RuleFactory,
) (*ChunkStorage[T], error) {
minSlot := int64(0)
minSlotBytes, err := db.Get(minSlotKey)
Expand All @@ -145,11 +150,13 @@ func NewChunkStorage[T Tx](
}

storage := &ChunkStorage[T]{
minimumExpiry: minSlot,
chunkEMap: emap.NewEMap[emapChunk[T]](),
chunkMap: make(map[ids.ID]*StoredChunkSignature[T]),
chunkDB: db,
verifier: verifier,
minimumExpiry: minSlot,
chunkEMap: emap.NewEMap[emapChunk[T]](),
pendingChunkMap: make(map[ids.ID]*StoredChunkSignature[T]),
pendingChunksSizes: make(map[ids.NodeID]uint64),
chunkDB: db,
verifier: verifier,
ruleFactory: ruleFactory,
}
return storage, storage.init()
}
Expand All @@ -168,7 +175,9 @@ func (s *ChunkStorage[T]) init() error {
return fmt.Errorf("failed to parse chunk %s at slot %d", chunkID, slot)
}
s.chunkEMap.Add([]emapChunk[T]{{chunk: chunk}})
s.chunkMap[chunk.id] = &StoredChunkSignature[T]{Chunk: chunk}
storedChunkSig := &StoredChunkSignature[T]{Chunk: chunk}
s.pendingChunkMap[chunk.id] = storedChunkSig
s.pendingChunksSizes[chunk.Producer] += uint64(len(chunk.bytes))
}

if err := iter.Error(); err != nil {
Expand All @@ -192,7 +201,7 @@ func (s *ChunkStorage[T]) SetChunkCert(ctx context.Context, chunkID ids.ID, cert
s.lock.Lock()
defer s.lock.Unlock()

storedChunk, ok := s.chunkMap[chunkID]
storedChunk, ok := s.pendingChunkMap[chunkID]
if !ok {
return fmt.Errorf("failed to store cert for non-existent chunk: %s", chunkID)
}
Expand All @@ -211,11 +220,12 @@ func (s *ChunkStorage[T]) SetChunkCert(ctx context.Context, chunkID ids.ID, cert
// 4. Return the local signature share
// TODO refactor and merge with AddLocalChunkWithCert
// Assumes caller has already verified this does not add a duplicate chunk
// Assumes that if the given chunk is a pending chunk, it would not surpass the producer's rate limit.
func (s *ChunkStorage[T]) VerifyRemoteChunk(c Chunk[T]) (*warp.BitSetSignature, error) {
s.lock.Lock()
defer s.lock.Unlock()

chunkCertInfo, ok := s.chunkMap[c.id]
chunkCertInfo, ok := s.pendingChunkMap[c.id]
if ok {
return chunkCertInfo.Cert.Signature, nil
}
Expand All @@ -228,20 +238,25 @@ func (s *ChunkStorage[T]) VerifyRemoteChunk(c Chunk[T]) (*warp.BitSetSignature,
return nil, nil
}

// putVerifiedChunk assumes that the given chunk is guaranteed not to surpass the producer's rate limit.
// The rate limit is being checked via a call to CheckRateLimit from BuildChunk (for locally generated chunks)
// and ChunkSignatureRequestVerifier.Verify for incoming chunk signature requests.
func (s *ChunkStorage[T]) putVerifiedChunk(c Chunk[T], cert *ChunkCertificate) error {
if err := s.chunkDB.Put(pendingChunkKey(c.Expiry, c.id), c.bytes); err != nil {
return err
}
s.chunkEMap.Add([]emapChunk[T]{{chunk: c}})

if chunkCert, ok := s.chunkMap[c.id]; ok {
if chunkCert, ok := s.pendingChunkMap[c.id]; ok {
if cert != nil {
chunkCert.Cert = cert
}
return nil
}
chunkCert := &StoredChunkSignature[T]{Chunk: c, Cert: cert}
s.chunkMap[c.id] = chunkCert
s.pendingChunkMap[c.id] = chunkCert
s.pendingChunksSizes[c.Producer] += uint64(len(c.bytes))

return nil
}

Expand All @@ -260,22 +275,22 @@ func (s *ChunkStorage[T]) SetMin(updatedMin int64, saveChunks []ids.ID) error {
return fmt.Errorf("failed to update persistent min slot: %w", err)
}
for _, saveChunkID := range saveChunks {
chunk, ok := s.chunkMap[saveChunkID]
chunk, ok := s.pendingChunkMap[saveChunkID]
if !ok {
return fmt.Errorf("failed to save chunk %s", saveChunkID)
}
if err := batch.Put(acceptedChunkKey(chunk.Chunk.Expiry, chunk.Chunk.id), chunk.Chunk.bytes); err != nil {
return fmt.Errorf("failed to save chunk %s: %w", saveChunkID, err)
}
delete(s.chunkMap, saveChunkID)
s.discardPendingChunk(saveChunkID)
}
expiredChunks := s.chunkEMap.SetMin(updatedMin)
for _, chunkID := range expiredChunks {
chunk, ok := s.chunkMap[chunkID]
chunk, ok := s.pendingChunkMap[chunkID]
if !ok {
continue
}
delete(s.chunkMap, chunkID)
s.discardPendingChunk(chunkID)
// TODO: switch to using DeleteRange(nil, pendingChunkKey(updatedMin, ids.Empty)) after
// merging main
if err := batch.Delete(pendingChunkKey(chunk.Chunk.Expiry, chunk.Chunk.id)); err != nil {
Expand All @@ -290,15 +305,29 @@ func (s *ChunkStorage[T]) SetMin(updatedMin int64, saveChunks []ids.ID) error {
return nil
}

// discardPendingChunk drops chunkID from the pending-chunk bookkeeping:
// it removes the chunk from the pending chunk map and subtracts its size
// from the producer's accumulated pending-chunk total. A producer whose
// total reaches zero is removed from the sizes map entirely.
func (s *ChunkStorage[T]) discardPendingChunk(chunkID ids.ID) {
	stored, found := s.pendingChunkMap[chunkID]
	if !found {
		return
	}
	delete(s.pendingChunkMap, chunkID)

	producer := stored.Chunk.Producer
	remaining := s.pendingChunksSizes[producer] - uint64(len(stored.Chunk.bytes))
	if remaining == 0 {
		delete(s.pendingChunksSizes, producer)
	} else {
		s.pendingChunksSizes[producer] = remaining
	}
}

// GatherChunkCerts provides a slice of chunk certificates to build
// a chunk based block.
// TODO: switch from returning random chunk certs to ordered by expiry
func (s *ChunkStorage[T]) GatherChunkCerts() []*ChunkCertificate {
s.lock.RLock()
defer s.lock.RUnlock()

chunkCerts := make([]*ChunkCertificate, 0, len(s.chunkMap))
for _, chunk := range s.chunkMap {
chunkCerts := make([]*ChunkCertificate, 0, len(s.pendingChunkMap))
for _, chunk := range s.pendingChunkMap {
if chunk.Cert == nil {
continue
}
Expand All @@ -314,7 +343,7 @@ func (s *ChunkStorage[T]) GetChunkBytes(expiry int64, chunkID ids.ID) ([]byte, e
s.lock.RLock()
defer s.lock.RUnlock()

chunk, ok := s.chunkMap[chunkID]
chunk, ok := s.pendingChunkMap[chunkID]
if ok {
return chunk.Chunk.bytes, nil
}
Expand All @@ -326,6 +355,15 @@ func (s *ChunkStorage[T]) GetChunkBytes(expiry int64, chunkID ids.ID) ([]byte, e
return chunkBytes, nil
}

// CheckRateLimit verifies that admitting the given chunk would not push the
// total size of the producer's pending chunks past the per-producer weight
// limit defined by the rules active at the chunk's expiry time.
// Returns ErrChunkRateLimitSurpassed when the limit would be exceeded.
func (s *ChunkStorage[T]) CheckRateLimit(chunk Chunk[T]) error {
	weightLimit := s.ruleFactory.GetRules(chunk.Expiry).GetMaxAccumulatedProducerChunkWeight()

	// pendingChunksSizes is mutated by putVerifiedChunk/discardPendingChunk
	// under s.lock, and both visible callers (Node.BuildChunk and
	// ChunkSignatureRequestVerifier.Verify) invoke CheckRateLimit without
	// holding it — so read under the lock to avoid a data race.
	s.lock.RLock()
	defer s.lock.RUnlock()

	if uint64(len(chunk.bytes))+s.pendingChunksSizes[chunk.Producer] > weightLimit {
		return ErrChunkRateLimitSurpassed
	}
	return nil
}

func createChunkKey(prefix byte, slot int64, chunkID ids.ID) []byte {
b := make([]byte, chunkKeySize)
b[0] = prefix
Expand Down
Loading

0 comments on commit 2ad6559

Please sign in to comment.