diff --git a/tests/test_smt.py b/tests/test_smt.py index b68331d..ccd8544 100644 --- a/tests/test_smt.py +++ b/tests/test_smt.py @@ -1,4 +1,5 @@ from hypothesis import ( + assume, given, strategies as st, ) @@ -11,12 +12,25 @@ ) +@st.composite +def binary_tuples(draw): + size = draw(st.integers(min_value=1, max_value=32)) + v = draw(st.binary(min_size=size, max_size=size)) + default = draw(st.binary(min_size=size, max_size=size)) + + # Ensure v and default are not equal + assume(v != default) + + return (v, default) + + @given( k=st.binary(min_size=1, max_size=32), - v=st.binary(min_size=1, max_size=32), + values=binary_tuples(), ) -def test_simple_kv(k, v): - smt = SparseMerkleTree(key_size=len(k)) +def test_simple_kv(k, values): + v, default = values + smt = SparseMerkleTree(key_size=len(k), default=default) empty_root = smt.root_hash # Nothing has been added yet diff --git a/trie/smt.py b/trie/smt.py index 6a7fe86..795bed3 100644 --- a/trie/smt.py +++ b/trie/smt.py @@ -252,7 +252,7 @@ def get(self, key: bytes) -> bytes: value, _ = self._get(key) # Ensure that it isn't blank! - if value == BLANK_NODE: + if value == self._default: raise KeyError("Key does not exist") return value @@ -261,7 +261,7 @@ def branch(self, key: bytes) -> Tuple[Hash32]: value, branch = self._get(key) # Ensure that it isn't blank! - if value == BLANK_NODE: + if value == self._default: raise KeyError("Key does not exist") return branch