Clearmatics Libsnark  0.1
C++ library for zkSNARK proofs
benes_routing_gadget.tcc
Go to the documentation of this file.
1 /** @file
2  *****************************************************************************
3 
4  Implementation of interfaces for the Benes routing gadget.
5 
6  See benes_routing_gadget.hpp .
7 
8  *****************************************************************************
9  * @author This file is part of libsnark, developed by SCIPR Lab
10  * and contributors (see AUTHORS).
11  * @copyright MIT license (see LICENSE file)
12  *****************************************************************************/
13 
14 #ifndef BENES_ROUTING_GADGET_TCC_
15 #define BENES_ROUTING_GADGET_TCC_
16 
17 #include <algorithm>
18 #include <libff/common/profiling.hpp>
19 
20 namespace libsnark
21 {
22 
23 template<typename FieldT>
24 benes_routing_gadget<FieldT>::benes_routing_gadget(
25  protoboard<FieldT> &pb,
26  const size_t num_packets,
27  const std::vector<pb_variable_array<FieldT>> &routing_input_bits,
28  const std::vector<pb_variable_array<FieldT>> &routing_output_bits,
29  const size_t lines_to_unpack,
30  const std::string &annotation_prefix)
31  : gadget<FieldT>(pb, annotation_prefix)
32  , num_packets(num_packets)
33  , num_columns(benes_num_columns(num_packets))
34  , routing_input_bits(routing_input_bits)
35  , routing_output_bits(routing_output_bits)
36  , lines_to_unpack(lines_to_unpack)
37  , packet_size(routing_input_bits[0].size())
38  , num_subpackets(libff::div_ceil(packet_size, FieldT::capacity()))
39 {
40  assert(lines_to_unpack <= routing_input_bits.size());
41  assert(num_packets == 1ul << libff::log2(num_packets));
42  assert(routing_input_bits.size() == num_packets);
43 
44  neighbors = generate_benes_topology(num_packets);
45 
46  routed_packets.resize(num_columns + 1);
47  for (size_t column_idx = 0; column_idx <= num_columns; ++column_idx) {
48  routed_packets[column_idx].resize(num_packets);
49  for (size_t packet_idx = 0; packet_idx < num_packets; ++packet_idx) {
50  routed_packets[column_idx][packet_idx].allocate(
51  pb,
52  num_subpackets,
53  FMT(annotation_prefix,
54  " routed_packets_%zu_%zu",
55  column_idx,
56  packet_idx));
57  }
58  }
59 
60  pack_inputs.reserve(num_packets);
61  unpack_outputs.reserve(num_packets);
62 
63  for (size_t packet_idx = 0; packet_idx < num_packets; ++packet_idx) {
64  pack_inputs.emplace_back(multipacking_gadget<FieldT>(
65  pb,
66  pb_variable_array<FieldT>(
67  routing_input_bits[packet_idx].begin(),
68  routing_input_bits[packet_idx].end()),
69  routed_packets[0][packet_idx],
70  FieldT::capacity(),
71  FMT(this->annotation_prefix, " pack_inputs_%zu", packet_idx)));
72  if (packet_idx < lines_to_unpack) {
73  unpack_outputs.emplace_back(multipacking_gadget<FieldT>(
74  pb,
75  pb_variable_array<FieldT>(
76  routing_output_bits[packet_idx].begin(),
77  routing_output_bits[packet_idx].end()),
78  routed_packets[num_columns][packet_idx],
79  FieldT::capacity(),
80  FMT(this->annotation_prefix,
81  " unpack_outputs_%zu",
82  packet_idx)));
83  }
84  }
85 
86  if (num_subpackets > 1) {
87  benes_switch_bits.resize(num_columns);
88  for (size_t column_idx = 0; column_idx < num_columns; ++column_idx) {
89  benes_switch_bits[column_idx].allocate(
90  pb,
91  num_packets,
92  FMT(this->annotation_prefix,
93  " benes_switch_bits_%zu",
94  column_idx));
95  }
96  }
97 }
98 
99 template<typename FieldT>
100 void benes_routing_gadget<FieldT>::generate_r1cs_constraints()
101 {
102  /* packing/unpacking */
103  for (size_t packet_idx = 0; packet_idx < num_packets; ++packet_idx) {
104  pack_inputs[packet_idx].generate_r1cs_constraints(false);
105  if (packet_idx < lines_to_unpack) {
106  unpack_outputs[packet_idx].generate_r1cs_constraints(true);
107  } else {
108  for (size_t subpacket_idx = 0; subpacket_idx < num_subpackets;
109  ++subpacket_idx) {
110  this->pb.add_r1cs_constraint(
111  r1cs_constraint<FieldT>(
112  1,
113  routed_packets[0][packet_idx][subpacket_idx],
114  routed_packets[num_columns][packet_idx][subpacket_idx]),
115  FMT(this->annotation_prefix,
116  " fix_line_%zu_subpacket_%zu",
117  packet_idx,
118  subpacket_idx));
119  }
120  }
121  }
122 
123  /* actual routing constraints */
124  for (size_t column_idx = 0; column_idx < num_columns; ++column_idx) {
125  for (size_t packet_idx = 0; packet_idx < num_packets; ++packet_idx) {
126  const size_t straight_edge =
127  neighbors[column_idx][packet_idx].first;
128  const size_t cross_edge = neighbors[column_idx][packet_idx].second;
129 
130  if (num_subpackets == 1) {
131  /* easy case: (cur-next)*(cur-cross) = 0 */
132  this->pb.add_r1cs_constraint(
133  r1cs_constraint<FieldT>(
134  routed_packets[column_idx][packet_idx][0] -
135  routed_packets[column_idx + 1][straight_edge][0],
136  routed_packets[column_idx][packet_idx][0] -
137  routed_packets[column_idx + 1][cross_edge][0],
138  0),
139  FMT(this->annotation_prefix,
140  " easy_route_%zu_%zu",
141  column_idx,
142  packet_idx));
143  } else {
144  /* routing bit must be boolean */
145  generate_boolean_r1cs_constraint<FieldT>(
146  this->pb,
147  benes_switch_bits[column_idx][packet_idx],
148  FMT(this->annotation_prefix,
149  " routing_bit_%zu_%zu",
150  column_idx,
151  packet_idx));
152 
153  /* route forward according to routing bits */
154  for (size_t subpacket_idx = 0; subpacket_idx < num_subpackets;
155  ++subpacket_idx) {
156  /*
157  (1-switch_bit) * (cur-straight_edge) + switch_bit *
158  (cur-cross_edge) = 0 switch_bit *
159  (cross_edge-straight_edge) = cur-straight_edge
160  */
161  this->pb.add_r1cs_constraint(
162  r1cs_constraint<FieldT>(
163  benes_switch_bits[column_idx][packet_idx],
164  routed_packets[column_idx + 1][cross_edge]
165  [subpacket_idx] -
166  routed_packets[column_idx + 1][straight_edge]
167  [subpacket_idx],
168  routed_packets[column_idx][packet_idx]
169  [subpacket_idx] -
170  routed_packets[column_idx + 1][straight_edge]
171  [subpacket_idx]),
172  FMT(this->annotation_prefix,
173  " route_forward_%zu_%zu_%zu",
174  column_idx,
175  packet_idx,
176  subpacket_idx));
177  }
178  }
179  }
180  }
181 }
182 
183 template<typename FieldT>
184 void benes_routing_gadget<FieldT>::generate_r1cs_witness(
185  const integer_permutation &permutation)
186 {
187  /* pack inputs */
188  for (size_t packet_idx = 0; packet_idx < num_packets; ++packet_idx) {
189  pack_inputs[packet_idx].generate_r1cs_witness_from_bits();
190  }
191 
192  /* do the routing */
193  const benes_routing routing = get_benes_routing(permutation);
194 
195  for (size_t column_idx = 0; column_idx < num_columns; ++column_idx) {
196  for (size_t packet_idx = 0; packet_idx < num_packets; ++packet_idx) {
197  const size_t straight_edge =
198  neighbors[column_idx][packet_idx].first;
199  const size_t cross_edge = neighbors[column_idx][packet_idx].second;
200 
201  if (num_subpackets > 1) {
202  this->pb.val(benes_switch_bits[column_idx][packet_idx]) =
203  FieldT(routing[column_idx][packet_idx] ? 1 : 0);
204  }
205 
206  for (size_t subpacket_idx = 0; subpacket_idx < num_subpackets;
207  ++subpacket_idx) {
208  this->pb.val(
209  routing[column_idx][packet_idx]
210  ? routed_packets[column_idx + 1][cross_edge]
211  [subpacket_idx]
212  : routed_packets[column_idx + 1][straight_edge]
213  [subpacket_idx]) =
214  this->pb.val(
215  routed_packets[column_idx][packet_idx][subpacket_idx]);
216  }
217  }
218  }
219 
220  /* unpack outputs */
221  for (size_t packet_idx = 0; packet_idx < lines_to_unpack; ++packet_idx) {
222  unpack_outputs[packet_idx].generate_r1cs_witness_from_packed();
223  }
224 }
225 
226 template<typename FieldT>
227 void test_benes_routing_gadget(
228  const size_t num_packets, const size_t packet_size)
229 {
230  const size_t dimension = libff::log2(num_packets);
231  assert(num_packets == 1ul << dimension);
232 
233  printf(
234  "testing benes_routing_gadget by routing 2^%zu-entry vector of %zu "
235  "bits (Fp fits all %zu bit integers)\n",
236  dimension,
237  packet_size,
238  FieldT::capacity());
239 
240  protoboard<FieldT> pb;
241  integer_permutation permutation(num_packets);
242  permutation.random_shuffle();
243  libff::print_time("generated permutation");
244 
245  std::vector<pb_variable_array<FieldT>> randbits(num_packets),
246  outbits(num_packets);
247  for (size_t packet_idx = 0; packet_idx < num_packets; ++packet_idx) {
248  randbits[packet_idx].allocate(
249  pb, packet_size, FMT("", "randbits_%zu", packet_idx));
250  outbits[packet_idx].allocate(
251  pb, packet_size, FMT("", "outbits_%zu", packet_idx));
252 
253  for (size_t bit_idx = 0; bit_idx < packet_size; ++bit_idx) {
254  pb.val(randbits[packet_idx][bit_idx]) =
255  (rand() % 2) ? FieldT::one() : FieldT::zero();
256  }
257  }
258  libff::print_time("generated bits to be routed");
259 
260  benes_routing_gadget<FieldT> r(
261  pb, num_packets, randbits, outbits, num_packets, "main_routing_gadget");
262  r.generate_r1cs_constraints();
263  libff::print_time("generated routing constraints");
264 
265  r.generate_r1cs_witness(permutation);
266  libff::print_time("generated routing assignment");
267 
268  printf("positive test\n");
269  assert(pb.is_satisfied());
270  for (size_t packet_idx = 0; packet_idx < num_packets; ++packet_idx) {
271  for (size_t bit_idx = 0; bit_idx < packet_size; ++bit_idx) {
272  assert(
273  pb.val(outbits[permutation.get(packet_idx)][bit_idx]) ==
274  pb.val(randbits[packet_idx][bit_idx]));
275  }
276  }
277 
278  printf("negative test\n");
279  pb.val(pb_variable<FieldT>(10)) = FieldT(12345);
280  assert(!pb.is_satisfied());
281 
282  printf(
283  "num_constraints = %zu, num_variables = %zu\n",
284  pb.num_constraints(),
285  pb.constraint_system.num_variables);
286 }
287 
288 } // namespace libsnark
289 
290 #endif // BENES_ROUTING_GADGET_TCC_