Zeth - Zerocash on Ethereum  0.8
Reference implementation of the Zeth protocol by Clearmatics
merkle_tree.py
Go to the documentation of this file.
1 # Copyright (c) 2015-2022 Clearmatics Technologies Ltd
2 #
3 # SPDX-License-Identifier: LGPL-3.0+
4 
5 from __future__ import annotations
6 from os.path import exists
7 import json
8 import math
9 from abc import ABC, abstractmethod
10 from typing import Dict, List, Tuple, Iterator, cast, Any
11 
12 
# A 32-byte all-zero value: the value of every unpopulated leaf, from which
# the per-layer default values of an empty tree are derived.
ZERO_ENTRY = bytes(32)
15 
16 
class ITreeHash(ABC):
    """
    Interface for the two-to-one compression function that a Merkle tree
    uses to combine a pair of sibling nodes into their parent.
    """

    @abstractmethod
    def hash(self, left: bytes, right: bytes) -> bytes:
        """
        Return the parent node value for the children `left` (first) and
        `right` (second). Concrete tree hashes must override this.
        """
26 
27 
class MerkleTreeData:
    """
    Plain container holding the state of a client-side Merkle tree so that it
    can be persisted. No hashing is performed here. `layers` runs from the
    top layer (the root, smallest) down to the leaves.
    """
    def __init__(
            self,
            depth: int,
            default_values: List[bytes],
            layers: List[List[bytes]]):
        self.depth = depth
        self.default_values = default_values
        self.layers = layers

    @staticmethod
    def from_json_dict(json_dict: Dict[str, Any]) -> MerkleTreeData:
        """
        Rebuild an instance from a dict in the format produced by
        `to_json_dict` (hex strings in place of raw bytes).
        """
        depth = cast(int, json_dict["depth"])
        default_values = [
            bytes.fromhex(value_hex)
            for value_hex in cast(List[str], json_dict["default_values"])]
        layers = [
            [bytes.fromhex(node_hex) for node_hex in layer_hex]
            for layer_hex in cast(List[List[str]], json_dict["layers"])]
        return MerkleTreeData(depth, default_values, layers)

    def to_json_dict(self) -> Dict[str, Any]:
        """
        Return a JSON-serializable dict with every bytes value hex-encoded.
        """
        json_dict: Dict[str, Any] = {}
        json_dict["depth"] = self.depth
        json_dict["default_values"] = [
            value.hex() for value in self.default_values]
        json_dict["layers"] = [
            [node.hex() for node in layer] for layer in self.layers]
        return json_dict
