Clearmatics Libsnark  0.1
C++ library for zkSNARK proofs
tally_cp.tcc
Go to the documentation of this file.
1 /** @file
2  *****************************************************************************
3 
4  Implementation of interfaces for the tally compliance predicate.
5 
6  See tally_cp.hpp .
7 
8  *****************************************************************************
9  * @author This file is part of libsnark, developed by SCIPR Lab
10  * and contributors (see AUTHORS).
11  * @copyright MIT license (see LICENSE file)
12  *****************************************************************************/
13 
14 #ifndef TALLY_CP_TCC_
15 #define TALLY_CP_TCC_
16 
17 #include <algorithm>
18 #include <functional>
19 #include <libff/algebra/fields/field_utils.hpp>
20 
21 namespace libsnark
22 {
23 
24 template<typename FieldT>
25 tally_pcd_message<FieldT>::tally_pcd_message(
26  const size_t type,
27  const size_t wordsize,
28  const size_t sum,
29  const size_t count)
30  : r1cs_pcd_message<FieldT>(type), wordsize(wordsize), sum(sum), count(count)
31 {
32 }
33 
34 template<typename FieldT>
35 r1cs_variable_assignment<FieldT> tally_pcd_message<
36  FieldT>::payload_as_r1cs_variable_assignment() const
37 {
38  std::function<FieldT(bool)> bit_to_FieldT = [](const bool bit) {
39  return bit ? FieldT::one() : FieldT::zero();
40  };
41 
42  const libff::bit_vector sum_bits =
43  libff::convert_field_element_to_bit_vector<FieldT>(sum, wordsize);
44  const libff::bit_vector count_bits =
45  libff::convert_field_element_to_bit_vector<FieldT>(count, wordsize);
46 
47  r1cs_variable_assignment<FieldT> result(2 * wordsize);
48  std::transform(
49  sum_bits.begin(), sum_bits.end(), result.begin(), bit_to_FieldT);
50  std::transform(
51  count_bits.begin(),
52  count_bits.end(),
53  result.begin() + wordsize,
54  bit_to_FieldT);
55 
56  return result;
57 }
58 
59 template<typename FieldT> void tally_pcd_message<FieldT>::print() const
60 {
61  printf("Tally message of type %zu:\n", this->type);
62  printf(" wordsize: %zu\n", wordsize);
63  printf(" sum: %zu\n", sum);
64  printf(" count: %zu\n", count);
65 }
66 
67 template<typename FieldT>
68 tally_pcd_local_data<FieldT>::tally_pcd_local_data(const size_t summand)
69  : summand(summand)
70 {
71 }
72 
73 template<typename FieldT>
74 r1cs_variable_assignment<FieldT> tally_pcd_local_data<
75  FieldT>::as_r1cs_variable_assignment() const
76 {
77  const r1cs_variable_assignment<FieldT> result = {FieldT(summand)};
78  return result;
79 }
80 
81 template<typename FieldT> void tally_pcd_local_data<FieldT>::print() const
82 {
83  printf("Tally PCD local data:\n");
84  printf(" summand: %zu\n", summand);
85 }
86 
87 template<typename FieldT>
88 class tally_pcd_message_variable : public r1cs_pcd_message_variable<FieldT>
89 {
90 public:
91  pb_variable_array<FieldT> sum_bits;
92  pb_variable_array<FieldT> count_bits;
93  size_t wordsize;
94 
95  tally_pcd_message_variable(
96  protoboard<FieldT> &pb,
97  const size_t wordsize,
98  const std::string &annotation_prefix)
99  : r1cs_pcd_message_variable<FieldT>(pb, annotation_prefix)
100  , wordsize(wordsize)
101  {
102  sum_bits.allocate(pb, wordsize, FMT(annotation_prefix, " sum_bits"));
103  count_bits.allocate(
104  pb, wordsize, FMT(annotation_prefix, " count_bits"));
105 
106  this->update_all_vars();
107  }
108 
109  std::shared_ptr<r1cs_pcd_message<FieldT>> get_message() const
110  {
111  const size_t type_val = this->pb.val(this->type).as_ulong();
112  const size_t sum_val =
113  sum_bits.get_field_element_from_bits(this->pb).as_ulong();
114  const size_t count_val =
115  count_bits.get_field_element_from_bits(this->pb).as_ulong();
116 
117  std::shared_ptr<r1cs_pcd_message<FieldT>> result;
118  result.reset(new tally_pcd_message<FieldT>(
119  type_val, wordsize, sum_val, count_val));
120  return result;
121  }
122 
123  ~tally_pcd_message_variable() = default;
124 };
125 
126 template<typename FieldT>
127 class tally_pcd_local_data_variable
128  : public r1cs_pcd_local_data_variable<FieldT>
129 {
130 public:
131  pb_variable<FieldT> summand;
132 
133  tally_pcd_local_data_variable(
134  protoboard<FieldT> &pb, const std::string &annotation_prefix)
135  : r1cs_pcd_local_data_variable<FieldT>(pb, annotation_prefix)
136  {
137  summand.allocate(pb, FMT(annotation_prefix, " summand"));
138 
139  this->update_all_vars();
140  }
141 
142  std::shared_ptr<r1cs_pcd_local_data<FieldT>> get_local_data() const
143  {
144  const size_t summand_val = this->pb.val(summand).as_ulong();
145 
146  std::shared_ptr<r1cs_pcd_local_data<FieldT>> result;
147  result.reset(new tally_pcd_local_data<FieldT>(summand_val));
148  return result;
149  }
150 
151  ~tally_pcd_local_data_variable() = default;
152 };
153 
154 template<typename FieldT>
155 tally_cp_handler<FieldT>::tally_cp_handler(
156  const size_t type,
157  const size_t max_arity,
158  const size_t wordsize,
159  const bool relies_on_same_type_inputs,
160  const std::set<size_t> accepted_input_types)
161  : compliance_predicate_handler<FieldT, protoboard<FieldT>>(
162  protoboard<FieldT>(),
163  type * 100,
164  type,
165  max_arity,
166  relies_on_same_type_inputs,
167  accepted_input_types)
168  , wordsize(wordsize)
169 {
170  this->outgoing_message.reset(new tally_pcd_message_variable<FieldT>(
171  this->pb, wordsize, "outgoing_message"));
172  this->arity.allocate(this->pb, "arity");
173 
174  for (size_t i = 0; i < max_arity; ++i) {
175  this->incoming_messages[i].reset(new tally_pcd_message_variable<FieldT>(
176  this->pb, wordsize, FMT("", "incoming_messages_%zu", i)));
177  }
178 
179  this->local_data.reset(
180  new tally_pcd_local_data_variable<FieldT>(this->pb, "local_data"));
181 
182  sum_out_packed.allocate(this->pb, "sum_out_packed");
183  count_out_packed.allocate(this->pb, "count_out_packed");
184 
185  sum_in_packed.allocate(this->pb, max_arity, "sum_in_packed");
186  count_in_packed.allocate(this->pb, max_arity, "count_in_packed");
187 
188  sum_in_packed_aux.allocate(this->pb, max_arity, "sum_in_packed_aux");
189  count_in_packed_aux.allocate(this->pb, max_arity, "count_in_packed_aux");
190 
191  type_val_inner_product.allocate(this->pb, "type_val_inner_product");
192  for (auto &msg : this->incoming_messages) {
193  incoming_types.emplace_back(msg->type);
194  }
195 
196  compute_type_val_inner_product.reset(new inner_product_gadget<FieldT>(
197  this->pb,
198  incoming_types,
199  sum_in_packed,
200  type_val_inner_product,
201  "compute_type_val_inner_product"));
202 
203  unpack_sum_out.reset(new packing_gadget<FieldT>(
204  this->pb,
205  std::dynamic_pointer_cast<tally_pcd_message_variable<FieldT>>(
206  this->outgoing_message)
207  ->sum_bits,
208  sum_out_packed,
209  "pack_sum_out"));
210  unpack_count_out.reset(new packing_gadget<FieldT>(
211  this->pb,
212  std::dynamic_pointer_cast<tally_pcd_message_variable<FieldT>>(
213  this->outgoing_message)
214  ->count_bits,
215  count_out_packed,
216  "pack_count_out"));
217 
218  for (size_t i = 0; i < max_arity; ++i) {
219  pack_sum_in.emplace_back(packing_gadget<FieldT>(
220  this->pb,
221  std::dynamic_pointer_cast<tally_pcd_message_variable<FieldT>>(
222  this->incoming_messages[i])
223  ->sum_bits,
224  sum_in_packed[i],
225  FMT("", "pack_sum_in_%zu", i)));
226  pack_count_in.emplace_back(packing_gadget<FieldT>(
227  this->pb,
228  std::dynamic_pointer_cast<tally_pcd_message_variable<FieldT>>(
229  this->incoming_messages[i])
230  ->sum_bits,
231  count_in_packed[i],
232  FMT("", "pack_count_in_%zu", i)));
233  }
234 
235  arity_indicators.allocate(this->pb, max_arity + 1, "arity_indicators");
236 }
237 
238 template<typename FieldT>
239 void tally_cp_handler<FieldT>::generate_r1cs_constraints()
240 {
241  unpack_sum_out->generate_r1cs_constraints(true);
242  unpack_count_out->generate_r1cs_constraints(true);
243 
244  for (size_t i = 0; i < this->max_arity; ++i) {
245  pack_sum_in[i].generate_r1cs_constraints(true);
246  pack_count_in[i].generate_r1cs_constraints(true);
247  }
248 
249  for (size_t i = 0; i < this->max_arity; ++i) {
250  this->pb.add_r1cs_constraint(
251  r1cs_constraint<FieldT>(
252  incoming_types[i], sum_in_packed_aux[i], sum_in_packed[i]),
253  FMT("", "initial_sum_%zu_is_zero", i));
254  this->pb.add_r1cs_constraint(
255  r1cs_constraint<FieldT>(
256  incoming_types[i], count_in_packed_aux[i], count_in_packed[i]),
257  FMT("", "initial_sum_%zu_is_zero", i));
258  }
259 
260  /* constrain arity indicator variables so that arity_indicators[arity] = 1
261  * and arity_indicators[i] = 0 for any other i */
262  for (size_t i = 0; i < this->max_arity; ++i) {
263  this->pb.add_r1cs_constraint(
264  r1cs_constraint<FieldT>(
265  this->arity - FieldT(i), arity_indicators[i], 0),
266  FMT("", "arity_indicators_%zu", i));
267  }
268 
269  this->pb.add_r1cs_constraint(
270  r1cs_constraint<FieldT>(1, pb_sum<FieldT>(arity_indicators), 1),
271  "arity_indicators");
272 
273  /* require that types of messages that are past arity (i.e. unbound wires)
274  * carry 0 */
275  for (size_t i = 0; i < this->max_arity; ++i) {
276  this->pb.add_r1cs_constraint(
277  r1cs_constraint<FieldT>(
278  0 + pb_sum<FieldT>(pb_variable_array<FieldT>(
279  arity_indicators.begin(),
280  arity_indicators.begin() + i)),
281  incoming_types[i],
282  0),
283  FMT("", "unbound_types_%zu", i));
284  }
285 
286  /* sum_out = local_data + \sum_i type[i] * sum_in[i] */
287  compute_type_val_inner_product->generate_r1cs_constraints();
288  this->pb.add_r1cs_constraint(
289  r1cs_constraint<FieldT>(
290  1,
291  type_val_inner_product +
292  std::dynamic_pointer_cast<
293  tally_pcd_local_data_variable<FieldT>>(this->local_data)
294  ->summand,
295  sum_out_packed),
296  "update_sum");
297 
298  /* count_out = 1 + \sum_i count_in[i] */
299  this->pb.add_r1cs_constraint(
300  r1cs_constraint<FieldT>(
301  1, 1 + pb_sum<FieldT>(count_in_packed), count_out_packed),
302  "update_count");
303 }
304 
305 template<typename FieldT>
306 void tally_cp_handler<FieldT>::generate_r1cs_witness(
307  const std::vector<std::shared_ptr<r1cs_pcd_message<FieldT>>>
308  &incoming_messages,
309  const std::shared_ptr<r1cs_pcd_local_data<FieldT>> &local_data)
310 {
311  base_handler::generate_r1cs_witness(incoming_messages, local_data);
312 
313  for (size_t i = 0; i < this->max_arity; ++i) {
314  pack_sum_in[i].generate_r1cs_witness_from_bits();
315  pack_count_in[i].generate_r1cs_witness_from_bits();
316 
317  if (!this->pb.val(incoming_types[i]).is_zero()) {
318  this->pb.val(sum_in_packed_aux[i]) =
319  this->pb.val(sum_in_packed[i]) *
320  this->pb.val(incoming_types[i]).inverse();
321  this->pb.val(count_in_packed_aux[i]) =
322  this->pb.val(count_in_packed[i]) *
323  this->pb.val(incoming_types[i]).inverse();
324  }
325  }
326 
327  for (size_t i = 0; i < this->max_arity + 1; ++i) {
328  this->pb.val(arity_indicators[i]) =
329  (incoming_messages.size() == i ? FieldT::one() : FieldT::zero());
330  }
331 
332  compute_type_val_inner_product->generate_r1cs_witness();
333  this->pb.val(sum_out_packed) =
334  this->pb.val(
335  std::dynamic_pointer_cast<tally_pcd_local_data_variable<FieldT>>(
336  this->local_data)
337  ->summand) +
338  this->pb.val(type_val_inner_product);
339 
340  this->pb.val(count_out_packed) = FieldT::one();
341  for (size_t i = 0; i < this->max_arity; ++i) {
342  this->pb.val(count_out_packed) += this->pb.val(count_in_packed[i]);
343  }
344 
345  unpack_sum_out->generate_r1cs_witness_from_packed();
346  unpack_count_out->generate_r1cs_witness_from_packed();
347 }
348 
349 template<typename FieldT>
350 std::shared_ptr<r1cs_pcd_message<FieldT>> tally_cp_handler<
351  FieldT>::get_base_case_message() const
352 {
353  const size_t type = 0;
354  const size_t sum = 0;
355  const size_t count = 0;
356 
357  std::shared_ptr<r1cs_pcd_message<FieldT>> result;
358  result.reset(new tally_pcd_message<FieldT>(type, wordsize, sum, count));
359 
360  return result;
361 }
362 
363 } // namespace libsnark
364 
365 #endif // TALLY_CP_TCC_