2 *****************************************************************************
4 Implementation of interfaces for the tally compliance predicate.
8 *****************************************************************************
9 * @author This file is part of libsnark, developed by SCIPR Lab
10 * and contributors (see AUTHORS).
11 * @copyright MIT license (see LICENSE file)
12 *****************************************************************************/
19 #include <libff/algebra/fields/field_utils.hpp>
// Constructor of tally_pcd_message.
// Forwards `type` to the r1cs_pcd_message<FieldT> base and copies wordsize,
// sum and count into the members of the same name.
// NOTE(review): only the `wordsize` parameter line is visible in this listing;
// the declarations of `type`, `sum` and `count` used by the initializer list,
// and the constructor body, are not shown here.
24 template<typename FieldT>
25 tally_pcd_message<FieldT>::tally_pcd_message(
27 const size_t wordsize,
30 : r1cs_pcd_message<FieldT>(type), wordsize(wordsize), sum(sum), count(count)
// Builds the payload of the message as an r1cs_variable_assignment made of
// the bit decompositions of `sum` and `count`, each decomposed with
// `wordsize` passed as the second argument of the conversion helper.
34 template<typename FieldT>
35 r1cs_variable_assignment<FieldT> tally_pcd_message<
36 FieldT>::payload_as_r1cs_variable_assignment() const
// Maps a single bit to a field element: true -> one(), false -> zero().
38 std::function<FieldT(bool)> bit_to_FieldT = [](const bool bit) {
39 return bit ? FieldT::one() : FieldT::zero();
// Bit decompositions of the two payload values.
42 const libff::bit_vector sum_bits =
43 libff::convert_field_element_to_bit_vector<FieldT>(sum, wordsize);
44 const libff::bit_vector count_bits =
45 libff::convert_field_element_to_bit_vector<FieldT>(count, wordsize);
// Result is constructed with 2 * wordsize; the sum bits are written from
// result.begin(), and a second range is written from result.begin() + wordsize.
// NOTE(review): the heads of the two calls below are not visible in this
// listing -- presumably std::transform, the second one over count_bits; the
// return statement is also not shown. Confirm against the full file.
47 r1cs_variable_assignment<FieldT> result(2 * wordsize);
49 sum_bits.begin(), sum_bits.end(), result.begin(), bit_to_FieldT);
53 result.begin() + wordsize,
// Prints the message to standard output via printf: a heading line with the
// message type, followed by one line each for wordsize, sum and count.
// All four values are formatted with %zu.
59 template<typename FieldT> void tally_pcd_message<FieldT>::print() const
61 printf("Tally message of type %zu:\n", this->type);
62 printf(" wordsize: %zu\n", wordsize);
63 printf(" sum: %zu\n", sum);
64 printf(" count: %zu\n", count);
// Constructor of tally_pcd_local_data, taking the local summand.
// NOTE(review): the initializer list and body are not visible in this
// listing; presumably the argument is stored in the `summand` member that the
// other member functions read -- confirm against the full file.
67 template<typename FieldT>
68 tally_pcd_local_data<FieldT>::tally_pcd_local_data(const size_t summand)
// Represents the local data as an r1cs_variable_assignment: a single-element
// assignment holding `summand` converted to a field element.
// NOTE(review): the return statement is not visible in this listing;
// presumably `result` is returned.
73 template<typename FieldT>
74 r1cs_variable_assignment<FieldT> tally_pcd_local_data<
75 FieldT>::as_r1cs_variable_assignment() const
77 const r1cs_variable_assignment<FieldT> result = {FieldT(summand)};
// Prints the local data to standard output via printf: a heading line
// followed by the summand formatted with %zu.
81 template<typename FieldT> void tally_pcd_local_data<FieldT>::print() const
83 printf("Tally PCD local data:\n");
84 printf(" summand: %zu\n", summand);
// Protoboard-side counterpart of a tally message: extends
// r1cs_pcd_message_variable<FieldT> with two variable arrays holding the bits
// of the message's sum and count.
87 template<typename FieldT>
88 class tally_pcd_message_variable : public r1cs_pcd_message_variable<FieldT>
// Bit variables of the sum and of the count.
91 pb_variable_array<FieldT> sum_bits;
92 pb_variable_array<FieldT> count_bits;
// Constructor: allocates `wordsize` variables for sum_bits on `pb`, annotated
// with the given prefix, then calls update_all_vars() of the base class.
// NOTE(review): the line carrying (pb, wordsize, ... " count_bits") has lost
// its call head in this listing -- presumably count_bits.allocate. The
// `wordsize` member read by get_message() is also declared/initialized on
// lines not shown here.
95 tally_pcd_message_variable(
96 protoboard<FieldT> &pb,
97 const size_t wordsize,
98 const std::string &annotation_prefix)
99 : r1cs_pcd_message_variable<FieldT>(pb, annotation_prefix)
102 sum_bits.allocate(pb, wordsize, FMT(annotation_prefix, " sum_bits"));
104 pb, wordsize, FMT(annotation_prefix, " count_bits"));
106 this->update_all_vars();
// Reads the current protoboard values back into a concrete message: the type
// variable and the field elements recombined from sum_bits / count_bits are
// converted with as_ulong() and passed, together with wordsize, to a newly
// allocated tally_pcd_message owned by a shared_ptr.
// NOTE(review): the return statement is not visible in this listing;
// presumably `result` is returned.
109 std::shared_ptr<r1cs_pcd_message<FieldT>> get_message() const
111 const size_t type_val = this->pb.val(this->type).as_ulong();
112 const size_t sum_val =
113 sum_bits.get_field_element_from_bits(this->pb).as_ulong();
114 const size_t count_val =
115 count_bits.get_field_element_from_bits(this->pb).as_ulong();
117 std::shared_ptr<r1cs_pcd_message<FieldT>> result;
118 result.reset(new tally_pcd_message<FieldT>(
119 type_val, wordsize, sum_val, count_val));
// Defaulted destructor.
123 ~tally_pcd_message_variable() = default;
// Protoboard-side counterpart of the tally local data: extends
// r1cs_pcd_local_data_variable<FieldT> with a single `summand` variable.
126 template<typename FieldT>
127 class tally_pcd_local_data_variable
128 : public r1cs_pcd_local_data_variable<FieldT>
// Variable holding the local summand.
131 pb_variable<FieldT> summand;
// Constructor: allocates `summand` on `pb` with an annotation derived from
// the prefix, then calls update_all_vars() of the base class.
133 tally_pcd_local_data_variable(
134 protoboard<FieldT> &pb, const std::string &annotation_prefix)
135 : r1cs_pcd_local_data_variable<FieldT>(pb, annotation_prefix)
137 summand.allocate(pb, FMT(annotation_prefix, " summand"));
139 this->update_all_vars();
// Reads the current protoboard value of `summand`, converts it with
// as_ulong(), and wraps it in a newly allocated tally_pcd_local_data owned by
// a shared_ptr.
// NOTE(review): the return statement is not visible in this listing;
// presumably `result` is returned.
142 std::shared_ptr<r1cs_pcd_local_data<FieldT>> get_local_data() const
144 const size_t summand_val = this->pb.val(summand).as_ulong();
146 std::shared_ptr<r1cs_pcd_local_data<FieldT>> result;
147 result.reset(new tally_pcd_local_data<FieldT>(summand_val));
// Defaulted destructor.
151 ~tally_pcd_local_data_variable() = default;
// Constructor of the tally compliance-predicate handler. Passes a fresh
// protoboard plus the arity / input-type settings to the
// compliance_predicate_handler base, then allocates every variable and gadget
// that generate_r1cs_constraints() and generate_r1cs_witness() use.
// NOTE(review): several lines are missing from this listing: the first
// parameter, some base-constructor arguments, and parts of the gadget
// argument lists. The comments below describe only what is visible.
154 template<typename FieldT>
155 tally_cp_handler<FieldT>::tally_cp_handler(
157 const size_t max_arity,
158 const size_t wordsize,
159 const bool relies_on_same_type_inputs,
160 const std::set<size_t> accepted_input_types)
161 : compliance_predicate_handler<FieldT, protoboard<FieldT>>(
162 protoboard<FieldT>(),
166 relies_on_same_type_inputs,
167 accepted_input_types)
// Outgoing message variable and the arity variable.
170 this->outgoing_message.reset(new tally_pcd_message_variable<FieldT>(
171 this->pb, wordsize, "outgoing_message"));
172 this->arity.allocate(this->pb, "arity");
// One incoming message variable per input slot, max_arity in total.
174 for (size_t i = 0; i < max_arity; ++i) {
175 this->incoming_messages[i].reset(new tally_pcd_message_variable<FieldT>(
176 this->pb, wordsize, FMT("", "incoming_messages_%zu", i)));
// Local data variable (holds the summand).
179 this->local_data.reset(
180 new tally_pcd_local_data_variable<FieldT>(this->pb, "local_data"));
// Packed (single field element) forms of the outgoing sum and count.
182 sum_out_packed.allocate(this->pb, "sum_out_packed");
183 count_out_packed.allocate(this->pb, "count_out_packed");
// Packed forms of each incoming sum and count, one entry per input slot.
185 sum_in_packed.allocate(this->pb, max_arity, "sum_in_packed");
186 count_in_packed.allocate(this->pb, max_arity, "count_in_packed");
// Auxiliary variables, one per input slot; the constraint generator
// multiplies them by the incoming type to obtain the packed sum / count.
188 sum_in_packed_aux.allocate(this->pb, max_arity, "sum_in_packed_aux");
189 count_in_packed_aux.allocate(this->pb, max_arity, "count_in_packed_aux");
// Collect the type variable of every incoming message into incoming_types.
191 type_val_inner_product.allocate(this->pb, "type_val_inner_product");
192 for (auto &msg : this->incoming_messages) {
193 incoming_types.emplace_back(msg->type);
// Inner-product gadget whose result is type_val_inner_product.
// NOTE(review): its two input vectors are on lines not shown here --
// presumably incoming_types and sum_in_packed; confirm.
196 compute_type_val_inner_product.reset(new inner_product_gadget<FieldT>(
200 type_val_inner_product,
201 "compute_type_val_inner_product"));
// Packing gadgets attached to the outgoing message variable (downcast to
// tally_pcd_message_variable). The bit array, packed variable and annotation
// arguments are not visible in this listing.
203 unpack_sum_out.reset(new packing_gadget<FieldT>(
205 std::dynamic_pointer_cast<tally_pcd_message_variable<FieldT>>(
206 this->outgoing_message)
210 unpack_count_out.reset(new packing_gadget<FieldT>(
212 std::dynamic_pointer_cast<tally_pcd_message_variable<FieldT>>(
213 this->outgoing_message)
// Per-slot packing gadgets attached to each incoming message variable.
// NOTE(review): the member selected after each downcast and the packed
// target are not visible here; verify that pack_count_in is bound to the
// count bits and count_in_packed[i].
218 for (size_t i = 0; i < max_arity; ++i) {
219 pack_sum_in.emplace_back(packing_gadget<FieldT>(
221 std::dynamic_pointer_cast<tally_pcd_message_variable<FieldT>>(
222 this->incoming_messages[i])
225 FMT("", "pack_sum_in_%zu", i)));
226 pack_count_in.emplace_back(packing_gadget<FieldT>(
228 std::dynamic_pointer_cast<tally_pcd_message_variable<FieldT>>(
229 this->incoming_messages[i])
232 FMT("", "pack_count_in_%zu", i)));
// max_arity + 1 indicator variables, one for each index 0 .. max_arity.
235 arity_indicators.allocate(this->pb, max_arity + 1, "arity_indicators");
// Adds all constraints of the tally compliance predicate to the protoboard:
// packing of the message bit arrays, the per-input type/aux relations, the
// arity indicator constraints, and the sum / count update rules.
// Reading note: r1cs_constraint<FieldT>(a, b, c) is presumably the rank-1
// constraint a * b = c -- this matches the "sum_out = ..." and
// "count_out = ..." comments further down; confirm against the r1cs header.
238 template<typename FieldT>
239 void tally_cp_handler<FieldT>::generate_r1cs_constraints()
// Packing constraints for the outgoing sum and count, called with `true`
// (presumably the flag that also enforces booleanity of the bits -- confirm).
241 unpack_sum_out->generate_r1cs_constraints(true);
242 unpack_count_out->generate_r1cs_constraints(true);
// Same for every incoming sum and count.
244 for (size_t i = 0; i < this->max_arity; ++i) {
245 pack_sum_in[i].generate_r1cs_constraints(true);
246 pack_count_in[i].generate_r1cs_constraints(true);
// For each input: incoming_types[i] * sum_in_packed_aux[i] = sum_in_packed[i]
// and incoming_types[i] * count_in_packed_aux[i] = count_in_packed[i].
// With a zero type the left-hand side is zero, so the packed value is forced
// to zero (hence the "is_zero" annotation).
// NOTE(review): the second constraint is on the count but reuses the
// annotation "initial_sum_%zu_is_zero"; it looks like a copy/paste slip and
// probably should read "initial_count_%zu_is_zero". Annotations are runtime
// strings, so this is left unchanged here.
249 for (size_t i = 0; i < this->max_arity; ++i) {
250 this->pb.add_r1cs_constraint(
251 r1cs_constraint<FieldT>(
252 incoming_types[i], sum_in_packed_aux[i], sum_in_packed[i]),
253 FMT("", "initial_sum_%zu_is_zero", i));
254 this->pb.add_r1cs_constraint(
255 r1cs_constraint<FieldT>(
256 incoming_types[i], count_in_packed_aux[i], count_in_packed[i]),
257 FMT("", "initial_sum_%zu_is_zero", i));
260 /* constrain arity indicator variables so that arity_indicators[arity] = 1
261 * and arity_indicators[i] = 0 for any other i */
// (arity - i) * arity_indicators[i] = 0: an indicator can be non-zero only at
// the index equal to arity.
// NOTE(review): this loop covers indices 0 .. max_arity - 1, while the
// constructor allocates max_arity + 1 indicators; the last indicator is only
// covered by the sum constraint below. Confirm that this is intended.
262 for (size_t i = 0; i < this->max_arity; ++i) {
263 this->pb.add_r1cs_constraint(
264 r1cs_constraint<FieldT>(
265 this->arity - FieldT(i), arity_indicators[i], 0),
266 FMT("", "arity_indicators_%zu", i));
// 1 * (sum of all arity indicators) = 1.
// NOTE(review): the annotation argument of this call is not visible in this
// listing.
269 this->pb.add_r1cs_constraint(
270 r1cs_constraint<FieldT>(1, pb_sum<FieldT>(arity_indicators), 1),
// Next loop: for each slot i, a constraint annotated "unbound_types_%zu"
// whose first factor is the sum of arity_indicators[0 .. i-1]; the other two
// operands are on lines not shown here.
// NOTE(review): in this listing the block comment that opens on the next line
// has lost its closing line, so as written the text up to the end of the
// following "sum_out" comment is read as comment text. No comment is inserted
// inside that span to avoid changing how it tokenizes.
273 /* require that types of messages that are past arity (i.e. unbound wires)
275 for (size_t i = 0; i < this->max_arity; ++i) {
276 this->pb.add_r1cs_constraint(
277 r1cs_constraint<FieldT>(
278 0 + pb_sum<FieldT>(pb_variable_array<FieldT>(
279 arity_indicators.begin(),
280 arity_indicators.begin() + i)),
283 FMT("", "unbound_types_%zu", i));
286 /* sum_out = local_data + \sum_i type[i] * sum_in[i] */
// The inner-product gadget supplies the sum over inputs; the constraint that
// follows involves type_val_inner_product plus a member of the downcast local
// data variable.
// NOTE(review): the selected member, the remaining operands and the
// annotation are on lines not shown here (presumably ->summand and
// sum_out_packed).
287 compute_type_val_inner_product->generate_r1cs_constraints();
288 this->pb.add_r1cs_constraint(
289 r1cs_constraint<FieldT>(
291 type_val_inner_product +
292 std::dynamic_pointer_cast<
293 tally_pcd_local_data_variable<FieldT>>(this->local_data)
298 /* count_out = 1 + \sum_i count_in[i] */
// 1 * (1 + sum of count_in_packed) = count_out_packed; the annotation
// argument is not visible in this listing.
299 this->pb.add_r1cs_constraint(
300 r1cs_constraint<FieldT>(
301 1, 1 + pb_sum<FieldT>(count_in_packed), count_out_packed),
// Fills in the protoboard values for one application of the predicate, given
// the concrete incoming messages and the local data.
// NOTE(review): the name of the first parameter is on a line not shown here;
// the body refers to it as `incoming_messages`.
305 template<typename FieldT>
306 void tally_cp_handler<FieldT>::generate_r1cs_witness(
307 const std::vector<std::shared_ptr<r1cs_pcd_message<FieldT>>>
309 const std::shared_ptr<r1cs_pcd_local_data<FieldT>> &local_data)
// Delegate to the base handler first.
311 base_handler::generate_r1cs_witness(incoming_messages, local_data);
// For each input slot: compute the packed sum and count from their bits,
// then, if the incoming type is non-zero, set aux = packed * type^{-1} so
// that type * aux = packed holds. No assignment for the zero-type case is
// visible here, so the aux values are apparently left as they were.
313 for (size_t i = 0; i < this->max_arity; ++i) {
314 pack_sum_in[i].generate_r1cs_witness_from_bits();
315 pack_count_in[i].generate_r1cs_witness_from_bits();
317 if (!this->pb.val(incoming_types[i]).is_zero()) {
318 this->pb.val(sum_in_packed_aux[i]) =
319 this->pb.val(sum_in_packed[i]) *
320 this->pb.val(incoming_types[i]).inverse();
321 this->pb.val(count_in_packed_aux[i]) =
322 this->pb.val(count_in_packed[i]) *
323 this->pb.val(incoming_types[i]).inverse();
// Indicator i is one exactly when i equals the number of incoming messages
// supplied, zero otherwise; all max_arity + 1 indicators are assigned.
327 for (size_t i = 0; i < this->max_arity + 1; ++i) {
328 this->pb.val(arity_indicators[i]) =
329 (incoming_messages.size() == i ? FieldT::one() : FieldT::zero());
// Outgoing sum: evaluate the inner-product gadget, then assign
// sum_out_packed from an expression that involves the downcast local data
// variable and ends with the value of type_val_inner_product.
// NOTE(review): the middle of that expression is on lines not shown here
// (presumably the summand value plus the inner product).
332 compute_type_val_inner_product->generate_r1cs_witness();
333 this->pb.val(sum_out_packed) =
335 std::dynamic_pointer_cast<tally_pcd_local_data_variable<FieldT>>(
338 this->pb.val(type_val_inner_product);
// Outgoing count: one plus the packed counts of all input slots.
340 this->pb.val(count_out_packed) = FieldT::one();
341 for (size_t i = 0; i < this->max_arity; ++i) {
342 this->pb.val(count_out_packed) += this->pb.val(count_in_packed[i]);
// Derive the outgoing message bits from the packed sum and count.
345 unpack_sum_out->generate_r1cs_witness_from_packed();
346 unpack_count_out->generate_r1cs_witness_from_packed();
// Builds the base-case message: a tally_pcd_message with type 0, sum 0 and
// count 0, using the handler's `wordsize`, owned by a shared_ptr.
// NOTE(review): the return statement is not visible in this listing;
// presumably `result` is returned.
349 template<typename FieldT>
350 std::shared_ptr<r1cs_pcd_message<FieldT>> tally_cp_handler<
351 FieldT>::get_base_case_message() const
353 const size_t type = 0;
354 const size_t sum = 0;
355 const size_t count = 0;
357 std::shared_ptr<r1cs_pcd_message<FieldT>> result;
358 result.reset(new tally_pcd_message<FieldT>(type, wordsize, sum, count));
363 } // namespace libsnark
365 #endif // TALLY_CP_TCC_