1 // Copyright (c) 2015-2022 Clearmatics Technologies Ltd
3 // SPDX-License-Identifier: LGPL-3.0+
5 #ifndef __ZETH_CIRCUITS_MIMC_ROUND_TCC__
6 #define __ZETH_CIRCUITS_MIMC_ROUND_TCC__
8 #include "libzeth/circuits/mimc/mimc_round.hpp"
// Reserve storage for the intermediate variables of the square-and-multiply
// chain. Only the vector is resized here; each variable is allocated later,
// in generate_r1cs_constraints().
13 template<typename FieldT, size_t Exponent>
14 void MiMC_round_gadget<FieldT, Exponent>::initialize()
16 // Each condition requires an intermediate variable, except the final one,
17 // which is constrained directly against `result` (offset by `add_to_result`
// when one was supplied), so it needs no variable of its own.
// NOTE(review): NUM_CONDITIONS is declared elsewhere; presumably it is the
// total number of multiplication constraints in the chain - confirm.
18 exponents.resize(NUM_CONDITIONS - 1);
// Construct a MiMC round gadget whose output is
//   result = (msg + key + round_const)^Exponent
// with no additive term: `add_to_result_is_valid` is false, so the final
// constraint in generate_r1cs_constraints() targets `result` alone.
//
// pb                - protoboard that receives the gadget's variables and
//                     constraints.
// msg, key          - linear combinations summed with `round_const` to form
//                     the base `t` of the exponentiation.
// round_const       - field constant added to msg + key.
// result            - variable constrained to hold the round output.
// annotation_prefix - prefix for variable and constraint annotations.
21 template<typename FieldT, size_t Exponent>
22 MiMC_round_gadget<FieldT, Exponent>::MiMC_round_gadget(
23 libsnark::protoboard<FieldT> &pb,
24 const libsnark::pb_linear_combination<FieldT> &msg,
25 const libsnark::pb_linear_combination<FieldT> &key,
26 const FieldT &round_const,
27 libsnark::pb_variable<FieldT> &result,
28 const std::string &annotation_prefix)
29 : libsnark::gadget<FieldT>(pb, annotation_prefix)
32 , round_const(round_const)
// No additive term for this overload.
34 , add_to_result_is_valid(false)
// Construct a MiMC round gadget whose output includes an additive term:
//   result = (msg + key + round_const)^Exponent + add_to_result
// The addition is folded into the final multiplication constraint, so it
// costs no extra constraint.
//
// pb                - protoboard that receives the gadget's variables and
//                     constraints.
// msg, key          - linear combinations summed with `round_const` to form
//                     the base `t` of the exponentiation.
// round_const       - field constant added to msg + key.
// result            - variable constrained to hold the round output.
// add_to_result     - linear combination added to t^Exponent.
// annotation_prefix - prefix for variable and constraint annotations.
39 template<typename FieldT, size_t Exponent>
40 MiMC_round_gadget<FieldT, Exponent>::MiMC_round_gadget(
41 libsnark::protoboard<FieldT> &pb,
42 const libsnark::pb_linear_combination<FieldT> &msg,
43 const libsnark::pb_linear_combination<FieldT> &key,
44 const FieldT &round_const,
45 libsnark::pb_variable<FieldT> &result,
46 const libsnark::pb_linear_combination<FieldT> &add_to_result,
47 const std::string &annotation_prefix)
48 : libsnark::gadget<FieldT>(pb, annotation_prefix)
51 , round_const(round_const)
53 , add_to_result(add_to_result)
// The additive term is present for this overload.
54 , add_to_result_is_valid(true)
// Add the R1CS constraints for one round:
//   t      = msg + key + round_const
//   result = t^Exponent                (or t^Exponent + add_to_result)
// t^Exponent is built by left-to-right square-and-multiply over the bits of
// Exponent. Each step adds one constraint and stores the intermediate power
// in `exponents`. The final multiply by t is constrained directly against
// `result`, so it needs no intermediate variable.
59 template<typename FieldT, size_t Exponent>
60 void MiMC_round_gadget<FieldT, Exponent>::generate_r1cs_constraints()
62 // Mask to capture the most significant bit (the "current" bit when
63 // iterating from most to least significant).
// NOTE(review): `1` is an int literal, so this assumes EXPONENT_NUM_BITS - 1
// is smaller than the bit width of int.
64 constexpr size_t mask = 1 << (EXPONENT_NUM_BITS - 1);
// Bind t to the protoboard so it can be used as a constraint operand.
66 libsnark::pb_linear_combination<FieldT> t;
67 t.assign(this->pb, msg + key + round_const);
69 // For first bit (1 by definition) compute t^2
// Shifting left by one moves past that leading bit of Exponent.
70 size_t exp = Exponent << 1;
71 exponents[0].allocate(
72 this->pb, FMT(this->annotation_prefix, " exponents[0]"));
73 this->pb.add_r1cs_constraint(
74 libsnark::r1cs_constraint<FieldT>(t, t, exponents[0]),
75 FMT(this->annotation_prefix, " calc_t^2"));
// `last` points at the variable holding the most recent power of t.
78 libsnark::pb_variable<FieldT> *last = &exponents[0];
80 // Square-and-multiply based on all bits up to the final (lowest-order) bit.
81 for (size_t i = 1; i < EXPONENT_NUM_BITS - 1; ++i) {
// Multiply step: exponents[exp_idx] = t * (*last).
// new_exp is only used to label the constraint.
84 const size_t new_exp = exp >> (EXPONENT_NUM_BITS - 1);
85 exponents[exp_idx].allocate(
87 FMT(this->annotation_prefix, " exponents[%zu]", exp_idx));
88 this->pb.add_r1cs_constraint(
89 libsnark::r1cs_constraint<FieldT>(t, *last, exponents[exp_idx]),
90 FMT(this->annotation_prefix, " calc_t^%zu", new_exp));
91 last = &exponents[exp_idx];
// Square step: exponents[exp_idx] = (*last) * (*last), which doubles the
// exponent. new_exp is again only used to label the constraint.
96 const size_t new_exp = 2 * (exp >> (EXPONENT_NUM_BITS - 1));
97 exponents[exp_idx].allocate(
98 this->pb, FMT(this->annotation_prefix, " exponents[%zu]", exp_idx));
99 this->pb.add_r1cs_constraint(
100 libsnark::r1cs_constraint<FieldT>(*last, *last, exponents[exp_idx]),
101 FMT(this->annotation_prefix, " calc_t^%zu", new_exp));
102 last = &exponents[exp_idx];
105 // Shift to capture the next bit by mask.
// Sanity check: the chain must use exactly the intermediate variables
// reserved by initialize().
108 assert(exp_idx == exponents.size());
110 // Final multiply by t (the lowest-order bit is known to be 1).
111 if (add_to_result_is_valid) {
112 // With addition of add_to_result:
113 // result = last * t + add_to_result
114 // <=> result - add_to_result = last * t
115 this->pb.add_r1cs_constraint(
116 libsnark::r1cs_constraint<FieldT>(*last, t, result - add_to_result),
117 FMT(this->annotation_prefix,
118 " calc_t^%zu_add_to_result",
// Without add_to_result: result = last * t
121 this->pb.add_r1cs_constraint(
122 libsnark::r1cs_constraint<FieldT>(*last, t, result),
123 FMT(this->annotation_prefix, " calc_t^%zu", Exponent));
// Assign witness values for one round, mirroring generate_r1cs_constraints().
// It evaluates the input linear combinations and fills `exponents` with the
// intermediate powers of t = msg + key + round_const. It then sets `result`
// to t^Exponent, plus add_to_result when one was supplied.
127 template<typename FieldT, size_t Exponent>
128 void MiMC_round_gadget<FieldT, Exponent>::generate_r1cs_witness() const
// Evaluate the inputs so their values can be read with lc_val below.
130 key.evaluate(this->pb);
131 msg.evaluate(this->pb);
// Same bit-selection mask as in generate_r1cs_constraints().
133 constexpr size_t mask = 1 << (EXPONENT_NUM_BITS - 1);
// t must match the base used by the constraints: msg + key + round_const.
134 const FieldT k_val = this->pb.lc_val(key);
135 const FieldT t = this->pb.lc_val(msg) + k_val + round_const;
137 // First intermediate variable has value t^2
138 size_t exp = Exponent << 1;
140 this->pb.val(exponents[0]) = v;
142 // Square-and-multiply remaining bits, except final one.
144 for (size_t i = 1; i < EXPONENT_NUM_BITS - 1; ++i) {
// Presumably mirrors the conditional multiply step of the constraint chain.
148 this->pb.val(exponents[var_idx++]) = v;
// Presumably mirrors the square step of the constraint chain.
152 this->pb.val(exponents[var_idx++]) = v;
155 // v = v * t + add_to_result
157 if (add_to_result_is_valid) {
158 add_to_result.evaluate(this->pb);
159 v = v + this->pb.lc_val(add_to_result);
// `result` receives the final value, matching the last constraint.
161 this->pb.val(result) = v;
164 } // namespace libzeth
166 #endif // __ZETH_CIRCUITS_MIMC_ROUND_TCC__