from __future__ import annotations
from os.path import exists
import json
import math
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Iterator, cast, Any

ZERO_ENTRY = bytes.fromhex(
    "0000000000000000000000000000000000000000000000000000000000000000")


class ITreeHash(ABC):
    """
    Abstract interface for a hash function to be used in a Merkle tree
    """
    @abstractmethod
    def hash(self, left: bytes, right: bytes) -> bytes:
        pass
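

# Minimal sketch of a concrete ITreeHash (illustrative; the Sha256TreeHash
# name and the choice of SHA-256 are assumptions, not part of this module):
from hashlib import sha256


class Sha256TreeHash(ITreeHash):
    def hash(self, left: bytes, right: bytes) -> bytes:
        # Hash the concatenation of the two child nodes.
        return sha256(left + right).digest()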


class MerkleTreeData:
    """
    Simple container to be persisted for a client-side Merkle tree. Does not
    perform any computation. Layers are ordered from top (smallest) to bottom.
    """
    def __init__(
            self,
            depth: int,
            default_values: List[bytes],
            layers: List[List[bytes]]):
        self.depth = depth
        self.default_values = default_values
        self.layers = layers

    @staticmethod
    def from_json_dict(json_dict: Dict[str, Any]) -> MerkleTreeData:
        depth = cast(int, json_dict["depth"])
        default_values = _to_list_bytes(
            cast(List[str], json_dict["default_values"]))
        layers = [_to_list_bytes(layer)
                  for layer in cast(List[List[str]], json_dict["layers"])]
        return MerkleTreeData(depth, default_values, layers)

    def to_json_dict(self) -> Dict[str, Any]:
        return {
            "depth": self.depth,
            "default_values": _to_list_str(self.default_values),
            "layers": [_to_list_str(layer) for layer in self.layers],
        }
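
    # For a depth-2 tree, to_json_dict produces hex-encoded values of the form
    # (layers ordered root first):
    #
    #   {
    #     "depth": 2,
    #     "default_values": ["<root default>", "<level-1 default>", "00...00"],
    #     "layers": [["<root>"], ["<level-1 nodes>"], ["<leaves>"]]
    #   }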


class MerkleTree:
    """
    Merkle tree structure matching that used in the mixer contract. Simple
    implementation where unpopulated values (zeroes) are also stored.
    """
    def __init__(
            self,
            tree_data: MerkleTreeData,
            depth: int,
            tree_hash: ITreeHash):
        self.depth = tree_data.depth
        self.tree_data = tree_data
        self.tree_hash = tree_hash

    @staticmethod
    def _empty_data_with_depth(
            depth: int, tree_hash: ITreeHash) -> MerkleTreeData:
        default_values = [ZERO_ENTRY] * (depth + 1)
        for i in range(depth - 1, -1, -1):
            default_values[i] = tree_hash.hash(
                default_values[i + 1], default_values[i + 1])

        layers: List[List[bytes]] = [[default_values[0]]]
        layers.extend([[] for _ in range(depth)])
        assert len(layers) == depth + 1
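        # For depth == 2: default_values[2] == ZERO_ENTRY,
        # default_values[1] == hash(ZERO_ENTRY, ZERO_ENTRY),
        # default_values[0] == hash(default_values[1], default_values[1]),
        # and layers starts as [[default_values[0]], [], []].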
        return MerkleTreeData(depth, default_values, layers)

    @staticmethod
    def empty_with_depth(depth: int, tree_hash: ITreeHash) -> MerkleTree:
        return MerkleTree(
            MerkleTree._empty_data_with_depth(depth, tree_hash), depth, tree_hash)

    @staticmethod
    def empty_with_size(num_leaves: int, tree_hash: ITreeHash) -> MerkleTree:
        depth = int(math.log(num_leaves, 2))
        assert pow(2, depth) == num_leaves, f"Non-pow-2 size {num_leaves} given"
        return MerkleTree.empty_with_depth(depth, tree_hash)
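
    # Usage sketch (illustrative; Sha256TreeHash is the example hash above and
    # the power-of-two constructor name is assumed):
    #
    #   tree = MerkleTree.empty_with_size(4, Sha256TreeHash())
    #   root = tree.get_node(tree.depth, 0)  # layer `depth` holds the root
    #
    # empty_with_depth(2, ...) builds the same empty tree directly from its
    # depth.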

        if index < len(leaves):

    def get_node(self, layer_idx: int, node_idx: int) -> bytes:
        assert layer_idx <= self.depth
        layer_idx = self.depth - layer_idx
        layer = self.tree_data.layers[layer_idx]
        if node_idx < len(layer):
            return layer[node_idx]
        return self.tree_data.default_values[layer_idx]
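
    # Indexing note: get_node takes the public orientation (layer 0 is the
    # leaf layer, layer `depth` is the root), while tree_data.layers are
    # stored root first, hence the `self.depth - layer_idx` flip. Nodes beyond
    # the populated part of a layer read as that layer's default value.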

        Public layers iterator.

        After some new leaves have been added, perform the minimal set of hashes
        to recompute the tree, expanding each layer to accommodate new nodes.

        layer_default, layer = next(layers_it)

        for parent_default, parent_layer in layers_it:
            start_idx, end_idx = _recompute_layer(
            layer_default = parent_default
            layer_size = int(layer_size / 2)

        assert len(layer) == 1
        assert layer_size == 1
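
    # The update walks the layers from the leaves towards the root via
    # _get_layers(), calling _recompute_layer() once per level, so only nodes
    # above the newly added leaves are rehashed. Adding a single leaf to a
    # depth-32 tree therefore costs roughly 32 hashes instead of a full
    # rebuild.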

    def _get_layers(self) -> Iterator[Tuple[bytes, List[bytes]]]:
        """
        Internal version of layers iterator for use during updating.
        With 0-th layer as the leaves (matching the public interface).
        """
        default_values = self.tree_data.default_values
        layers = self.tree_data.layers
        for i in range(self.depth, -1, -1):
            yield (default_values[i], layers[i])


def compute_merkle_path(address: int, mk_tree: MerkleTree) -> List[str]:
    """
    Given an "address" (index into leaves of a Merkle tree), compute the path
    to the root.
    """
    merkle_path: List[str] = []
    for depth in range(mk_tree.depth):
        address_bit = address & 0x1
        if address_bit:
            merkle_path.append(mk_tree.get_node(depth, address - 1).hex())
        else:
            merkle_path.append(mk_tree.get_node(depth, address + 1).hex())
        address = address >> 1
    return merkle_path
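
# Worked example: for leaf address 5 (binary 101) in a depth-3 tree, the loop
# collects the siblings (layer 0, node 4), (layer 1, node 3) and
# (layer 2, node 0), halving the address after each level. Hashing the leaf
# against these siblings in order reproduces the root.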


class PersistentMerkleTree(MerkleTree):
    """
    Version of MerkleTree that also supports persistence.
    """
    def __init__(
            self,
            filename: str,
            tree_data: MerkleTreeData,
            depth: int,
            tree_hash: ITreeHash):
        MerkleTree.__init__(self, tree_data, depth, tree_hash)
        self.filename = filename

    @staticmethod
    def open(
            filename: str,
            max_num_leaves: int,
            tree_hash: ITreeHash) -> PersistentMerkleTree:
        depth = int(math.log(max_num_leaves, 2))
        assert max_num_leaves == int(math.pow(2, depth))
        if exists(filename):
            with open(filename, "r") as tree_f:
                json_dict = json.load(tree_f)
                tree_data = MerkleTreeData.from_json_dict(json_dict)
            assert depth == tree_data.depth
        else:
            tree_data = MerkleTree._empty_data_with_depth(depth, tree_hash)
        return PersistentMerkleTree(filename, tree_data, depth, tree_hash)

    def save(self) -> None:
        with open(self.filename, "w") as tree_f:
            json.dump(self.tree_data.to_json_dict(), tree_f)
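
    # Illustrative persistence flow (the open/save method names and
    # Sha256TreeHash are assumptions): load the tree from disk if the file
    # exists, otherwise start from an empty one, then write it back after
    # updating.
    #
    #   tree = PersistentMerkleTree.open("tree.json", 1024, Sha256TreeHash())
    #   ... update the tree ...
    #   tree.save()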


def _leaf_address_to_node_address(address_leaf: int, tree_depth: int) -> int:
    """
    Converts the relative address of a leaf to an absolute address in the tree.
    Important note: The merkle root index is 0 (not 1!)
    """
    address = address_leaf + (2 ** tree_depth - 1)
    if address > (2 ** (tree_depth + 1) - 1):
        return -1
    return address
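
# Example: in a depth-3 tree the root has absolute address 0 and the 8 leaves
# occupy addresses 7..14, so _leaf_address_to_node_address(0, 3) == 7 and
# _leaf_address_to_node_address(7, 3) == 14 (the -1 sentinel for out-of-range
# input is an assumption).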


def _recompute_layer(
        child_layer: List[bytes],
        child_start_idx: int,
        child_end_idx: int,
        child_default_value: bytes,
        parent_layer: List[bytes],
        tree_hash: ITreeHash) -> Tuple[int, int]:
    """
    Recompute nodes in the parent layer that are affected by entries
    [child_start_idx, child_end_idx[ in the child layer. If `child_end_idx` is
    required in the calculation, the final entry of the child layer is used
    (since this contains the default entry for the layer if the tree is not
    full). Returns the start and end indices (within the parent layer) of
    touched parent nodes.
    """
    new_parent_layer_length = int((child_end_idx + 1) / 2)
    parent_layer.extend(
        [ZERO_ENTRY] * (new_parent_layer_length - len(parent_layer)))

    child_left_idx_rend = int(child_start_idx / 2) * 2

    if child_end_idx & 1:
        child_left_idx = child_end_idx - 1
        parent_layer[child_left_idx >> 1] = tree_hash.hash(
            child_layer[child_left_idx], child_default_value)
    else:
        child_left_idx = child_end_idx

    while child_left_idx > child_left_idx_rend:
        child_left_idx = child_left_idx - 2
        parent_layer[child_left_idx >> 1] = tree_hash.hash(
            child_layer[child_left_idx], child_layer[child_left_idx + 1])

    return child_start_idx >> 1, new_parent_layer_length
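
# Worked example: with child_start_idx=2 and child_end_idx=5,
# new_parent_layer_length == 3. child_end_idx is odd, so
# parent_layer[2] = hash(child_layer[4], child_default_value); the loop then
# sets parent_layer[1] = hash(child_layer[2], child_layer[3]) and the function
# returns (1, 3), i.e. parent nodes 1 and 2 were recomputed.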


def _to_list_bytes(list_str: List[str]) -> List[bytes]:
    return [bytes.fromhex(entry) for entry in list_str]


def _to_list_str(list_bytes: List[bytes]) -> List[str]:
    return [entry.hex() for entry in list_bytes]