2 *****************************************************************************
4 Implementation of interfaces for the Benes routing gadget.
6 See benes_routing_gadget.hpp.
8 *****************************************************************************
9 * @author This file is part of libsnark, developed by SCIPR Lab
10 * and contributors (see AUTHORS).
11 * @copyright MIT license (see LICENSE file)
12 *****************************************************************************/
14 #ifndef BENES_ROUTING_GADGET_TCC_
15 #define BENES_ROUTING_GADGET_TCC_
18 #include <libff/common/profiling.hpp>
/**
 * Construct a Benes routing gadget over num_packets lines.
 *
 * Allocates num_columns + 1 columns of packed line variables
 * (routed_packets). Column 0 receives the packed inputs and column
 * num_columns feeds the output unpacking. Each line is represented by
 * num_subpackets = ceil(packet_size / FieldT::capacity()) field elements,
 * where packet_size is taken from routing_input_bits[0].
 *
 * @param pb                  protoboard handed to the gadget base.
 * @param num_packets         number of lines; asserted to be a power of two.
 * @param routing_input_bits  per-line input bits; asserted to contain
 *                            exactly num_packets entries.
 * @param routing_output_bits per-line output bits; only the first
 *                            lines_to_unpack entries are wired to
 *                            unpacking gadgets here.
 * @param lines_to_unpack     number of leading lines whose packed outputs
 *                            are unpacked into routing_output_bits;
 *                            asserted <= routing_input_bits.size().
 * @param annotation_prefix   prefix used for variable annotations.
 *
 * Switch-bit variables (benes_switch_bits) are allocated only when a line
 * spans more than one field element (num_subpackets > 1).
 */
