diff --git a/x/dsmr/node.go b/x/dsmr/node.go index 3a7b40863a..0645853dba 100644 --- a/x/dsmr/node.go +++ b/x/dsmr/node.go @@ -65,6 +65,7 @@ type Validator struct { type Rules interface { GetValidityWindow() int64 + GetMaxAccumulatedProducerChunkWeight() uint64 } type RuleFactory interface { @@ -183,6 +184,9 @@ func (n *Node[T]) BuildChunk( // we have duplicates return ErrDuplicateChunk } + if err := n.storage.CheckRateLimit(chunk); err != nil { + return fmt.Errorf("failed to meet chunk rate limits threshold : %w", err) + } packer := wrappers.Packer{MaxSize: MaxMessageSize} if err := codec.LinearCodec.MarshalInto(chunkRef, &packer); err != nil { diff --git a/x/dsmr/node_test.go b/x/dsmr/node_test.go index 84fa3ea88b..8d446533a9 100644 --- a/x/dsmr/node_test.go +++ b/x/dsmr/node_test.go @@ -35,6 +35,7 @@ import ( const ( networkID = uint32(123) testingDefaultValidityWindowDuration = 5 * time.Second + testingDefaultMaxProducerChunkWeight = 1024 * 1024 ) var ( @@ -45,7 +46,8 @@ var ( chainID = ids.Empty testRuleFactory = ruleFactory{ rules: rules{ - validityWindow: int64(testingDefaultValidityWindowDuration), + validityWindow: int64(testingDefaultValidityWindowDuration), + maxProducerChunkWeight: testingDefaultMaxProducerChunkWeight, }, } @@ -559,7 +561,7 @@ func TestNode_GetChunkSignature_SignValidChunk(t *testing.T) { }, } - chunkStorage, err := NewChunkStorage[dsmrtest.Tx](tt.verifier, memdb.New()) + chunkStorage, err := NewChunkStorage[dsmrtest.Tx](tt.verifier, memdb.New(), testRuleFactory) r.NoError(err) chainState := newTestChainState(validators, 1, 1) @@ -1395,7 +1397,7 @@ func newTestNodes(t *testing.T, n int) []*Node[dsmrtest.Tx] { chainState, testRuleFactory, ) - chunkStorage, err := NewChunkStorage[dsmrtest.Tx](verifier, memdb.New()) + chunkStorage, err := NewChunkStorage[dsmrtest.Tx](verifier, memdb.New(), testRuleFactory) require.NoError(t, err) getChunkHandler := &GetChunkHandler[dsmrtest.Tx]{ @@ -1587,7 +1589,9 @@ type ruleFactory struct { func (r ruleFactory) GetRules(int64) Rules { return r.rules } type rules struct { - validityWindow int64 + validityWindow int64 + maxProducerChunkWeight uint64 } -func (r rules) GetValidityWindow() int64 { return r.validityWindow } +func (r rules) GetValidityWindow() int64 { return r.validityWindow } +func (r rules) GetMaxAccumulatedProducerChunkWeight() uint64 { return r.maxProducerChunkWeight } diff --git a/x/dsmr/p2p.go b/x/dsmr/p2p.go index 5fdf298a6a..f79187ea20 100644 --- a/x/dsmr/p2p.go +++ b/x/dsmr/p2p.go @@ -127,9 +127,7 @@ func (c ChunkSignatureRequestVerifier[T]) Verify( return ErrInvalidChunk } - // check to see if this chunk was already accepted. - _, err = c.storage.GetChunkBytes(chunk.Expiry, chunk.id) - if err != nil && !errors.Is(err, database.ErrNotFound) { + if err := c.storage.CheckRateLimit(chunk); err != nil { return &common.AppError{ Code: p2p.ErrUnexpected.Code, Message: err.Error(), diff --git a/x/dsmr/storage.go b/x/dsmr/storage.go index dafa1b56bc..4c4fcc6078 100644 --- a/x/dsmr/storage.go +++ b/x/dsmr/storage.go @@ -35,6 +35,7 @@ var minSlotKey []byte = []byte{metadataByte, minSlotByte} var ( ErrChunkProducerNotValidator = errors.New("chunk producer is not in the validator set") ErrInvalidChunkCertificate = errors.New("invalid chunk certificate") + ErrChunkRateLimitSurpassed = errors.New("chunk rate limit surpassed") ) type Verifier[T Tx] interface { @@ -81,8 +82,6 @@ func (c ChunkVerifier[T]) Verify(chunk Chunk[T]) error { return fmt.Errorf("%w: producer node id %v", ErrChunkProducerNotValidator, chunk.UnsignedChunk.Producer) } - // TODO: - // add rate limiting for a given producer. return chunk.Verify(c.chainState.GetNetworkID(), c.chainState.GetChainID()) } @@ -123,12 +122,18 @@ type ChunkStorage[T Tx] struct { // TODO do we need caching // Chunk + signature + cert - chunkMap map[ids.ID]*StoredChunkSignature[T] + pendingChunkMap map[ids.ID]*StoredChunkSignature[T] + + // pendingChunksSizes map a chunk producer to the total size of storage being used for it's pending chunks. + pendingChunksSizes map[ids.NodeID]uint64 + + ruleFactory RuleFactory } func NewChunkStorage[T Tx]( verifier Verifier[T], db database.Database, + ruleFactory RuleFactory, ) (*ChunkStorage[T], error) { minSlot := int64(0) minSlotBytes, err := db.Get(minSlotKey) @@ -145,11 +150,13 @@ func NewChunkStorage[T Tx]( } storage := &ChunkStorage[T]{ - minimumExpiry: minSlot, - chunkEMap: emap.NewEMap[emapChunk[T]](), - chunkMap: make(map[ids.ID]*StoredChunkSignature[T]), - chunkDB: db, - verifier: verifier, + minimumExpiry: minSlot, + chunkEMap: emap.NewEMap[emapChunk[T]](), + pendingChunkMap: make(map[ids.ID]*StoredChunkSignature[T]), + pendingChunksSizes: make(map[ids.NodeID]uint64), + chunkDB: db, + verifier: verifier, + ruleFactory: ruleFactory, } return storage, storage.init() } @@ -168,7 +175,9 @@ func (s *ChunkStorage[T]) init() error { return fmt.Errorf("failed to parse chunk %s at slot %d", chunkID, slot) } s.chunkEMap.Add([]emapChunk[T]{{chunk: chunk}}) - s.chunkMap[chunk.id] = &StoredChunkSignature[T]{Chunk: chunk} + storedChunkSig := &StoredChunkSignature[T]{Chunk: chunk} + s.pendingChunkMap[chunk.id] = storedChunkSig + s.pendingChunksSizes[chunk.Producer] += uint64(len(chunk.bytes)) } if err := iter.Error(); err != nil { @@ -192,7 +201,7 @@ func (s *ChunkStorage[T]) SetChunkCert(ctx context.Context, chunkID ids.ID, cert s.lock.Lock() defer s.lock.Unlock() - storedChunk, ok := s.chunkMap[chunkID] + storedChunk, ok := s.pendingChunkMap[chunkID] if !ok { return fmt.Errorf("failed to store cert for non-existent chunk: %s", chunkID) } @@ -211,11 +220,12 @@ func (s *ChunkStorage[T]) SetChunkCert(ctx context.Context, chunkID ids.ID, cert // 4. Return the local signature share // TODO refactor and merge with AddLocalChunkWithCert // Assumes caller has already verified this does not add a duplicate chunk +// Assumes that if the given chunk is a pending chunk, it would not surpass the producer's rate limit. func (s *ChunkStorage[T]) VerifyRemoteChunk(c Chunk[T]) (*warp.BitSetSignature, error) { s.lock.Lock() defer s.lock.Unlock() - chunkCertInfo, ok := s.chunkMap[c.id] + chunkCertInfo, ok := s.pendingChunkMap[c.id] if ok { return chunkCertInfo.Cert.Signature, nil } @@ -228,20 +238,25 @@ func (s *ChunkStorage[T]) VerifyRemoteChunk(c Chunk[T]) (*warp.BitSetSignature, return nil, nil } +// putVerifiedChunk assumes that the given chunk is guaranteed not to surpass the producer's rate limit. +// The rate limit is being checked via a call to CheckRateLimit from BuildChunk (for locally generated chunks) +// and ChunkSignatureRequestVerifier.Verify for incoming chunk signature requests. func (s *ChunkStorage[T]) putVerifiedChunk(c Chunk[T], cert *ChunkCertificate) error { if err := s.chunkDB.Put(pendingChunkKey(c.Expiry, c.id), c.bytes); err != nil { return err } s.chunkEMap.Add([]emapChunk[T]{{chunk: c}}) - if chunkCert, ok := s.chunkMap[c.id]; ok { + if chunkCert, ok := s.pendingChunkMap[c.id]; ok { if cert != nil { chunkCert.Cert = cert } return nil } chunkCert := &StoredChunkSignature[T]{Chunk: c, Cert: cert} - s.chunkMap[c.id] = chunkCert + s.pendingChunkMap[c.id] = chunkCert + s.pendingChunksSizes[c.Producer] += uint64(len(c.bytes)) + return nil } @@ -260,22 +275,22 @@ func (s *ChunkStorage[T]) SetMin(updatedMin int64, saveChunks []ids.ID) error { return fmt.Errorf("failed to update persistent min slot: %w", err) } for _, saveChunkID := range saveChunks { - chunk, ok := s.chunkMap[saveChunkID] + chunk, ok := s.pendingChunkMap[saveChunkID] if !ok { return fmt.Errorf("failed to save chunk %s", saveChunkID) } if err := batch.Put(acceptedChunkKey(chunk.Chunk.Expiry, chunk.Chunk.id), chunk.Chunk.bytes); err != nil { return fmt.Errorf("failed to save chunk %s: %w", saveChunkID, err) } - delete(s.chunkMap, saveChunkID) + s.discardPendingChunk(saveChunkID) } expiredChunks := s.chunkEMap.SetMin(updatedMin) for _, chunkID := range expiredChunks { - chunk, ok := s.chunkMap[chunkID] + chunk, ok := s.pendingChunkMap[chunkID] if !ok { continue } - delete(s.chunkMap, chunkID) + s.discardPendingChunk(chunkID) // TODO: switch to using DeleteRange(nil, pendingChunkKey(updatedMin, ids.Empty)) after // merging main if err := batch.Delete(pendingChunkKey(chunk.Chunk.Expiry, chunk.Chunk.id)); err != nil { @@ -290,6 +305,20 @@ func (s *ChunkStorage[T]) SetMin(updatedMin int64, saveChunks []ids.ID) error { return nil } +// discardPendingChunk removes the given chunkID from the +// pending chunk map as well as from the pending chunks producers map. +func (s *ChunkStorage[T]) discardPendingChunk(chunkID ids.ID) { + chunk, ok := s.pendingChunkMap[chunkID] + if !ok { + return + } + delete(s.pendingChunkMap, chunkID) + s.pendingChunksSizes[chunk.Chunk.Producer] -= uint64(len(chunk.Chunk.bytes)) + if s.pendingChunksSizes[chunk.Chunk.Producer] == 0 { + delete(s.pendingChunksSizes, chunk.Chunk.Producer) + } +} + // GatherChunkCerts provides a slice of chunk certificates to build // a chunk based block. // TODO: switch from returning random chunk certs to ordered by expiry @@ -297,8 +326,8 @@ func (s *ChunkStorage[T]) GatherChunkCerts() []*ChunkCertificate { s.lock.RLock() defer s.lock.RUnlock() - chunkCerts := make([]*ChunkCertificate, 0, len(s.chunkMap)) - for _, chunk := range s.chunkMap { + chunkCerts := make([]*ChunkCertificate, 0, len(s.pendingChunkMap)) + for _, chunk := range s.pendingChunkMap { if chunk.Cert == nil { continue } @@ -314,7 +343,7 @@ func (s *ChunkStorage[T]) GetChunkBytes(expiry int64, chunkID ids.ID) ([]byte, e s.lock.RLock() defer s.lock.RUnlock() - chunk, ok := s.chunkMap[chunkID] + chunk, ok := s.pendingChunkMap[chunkID] if ok { return chunk.Chunk.bytes, nil } @@ -326,6 +355,15 @@ func (s *ChunkStorage[T]) GetChunkBytes(expiry int64, chunkID ids.ID) ([]byte, e return chunkBytes, nil } +func (s *ChunkStorage[T]) CheckRateLimit(chunk Chunk[T]) error { + weightLimit := s.ruleFactory.GetRules(chunk.Expiry).GetMaxAccumulatedProducerChunkWeight() + + if uint64(len(chunk.bytes))+s.pendingChunksSizes[chunk.Producer] > weightLimit { + return ErrChunkRateLimitSurpassed + } + return nil +} + func createChunkKey(prefix byte, slot int64, chunkID ids.ID) []byte { b := make([]byte, chunkKeySize) b[0] = prefix diff --git a/x/dsmr/storage_test.go b/x/dsmr/storage_test.go index f085a1cc0f..4d943fea62 100644 --- a/x/dsmr/storage_test.go +++ b/x/dsmr/storage_test.go @@ -7,6 +7,7 @@ import ( "context" "errors" "fmt" + "slices" "testing" "time" @@ -26,6 +27,8 @@ var errInvalidTestItem = errors.New("invalid test item") var _ Verifier[Tx] = testVerifier[Tx]{} +var testDefaultProducer = ids.GenerateTestNodeID() + type testVerifier[T Tx] struct { correctIDs set.Set[ids.ID] correctCerts set.Set[*ChunkCertificate] @@ -47,7 +50,7 @@ func (t testVerifier[T]) VerifyCertificate(_ context.Context, cert *ChunkCertifi return fmt.Errorf("%w: %s", errInvalidTestItem, cert.ChunkID) } -func createTestStorage(t *testing.T, validChunkExpiry, invalidChunkExpiry []int64) ( +func createTestStorage(t *testing.T, validChunkExpiry, invalidChunkExpiry []int64, ruleFactory RuleFactory) ( *ChunkStorage[dsmrtest.Tx], []Chunk[dsmrtest.Tx], []Chunk[dsmrtest.Tx], @@ -60,7 +63,7 @@ func createTestStorage(t *testing.T, validChunkExpiry, invalidChunkExpiry []int6 for _, expiry := range validChunkExpiry { chunk, err := newChunk( UnsignedChunk[dsmrtest.Tx]{ - Producer: ids.EmptyNodeID, + Producer: testDefaultProducer, Beneficiary: codec.Address{}, Expiry: expiry, Txs: []dsmrtest.Tx{{ID: ids.GenerateTestID(), Expiry: 1_000_000}}, @@ -76,7 +79,7 @@ func createTestStorage(t *testing.T, validChunkExpiry, invalidChunkExpiry []int6 for _, expiry := range invalidChunkExpiry { chunk, err := newChunk( UnsignedChunk[dsmrtest.Tx]{ - Producer: ids.EmptyNodeID, + Producer: testDefaultProducer, Beneficiary: codec.Address{}, Expiry: expiry, Txs: []dsmrtest.Tx{{ID: ids.GenerateTestID(), Expiry: 1_000_000}}, @@ -101,6 +104,7 @@ func createTestStorage(t *testing.T, validChunkExpiry, invalidChunkExpiry []int6 storage, err := NewChunkStorage[dsmrtest.Tx]( testVerifier, db, + ruleFactory, ) require.NoError(err) @@ -112,6 +116,7 @@ func createTestStorage(t *testing.T, validChunkExpiry, invalidChunkExpiry []int6 storage, err := NewChunkStorage[dsmrtest.Tx]( testVerifier, db, + ruleFactory, ) require.NoError(err) return storage @@ -122,7 +127,7 @@ func createTestStorage(t *testing.T, validChunkExpiry, invalidChunkExpiry []int6 func TestStoreAndSaveValidChunk(t *testing.T) { require := require.New(t) - storage, validChunks, _, _, verifier := createTestStorage(t, []int64{time.Now().Unix()}, []int64{}) + storage, validChunks, _, _, verifier := createTestStorage(t, []int64{time.Now().Unix()}, []int64{}, testRuleFactory) chunk := validChunks[0] _, err := storage.VerifyRemoteChunk(chunk) @@ -164,7 +169,7 @@ func TestStoreAndSaveValidChunk(t *testing.T) { func TestStoreAndExpireValidChunk(t *testing.T) { require := require.New(t) - storage, validChunks, _, _, verifier := createTestStorage(t, []int64{time.Now().Unix()}, []int64{}) + storage, validChunks, _, _, verifier := createTestStorage(t, []int64{time.Now().Unix()}, []int64{}, testRuleFactory) chunk := validChunks[0] _, err := storage.VerifyRemoteChunk(chunk) @@ -202,7 +207,7 @@ func TestStoreAndExpireValidChunk(t *testing.T) { func TestStoreInvalidChunk(t *testing.T) { require := require.New(t) - storage, _, invalidChunks, _, _ := createTestStorage(t, []int64{}, []int64{time.Now().Unix()}) + storage, _, invalidChunks, _, _ := createTestStorage(t, []int64{}, []int64{time.Now().Unix()}, testRuleFactory) chunk := invalidChunks[0] _, err := storage.VerifyRemoteChunk(chunk) @@ -218,7 +223,7 @@ func TestStoreInvalidChunk(t *testing.T) { func TestStoreAndSaveLocalChunk(t *testing.T) { require := require.New(t) - storage, validChunks, _, _, _ := createTestStorage(t, []int64{time.Now().Unix()}, []int64{}) + storage, validChunks, _, _, _ := createTestStorage(t, []int64{time.Now().Unix()}, []int64{}, testRuleFactory) chunk := validChunks[0] chunkCert := &ChunkCertificate{ ChunkReference: ChunkReference{ @@ -251,7 +256,7 @@ func TestStoreAndSaveLocalChunk(t *testing.T) { func TestStoreAndExpireLocalChunk(t *testing.T) { require := require.New(t) - storage, validChunks, _, _, _ := createTestStorage(t, []int64{time.Now().Unix()}, []int64{}) + storage, validChunks, _, _, _ := createTestStorage(t, []int64{time.Now().Unix()}, []int64{}, testRuleFactory) chunk := validChunks[0] chunkCert := &ChunkCertificate{ ChunkReference: ChunkReference{ @@ -291,7 +296,7 @@ func TestRestartSavedChunks(t *testing.T) { // 5. Pending local chunk // 6. Pending remote chunk numChunks := 6 - storage, validChunks, _, restart, verifier := createTestStorage(t, []int64{2, 2, 1, 1, 2, 2}, []int64{}) + storage, validChunks, _, restart, verifier := createTestStorage(t, []int64{2, 2, 1, 1, 2, 2}, []int64{}, testRuleFactory) chunkCerts := make([]*ChunkCertificate, 0, numChunks) for _, chunk := range validChunks { chunkCert := &ChunkCertificate{ @@ -369,3 +374,122 @@ func TestRestartSavedChunks(t *testing.T) { storage = restart() confirmChunkStorage(storage) } + +func TestChunkProducerRateLimiting(t *testing.T) { + chunk, err := newChunk( + UnsignedChunk[dsmrtest.Tx]{ + Producer: testDefaultProducer, + Beneficiary: codec.Address{}, + Expiry: 1, + Txs: []dsmrtest.Tx{{ID: ids.GenerateTestID(), Expiry: 1_000_000}}, + }, + [48]byte{}, + [96]byte{}, + ) + require.NoError(t, err) + chunkSize := uint64(len(chunk.bytes)) + + testCases := []struct { + name string + expiryTimes []int64 + newChunkExpiry int64 + weightLimit uint64 + minTime int64 + acceptedChunksExpiry []int64 + expectedErr error + }{ + { + name: "success - first", + expiryTimes: []int64{}, + newChunkExpiry: 50, + weightLimit: chunkSize * 5, + expectedErr: nil, + }, + { + name: "success - after", + expiryTimes: []int64{1}, + newChunkExpiry: 50, + weightLimit: chunkSize * 5, + expectedErr: nil, + }, + { + name: "success - before", + expiryTimes: []int64{50}, + newChunkExpiry: 1, + weightLimit: chunkSize * 5, + expectedErr: nil, + }, + { + name: "fail - localized saturated range", + expiryTimes: []int64{0, 50, 99, 150, 500}, + newChunkExpiry: 75, + weightLimit: chunkSize * 2, + expectedErr: ErrChunkRateLimitSurpassed, + }, + { + name: "fail - localized saturated range with multiple elements", + expiryTimes: []int64{0, 100, 120, 150, 200}, + newChunkExpiry: 130, + weightLimit: chunkSize * 3, + expectedErr: ErrChunkRateLimitSurpassed, + }, + { + name: "success - accepted block clear previous limit", + expiryTimes: []int64{0, 50, 100}, + newChunkExpiry: 150, + weightLimit: chunkSize * 3, + expectedErr: nil, + minTime: 10, + acceptedChunksExpiry: []int64{50}, + }, + { + name: "success - expired chunks clear previous limit", + expiryTimes: []int64{50, 100, 150}, + newChunkExpiry: 200, + weightLimit: chunkSize * 3, + expectedErr: nil, + minTime: 55, + acceptedChunksExpiry: []int64{}, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + require := require.New(t) + rules := &ruleFactory{ + rules: rules{ + maxProducerChunkWeight: testCase.weightLimit, + }, + } + storage, chunks, _, _, _ := createTestStorage(t, testCase.expiryTimes, []int64{}, rules) + + for _, chunk := range chunks { + require.NoError(storage.AddLocalChunkWithCert(chunk, nil)) + } + + var acceptedChunks []ids.ID + for _, chunkExpiry := range testCase.acceptedChunksExpiry { + // find the chunk that corresponds to this expiry in the chunks slice. + chunkIndex := slices.IndexFunc(chunks, func(chunk Chunk[dsmrtest.Tx]) bool { + return chunk.Expiry == chunkExpiry + }) + require.NotEqual(-1, chunkIndex, "acceptedChunksExpiry contains an expiry time missing from expiryTimes") + acceptedChunks = append(acceptedChunks, chunks[chunkIndex].id) + } + require.NoError(storage.SetMin(testCase.minTime, acceptedChunks)) + + chunk, err := newChunk( + UnsignedChunk[dsmrtest.Tx]{ + Producer: testDefaultProducer, + Beneficiary: codec.Address{}, + Expiry: testCase.newChunkExpiry, + Txs: []dsmrtest.Tx{{ID: ids.GenerateTestID(), Expiry: 1_000_000}}, + }, + [48]byte{}, + [96]byte{}, + ) + require.NoError(err) + + require.ErrorIs(storage.CheckRateLimit(chunk), testCase.expectedErr) + }) + } +}