Zeth - Zerocash on Ethereum  0.8
Reference implementation of the Zeth protocol by Clearmatics
binary_operation.tcc
Go to the documentation of this file.
1 // Copyright (c) 2015-2022 Clearmatics Technologies Ltd
2 //
3 // SPDX-License-Identifier: LGPL-3.0+
4 
5 #ifndef __ZETH_CIRCUITS_BINARY_OPERATION_TCC__
6 #define __ZETH_CIRCUITS_BINARY_OPERATION_TCC__
7 
8 #include "libzeth/circuits/circuit_utils.hpp"
9 #include "libzeth/core/bits.hpp"
10 
11 #include <libsnark/gadgetlib1/gadget.hpp>
12 #include <libsnark/gadgetlib1/gadgets/basic_gadgets.hpp>
13 
14 namespace libzeth
15 {
16 
17 template<typename FieldT>
18 xor_gadget<FieldT>::xor_gadget(
19  libsnark::protoboard<FieldT> &pb,
20  const libsnark::pb_variable_array<FieldT> &a,
21  const libsnark::pb_variable_array<FieldT> &b,
22  const libsnark::pb_variable_array<FieldT> &res,
23  const std::string &annotation_prefix)
24  : libsnark::gadget<FieldT>(pb, annotation_prefix), a(a), b(b), res(res)
25 {
26  assert(a.size() == b.size());
27  assert(b.size() == res.size());
28 }
29 
30 template<typename FieldT> void xor_gadget<FieldT>::generate_r1cs_constraints()
31 {
32  // Set the constraints (#constraints = length of bit string)
33  for (size_t i = 0; i < a.size(); i++) {
34  // res = a XOR b <=> (2.a) * b = a + b - res
35  this->pb.add_r1cs_constraint(
36  libsnark::r1cs_constraint<FieldT>(
37  2 * a[i], b[i], a[i] + b[i] - res[i]),
38  FMT(this->annotation_prefix, " xored_bits_%zu", i));
39  }
40 }
41 
42 template<typename FieldT> void xor_gadget<FieldT>::generate_r1cs_witness()
43 {
44  for (size_t i = 0; i < a.size(); i++) {
45  if (this->pb.val(a[i]) == FieldT("1") &&
46  this->pb.val(b[i]) == FieldT("1")) {
47  this->pb.val(res[i]) = FieldT("0");
48  } else {
49  this->pb.val(res[i]) = this->pb.val(a[i]) + this->pb.val(b[i]);
50  }
51  }
52 }
53 
54 template<typename FieldT>
55 xor_constant_gadget<FieldT>::xor_constant_gadget(
56  libsnark::protoboard<FieldT> &pb,
57  const libsnark::pb_variable_array<FieldT> &a,
58  const libsnark::pb_variable_array<FieldT> &b,
59  const std::vector<FieldT> &c,
60  const libsnark::pb_variable_array<FieldT> &res,
61  const std::string &annotation_prefix)
62  : libsnark::gadget<FieldT>(pb, annotation_prefix)
63  , a(a)
64  , b(b)
65  , c(c)
66  , res(res)
67 {
68  assert(a.size() == b.size());
69  assert(b.size() == c.size());
70  assert(c.size() == res.size());
71 }
72 
73 template<typename FieldT>
74 void xor_constant_gadget<FieldT>::generate_r1cs_constraints()
75 {
76  // Set the constraints (#constraints = length of bit string)
77  //
78  // We know that: res = a XOR b <=> (2.a) * b = a + b - res
79  // we can write: res = a + b - 2ab
80  //
81  // Hence, res2 = a XOR b XOR c = (a XOR b) XOR c <=> res XOR c
82  // Thus: res2 = res XOR c is constrainted as:
83  // 2(res) * c = res + c - res2
84  // which leads to:
85  // 2(a + b - 2ab) * c = a + b - 2ab + c - res2
86  // => res2 = a + b - 2ab + c - 2ac - 2bc + 4abc
87  // => res2 - c = b * (4ac - 2a) + a + b - 2ac - 2bc
88  // => res2 - c = b * (4ac - 2a) + a * (1 - 2c) + b * (1 - 2c)
89  // => res2 - c - a * (1 - 2c) - b * (1 - 2c) = b * (4ac - 2a)
90  // and b * (4ac - 2a) = b * [2 * (2c - 1) *a] = b * [-2 * (1 - 2c) *a]
91  for (size_t i = 0; i < a.size(); i++) {
92  this->pb.add_r1cs_constraint(
93  libsnark::r1cs_constraint<FieldT>(
94  -FieldT("2") * (FieldT("1") - FieldT("2") * c[i]) * a[i],
95  b[i],
96  res[i] - c[i] - a[i] * (FieldT("1") - FieldT("2") * c[i]) -
97  b[i] * (FieldT("1") - FieldT("2") * c[i])),
98  FMT(this->annotation_prefix, " rotated_xored_bits_%zu", i));
99  }
100 }
101 
102 template<typename FieldT>
103 void xor_constant_gadget<FieldT>::generate_r1cs_witness()
104 {
105  for (size_t i = 0; i < a.size(); i++) {
106  if ((this->pb.val(a[i]) == FieldT("0") &&
107  this->pb.val(b[i]) == FieldT("0") && c[i] == FieldT("0")) ||
108  (this->pb.val(a[i]) == FieldT("1") &&
109  this->pb.val(b[i]) == FieldT("0") && c[i] == FieldT("1")) ||
110  (this->pb.val(a[i]) == FieldT("0") &&
111  this->pb.val(b[i]) == FieldT("1") && c[i] == FieldT("1")) ||
112  (this->pb.val(a[i]) == FieldT("1") &&
113  this->pb.val(b[i]) == FieldT("1") && c[i] == FieldT("0"))) {
114  this->pb.val(res[i]) = FieldT("0");
115  } else {
116  this->pb.val(res[i]) = FieldT("1");
117  }
118  }
119 }
120 
121 template<typename FieldT>
122 xor_rot_gadget<FieldT>::xor_rot_gadget(
123  libsnark::protoboard<FieldT> &pb,
124  const libsnark::pb_variable_array<FieldT> &a,
125  const libsnark::pb_variable_array<FieldT> &b,
126  const size_t shift,
127  const libsnark::pb_variable_array<FieldT> &res,
128  const std::string &annotation_prefix)
129  : libsnark::gadget<FieldT>(pb, annotation_prefix)
130  , a(a)
131  , b(b)
132  , shift(shift)
133  , res(res)
134 {
135  assert(a.size() == b.size());
136  assert(b.size() == res.size());
137 }
138 
139 template<typename FieldT>
140 void xor_rot_gadget<FieldT>::generate_r1cs_constraints()
141 {
142  // Set the constraints (#constraints = length of bit string)
143  for (size_t i = 0; i < a.size(); i++) {
144  this->pb.add_r1cs_constraint(
145  libsnark::r1cs_constraint<FieldT>(
146  2 * a[i], b[i], a[i] + b[i] - res[(i + shift) % a.size()]),
147  FMT(this->annotation_prefix, " rotated_xored_bits_%zu", i));
148  }
149 }
150 
151 template<typename FieldT> void xor_rot_gadget<FieldT>::generate_r1cs_witness()
152 {
153  // Set the witness (#values = length of bit string)
154  for (size_t i = 0; i < a.size(); i++) {
155  if (this->pb.val(a[i]) == FieldT("1") &&
156  this->pb.val(b[i]) == FieldT("1")) {
157  this->pb.val(res[(i + shift) % a.size()]) = FieldT("0");
158  } else {
159  this->pb.val(res[(i + shift) % a.size()]) =
160  this->pb.val(a[i]) + this->pb.val(b[i]);
161  }
162  }
163 }
164 
165 template<typename FieldT>
166 double_bit32_sum_eq_gadget<FieldT>::double_bit32_sum_eq_gadget(
167  libsnark::protoboard<FieldT> &pb,
168  const libsnark::pb_variable_array<FieldT> &a,
169  const libsnark::pb_variable_array<FieldT> &b,
170  const libsnark::pb_variable_array<FieldT> &res,
171  const std::string &annotation_prefix)
172  : libsnark::gadget<FieldT>(pb, annotation_prefix), a(a), b(b), res(res)
173 {
174  assert(a.size() == 32);
175  assert(a.size() == b.size());
176  assert(a.size() == res.size());
177 }
178 
179 template<typename FieldT>
180 void double_bit32_sum_eq_gadget<FieldT>::generate_r1cs_constraints(
181  bool enforce_boolean)
182 {
183  // We want to check that a + b = c mod 2^32
184  // A way to do this is to follow the proposed implementation
185  // section A.3.7 of the Zcash protocol spec:
186  // https://github.com/zcash/zips/blob/master/protocol/protocol.pdf
187  //
188  // Below, we propose an alternative way to constraint the result to
189  // be a boolean string and to be the valid sum of a and b.
190  //
191  // Let a and b be the input bit string of length 32 bits (uint32)
192  // Let res be the claimed result of a + b of length 33 bits (an additional
193  // bit account for the potential carry of the addition of a and b)
194  //
195  // The goal here is to:
196  // 1. Constraint the 33bits of res to make sure it is a bit string of
197  // length 33bits
198  // $\forall i \in {0, 32} c_i*(c_i-1) = 0$ (33 constraints)
199  // 2. Constraint a,b and res to make sure that a + b = res % 2^32
200  // $\sum_{i=0}^{31} (a_i + b_i) * 2^i = \sum_{i=0}^{32} c_i * 2^i$
201  //
202  // The first set of constraints can be re-written as:
203  // 1.1 $\forall i \in {0, 31} c_i*(c_i-1) = 0$ (32 constraints)
204  // 1.2 $c_{32}*(c_{32}-1) = 0$ => $2^{32}c_{32}*(2^{32}c_{32}-2^{32}) = 0$
205  // (multiply by $2^{{32}^2}$)
206  //
207  // and 2. can be rewritten as:
208  // $\sum_{i=0}^{31} (a_i + b_i - c_i) * 2^i = c_{32} * 2^{32}$
209  //
210  // Now, we can replace $2^{32}c_{32}$ in 1.2 by
211  // $\sum_{i=0}^{31} (a_i + b_i - c_i) * 2^i$
212  // and we obtain:
213  // 1.2' $[\sum_{i=0}^{31} (a_i + b_i - c_i) * 2^i] *
214  // ([\sum_{i=0}^{31} (a_i + b_i - c_i) * 2^i] - 2^{32}) = 0$
215  //
216  // Hence, we finally obtain the following constraint system of 33
217  // constraints:
218  // 1. $\forall i \in {0, 31} c_i*(c_i-1) = 0$ (32 constraints)
219  // 2. $[\sum_{i=0}^{31} (a_i + b_i - c_i) * 2^i] *
220  // ([\sum_{i=0}^{31} (a_i + b_i - c_i) * 2^i] - 2^{32}) = 0$
221  // (1 constraint)
222 
223  // 1. Implement the first set of constraints:
224  // $\forall i \in {0, 31} c_i*(c_i-1) = 0$
225  if (enforce_boolean) {
226  for (size_t i = 0; i < 32; i++) {
227  libsnark::generate_boolean_r1cs_constraint<FieldT>(
228  this->pb, res[i], FMT(this->annotation_prefix, " res[%zu]", i));
229  }
230  }
231 
232  libsnark::linear_combination<FieldT> left_side =
233  packed_addition(a) + packed_addition(b);
234 
235  // 2. Final constraint:
236  // $[\sum_{i=0}^{31} (a_i + b_i - c_i) * 2^i] *
237  // ([\sum_{i=0}^{31} (a_i + b_i - c_i) * 2^i] - 2^{32}) = 0$
238  // The only way to satisfy this constraint is to have either:
239  // a. left_side = res + 0 * 2*32, or
240  // b. left_side = res + 1 * 2^32
241  // This constraint leverages the fact that the sum of two N-bit numbers
242  // can at most lead to a (N+1)-bit number.
243  this->pb.add_r1cs_constraint(
244  libsnark::r1cs_constraint<FieldT>(
245  (left_side - packed_addition(res)),
246  (left_side - packed_addition(res) - pow(2, 32)),
247  0),
248  FMT(this->annotation_prefix, " sum_equal_sum_constraint"));
249 }
250 
251 template<typename FieldT>
252 void double_bit32_sum_eq_gadget<FieldT>::generate_r1cs_witness()
253 {
254  bits32 a_bits32 = bits32::from_vector(a.get_bits(this->pb));
255  bits32 b_bits32 = bits32::from_vector(b.get_bits(this->pb));
256  bits32 left_side_acc = bits_add<32>(a_bits32, b_bits32, false);
257  left_side_acc.fill_pb_variable_array(this->pb, res);
258 }
259 
260 } // namespace libzeth
261 
262 #endif // __ZETH_CIRCUITS_BINARY_OPERATION_TCC__