23 template<typename FieldT>
24 benes_routing_gadget<FieldT>::benes_routing_gadget(
25 protoboard<FieldT> &pb,
26 const size_t num_packets,
27 const std::vector<pb_variable_array<FieldT>> &routing_input_bits,
28 const std::vector<pb_variable_array<FieldT>> &routing_output_bits,
29 const size_t lines_to_unpack,
30 const std::string &annotation_prefix)
31 : gadget<FieldT>(pb, annotation_prefix)
32 , num_packets(num_packets)
33 , num_columns(benes_num_columns(num_packets))
34 , routing_input_bits(routing_input_bits)
35 , routing_output_bits(routing_output_bits)
36 , lines_to_unpack(lines_to_unpack)
// NOTE(review): routing_input_bits[0] is read here, before the size
// assertions in the body run; an empty routing_input_bits would be an
// out-of-bounds access. All lines presumably share this packet size --
// TODO confirm with callers.
37 , packet_size(routing_input_bits[0].size())
38 , num_subpackets(libff::div_ceil(packet_size, FieldT::capacity()))
// Preconditions; like any assert(), these are compiled out under NDEBUG.
// NOTE(review): routing_output_bits.size() is not checked, although it is
// indexed below for every packet_idx < lines_to_unpack.
40 assert(lines_to_unpack <= routing_input_bits.size());
41 assert(num_packets == 1ul << libff::log2(num_packets));
42 assert(routing_input_bits.size() == num_packets);
44 neighbors = generate_benes_topology(num_packets);
// routed_packets[c][p] is the packed value of line p at column c.
// Column 0 feeds pack_inputs; column num_columns feeds unpack_outputs.
46 routed_packets.resize(num_columns + 1);
47 for (size_t column_idx = 0; column_idx <= num_columns; ++column_idx) {
48 routed_packets[column_idx].resize(num_packets);
49 for (size_t packet_idx = 0; packet_idx < num_packets; ++packet_idx) {
50 routed_packets[column_idx][packet_idx].allocate(
53 FMT(annotation_prefix,
54 " routed_packets_%zu_%zu",
// One packing gadget per line (input bits -> column 0). Unpacking gadgets
// exist only for lines 0 .. lines_to_unpack-1 and are appended in line
// order, so unpack_outputs[packet_idx] corresponds to line packet_idx.
60 pack_inputs.reserve(num_packets);
61 unpack_outputs.reserve(num_packets);
63 for (size_t packet_idx = 0; packet_idx < num_packets; ++packet_idx) {
64 pack_inputs.emplace_back(multipacking_gadget<FieldT>(
66 pb_variable_array<FieldT>(
67 routing_input_bits[packet_idx].begin(),
68 routing_input_bits[packet_idx].end()),
69 routed_packets[0][packet_idx],
71 FMT(this->annotation_prefix, " pack_inputs_%zu", packet_idx)));
72 if (packet_idx < lines_to_unpack) {
73 unpack_outputs.emplace_back(multipacking_gadget<FieldT>(
75 pb_variable_array<FieldT>(
76 routing_output_bits[packet_idx].begin(),
77 routing_output_bits[packet_idx].end()),
78 routed_packets[num_columns][packet_idx],
80 FMT(this->annotation_prefix,
81 " unpack_outputs_%zu",
// Explicit switch bits are needed only for multi-subpacket lines; the
// single-subpacket routing constraint in generate_r1cs_constraints does
// not use them. One switch-bit array is allocated per column.
86 if (num_subpackets > 1) {
87 benes_switch_bits.resize(num_columns);
88 for (size_t column_idx = 0; column_idx < num_columns; ++column_idx) {
89 benes_switch_bits[column_idx].allocate(
92 FMT(this->annotation_prefix,
93 " benes_switch_bits_%zu",
/**
 * Emit all R1CS constraints of the routing gadget.
 *
 * 1. Packing/unpacking.
 *    - Every line's input bits are tied to column 0 through pack_inputs
 *      (called with argument false).
 *    - The first lines_to_unpack lines are tied to their output bits
 *      through unpack_outputs (called with argument true). The flag is
 *      presumably the multipacking gadget's bitness-enforcement switch;
 *      TODO confirm.
 *    - The remaining lines appear to receive per-subpacket "fix_line"
 *      constraints relating routed_packets[0][p] and
 *      routed_packets[num_columns][p] of the same line.
 *      NOTE(review): if that constraint is an equality, such lines must be
 *      fixed points of the routed permutation; confirm against callers.
 * 2. Routing. For every column c and line p, the value cur of line p must
 *    reappear in column c + 1 at either its straight or its cross
 *    neighbour (both taken from neighbors[c][p]).
 *    - num_subpackets == 1: (cur - straight) * (cur - cross) = 0, so no
 *      switch bit is needed.
 *    - num_subpackets > 1: a switch bit s per (c, p) is constrained to be
 *      boolean. For every subpacket, s * (cross - straight) = cur - straight.
 *      Because one s is shared by all subpackets of a line, the whole line
 *      moves as a unit.
 */
99 template<typename FieldT>
100 void benes_routing_gadget<FieldT>::generate_r1cs_constraints()
102 /* packing/unpacking */
103 for (size_t packet_idx = 0; packet_idx < num_packets; ++packet_idx) {
104 pack_inputs[packet_idx].generate_r1cs_constraints(false);
105 if (packet_idx < lines_to_unpack) {
106 unpack_outputs[packet_idx].generate_r1cs_constraints(true);
108 for (size_t subpacket_idx = 0; subpacket_idx < num_subpackets;
110 this->pb.add_r1cs_constraint(
111 r1cs_constraint<FieldT>(
113 routed_packets[0][packet_idx][subpacket_idx],
114 routed_packets[num_columns][packet_idx][subpacket_idx]),
115 FMT(this->annotation_prefix,
116 " fix_line_%zu_subpacket_%zu",
123 /* actual routing constraints */
// neighbors[c][p] = (straight, cross): indices in column c + 1 that line p
// of column c may be routed to.
124 for (size_t column_idx = 0; column_idx < num_columns; ++column_idx) {
125 for (size_t packet_idx = 0; packet_idx < num_packets; ++packet_idx) {
126 const size_t straight_edge =
127 neighbors[column_idx][packet_idx].first;
128 const size_t cross_edge = neighbors[column_idx][packet_idx].second;
130 if (num_subpackets == 1) {
131 /* easy case: (cur-next)*(cur-cross) = 0 */
132 this->pb.add_r1cs_constraint(
133 r1cs_constraint<FieldT>(
134 routed_packets[column_idx][packet_idx][0] -
135 routed_packets[column_idx + 1][straight_edge][0],
136 routed_packets[column_idx][packet_idx][0] -
137 routed_packets[column_idx + 1][cross_edge][0],
139 FMT(this->annotation_prefix,
140 " easy_route_%zu_%zu",
144 /* routing bit must be boolean */
145 generate_boolean_r1cs_constraint<FieldT>(
147 benes_switch_bits[column_idx][packet_idx],
148 FMT(this->annotation_prefix,
149 " routing_bit_%zu_%zu",
153 /* route forward according to routing bits */
154 for (size_t subpacket_idx = 0; subpacket_idx < num_subpackets;
157 (1-switch_bit) * (cur-straight_edge) + switch_bit *
158 (cur-cross_edge) = 0 switch_bit *
159 (cross_edge-straight_edge) = cur-straight_edge
161 this->pb.add_r1cs_constraint(
162 r1cs_constraint<FieldT>(
163 benes_switch_bits[column_idx][packet_idx],
164 routed_packets[column_idx + 1][cross_edge]
166 routed_packets[column_idx + 1][straight_edge]
168 routed_packets[column_idx][packet_idx]
170 routed_packets[column_idx + 1][straight_edge]
172 FMT(this->annotation_prefix,
173 " route_forward_%zu_%zu_%zu",
/**
 * Fill in the witness that routes the input lines according to permutation.
 *
 * Steps:
 * 1. Pack every line's input bits into column 0 with
 *    generate_r1cs_witness_from_bits. This presumably requires
 *    routing_input_bits to hold values already; TODO confirm.
 * 2. Obtain per-column switch settings from get_benes_routing(permutation).
 * 3. For each column, propagate every line's packed value into column + 1.
 *    The destination is the cross neighbour when routing[c][p] is true and
 *    the straight neighbour otherwise.
 *    - When num_subpackets > 1, the same setting is recorded in
 *      benes_switch_bits[c][p] (1 = cross, 0 = straight).
 * 4. Unpack the first lines_to_unpack lines of the last column into
 *    routing_output_bits.
 *
 * @param permutation permutation to realise; presumably of size num_packets.
 *                    TODO confirm.
 */
183 template<typename FieldT>
184 void benes_routing_gadget<FieldT>::generate_r1cs_witness(
185 const integer_permutation &permutation)
188 for (size_t packet_idx = 0; packet_idx < num_packets; ++packet_idx) {
189 pack_inputs[packet_idx].generate_r1cs_witness_from_bits();
193 const benes_routing routing = get_benes_routing(permutation);
195 for (size_t column_idx = 0; column_idx < num_columns; ++column_idx) {
196 for (size_t packet_idx = 0; packet_idx < num_packets; ++packet_idx) {
197 const size_t straight_edge =
198 neighbors[column_idx][packet_idx].first;
199 const size_t cross_edge = neighbors[column_idx][packet_idx].second;
// Switch bits exist only in the multi-subpacket layout (see constructor).
201 if (num_subpackets > 1) {
202 this->pb.val(benes_switch_bits[column_idx][packet_idx]) =
203 FieldT(routing[column_idx][packet_idx] ? 1 : 0);
206 for (size_t subpacket_idx = 0; subpacket_idx < num_subpackets;
209 routing[column_idx][packet_idx]
210 ? routed_packets[column_idx + 1][cross_edge]
212 : routed_packets[column_idx + 1][straight_edge]
215 routed_packets[column_idx][packet_idx][subpacket_idx]);
// Only the unpacked lines get output bits. NOTE(review): lines at index
// >= lines_to_unpack keep whatever value routing produced in the last
// column; check that this matches the fix_line constraints for them.
221 for (size_t packet_idx = 0; packet_idx < lines_to_unpack; ++packet_idx) {
222 unpack_outputs[packet_idx].generate_r1cs_witness_from_packed();
/**
 * Self-test for benes_routing_gadget.
 *
 * Routes num_packets random packet_size-bit vectors through the gadget
 * using a random permutation, with all lines unpacked
 * (lines_to_unpack = num_packets).
 *
 * - Positive test: the protoboard must be satisfied, and each output bit
 *   outbits[permutation.get(i)] is compared with randbits[i].
 * - Negative test: overwriting one variable must make the protoboard
 *   unsatisfied.
 *
 * @param num_packets number of lines; asserted to be a power of two.
 * @param packet_size number of bits per line.
 *
 * NOTE(review): the visible checks are assert() calls, so under NDEBUG this
 * test verifies nothing. rand() is not seeded here, so the bit pattern
 * depends on the caller's srand state.
 */
226 template<typename FieldT>
227 void test_benes_routing_gadget(
228 const size_t num_packets, const size_t packet_size)
230 const size_t dimension = libff::log2(num_packets);
231 assert(num_packets == 1ul << dimension);
234 "testing benes_routing_gadget by routing 2^%zu-entry vector of %zu "
235 "bits (Fp fits all %zu bit integers)\n",
240 protoboard<FieldT> pb;
241 integer_permutation permutation(num_packets);
242 permutation.random_shuffle();
243 libff::print_time("generated permutation");
// Input lines get random bit values; output lines are only allocated and
// are filled in by the gadget's witness generation.
245 std::vector<pb_variable_array<FieldT>> randbits(num_packets),
246 outbits(num_packets);
247 for (size_t packet_idx = 0; packet_idx < num_packets; ++packet_idx) {
248 randbits[packet_idx].allocate(
249 pb, packet_size, FMT("", "randbits_%zu", packet_idx));
250 outbits[packet_idx].allocate(
251 pb, packet_size, FMT("", "outbits_%zu", packet_idx));
253 for (size_t bit_idx = 0; bit_idx < packet_size; ++bit_idx) {
254 pb.val(randbits[packet_idx][bit_idx]) =
255 (rand() % 2) ? FieldT::one() : FieldT::zero();
258 libff::print_time("generated bits to be routed");
// All num_packets lines are unpacked, so every output line is checked.
260 benes_routing_gadget<FieldT> r(
261 pb, num_packets, randbits, outbits, num_packets, "main_routing_gadget");
262 r.generate_r1cs_constraints();
263 libff::print_time("generated routing constraints");
265 r.generate_r1cs_witness(permutation);
266 libff::print_time("generated routing assignment");
268 printf("positive test\n");
269 assert(pb.is_satisfied());
270 for (size_t packet_idx = 0; packet_idx < num_packets; ++packet_idx) {
271 for (size_t bit_idx = 0; bit_idx < packet_size; ++bit_idx) {
273 pb.val(outbits[permutation.get(packet_idx)][bit_idx]) ==
274 pb.val(randbits[packet_idx][bit_idx]));
// NOTE(review): variable index 10 is hard-coded. This assumes that the
// variable exists and that assigning 12345 to it violates some constraint;
// confirm for the smallest supported configuration.
278 printf("negative test\n");
279 pb.val(pb_variable<FieldT>(10)) = FieldT(12345);
280 assert(!pb.is_satisfied());
283 "num_constraints = %zu, num_variables = %zu\n",
284 pb.num_constraints(),
285 pb.constraint_system.num_variables);
288 } // namespace libsnark
290 #endif // BENES_ROUTING_GADGET_TCC_