58 
59 
class MerkleTree:
    """
    Merkle tree structure matching that used in the mixer contract. Simple
    implementation where unpopulated values (zeroes) are also stored.

    Leaves added with `insert` are only hashed into the tree when
    `recompute_root` is called. Until then they are "pending", and the
    accessors that read internal nodes (`get_node`, `get_layers`,
    `get_root`) assert that no pending leaves remain.

    Layer indexing: `tree_data.layers` and `tree_data.default_values` are
    ordered top-down (index 0 is the root layer), whereas the public
    accessors `get_node` and `get_layers` count layers bottom-up, with 0
    being the leaves.
    """
    def __init__(
            self,
            tree_data: MerkleTreeData,
            depth: int,
            tree_hash: ITreeHash):
        # The leaf capacity is derived from `depth` while the layer structure
        # comes from `tree_data`; if they disagree the tree is inconsistent
        # (recompute_root would eventually fail its final size check), so
        # fail early with a clear message instead.
        assert depth == tree_data.depth, \
            f"depth {depth} does not match tree data depth {tree_data.depth}"
        self.max_num_leaves = 1 << depth
        self.depth = tree_data.depth
        self.tree_data = tree_data
        # Number of leaves appended since the last recompute_root.
        self.num_new_leaves = 0
        self.tree_hash = tree_hash

    @staticmethod
    def _empty_data_with_depth(
            depth: int, tree_hash: ITreeHash) -> MerkleTreeData:
        """
        Build the data for an empty tree of the given depth: per-layer
        default values (ZERO_ENTRY at the leaves, each layer above being the
        hash of two copies of the layer below), the default root in layer 0,
        and no entries in any other layer.
        """
        # Compute default values for each layer, from the leaves upwards.
        default_values = [ZERO_ENTRY] * (depth + 1)
        for i in range(depth - 1, -1, -1):
            default_values[i] = tree_hash.hash(
                default_values[i + 1], default_values[i + 1])

        # Initial layer data (fill the 0-th layer with the default root so it's
        # always available).
        layers: List[List[bytes]] = [[default_values[0]]]
        layers.extend([[] for _ in range(depth)])

        assert len(layers) == depth + 1
        return MerkleTreeData(depth, default_values, layers)

    @staticmethod
    def empty_with_depth(depth: int, tree_hash: ITreeHash) -> MerkleTree:
        """
        Create an empty tree of the given depth (2**depth leaves).
        """
        return MerkleTree(
            MerkleTree._empty_data_with_depth(depth, tree_hash), depth, tree_hash)

    @staticmethod
    def empty_with_size(num_leaves: int, tree_hash: ITreeHash) -> MerkleTree:
        """
        Create an empty tree with exactly `num_leaves` leaves, which must be a
        positive power of 2.
        """
        # Exact integer arithmetic: math.log(x, 2) is computed in floating
        # point and int() would truncate any result that misrounds downwards.
        depth = num_leaves.bit_length() - 1
        assert num_leaves > 0 and (1 << depth) == num_leaves, \
            f"Non-pow-2 size {num_leaves} given"
        return MerkleTree.empty_with_depth(depth, tree_hash)

    def get_num_entries(self) -> int:
        """
        Return the number of populated leaves (including pending ones).
        """
        return len(self.tree_data.layers[self.depth])

    def get_leaf(self, index: int) -> bytes:
        """
        Return the leaf at `index`, or ZERO_ENTRY if it is not populated.
        """
        # Reject negative indices: Python's negative indexing would otherwise
        # silently return a leaf counted from the end of the list.
        assert index >= 0, f"negative leaf index {index}"
        leaves = self.tree_data.layers[self.depth]
        if index < len(leaves):
            return leaves[index]
        return ZERO_ENTRY

    def get_leaves(self) -> List[bytes]:
        """
        Return the list of populated leaves. This is the internal list, not a
        copy.
        """
        return self.tree_data.layers[self.depth]

    def get_node(self, layer_idx: int, node_idx: int) -> bytes:
        """
        Return node `node_idx` of layer `layer_idx`, where layer 0 is the
        leaves and layer `depth` is the root. Unpopulated nodes take the
        layer's default value. Requires that no leaves are pending.
        """
        assert 0 <= layer_idx <= self.depth
        # As in get_leaf, a negative index must not wrap around.
        assert node_idx >= 0, f"negative node index {node_idx}"
        assert self.num_new_leaves == 0
        # tree_data stores layers top-down, so flip the index.
        data_idx = self.depth - layer_idx
        layer = self.tree_data.layers[data_idx]
        if node_idx < len(layer):
            return layer[node_idx]
        return self.tree_data.default_values[data_idx]

    def get_layers(self) -> Iterator[Tuple[bytes, List[bytes]]]:
        """
        Public layers iterator. Yields (default_value, layer) pairs from the
        leaves up to the root. Requires that no leaves are pending.
        """
        assert self.num_new_leaves == 0
        return self._get_layers()

    def get_root(self) -> bytes:
        """
        Return the root. Requires that no leaves are pending.
        """
        assert self.num_new_leaves == 0
        return self.tree_data.layers[0][0]

    def insert(self, value: bytes) -> None:
        """
        Append `value` as the next leaf. The tree is not rehashed until
        `recompute_root` is called.
        """
        leaves = self.tree_data.layers[self.depth]
        assert len(leaves) < self.max_num_leaves
        leaves.append(value)
        self.num_new_leaves = self.num_new_leaves + 1

    def recompute_root(self) -> bytes:
        """
        After some new leaves have been added, perform the minimal set of hashes
        to recompute the tree, expanding each layer to accommodate new nodes.
        Returns the (new) root.
        """
        if self.num_new_leaves == 0:
            return self.get_root()

        layers_it = self._get_layers()

        layer_default, layer = next(layers_it)
        end_idx = len(layer)
        start_idx = end_idx - self.num_new_leaves
        layer_size = self.max_num_leaves

        for parent_default, parent_layer in layers_it:
            # Computation for each layer is performed in _recompute_layer, which
            # also computes the start and end indices for the next layer in the
            # tree.
            start_idx, end_idx = _recompute_layer(
                layer,
                start_idx,
                end_idx,
                layer_default,
                parent_layer,
                self.tree_hash)
            layer = parent_layer
            layer_default = parent_default
            layer_size = layer_size // 2

        self.num_new_leaves = 0
        assert len(layer) == 1
        assert layer_size == 1
        return layer[0]

    def _get_layers(self) -> Iterator[Tuple[bytes, List[bytes]]]:
        """
        Internal version of layers iterator for use during updating (no
        pending-leaves check). Yields (default_value, layer) pairs with the
        leaves first (matching the public interface).
        """
        default_values = self.tree_data.default_values
        layers = self.tree_data.layers
        for i in range(self.depth, -1, -1):
            yield (default_values[i], layers[i])
