2  *****************************************************************************
 
    4  Implementation of interfaces for the tally compliance predicate.
 
    8  *****************************************************************************
 
    9  * @author     This file is part of libsnark, developed by SCIPR Lab
 
   10  *             and contributors (see AUTHORS).
 
   11  * @copyright  MIT license (see LICENSE file)
 
   12  *****************************************************************************/
 
   19 #include <libff/algebra/fields/field_utils.hpp>
 
   // Constructor for tally_pcd_message: forwards `type` to the
   // r1cs_pcd_message base and stores wordsize, sum and count in the members
   // of the same names (initializer list at gutter line 30).
   // NOTE(review): this copy of the file is a lossy extraction -- the gutter
   // jumps 25 -> 27 -> 30, so the declarations presumably providing `type`,
   // `sum` and `count`, and the body braces, are missing. Restore the code
   // from upstream before editing it.
   24 template<typename FieldT>

   25 tally_pcd_message<FieldT>::tally_pcd_message(

   27     const size_t wordsize,

   30     : r1cs_pcd_message<FieldT>(type), wordsize(wordsize), sum(sum), count(count)
 
   // Serializes the message payload into an r1cs_variable_assignment of
   // 2 * wordsize field elements. `sum` and `count` are each converted to a
   // wordsize-bit libff::bit_vector, and every bit is mapped to
   // FieldT::one() / FieldT::zero() by bit_to_FieldT.
   // The sum bits fill the front of `result`. The count bits presumably fill
   // result[wordsize, 2 * wordsize) -- the second copy call is only partly
   // visible.
   // NOTE(review): gutter lines 37, 40-41, 46, 48, 50-52 and 54+ are missing
   // from this copy (braces, call heads, return).
   34 template<typename FieldT>

   35 r1cs_variable_assignment<FieldT> tally_pcd_message<

   36     FieldT>::payload_as_r1cs_variable_assignment() const

          // Maps a single bit to the corresponding field element.
   38     std::function<FieldT(bool)> bit_to_FieldT = [](const bool bit) {

   39         return bit ? FieldT::one() : FieldT::zero();

   42     const libff::bit_vector sum_bits =

   43         libff::convert_field_element_to_bit_vector<FieldT>(sum, wordsize);

   44     const libff::bit_vector count_bits =

   45         libff::convert_field_element_to_bit_vector<FieldT>(count, wordsize);

          // Room for both bit vectors (wordsize bits each).
   47     r1cs_variable_assignment<FieldT> result(2 * wordsize);

          // Tail of a call whose head is on missing gutter line 48 (presumably
          // std::transform): writes the mapped sum bits starting at result.begin().
   49         sum_bits.begin(), sum_bits.end(), result.begin(), bit_to_FieldT);

          // Destination offset `wordsize` for the second (count) copy.
   53         result.begin() + wordsize,
 
   // Prints the message's type, wordsize, sum and count to stdout, one value
   // per line, each formatted with %zu.
   // NOTE(review): the function's braces (gutter lines 60, 65) are missing
   // from this copy. %zu assumes these members are size_t -- confirm against
   // the class declaration.
   59 template<typename FieldT> void tally_pcd_message<FieldT>::print() const

   61     printf("Tally message of type %zu:\n", this->type);

   62     printf("  wordsize: %zu\n", wordsize);

   63     printf("  sum: %zu\n", sum);

   64     printf("  count: %zu\n", count);
 
   // Constructor for tally_pcd_local_data, taking a size_t summand.
   // NOTE(review): the initializer list / body (gutter lines 69-71) is
   // missing from this copy, so how `summand` is stored is not visible here.
   67 template<typename FieldT>

   68 tally_pcd_local_data<FieldT>::tally_pcd_local_data(const size_t summand)
 
   // Returns the local data as a one-element r1cs_variable_assignment
   // holding FieldT(summand).
   // NOTE(review): the braces and the rest of the function (gutter 76,
   // 78-79, presumably returning `result`) are missing from this copy.
   73 template<typename FieldT>

   74 r1cs_variable_assignment<FieldT> tally_pcd_local_data<

   75     FieldT>::as_r1cs_variable_assignment() const

   77     const r1cs_variable_assignment<FieldT> result = {FieldT(summand)};
 
   // Prints a header line followed by the summand (formatted with %zu) to
   // stdout.
   // NOTE(review): the function's braces (gutter lines 82, 85) are missing
   // from this copy.
   81 template<typename FieldT> void tally_pcd_local_data<FieldT>::print() const

   83     printf("Tally PCD local data:\n");

   84     printf("  summand: %zu\n", summand);
 
   // Protoboard-side view of a tally message: wordsize bit variables for the
   // sum and wordsize bit variables for the count, in addition to the `type`
   // variable inherited from r1cs_pcd_message_variable.
   // NOTE(review): this copy is lossy -- access specifiers, the declaration
   // of the `wordsize` member used by get_message(), braces and some call
   // heads (gutter 89-90, 93-94, 100-101, 103, 105, 107-108, 110, 116,
   // 120-122) are missing.
   87 template<typename FieldT>

   88 class tally_pcd_message_variable : public r1cs_pcd_message_variable<FieldT>

   91     pb_variable_array<FieldT> sum_bits;

   92     pb_variable_array<FieldT> count_bits;

          // Allocates wordsize bits for the sum and (presumably via
          // count_bits.allocate on missing gutter 103) wordsize bits for the
          // count, then calls update_all_vars().
   95     tally_pcd_message_variable(

   96         protoboard<FieldT> &pb,

   97         const size_t wordsize,

   98         const std::string &annotation_prefix)

   99         : r1cs_pcd_message_variable<FieldT>(pb, annotation_prefix)

  102         sum_bits.allocate(pb, wordsize, FMT(annotation_prefix, " sum_bits"));

  104             pb, wordsize, FMT(annotation_prefix, " count_bits"));

  106         this->update_all_vars();

          // Reads the current protoboard values back into a new
          // tally_pcd_message: the type variable and the field elements
          // recomposed from sum_bits / count_bits, each narrowed via as_ulong().
          // NOTE(review): as_ulong() can only represent values that fit in an
          // unsigned long; presumably wordsize is expected to stay within that
          // width -- confirm.
  109     std::shared_ptr<r1cs_pcd_message<FieldT>> get_message() const

  111         const size_t type_val = this->pb.val(this->type).as_ulong();

  112         const size_t sum_val =

  113             sum_bits.get_field_element_from_bits(this->pb).as_ulong();

  114         const size_t count_val =

  115             count_bits.get_field_element_from_bits(this->pb).as_ulong();

  117         std::shared_ptr<r1cs_pcd_message<FieldT>> result;

  118         result.reset(new tally_pcd_message<FieldT>(

  119             type_val, wordsize, sum_val, count_val));

  123     ~tally_pcd_message_variable() = default;
 
  // Protoboard-side view of tally local data: a single pb_variable holding
  // the summand.
  // NOTE(review): access specifiers and braces are missing from this copy
  // (gutter 129-130, 132, 136, 138, 140-141, 143, 145, 148-150).
  126 template<typename FieldT>

  127 class tally_pcd_local_data_variable

  128     : public r1cs_pcd_local_data_variable<FieldT>

  131     pb_variable<FieldT> summand;

          // Allocates the summand variable, then calls update_all_vars().
  133     tally_pcd_local_data_variable(

  134         protoboard<FieldT> &pb, const std::string &annotation_prefix)

  135         : r1cs_pcd_local_data_variable<FieldT>(pb, annotation_prefix)

  137         summand.allocate(pb, FMT(annotation_prefix, " summand"));

  139         this->update_all_vars();

          // Reads the summand from the protoboard (narrowed via as_ulong())
          // and wraps it in a new tally_pcd_local_data.
  142     std::shared_ptr<r1cs_pcd_local_data<FieldT>> get_local_data() const

  144         const size_t summand_val = this->pb.val(summand).as_ulong();

  146         std::shared_ptr<r1cs_pcd_local_data<FieldT>> result;

  147         result.reset(new tally_pcd_local_data<FieldT>(summand_val));

  151     ~tally_pcd_local_data_variable() = default;
 
  // Builds the tally compliance predicate. On the handler's protoboard it
  // allocates, in order:
  //   * the outgoing message, the arity variable, max_arity incoming messages
  //     and the local data;
  //   * the packed and auxiliary sum/count variables;
  //   * the packing and inner-product gadgets used by
  //     generate_r1cs_constraints() and generate_r1cs_witness();
  //   * the max_arity + 1 arity indicators.
  // NOTE(review): this copy is lossy -- e.g. gutter 156 and 163-165 of the
  // base-class initializer, 197-199 (inner-product operands) and several
  // packing-gadget arguments are missing, as are all braces.
  154 template<typename FieldT>

  155 tally_cp_handler<FieldT>::tally_cp_handler(

  157     const size_t max_arity,

  158     const size_t wordsize,

  159     const bool relies_on_same_type_inputs,

  160     const std::set<size_t> accepted_input_types)

  161     : compliance_predicate_handler<FieldT, protoboard<FieldT>>(

  162           protoboard<FieldT>(),

  166           relies_on_same_type_inputs,

  167           accepted_input_types)

  170     this->outgoing_message.reset(new tally_pcd_message_variable<FieldT>(

  171         this->pb, wordsize, "outgoing_message"));

  172     this->arity.allocate(this->pb, "arity");

          // Indexes incoming_messages[0, max_arity); presumably the base class
          // already sized the container to max_arity -- confirm.
  174     for (size_t i = 0; i < max_arity; ++i) {

  175         this->incoming_messages[i].reset(new tally_pcd_message_variable<FieldT>(

  176             this->pb, wordsize, FMT("", "incoming_messages_%zu", i)));

  179     this->local_data.reset(

  180         new tally_pcd_local_data_variable<FieldT>(this->pb, "local_data"));

          // One packed sum/count for the outgoing message; max_arity packed
          // values plus auxiliaries for the incoming messages.
  182     sum_out_packed.allocate(this->pb, "sum_out_packed");

  183     count_out_packed.allocate(this->pb, "count_out_packed");

  185     sum_in_packed.allocate(this->pb, max_arity, "sum_in_packed");

  186     count_in_packed.allocate(this->pb, max_arity, "count_in_packed");

  188     sum_in_packed_aux.allocate(this->pb, max_arity, "sum_in_packed_aux");

  189     count_in_packed_aux.allocate(this->pb, max_arity, "count_in_packed_aux");

  191     type_val_inner_product.allocate(this->pb, "type_val_inner_product");

          // Collect the type variable of every incoming message. These are used
          // in the initial-sum constraints and presumably as one operand of
          // compute_type_val_inner_product (its operands are missing here).
  192     for (auto &msg : this->incoming_messages) {

  193         incoming_types.emplace_back(msg->type);

  196     compute_type_val_inner_product.reset(new inner_product_gadget<FieldT>(

  200         type_val_inner_product,

  201         "compute_type_val_inner_product"));

          // Packing gadgets over the outgoing message's bits; presumably they
          // link sum_bits/count_bits to sum_out_packed/count_out_packed
          // (remaining arguments missing in this copy).
  203     unpack_sum_out.reset(new packing_gadget<FieldT>(

  205         std::dynamic_pointer_cast<tally_pcd_message_variable<FieldT>>(

  206             this->outgoing_message)

  210     unpack_count_out.reset(new packing_gadget<FieldT>(

  212         std::dynamic_pointer_cast<tally_pcd_message_variable<FieldT>>(

  213             this->outgoing_message)

          // One sum and one count packing gadget per incoming message.
  218     for (size_t i = 0; i < max_arity; ++i) {

  219         pack_sum_in.emplace_back(packing_gadget<FieldT>(

  221             std::dynamic_pointer_cast<tally_pcd_message_variable<FieldT>>(

  222                 this->incoming_messages[i])

  225             FMT("", "pack_sum_in_%zu", i)));

  226         pack_count_in.emplace_back(packing_gadget<FieldT>(

  228             std::dynamic_pointer_cast<tally_pcd_message_variable<FieldT>>(

  229                 this->incoming_messages[i])

  232             FMT("", "pack_count_in_%zu", i)));

          // max_arity + 1 indicators: one per possible arity value 0..max_arity.
  235     arity_indicators.allocate(this->pb, max_arity + 1, "arity_indicators");
 
  // Emits the R1CS constraints of the tally predicate (each
  // r1cs_constraint(a, b, c) encodes a * b = c):
  //   * packing-gadget constraints for the outgoing and incoming messages;
  //   * incoming_types[i] * aux[i] = packed[i] for sum and count, which
  //     forces packed[i] = 0 whenever incoming_types[i] = 0;
  //   * one-hot constraints on the arity indicators;
  //   * constraints on the types of unbound incoming wires;
  //   * sum_out = local summand + inner product of types and incoming sums;
  //   * count_out = 1 + sum of incoming counts.
  // NOTE(review): this copy is lossy -- braces and several operands or
  // annotations (e.g. gutter 258-259, 267-268, 271-272, 274, 281-282, 290,
  // 294-297, 302-304) are missing.
  238 template<typename FieldT>

  239 void tally_cp_handler<FieldT>::generate_r1cs_constraints()

  241     unpack_sum_out->generate_r1cs_constraints(true);

  242     unpack_count_out->generate_r1cs_constraints(true);

  244     for (size_t i = 0; i < this->max_arity; ++i) {

  245         pack_sum_in[i].generate_r1cs_constraints(true);

  246         pack_count_in[i].generate_r1cs_constraints(true);

  249     for (size_t i = 0; i < this->max_arity; ++i) {

  250         this->pb.add_r1cs_constraint(

  251             r1cs_constraint<FieldT>(

  252                 incoming_types[i], sum_in_packed_aux[i], sum_in_packed[i]),

  253             FMT("", "initial_sum_%zu_is_zero", i));

              // NOTE(review): this count constraint reuses the label
              // "initial_sum_%zu_is_zero"; presumably "initial_count_%zu_is_zero"
              // was intended. This presumably affects only the constraint's label.
  254         this->pb.add_r1cs_constraint(

  255             r1cs_constraint<FieldT>(

  256                 incoming_types[i], count_in_packed_aux[i], count_in_packed[i]),

  257             FMT("", "initial_sum_%zu_is_zero", i));

  260     /* constrain arity indicator variables so that arity_indicators[arity] = 1

  261      * and arity_indicators[i] = 0 for any other i */

          // (arity - i) * arity_indicators[i] = 0 for each i < max_arity.
          // NOTE(review): arity_indicators has max_arity + 1 entries (see the
          // constructor), but arity_indicators[max_arity] gets no constraint of
          // this form, so it is not tied to `arity` here -- confirm intended.
  262     for (size_t i = 0; i < this->max_arity; ++i) {

  263         this->pb.add_r1cs_constraint(

  264             r1cs_constraint<FieldT>(

  265                 this->arity - FieldT(i), arity_indicators[i], 0),

  266             FMT("", "arity_indicators_%zu", i));

          // 1 * (sum of all indicators) = 1.
  269     this->pb.add_r1cs_constraint(

  270         r1cs_constraint<FieldT>(1, pb_sum<FieldT>(arity_indicators), 1),

          // NOTE(review): in this copy the block comment opened on the next line
          // has no terminator until the one at gutter 286 (its continuation line,
          // gutter 274, is missing). A compiler would therefore treat the
          // unbound_types loop below as comment text.
  273     /* require that types of messages that are past arity (i.e. unbound wires)

          // Left factor: the sum of arity_indicators over [0, i), which is 1 iff
          // the one-hot position lies below i, i.e. i > arity.
          // NOTE(review): the witness sets arity_indicators[incoming_messages.size()]
          // = 1, so wire i == arity (the first unbound wire) gets a zero factor and
          // is not covered. If "past arity" should include index arity, then
          // begin() + i + 1 may be intended -- confirm. The other factors are
          // missing from this copy.
  275     for (size_t i = 0; i < this->max_arity; ++i) {

  276         this->pb.add_r1cs_constraint(

  277             r1cs_constraint<FieldT>(

  278                 0 + pb_sum<FieldT>(pb_variable_array<FieldT>(

  279                         arity_indicators.begin(),

  280                         arity_indicators.begin() + i)),

  283             FMT("", "unbound_types_%zu", i));

  286     /* sum_out = local_data + \sum_i type[i] * sum_in[i] */

          // Only part of this constraint survives: type_val_inner_product plus
          // the local summand (gutter 290 and 294-297 are missing).
  287     compute_type_val_inner_product->generate_r1cs_constraints();

  288     this->pb.add_r1cs_constraint(

  289         r1cs_constraint<FieldT>(

  291             type_val_inner_product +

  292                 std::dynamic_pointer_cast<

  293                     tally_pcd_local_data_variable<FieldT>>(this->local_data)

  298     /* count_out = 1 + \sum_i count_in[i] */

  299     this->pb.add_r1cs_constraint(

  300         r1cs_constraint<FieldT>(

  301             1, 1 + pb_sum<FieldT>(count_in_packed), count_out_packed),
 
  // Fills in the witness for the tally predicate, in order:
  //   1. calls base_handler::generate_r1cs_witness(incoming_messages,
  //      local_data);
  //   2. packs each incoming message's sum/count bits and computes the aux
  //      values;
  //   3. sets the one-hot arity indicators from incoming_messages.size();
  //   4. computes sum_out_packed and count_out_packed;
  //   5. runs the outgoing unpacking gadgets on those packed values.
  // NOTE(review): this copy is lossy -- braces and some lines (e.g. gutter
  // 308, 310, 312, 316, 324-326, 330-331, 334, 336-337, 339, 343-344) are
  // missing.
  305 template<typename FieldT>

  306 void tally_cp_handler<FieldT>::generate_r1cs_witness(

  307     const std::vector<std::shared_ptr<r1cs_pcd_message<FieldT>>>

  309     const std::shared_ptr<r1cs_pcd_local_data<FieldT>> &local_data)

  311     base_handler::generate_r1cs_witness(incoming_messages, local_data);

  313     for (size_t i = 0; i < this->max_arity; ++i) {

  314         pack_sum_in[i].generate_r1cs_witness_from_bits();

  315         pack_count_in[i].generate_r1cs_witness_from_bits();

              // aux = packed / type, so that type * aux = packed holds. When the
              // type is zero the aux values are not assigned here, and the
              // constraint instead requires the packed value to be zero.
  317         if (!this->pb.val(incoming_types[i]).is_zero()) {

  318             this->pb.val(sum_in_packed_aux[i]) =

  319                 this->pb.val(sum_in_packed[i]) *

  320                 this->pb.val(incoming_types[i]).inverse();

  321             this->pb.val(count_in_packed_aux[i]) =

  322                 this->pb.val(count_in_packed[i]) *

  323                 this->pb.val(incoming_types[i]).inverse();

          // One-hot: only the indicator at index incoming_messages.size() is 1.
  327     for (size_t i = 0; i < this->max_arity + 1; ++i) {

  328         this->pb.val(arity_indicators[i]) =

  329             (incoming_messages.size() == i ? FieldT::one() : FieldT::zero());

          // sum_out = local summand + type_val_inner_product (expression partly
          // missing from this copy).
  332     compute_type_val_inner_product->generate_r1cs_witness();

  333     this->pb.val(sum_out_packed) =

  335             std::dynamic_pointer_cast<tally_pcd_local_data_variable<FieldT>>(

  338         this->pb.val(type_val_inner_product);

          // count_out = 1 + sum of all max_arity incoming counts. This mirrors
          // the constraint 1 * (1 + sum count_in_packed) = count_out_packed.
  340     this->pb.val(count_out_packed) = FieldT::one();

  341     for (size_t i = 0; i < this->max_arity; ++i) {

  342         this->pb.val(count_out_packed) += this->pb.val(count_in_packed[i]);

  345     unpack_sum_out->generate_r1cs_witness_from_packed();

  346     unpack_count_out->generate_r1cs_witness_from_packed();
 
  // Returns the base-case message (per the method name): type 0, sum 0 and
  // count 0, built with the handler's wordsize.
  // NOTE(review): braces and the rest of the function (gutter 352, 356,
  // 359-362, presumably returning `result`) are missing from this copy.
  349 template<typename FieldT>

  350 std::shared_ptr<r1cs_pcd_message<FieldT>> tally_cp_handler<

  351     FieldT>::get_base_case_message() const

  353     const size_t type = 0;

  354     const size_t sum = 0;

  355     const size_t count = 0;

  357     std::shared_ptr<r1cs_pcd_message<FieldT>> result;

  358     result.reset(new tally_pcd_message<FieldT>(type, wordsize, sum, count));
 
  363 } // namespace libsnark
 
  365 #endif // TALLY_CP_TCC_