Skip to content

Commit

Permalink
more datalayer python/rust parametrization (#19269)
Browse files Browse the repository at this point in the history
  • Loading branch information
altendky authored Feb 14, 2025
1 parent 125a234 commit 06a1e25
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 24 deletions.
44 changes: 23 additions & 21 deletions chia/_tests/core/data_layer/test_merkle_blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,36 +188,36 @@ def test_raw_node_to_blob(case: RawNodeFromBlobCase[RawMerkleNodeProtocol]) -> N
assert blob == case.packed


def test_merkle_blob_one_leaf_loads() -> None:
def test_merkle_blob_one_leaf_loads(merkle_types: MerkleTypes) -> None:
# TODO: need to persist reference data
leaf = RawLeafMerkleNode(
leaf = merkle_types.leaf(
hash=bytes32(range(32)),
parent=None,
key=KeyId(KeyOrValueId(int64(0x0405060708090A0B))),
value=ValueId(KeyOrValueId(int64(0x0405060708090A1B))),
)
blob = bytearray(bytes(NodeMetadata(type=NodeType.leaf, dirty=False)) + pack_raw_node(leaf))

merkle_blob = MerkleBlob(blob=blob)
merkle_blob = merkle_types.blob(blob=blob)
assert merkle_blob.get_raw_node(TreeIndex(0)) == leaf


def test_merkle_blob_two_leafs_loads() -> None:
def test_merkle_blob_two_leafs_loads(merkle_types: MerkleTypes) -> None:
# TODO: break this test down into some reusable data and multiple tests
# TODO: need to persist reference data
root = RawInternalMerkleNode(
root = merkle_types.internal(
hash=bytes32(range(32)),
parent=None,
left=TreeIndex(1),
right=TreeIndex(2),
)
left_leaf = RawLeafMerkleNode(
left_leaf = merkle_types.leaf(
hash=bytes32(range(32)),
parent=TreeIndex(0),
key=KeyId(KeyOrValueId(int64(0x0405060708090A0B))),
value=ValueId(KeyOrValueId(int64(0x0405060708090A1B))),
)
right_leaf = RawLeafMerkleNode(
right_leaf = merkle_types.leaf(
hash=bytes32(range(32)),
parent=TreeIndex(0),
key=KeyId(KeyOrValueId(int64(0x1415161718191A1B))),
Expand All @@ -228,7 +228,7 @@ def test_merkle_blob_two_leafs_loads() -> None:
blob.extend(bytes(NodeMetadata(type=NodeType.leaf, dirty=False)) + pack_raw_node(left_leaf))
blob.extend(bytes(NodeMetadata(type=NodeType.leaf, dirty=False)) + pack_raw_node(right_leaf))

merkle_blob = MerkleBlob(blob=blob)
merkle_blob = merkle_types.blob(blob=blob)
assert merkle_blob.get_raw_node(TreeIndex(0)) == root
assert merkle_blob.get_raw_node(root.left) == left_leaf
assert merkle_blob.get_raw_node(root.right) == right_leaf
Expand All @@ -238,7 +238,7 @@ def test_merkle_blob_two_leafs_loads() -> None:
assert merkle_blob.get_raw_node(right_leaf.parent) == root

assert merkle_blob.get_lineage_with_indexes(TreeIndex(0)) == [(TreeIndex(0), root)]
expected: list[tuple[TreeIndex, RawMerkleNodeProtocol]] = [
expected = [
(TreeIndex(1), left_leaf),
(TreeIndex(0), root),
]
Expand All @@ -263,8 +263,8 @@ def generate_hash(seed: int) -> bytes32:
return bytes32(hash_obj.digest())


def test_insert_delete_loads_all_keys() -> None:
merkle_blob = MerkleBlob(blob=bytearray())
def test_insert_delete_loads_all_keys(merkle_types: MerkleTypes) -> None:
merkle_blob = merkle_types.blob(blob=bytearray())
num_keys = 200000
extra_keys = 100000
max_height = 25
Expand All @@ -290,7 +290,7 @@ def test_insert_delete_loads_all_keys() -> None:
key, value = generate_kvid(seed)
hash = generate_hash(seed)
merkle_blob.insert(key, value, hash)
key_index = merkle_blob.key_to_index[key]
key_index = merkle_blob.get_key_index(key)
lineage = merkle_blob.get_lineage_with_indexes(key_index)
assert len(lineage) <= max_height
keys_values[key] = value
Expand All @@ -304,20 +304,20 @@ def test_insert_delete_loads_all_keys() -> None:

assert merkle_blob.get_keys_values() == keys_values

merkle_blob_2 = MerkleBlob(blob=bytearray(merkle_blob.blob))
merkle_blob_2 = merkle_types.blob(blob=bytearray(merkle_blob.blob))
for seed in range(num_keys, num_keys + extra_keys):
key, value = generate_kvid(seed)
hash = generate_hash(seed)
merkle_blob_2.upsert(key, value, hash)
key_index = merkle_blob_2.key_to_index[key]
key_index = merkle_blob_2.get_key_index(key)
lineage = merkle_blob_2.get_lineage_with_indexes(key_index)
assert len(lineage) <= max_height
keys_values[key] = value
assert merkle_blob_2.get_keys_values() == keys_values


def test_small_insert_deletes() -> None:
merkle_blob = MerkleBlob(blob=bytearray())
def test_small_insert_deletes(merkle_types: MerkleTypes) -> None:
merkle_blob = merkle_types.blob(blob=bytearray())
num_repeats = 100
max_inserts = 25
seed = 0
Expand Down Expand Up @@ -403,8 +403,8 @@ def test_proof_of_inclusion_merkle_blob() -> None:


@pytest.mark.parametrize(argnames="index", argvalues=[TreeIndex(1), undefined_index])
def test_get_raw_node_raises_for_invalid_indexes(index: TreeIndex) -> None:
merkle_blob = MerkleBlob(blob=bytearray())
def test_get_raw_node_raises_for_invalid_indexes(index: TreeIndex, merkle_types: MerkleTypes) -> None:
merkle_blob = merkle_types.blob(blob=bytearray())
merkle_blob.insert(
KeyId(KeyOrValueId(int64(0x1415161718191A1B))),
ValueId(KeyOrValueId(int64(0x1415161718191A1B))),
Expand All @@ -419,8 +419,10 @@ def test_get_raw_node_raises_for_invalid_indexes(index: TreeIndex) -> None:
with pytest.raises(expected):
merkle_blob.get_raw_node(index)

with pytest.raises(InvalidIndexError):
merkle_blob._get_metadata(index)
# this is a python-implementation detail test
if isinstance(merkle_blob, MerkleBlob):
with pytest.raises(InvalidIndexError):
merkle_blob._get_metadata(index)


def test_helper_methods(merkle_types: MerkleTypes) -> None:
Expand Down Expand Up @@ -452,7 +454,7 @@ def test_insert_with_reference_key_and_side(merkle_types: MerkleTypes) -> None:
merkle_blob.insert(key, value, hash, reference_kid, side)
if reference_kid is not None:
assert side is not None
index = merkle_blob.key_to_index[key]
index = merkle_blob.get_key_index(key)
node = merkle_blob.get_raw_node(index)
parent = merkle_blob.get_raw_node(node.parent)
if side == Side.LEFT:
Expand Down
7 changes: 5 additions & 2 deletions chia/data_layer/util/merkle_blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import io
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, ClassVar, Optional, Protocol, TypeVar, Union, cast, final
from typing import TYPE_CHECKING, ClassVar, Optional, Protocol, SupportsBytes, TypeVar, Union, cast, final

from chia_rs.datalayer import KeyId, TreeIndex, ValueId

Expand Down Expand Up @@ -122,6 +122,9 @@ def get_new_index(self) -> TreeIndex:

return self.free_indexes.pop()

def get_key_index(self, key: KeyId) -> TreeIndex:
return self.key_to_index[key]

def get_raw_node(self, index: TreeIndex) -> RawMerkleNodeProtocol:
if undefined_index.raw <= index.raw:
raise InvalidIndexError(index=index)
Expand Down Expand Up @@ -689,7 +692,7 @@ def unpack_raw_node(index: TreeIndex, metadata: NodeMetadata, data: bytes) -> Ra


# TODO: allow broader bytes'ish types
def pack_raw_node(raw_node: RawMerkleNodeProtocol) -> bytes:
def pack_raw_node(raw_node: SupportsBytes) -> bytes:
data = bytes(raw_node)
padding = data_size - len(data)
assert padding >= 0, f"unexpected negative padding: {padding}"
Expand Down
2 changes: 1 addition & 1 deletion poetry.lock
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ typing-extensions = "*"
type = "git"
url = "https://github.com/chia-network/chia_rs"
reference = "long_lived/initial_datalayer"
resolved_reference = "7ad929dd1f261151588cee8c72ca1f1260c86de9"
resolved_reference = "b3524140bfebf59d5009d9b4db0a5e1faba5ce15"
subdirectory = "wheel/"

[[package]]
Expand Down

0 comments on commit 06a1e25

Please sign in to comment.