187 
def compute_merkle_path(address: int, mk_tree: MerkleTree) -> List[str]:
    """
    Return the authentication path for the leaf at index `address`: the
    hex-encoded sibling of the current node at each layer, from the leaves
    (layer 0) upwards. An `address` of -1 yields an empty path.
    """
    if address == -1:
        return []

    path: List[str] = []
    node_idx = address
    for layer_idx in range(mk_tree.depth):
        # An odd index is a right child, so its sibling is on the left; an
        # even index is a left child, so its sibling is on the right.
        sibling_idx = node_idx - 1 if node_idx & 0x1 else node_idx + 1
        path.append(mk_tree.get_node(layer_idx, sibling_idx).hex())
        # Move to the parent in the layer above.
        node_idx >>= 1
    return path
207 
208 
class PersistentMerkleTree(MerkleTree):
    """
    Version of MerkleTree that also supports persistence (as JSON in a file).
    """
    def __init__(
            self,
            filename: str,
            tree_data: MerkleTreeData,
            depth: int,
            tree_hash: ITreeHash):
        MerkleTree.__init__(self, tree_data, depth, tree_hash)
        self.filename = filename

    @staticmethod
    def open(
            filename: str,
            max_num_leaves: int,
            tree_hash: ITreeHash) -> PersistentMerkleTree:
        """
        Load the tree stored in `filename` if that file exists, otherwise
        create an empty tree in memory (nothing is written until `save` is
        called). `max_num_leaves` must be a positive power of 2 and, when the
        file exists, must match the depth of the stored tree.
        """
        # Exact integer arithmetic: math.log(x, 2) is computed in floating
        # point and int() would truncate any result that misrounds downwards.
        depth = max_num_leaves.bit_length() - 1
        assert max_num_leaves > 0 and (1 << depth) == max_num_leaves, \
            f"Non-pow-2 size {max_num_leaves} given"
        if exists(filename):
            with open(filename, "r", encoding="utf-8") as tree_f:
                json_dict = json.load(tree_f)
            tree_data = MerkleTreeData.from_json_dict(json_dict)
            assert depth == tree_data.depth, \
                f"tree in {filename} has depth {tree_data.depth}, " \
                f"expected {depth}"
        else:
            tree_data = MerkleTree._empty_data_with_depth(depth, tree_hash)

        return PersistentMerkleTree(filename, tree_data, depth, tree_hash)

    def save(self) -> None:
        """
        Write the tree data to `self.filename` as JSON. The data is first
        written to a temporary file alongside it and then moved over the
        original, so a failure part-way through writing does not leave a
        truncated tree file in place of the previous one.
        """
        import os  # pylint: disable=import-outside-toplevel
        tmp_filename = self.filename + ".tmp"
        with open(tmp_filename, "w", encoding="utf-8") as tree_f:
            json.dump(self.tree_data.to_json_dict(), tree_f)
        os.replace(tmp_filename, self.filename)
242 
243 
def _leaf_address_to_node_address(address_leaf: int, tree_depth: int) -> int:
    """
    Converts the relative address of a leaf (its index among the
    2**tree_depth leaves) to an absolute address in the tree, numbering all
    nodes layer by layer from the root.
    Important note: The merkle root index is 0 (not 1!)

    Returns -1 if `address_leaf` is not a valid leaf index for a tree of
    depth `tree_depth`.
    """
    num_leaves = 1 << tree_depth
    # Valid leaves are [0, num_leaves). The previous bound check compared
    # against the node count with `>` and so accepted one address past the
    # last node; negative leaf indices also mapped onto internal nodes.
    if not 0 <= address_leaf < num_leaves:
        return -1
    # Leaves follow the num_leaves - 1 internal nodes in the numbering.
    return address_leaf + (num_leaves - 1)
253 
254 
def _recompute_layer(
        child_layer: List[bytes],
        child_start_idx: int,
        child_end_idx: int,
        child_default_value: bytes,
        parent_layer: List[bytes],
        tree_hash: ITreeHash) -> Tuple[int, int]:
    """
    Recompute the nodes of `parent_layer` affected by the child entries in
    the half-open range [child_start_idx, child_end_idx). `parent_layer` is
    first extended in place (with ZERO_ENTRY placeholders) so that it has one
    entry per pair of the first `child_end_idx` children. If `child_end_idx`
    is odd, the last child has no right sibling in `child_layer`, so
    `child_default_value` is hashed in its place.

    Returns the half-open range [start, end) of parent indices that were
    recomputed, suitable for passing to this function for the next layer up.
    """

    #                 /   \           /   \           /   \
    # Parent:       ?       ?       F       G       H       0
    #              / \     / \     / \     / \     / \     / \
    # Child:      ?   ?   ?   ?   A   B   C   D   E   ?   ?   0
    #                             ^                   ^
    #                      child_start_idx      child_end_idx
    #
    # F = H(A, B), G = H(C, D), H = H(E, child_default_value); nodes marked
    # '?' are not touched.

    # Extend the parent layer to ensure it has enough capacity (rounding up,
    # since a trailing unpaired child still has a parent).
    new_parent_layer_length = (child_end_idx + 1) // 2
    parent_layer.extend(
        [ZERO_ENTRY] * (new_parent_layer_length - len(parent_layer)))

    # Pairs are processed right to left, down to `child_left_idx_rend`
    # (reverse-end): `child_start_idx` rounded down to an even number, so that
    # a new right child is re-hashed together with its existing left sibling.
    child_left_idx_rend = (child_start_idx // 2) * 2

    # If child_end_idx is odd, the first hash must use the child layer's
    # default value on the right.
    if child_end_idx & 1:
        child_left_idx = child_end_idx - 1
        parent_layer[child_left_idx >> 1] = tree_hash.hash(
            child_layer[child_left_idx], child_default_value)
    else:
        child_left_idx = child_end_idx

    # At this stage, all remaining pairs are populated. Hash pairs and write
    # them to the parent layer.
    while child_left_idx > child_left_idx_rend:
        child_left_idx = child_left_idx - 2
        parent_layer[child_left_idx >> 1] = tree_hash.hash(
            child_layer[child_left_idx], child_layer[child_left_idx + 1])

    return child_start_idx >> 1, new_parent_layer_length
305 
306 
def _to_list_bytes(list_str: List[str]) -> List[bytes]:
    """Decode every hex string in `list_str` into bytes, preserving order."""
    return list(map(bytes.fromhex, list_str))
309 
310 
def _to_list_str(list_bytes: List[bytes]) -> List[str]:
    """Encode every value in `list_bytes` as a hex string, preserving order."""
    return list(value.hex() for value in list_bytes)
zeth.core.merkle_tree.MerkleTreeData.layers
layers
Definition: merkle_tree.py:36
zeth.core.merkle_tree.MerkleTreeData
Definition: merkle_tree.py:28
zeth.core.merkle_tree.MerkleTree.max_num_leaves
max_num_leaves
Definition: merkle_tree.py:66
zeth.cli.zeth_deploy.int
int
Definition: zeth_deploy.py:27
zeth.core.merkle_tree.MerkleTree.empty_with_depth
MerkleTree empty_with_depth(int depth, ITreeHash tree_hash)
Definition: merkle_tree.py:94
zeth.core.merkle_tree.MerkleTree.insert
None insert(self, bytes value)
Definition: merkle_tree.py:136
zeth.core.merkle_tree.MerkleTreeData.__init__
def __init__(self, int depth, List[bytes] default_values, List[List[bytes]] layers)
Definition: merkle_tree.py:33
zeth.core.merkle_tree.MerkleTree.num_new_leaves
num_new_leaves
Definition: merkle_tree.py:69
zeth.core.merkle_tree.PersistentMerkleTree
Definition: merkle_tree.py:209
zeth.core.merkle_tree.MerkleTree._get_layers
Iterator[Tuple[bytes, List[bytes]]] _get_layers(self)
Definition: merkle_tree.py:177
zeth.core.merkle_tree.compute_merkle_path
List[str] compute_merkle_path(int address, MerkleTree mk_tree)
Definition: merkle_tree.py:188
zeth.core.merkle_tree.PersistentMerkleTree.open
PersistentMerkleTree open(str filename, int max_num_leaves, ITreeHash tree_hash)
Definition: merkle_tree.py:223
zeth.core.merkle_tree.MerkleTreeData.depth
depth
Definition: merkle_tree.py:34
zeth.core.merkle_tree.MerkleTree.recompute_root
bytes recompute_root(self)
Definition: merkle_tree.py:142
zeth.core.merkle_tree.MerkleTree.__init__
def __init__(self, MerkleTreeData tree_data, int depth, ITreeHash tree_hash)
Definition: merkle_tree.py:65
zeth.core.merkle_tree.MerkleTree.get_num_entries
int get_num_entries(self)
Definition: merkle_tree.py:104
zeth.core.merkle_tree.MerkleTree.get_leaf
bytes get_leaf(self, int index)
Definition: merkle_tree.py:107
zeth.core.merkle_tree.PersistentMerkleTree.__init__
def __init__(self, str filename, MerkleTreeData tree_data, int depth, ITreeHash tree_hash)
Definition: merkle_tree.py:213
zeth.core.merkle_tree.MerkleTree.get_root
bytes get_root(self)
Definition: merkle_tree.py:132
zeth.core.merkle_tree.MerkleTreeData.from_json_dict
MerkleTreeData from_json_dict(Dict[str, Any] json_dict)
Definition: merkle_tree.py:43
zeth.core.merkle_tree.ITreeHash.hash
bytes hash(self, bytes left, bytes right)
Definition: merkle_tree.py:24
zeth.core.merkle_tree.MerkleTreeData.to_json_dict
Dict[str, Any] to_json_dict(self)
Definition: merkle_tree.py:52
zeth.core.merkle_tree.MerkleTree.get_node
bytes get_node(self, int layer_idx, int node_idx)
Definition: merkle_tree.py:116
zeth.core.merkle_tree.PersistentMerkleTree.filename
filename
Definition: merkle_tree.py:215
zeth.core.merkle_tree.MerkleTree.get_layers
Iterator[Tuple[bytes, List[bytes]]] get_layers(self)
Definition: merkle_tree.py:125
zeth.core.merkle_tree.PersistentMerkleTree.save
None save(self)
Definition: merkle_tree.py:239
zeth.core.merkle_tree.MerkleTree.depth
depth
Definition: merkle_tree.py:67
zeth.core.merkle_tree.MerkleTree.get_leaves
List[bytes] get_leaves(self)
Definition: merkle_tree.py:113
zeth.core.merkle_tree.MerkleTree.tree_data
tree_data
Definition: merkle_tree.py:68
zeth.core.merkle_tree.ITreeHash
Definition: merkle_tree.py:17
zeth.core.merkle_tree.MerkleTree.tree_hash
tree_hash
Definition: merkle_tree.py:70
zeth.core.merkle_tree.MerkleTree.empty_with_size
MerkleTree empty_with_size(int num_leaves, ITreeHash tree_hash)
Definition: merkle_tree.py:99
zeth.core.merkle_tree.MerkleTreeData.default_values
default_values
Definition: merkle_tree.py:35
zeth.core.merkle_tree.MerkleTree
Definition: merkle_tree.py:60