Clearmatics Libsnark  0.1
C++ library for zkSNARK proofs
alu_arithmetic.tcc
Go to the documentation of this file.
1 /** @file
2  *****************************************************************************
3 
4  Implementation of interfaces for the TinyRAM ALU arithmetic gadgets.
5 
6  See alu_arithmetic.hpp .
7 
8  *****************************************************************************
9  * @author This file is part of libsnark, developed by SCIPR Lab
10  * and contributors (see AUTHORS).
11  * @copyright MIT license (see LICENSE file)
12  *****************************************************************************/
13 
14 #ifndef ALU_ARITHMETIC_TCC_
15 #define ALU_ARITHMETIC_TCC_
16 
17 #include <functional>
18 #include <libff/common/profiling.hpp>
19 #include <libff/common/utils.hpp>
20 
21 namespace libsnark
22 {
23 
/* The code here is full of template lambda magic, but it is better to
 have limited presence of such code than to have code duplication in
 testing functions, which basically do the same thing: brute force
 the range of inputs with different success predicates. */
28 
29 template<class T, typename FieldT>
30 using initializer_fn = std::function<T *(
31  tinyram_protoboard<FieldT> &, // pb
32  pb_variable_array<FieldT> &, // opcode_indicators
33  word_variable_gadget<FieldT> &, // desval
34  word_variable_gadget<FieldT> &, // arg1val
35  word_variable_gadget<FieldT> &, // arg2val
36  pb_variable<FieldT> &, // flag
37  pb_variable<FieldT> &, // result
38  pb_variable<FieldT> & // result_flag
39  )>;
40 
41 template<class T, typename FieldT>
42 void brute_force_arithmetic_gadget(
43  const size_t w,
44  const size_t opcode,
45  initializer_fn<T, FieldT> initializer,
46  std::function<size_t(size_t, bool, size_t, size_t)> res_function,
47  std::function<bool(size_t, bool, size_t, size_t)> flag_function)
48 /* parameters for res_function and flag_function are both desval, flag, arg1val,
49  * arg2val */
50 {
51  printf("testing on all %zu bit inputs\n", w);
52 
53  tinyram_architecture_params ap(w, 16);
54  tinyram_program P;
55  P.instructions = generate_tinyram_prelude(ap);
56  tinyram_protoboard<FieldT> pb(ap, P.size(), 0, 10);
57 
58  pb_variable_array<FieldT> opcode_indicators;
59  opcode_indicators.allocate(
60  pb, 1ul << ap.opcode_width(), "opcode_indicators");
61  for (size_t i = 0; i < 1ul << ap.opcode_width(); ++i) {
62  pb.val(opcode_indicators[i]) =
63  (i == opcode ? FieldT::one() : FieldT::zero());
64  }
65 
66  word_variable_gadget<FieldT> desval(pb, "desval");
67  desval.generate_r1cs_constraints(true);
68  word_variable_gadget<FieldT> arg1val(pb, "arg1val");
69  arg1val.generate_r1cs_constraints(true);
70  word_variable_gadget<FieldT> arg2val(pb, "arg2val");
71  arg2val.generate_r1cs_constraints(true);
72  pb_variable<FieldT> flag;
73  flag.allocate(pb, "flag");
74  pb_variable<FieldT> result;
75  result.allocate(pb, "result");
76  pb_variable<FieldT> result_flag;
77  result_flag.allocate(pb, "result_flag");
78 
79  std::unique_ptr<T> g;
80  g.reset(initializer(
81  pb,
82  opcode_indicators,
83  desval,
84  arg1val,
85  arg2val,
86  flag,
87  result,
88  result_flag));
89  g->generate_r1cs_constraints();
90 
91  for (size_t des = 0; des < (1u << w); ++des) {
92  pb.val(desval.packed) = FieldT(des);
93  desval.generate_r1cs_witness_from_packed();
94 
95  for (char f = 0; f <= 1; ++f) {
96  pb.val(flag) = (f ? FieldT::one() : FieldT::zero());
97 
98  for (size_t arg1 = 0; arg1 < (1u << w); ++arg1) {
99  pb.val(arg1val.packed) = FieldT(arg1);
100  arg1val.generate_r1cs_witness_from_packed();
101 
102  for (size_t arg2 = 0; arg2 < (1u << w); ++arg2) {
103  pb.val(arg2val.packed) = FieldT(arg2);
104  arg2val.generate_r1cs_witness_from_packed();
105 
106  size_t res = res_function(des, f, arg1, arg2);
107  bool res_f = flag_function(des, f, arg1, arg2);
108 #ifdef DEBUG
109  printf(
110  "with the following parameters: flag = %d"
111  ", desval = %zu (%d)"
112  ", arg1val = %zu (%d)"
113  ", arg2val = %zu (%d)"
114  ". expected result: %zu (%d), expected flag: %d\n",
115  f,
116  des,
117  libff::from_twos_complement(des, w),
118  arg1,
119  libff::from_twos_complement(arg1, w),
120  arg2,
121  libff::from_twos_complement(arg2, w),
122  res,
123  libff::from_twos_complement(res, w),
124  res_f);
125 #else
126  libff::UNUSED(res);
127  libff::UNUSED(res_f);
128 #endif
129  g->generate_r1cs_witness();
130 #ifdef DEBUG
131  printf("result: ");
132  pb.val(result).print();
133  printf("flag: ");
134  pb.val(result_flag).print();
135 #endif
136  assert(pb.is_satisfied());
137  assert(pb.val(result) == FieldT(res));
138  assert(
139  pb.val(result_flag) ==
140  (res_f ? FieldT::one() : FieldT::zero()));
141  }
142  }
143  }
144  }
145 }
146 
147 /* and */
148 template<typename FieldT>
149 void ALU_and_gadget<FieldT>::generate_r1cs_constraints()
150 {
151  for (size_t i = 0; i < this->pb.ap.w; ++i) {
152  this->pb.add_r1cs_constraint(
153  r1cs_constraint<FieldT>(
154  {this->arg1val.bits[i]},
155  {this->arg2val.bits[i]},
156  {this->res_word[i]}),
157  FMT(this->annotation_prefix, " res_word_%zu", i));
158  }
159 
160  /* generate result */
161  pack_result->generate_r1cs_constraints(false);
162  not_all_zeros->generate_r1cs_constraints();
163 
164  /* result_flag = 1 - not_all_zeros = result is 0^w */
165  this->pb.add_r1cs_constraint(
166  r1cs_constraint<FieldT>(
167  {ONE},
168  {ONE, this->not_all_zeros_result * (-1)},
169  {this->result_flag}),
170  FMT(this->annotation_prefix, " result_flag"));
171 }
172 
173 template<typename FieldT> void ALU_and_gadget<FieldT>::generate_r1cs_witness()
174 {
175  for (size_t i = 0; i < this->pb.ap.w; ++i) {
176  bool b1 = this->pb.val(this->arg1val.bits[i]) == FieldT::one();
177  bool b2 = this->pb.val(this->arg2val.bits[i]) == FieldT::one();
178 
179  this->pb.val(this->res_word[i]) =
180  (b1 && b2 ? FieldT::one() : FieldT::zero());
181  }
182 
183  pack_result->generate_r1cs_witness_from_bits();
184  not_all_zeros->generate_r1cs_witness();
185  this->pb.val(this->result_flag) =
186  FieldT::one() - this->pb.val(not_all_zeros_result);
187 }
188 
189 template<typename FieldT> void test_ALU_and_gadget(const size_t w)
190 {
191  libff::print_time("starting and test");
192  brute_force_arithmetic_gadget<ALU_and_gadget<FieldT>, FieldT>(
193  w,
194  tinyram_opcode_AND,
195  [](tinyram_protoboard<FieldT> &pb,
196  pb_variable_array<FieldT> &opcode_indicators,
197  word_variable_gadget<FieldT> &desval,
198  word_variable_gadget<FieldT> &arg1val,
199  word_variable_gadget<FieldT> &arg2val,
200  pb_variable<FieldT> &flag,
201  pb_variable<FieldT> &result,
202  pb_variable<FieldT> &result_flag) -> ALU_and_gadget<FieldT> * {
203  return new ALU_and_gadget<FieldT>(
204  pb,
205  opcode_indicators,
206  desval,
207  arg1val,
208  arg2val,
209  flag,
210  result,
211  result_flag,
212  "ALU_and_gadget");
213  },
214  [w](size_t, bool, size_t x, size_t y) -> size_t { return x & y; },
215  [w](size_t, bool, size_t x, size_t y) -> bool { return (x & y) == 0; });
216  libff::print_time("and tests successful");
217 }
218 
219 /* or */
220 template<typename FieldT>
221 void ALU_or_gadget<FieldT>::generate_r1cs_constraints()
222 {
223  for (size_t i = 0; i < this->pb.ap.w; ++i) {
224  this->pb.add_r1cs_constraint(
225  r1cs_constraint<FieldT>(
226  {ONE, this->arg1val.bits[i] * (-1)},
227  {ONE, this->arg2val.bits[i] * (-1)},
228  {ONE, this->res_word[i] * (-1)}),
229  FMT(this->annotation_prefix, " res_word_%zu", i));
230  }
231 
232  /* generate result */
233  pack_result->generate_r1cs_constraints(false);
234  not_all_zeros->generate_r1cs_constraints();
235 
236  /* result_flag = 1 - not_all_zeros = result is 0^w */
237  this->pb.add_r1cs_constraint(
238  r1cs_constraint<FieldT>(
239  {ONE},
240  {ONE, this->not_all_zeros_result * (-1)},
241  {this->result_flag}),
242  FMT(this->annotation_prefix, " result_flag"));
243 }
244 
245 template<typename FieldT> void ALU_or_gadget<FieldT>::generate_r1cs_witness()
246 {
247  for (size_t i = 0; i < this->pb.ap.w; ++i) {
248  bool b1 = this->pb.val(this->arg1val.bits[i]) == FieldT::one();
249  bool b2 = this->pb.val(this->arg2val.bits[i]) == FieldT::one();
250 
251  this->pb.val(this->res_word[i]) =
252  (b1 || b2 ? FieldT::one() : FieldT::zero());
253  }
254 
255  pack_result->generate_r1cs_witness_from_bits();
256  not_all_zeros->generate_r1cs_witness();
257  this->pb.val(this->result_flag) =
258  FieldT::one() - this->pb.val(this->not_all_zeros_result);
259 }
260 
261 template<typename FieldT> void test_ALU_or_gadget(const size_t w)
262 {
263  libff::print_time("starting or test");
264  brute_force_arithmetic_gadget<ALU_or_gadget<FieldT>, FieldT>(
265  w,
266  tinyram_opcode_OR,
267  [](tinyram_protoboard<FieldT> &pb,
268  pb_variable_array<FieldT> &opcode_indicators,
269  word_variable_gadget<FieldT> &desval,
270  word_variable_gadget<FieldT> &arg1val,
271  word_variable_gadget<FieldT> &arg2val,
272  pb_variable<FieldT> &flag,
273  pb_variable<FieldT> &result,
274  pb_variable<FieldT> &result_flag) -> ALU_or_gadget<FieldT> * {
275  return new ALU_or_gadget<FieldT>(
276  pb,
277  opcode_indicators,
278  desval,
279  arg1val,
280  arg2val,
281  flag,
282  result,
283  result_flag,
284  "ALU_or_gadget");
285  },
286  [w](size_t, bool, size_t x, size_t y) -> size_t { return x | y; },
287  [w](size_t, bool, size_t x, size_t y) -> bool { return (x | y) == 0; });
288  libff::print_time("or tests successful");
289 }
290 
291 /* xor */
292 template<typename FieldT>
293 void ALU_xor_gadget<FieldT>::generate_r1cs_constraints()
294 {
295  for (size_t i = 0; i < this->pb.ap.w; ++i) {
296  /* a = b ^ c <=> a = b + c - 2*b*c, (2*b)*c = b+c - a */
297  this->pb.add_r1cs_constraint(
298  r1cs_constraint<FieldT>(
299  {this->arg1val.bits[i] * 2},
300  {this->arg2val.bits[i]},
301  {this->arg1val.bits[i],
302  this->arg2val.bits[i],
303  this->res_word[i] * (-1)}),
304  FMT(this->annotation_prefix, " res_word_%zu", i));
305  }
306 
307  /* generate result */
308  pack_result->generate_r1cs_constraints(false);
309  not_all_zeros->generate_r1cs_constraints();
310 
311  /* result_flag = 1 - not_all_zeros = result is 0^w */
312  this->pb.add_r1cs_constraint(
313  r1cs_constraint<FieldT>(
314  {ONE},
315  {ONE, this->not_all_zeros_result * (-1)},
316  {this->result_flag}),
317  FMT(this->annotation_prefix, " result_flag"));
318 }
319 
320 template<typename FieldT> void ALU_xor_gadget<FieldT>::generate_r1cs_witness()
321 {
322  for (size_t i = 0; i < this->pb.ap.w; ++i) {
323  bool b1 = this->pb.val(this->arg1val.bits[i]) == FieldT::one();
324  bool b2 = this->pb.val(this->arg2val.bits[i]) == FieldT::one();
325 
326  this->pb.val(this->res_word[i]) =
327  (b1 ^ b2 ? FieldT::one() : FieldT::zero());
328  }
329 
330  pack_result->generate_r1cs_witness_from_bits();
331  not_all_zeros->generate_r1cs_witness();
332  this->pb.val(this->result_flag) =
333  FieldT::one() - this->pb.val(this->not_all_zeros_result);
334 }
335 
336 template<typename FieldT> void test_ALU_xor_gadget(const size_t w)
337 {
338  libff::print_time("starting xor test");
339  brute_force_arithmetic_gadget<ALU_xor_gadget<FieldT>, FieldT>(
340  w,
341  tinyram_opcode_XOR,
342  [](tinyram_protoboard<FieldT> &pb,
343  pb_variable_array<FieldT> &opcode_indicators,
344  word_variable_gadget<FieldT> &desval,
345  word_variable_gadget<FieldT> &arg1val,
346  word_variable_gadget<FieldT> &arg2val,
347  pb_variable<FieldT> &flag,
348  pb_variable<FieldT> &result,
349  pb_variable<FieldT> &result_flag) -> ALU_xor_gadget<FieldT> * {
350  return new ALU_xor_gadget<FieldT>(
351  pb,
352  opcode_indicators,
353  desval,
354  arg1val,
355  arg2val,
356  flag,
357  result,
358  result_flag,
359  "ALU_xor_gadget");
360  },
361  [w](size_t, bool, size_t x, size_t y) -> size_t { return x ^ y; },
362  [w](size_t, bool, size_t x, size_t y) -> bool { return (x ^ y) == 0; });
363  libff::print_time("xor tests successful");
364 }
365 
366 /* not */
367 template<typename FieldT>
368 void ALU_not_gadget<FieldT>::generate_r1cs_constraints()
369 {
370  for (size_t i = 0; i < this->pb.ap.w; ++i) {
371  this->pb.add_r1cs_constraint(
372  r1cs_constraint<FieldT>(
373  {ONE},
374  {ONE, this->arg2val.bits[i] * (-1)},
375  {this->res_word[i]}),
376  FMT(this->annotation_prefix, " res_word_%zu", i));
377  }
378 
379  /* generate result */
380  pack_result->generate_r1cs_constraints(false);
381  not_all_zeros->generate_r1cs_constraints();
382 
383  /* result_flag = 1 - not_all_zeros = result is 0^w */
384  this->pb.add_r1cs_constraint(
385  r1cs_constraint<FieldT>(
386  {ONE},
387  {ONE, this->not_all_zeros_result * (-1)},
388  {this->result_flag}),
389  FMT(this->annotation_prefix, " result_flag"));
390 }
391 
392 template<typename FieldT> void ALU_not_gadget<FieldT>::generate_r1cs_witness()
393 {
394  for (size_t i = 0; i < this->pb.ap.w; ++i) {
395  bool b2 = this->pb.val(this->arg2val.bits[i]) == FieldT::one();
396 
397  this->pb.val(this->res_word[i]) =
398  (!b2 ? FieldT::one() : FieldT::zero());
399  }
400 
401  pack_result->generate_r1cs_witness_from_bits();
402  not_all_zeros->generate_r1cs_witness();
403  this->pb.val(this->result_flag) =
404  FieldT::one() - this->pb.val(this->not_all_zeros_result);
405 }
406 
407 template<typename FieldT> void test_ALU_not_gadget(const size_t w)
408 {
409  libff::print_time("starting not test");
410  brute_force_arithmetic_gadget<ALU_not_gadget<FieldT>, FieldT>(
411  w,
412  tinyram_opcode_NOT,
413  [](tinyram_protoboard<FieldT> &pb,
414  pb_variable_array<FieldT> &opcode_indicators,
415  word_variable_gadget<FieldT> &desval,
416  word_variable_gadget<FieldT> &arg1val,
417  word_variable_gadget<FieldT> &arg2val,
418  pb_variable<FieldT> &flag,
419  pb_variable<FieldT> &result,
420  pb_variable<FieldT> &result_flag) -> ALU_not_gadget<FieldT> * {
421  return new ALU_not_gadget<FieldT>(
422  pb,
423  opcode_indicators,
424  desval,
425  arg1val,
426  arg2val,
427  flag,
428  result,
429  result_flag,
430  "ALU_not_gadget");
431  },
432  [w](size_t, bool, size_t, size_t y) -> size_t {
433  return (1ul << w) - 1 - y;
434  },
435  [w](size_t, bool, size_t, size_t y) -> bool {
436  return ((1ul << w) - 1 - y) == 0;
437  });
438  libff::print_time("not tests successful");
439 }
440 
441 /* add */
442 template<typename FieldT>
443 void ALU_add_gadget<FieldT>::generate_r1cs_constraints()
444 {
445  /* addition_result = 1 * (arg1val + arg2val) */
446  this->pb.add_r1cs_constraint(
447  r1cs_constraint<FieldT>(
448  {ONE},
449  {this->arg1val.packed, this->arg2val.packed},
450  {this->addition_result}),
451  FMT(this->annotation_prefix, " addition_result"));
452 
453  /* unpack into bits */
454  unpack_addition->generate_r1cs_constraints(true);
455 
456  /* generate result */
457  pack_result->generate_r1cs_constraints(false);
458 }
459 
460 template<typename FieldT> void ALU_add_gadget<FieldT>::generate_r1cs_witness()
461 {
462  this->pb.val(addition_result) =
463  this->pb.val(this->arg1val.packed) + this->pb.val(this->arg2val.packed);
464  unpack_addition->generate_r1cs_witness_from_packed();
465  pack_result->generate_r1cs_witness_from_bits();
466 }
467 
468 template<typename FieldT> void test_ALU_add_gadget(const size_t w)
469 {
470  libff::print_time("starting add test");
471  brute_force_arithmetic_gadget<ALU_add_gadget<FieldT>, FieldT>(
472  w,
473  tinyram_opcode_ADD,
474  [](tinyram_protoboard<FieldT> &pb,
475  pb_variable_array<FieldT> &opcode_indicators,
476  word_variable_gadget<FieldT> &desval,
477  word_variable_gadget<FieldT> &arg1val,
478  word_variable_gadget<FieldT> &arg2val,
479  pb_variable<FieldT> &flag,
480  pb_variable<FieldT> &result,
481  pb_variable<FieldT> &result_flag) -> ALU_add_gadget<FieldT> * {
482  return new ALU_add_gadget<FieldT>(
483  pb,
484  opcode_indicators,
485  desval,
486  arg1val,
487  arg2val,
488  flag,
489  result,
490  result_flag,
491  "ALU_add_gadget");
492  },
493  [w](size_t, bool, size_t x, size_t y) -> size_t {
494  return (x + y) % (1ul << w);
495  },
496  [w](size_t, bool, size_t x, size_t y) -> bool {
497  return (x + y) >= (1ul << w);
498  });
499  libff::print_time("add tests successful");
500 }
501 
502 /* sub */
503 template<typename FieldT>
504 void ALU_sub_gadget<FieldT>::generate_r1cs_constraints()
505 {
506  /* intermediate_result = 2^w + (arg1val - arg2val) */
507  FieldT twoi = FieldT::one();
508 
509  linear_combination<FieldT> a, b, c;
510 
511  a.add_term(0, 1);
512  for (size_t i = 0; i < this->pb.ap.w; ++i) {
513  twoi = twoi + twoi;
514  }
515  b.add_term(0, twoi);
516  b.add_term(this->arg1val.packed, 1);
517  b.add_term(this->arg2val.packed, -1);
518  c.add_term(intermediate_result, 1);
519 
520  this->pb.add_r1cs_constraint(
521  r1cs_constraint<FieldT>(a, b, c),
522  FMT(this->annotation_prefix, " main_constraint"));
523 
524  /* unpack into bits */
525  unpack_intermediate->generate_r1cs_constraints(true);
526 
527  /* generate result */
528  pack_result->generate_r1cs_constraints(false);
529  this->pb.add_r1cs_constraint(
530  r1cs_constraint<FieldT>(
531  {ONE}, {ONE, this->negated_flag * (-1)}, {this->result_flag}),
532  FMT(this->annotation_prefix, " result_flag"));
533 }
534 
535 template<typename FieldT> void ALU_sub_gadget<FieldT>::generate_r1cs_witness()
536 {
537  FieldT twoi = FieldT::one();
538  for (size_t i = 0; i < this->pb.ap.w; ++i) {
539  twoi = twoi + twoi;
540  }
541 
542  this->pb.val(intermediate_result) = twoi +
543  this->pb.val(this->arg1val.packed) -
544  this->pb.val(this->arg2val.packed);
545  unpack_intermediate->generate_r1cs_witness_from_packed();
546  pack_result->generate_r1cs_witness_from_bits();
547  this->pb.val(this->result_flag) =
548  FieldT::one() - this->pb.val(this->negated_flag);
549 }
550 
551 template<typename FieldT> void test_ALU_sub_gadget(const size_t w)
552 {
553  libff::print_time("starting sub test");
554  brute_force_arithmetic_gadget<ALU_sub_gadget<FieldT>, FieldT>(
555  w,
556  tinyram_opcode_SUB,
557  [](tinyram_protoboard<FieldT> &pb,
558  pb_variable_array<FieldT> &opcode_indicators,
559  word_variable_gadget<FieldT> &desval,
560  word_variable_gadget<FieldT> &arg1val,
561  word_variable_gadget<FieldT> &arg2val,
562  pb_variable<FieldT> &flag,
563  pb_variable<FieldT> &result,
564  pb_variable<FieldT> &result_flag) -> ALU_sub_gadget<FieldT> * {
565  return new ALU_sub_gadget<FieldT>(
566  pb,
567  opcode_indicators,
568  desval,
569  arg1val,
570  arg2val,
571  flag,
572  result,
573  result_flag,
574  "ALU_sub_gadget");
575  },
576  [w](size_t, bool, size_t x, size_t y) -> size_t {
577  const size_t unsigned_result = ((1ul << w) + x - y) % (1ul << w);
578  return unsigned_result;
579  },
580  [w](size_t, bool, size_t x, size_t y) -> bool {
581  const size_t msb = ((1ul << w) + x - y) >> w;
582  return (msb == 0);
583  });
584  libff::print_time("sub tests successful");
585 }
586 
587 /* mov */
588 template<typename FieldT>
589 void ALU_mov_gadget<FieldT>::generate_r1cs_constraints()
590 {
591  this->pb.add_r1cs_constraint(
592  r1cs_constraint<FieldT>({ONE}, {this->arg2val.packed}, {this->result}),
593  FMT(this->annotation_prefix, " mov_result"));
594 
595  this->pb.add_r1cs_constraint(
596  r1cs_constraint<FieldT>({ONE}, {this->flag}, {this->result_flag}),
597  FMT(this->annotation_prefix, " mov_result_flag"));
598 }
599 
600 template<typename FieldT> void ALU_mov_gadget<FieldT>::generate_r1cs_witness()
601 {
602  this->pb.val(this->result) = this->pb.val(this->arg2val.packed);
603  this->pb.val(this->result_flag) = this->pb.val(this->flag);
604 }
605 
606 template<typename FieldT> void test_ALU_mov_gadget(const size_t w)
607 {
608  libff::print_time("starting mov test");
609  brute_force_arithmetic_gadget<ALU_mov_gadget<FieldT>, FieldT>(
610  w,
611  tinyram_opcode_MOV,
612  [](tinyram_protoboard<FieldT> &pb,
613  pb_variable_array<FieldT> &opcode_indicators,
614  word_variable_gadget<FieldT> &desval,
615  word_variable_gadget<FieldT> &arg1val,
616  word_variable_gadget<FieldT> &arg2val,
617  pb_variable<FieldT> &flag,
618  pb_variable<FieldT> &result,
619  pb_variable<FieldT> &result_flag) -> ALU_mov_gadget<FieldT> * {
620  return new ALU_mov_gadget<FieldT>(
621  pb,
622  opcode_indicators,
623  desval,
624  arg1val,
625  arg2val,
626  flag,
627  result,
628  result_flag,
629  "ALU_mov_gadget");
630  },
631  [w](size_t, bool, size_t, size_t y) -> size_t { return y; },
632  [w](size_t, bool f, size_t, size_t) -> bool { return f; });
633  libff::print_time("mov tests successful");
634 }
635 
636 /* cmov */
637 template<typename FieldT>
638 void ALU_cmov_gadget<FieldT>::generate_r1cs_constraints()
639 {
640  /*
641  flag1 * arg2val + (1-flag1) * desval = result
642  flag1 * (arg2val - desval) = result - desval
643  */
644  this->pb.add_r1cs_constraint(
645  r1cs_constraint<FieldT>(
646  {this->flag},
647  {this->arg2val.packed, this->desval.packed * (-1)},
648  {this->result, this->desval.packed * (-1)}),
649  FMT(this->annotation_prefix, " cmov_result"));
650 
651  this->pb.add_r1cs_constraint(
652  r1cs_constraint<FieldT>({ONE}, {this->flag}, {this->result_flag}),
653  FMT(this->annotation_prefix, " cmov_result_flag"));
654 }
655 
656 template<typename FieldT> void ALU_cmov_gadget<FieldT>::generate_r1cs_witness()
657 {
658  this->pb.val(this->result) =
659  ((this->pb.val(this->flag) == FieldT::one())
660  ? this->pb.val(this->arg2val.packed)
661  : this->pb.val(this->desval.packed));
662  this->pb.val(this->result_flag) = this->pb.val(this->flag);
663 }
664 
665 template<typename FieldT> void test_ALU_cmov_gadget(const size_t w)
666 {
667  libff::print_time("starting cmov test");
668  brute_force_arithmetic_gadget<ALU_cmov_gadget<FieldT>, FieldT>(
669  w,
670  tinyram_opcode_CMOV,
671  [](tinyram_protoboard<FieldT> &pb,
672  pb_variable_array<FieldT> &opcode_indicators,
673  word_variable_gadget<FieldT> &desval,
674  word_variable_gadget<FieldT> &arg1val,
675  word_variable_gadget<FieldT> &arg2val,
676  pb_variable<FieldT> &flag,
677  pb_variable<FieldT> &result,
678  pb_variable<FieldT> &result_flag) -> ALU_cmov_gadget<FieldT> * {
679  return new ALU_cmov_gadget<FieldT>(
680  pb,
681  opcode_indicators,
682  desval,
683  arg1val,
684  arg2val,
685  flag,
686  result,
687  result_flag,
688  "ALU_cmov_gadget");
689  },
690  [w](size_t des, bool f, size_t, size_t y) -> size_t {
691  return f ? y : des;
692  },
693  [w](size_t, bool f, size_t, size_t) -> bool { return f; });
694  libff::print_time("cmov tests successful");
695 }
696 
697 /* unsigned comparison */
698 template<typename FieldT>
699 void ALU_cmp_gadget<FieldT>::generate_r1cs_constraints()
700 {
701  comparator.generate_r1cs_constraints();
702  /*
703  cmpe = cmpae * (1-cmpa)
704  */
705  this->pb.add_r1cs_constraint(
706  r1cs_constraint<FieldT>(
707  {cmpae_result_flag},
708  {ONE, cmpa_result_flag * (-1)},
709  {cmpe_result_flag}),
710  FMT(this->annotation_prefix, " cmpa_result_flag"));
711 
712  /* copy over results */
713  this->pb.add_r1cs_constraint(
714  r1cs_constraint<FieldT>({ONE}, {this->desval.packed}, {cmpe_result}),
715  FMT(this->annotation_prefix, " cmpe_result"));
716 
717  this->pb.add_r1cs_constraint(
718  r1cs_constraint<FieldT>({ONE}, {this->desval.packed}, {cmpa_result}),
719  FMT(this->annotation_prefix, " cmpa_result"));
720 
721  this->pb.add_r1cs_constraint(
722  r1cs_constraint<FieldT>({ONE}, {this->desval.packed}, {cmpae_result}),
723  FMT(this->annotation_prefix, " cmpae_result"));
724 }
725 
726 template<typename FieldT> void ALU_cmp_gadget<FieldT>::generate_r1cs_witness()
727 {
728  comparator.generate_r1cs_witness();
729 
730  this->pb.val(cmpe_result) = this->pb.val(this->desval.packed);
731  this->pb.val(cmpa_result) = this->pb.val(this->desval.packed);
732  this->pb.val(cmpae_result) = this->pb.val(this->desval.packed);
733 
734  this->pb.val(cmpe_result_flag) =
735  ((this->pb.val(cmpae_result_flag) == FieldT::one()) &&
736  (this->pb.val(cmpa_result_flag) == FieldT::zero())
737  ? FieldT::one()
738  : FieldT::zero());
739 }
740 
741 template<typename FieldT> void test_ALU_cmpe_gadget(const size_t w)
742 {
743  libff::print_time("starting cmpe test");
744  brute_force_arithmetic_gadget<ALU_cmp_gadget<FieldT>, FieldT>(
745  w,
746  tinyram_opcode_CMPE,
747  [](tinyram_protoboard<FieldT> &pb,
748  pb_variable_array<FieldT> &opcode_indicators,
749  word_variable_gadget<FieldT> &desval,
750  word_variable_gadget<FieldT> &arg1val,
751  word_variable_gadget<FieldT> &arg2val,
752  pb_variable<FieldT> &flag,
753  pb_variable<FieldT> &result,
754  pb_variable<FieldT> &result_flag) -> ALU_cmp_gadget<FieldT> * {
755  pb_variable<FieldT> cmpa_result;
756  cmpa_result.allocate(pb, "cmpa_result");
757  pb_variable<FieldT> cmpa_result_flag;
758  cmpa_result_flag.allocate(pb, "cmpa_result_flag");
759  pb_variable<FieldT> cmpae_result;
760  cmpae_result.allocate(pb, "cmpae_result");
761  pb_variable<FieldT> cmpae_result_flag;
762  cmpae_result_flag.allocate(pb, "cmpae_result_flag");
763  return new ALU_cmp_gadget<FieldT>(
764  pb,
765  opcode_indicators,
766  desval,
767  arg1val,
768  arg2val,
769  flag,
770  result,
771  result_flag,
772  cmpa_result,
773  cmpa_result_flag,
774  cmpae_result,
775  cmpae_result_flag,
776  "ALU_cmp_gadget");
777  },
778  [w](size_t des, bool, size_t, size_t) -> size_t { return des; },
779  [w](size_t, bool, size_t x, size_t y) -> bool { return x == y; });
780  libff::print_time("cmpe tests successful");
781 }
782 
783 template<typename FieldT> void test_ALU_cmpa_gadget(const size_t w)
784 {
785  libff::print_time("starting cmpa test");
786  brute_force_arithmetic_gadget<ALU_cmp_gadget<FieldT>, FieldT>(
787  w,
788  tinyram_opcode_CMPA,
789  [](tinyram_protoboard<FieldT> &pb,
790  pb_variable_array<FieldT> &opcode_indicators,
791  word_variable_gadget<FieldT> &desval,
792  word_variable_gadget<FieldT> &arg1val,
793  word_variable_gadget<FieldT> &arg2val,
794  pb_variable<FieldT> &flag,
795  pb_variable<FieldT> &result,
796  pb_variable<FieldT> &result_flag) -> ALU_cmp_gadget<FieldT> * {
797  pb_variable<FieldT> cmpe_result;
798  cmpe_result.allocate(pb, "cmpe_result");
799  pb_variable<FieldT> cmpe_result_flag;
800  cmpe_result_flag.allocate(pb, "cmpe_result_flag");
801  pb_variable<FieldT> cmpae_result;
802  cmpae_result.allocate(pb, "cmpae_result");
803  pb_variable<FieldT> cmpae_result_flag;
804  cmpae_result_flag.allocate(pb, "cmpae_result_flag");
805  return new ALU_cmp_gadget<FieldT>(
806  pb,
807  opcode_indicators,
808  desval,
809  arg1val,
810  arg2val,
811  flag,
812  cmpe_result,
813  cmpe_result_flag,
814  result,
815  result_flag,
816  cmpae_result,
817  cmpae_result_flag,
818  "ALU_cmp_gadget");
819  },
820  [w](size_t des, bool, size_t, size_t) -> size_t { return des; },
821  [w](size_t, bool, size_t x, size_t y) -> bool { return x > y; });
822  libff::print_time("cmpa tests successful");
823 }
824 
825 template<typename FieldT> void test_ALU_cmpae_gadget(const size_t w)
826 {
827  libff::print_time("starting cmpae test");
828  brute_force_arithmetic_gadget<ALU_cmp_gadget<FieldT>, FieldT>(
829  w,
830  tinyram_opcode_CMPAE,
831  [](tinyram_protoboard<FieldT> &pb,
832  pb_variable_array<FieldT> &opcode_indicators,
833  word_variable_gadget<FieldT> &desval,
834  word_variable_gadget<FieldT> &arg1val,
835  word_variable_gadget<FieldT> &arg2val,
836  pb_variable<FieldT> &flag,
837  pb_variable<FieldT> &result,
838  pb_variable<FieldT> &result_flag) -> ALU_cmp_gadget<FieldT> * {
839  pb_variable<FieldT> cmpe_result;
840  cmpe_result.allocate(pb, "cmpe_result");
841  pb_variable<FieldT> cmpe_result_flag;
842  cmpe_result_flag.allocate(pb, "cmpe_result_flag");
843  pb_variable<FieldT> cmpa_result;
844  cmpa_result.allocate(pb, "cmpa_result");
845  pb_variable<FieldT> cmpa_result_flag;
846  cmpa_result_flag.allocate(pb, "cmpa_result_flag");
847  return new ALU_cmp_gadget<FieldT>(
848  pb,
849  opcode_indicators,
850  desval,
851  arg1val,
852  arg2val,
853  flag,
854  cmpe_result,
855  cmpe_result_flag,
856  cmpa_result,
857  cmpa_result_flag,
858  result,
859  result_flag,
860  "ALU_cmp_gadget");
861  },
862  [w](size_t des, bool, size_t, size_t) -> size_t { return des; },
863  [w](size_t, bool, size_t x, size_t y) -> bool { return x >= y; });
864  libff::print_time("cmpae tests successful");
865 }
866 
867 /* signed comparison */
868 template<typename FieldT>
869 void ALU_cmps_gadget<FieldT>::generate_r1cs_constraints()
870 {
871  /* negate sign bits */
872  this->pb.add_r1cs_constraint(
873  r1cs_constraint<FieldT>(
874  {ONE},
875  {ONE, this->arg1val.bits[this->pb.ap.w - 1] * (-1)},
876  {negated_arg1val_sign}),
877  FMT(this->annotation_prefix, " negated_arg1val_sign"));
878  this->pb.add_r1cs_constraint(
879  r1cs_constraint<FieldT>(
880  {ONE},
881  {ONE, this->arg2val.bits[this->pb.ap.w - 1] * (-1)},
882  {negated_arg2val_sign}),
883  FMT(this->annotation_prefix, " negated_arg2val_sign"));
884 
885  /* pack */
886  pack_modified_arg1->generate_r1cs_constraints(false);
887  pack_modified_arg2->generate_r1cs_constraints(false);
888 
889  /* compare */
890  comparator->generate_r1cs_constraints();
891 
892  /* copy over results */
893  this->pb.add_r1cs_constraint(
894  r1cs_constraint<FieldT>({ONE}, {this->desval.packed}, {cmpg_result}),
895  FMT(this->annotation_prefix, " cmpg_result"));
896 
897  this->pb.add_r1cs_constraint(
898  r1cs_constraint<FieldT>({ONE}, {this->desval.packed}, {cmpge_result}),
899  FMT(this->annotation_prefix, " cmpge_result"));
900 }
901 
902 template<typename FieldT> void ALU_cmps_gadget<FieldT>::generate_r1cs_witness()
903 {
904  /* negate sign bits */
905  this->pb.val(negated_arg1val_sign) =
906  FieldT::one() - this->pb.val(this->arg1val.bits[this->pb.ap.w - 1]);
907  this->pb.val(negated_arg2val_sign) =
908  FieldT::one() - this->pb.val(this->arg2val.bits[this->pb.ap.w - 1]);
909 
910  /* pack */
911  pack_modified_arg1->generate_r1cs_witness_from_bits();
912  pack_modified_arg2->generate_r1cs_witness_from_bits();
913 
914  /* produce result */
915  comparator->generate_r1cs_witness();
916 
917  this->pb.val(cmpg_result) = this->pb.val(this->desval.packed);
918  this->pb.val(cmpge_result) = this->pb.val(this->desval.packed);
919 }
920 
921 template<typename FieldT> void test_ALU_cmpg_gadget(const size_t w)
922 {
923  libff::print_time("starting cmpg test");
924  brute_force_arithmetic_gadget<ALU_cmps_gadget<FieldT>, FieldT>(
925  w,
926  tinyram_opcode_CMPG,
927  [](tinyram_protoboard<FieldT> &pb,
928  pb_variable_array<FieldT> &opcode_indicators,
929  word_variable_gadget<FieldT> &desval,
930  word_variable_gadget<FieldT> &arg1val,
931  word_variable_gadget<FieldT> &arg2val,
932  pb_variable<FieldT> &flag,
933  pb_variable<FieldT> &result,
934  pb_variable<FieldT> &result_flag) -> ALU_cmps_gadget<FieldT> * {
935  pb_variable<FieldT> cmpge_result;
936  cmpge_result.allocate(pb, "cmpge_result");
937  pb_variable<FieldT> cmpge_result_flag;
938  cmpge_result_flag.allocate(pb, "cmpge_result_flag");
939  return new ALU_cmps_gadget<FieldT>(
940  pb,
941  opcode_indicators,
942  desval,
943  arg1val,
944  arg2val,
945  flag,
946  result,
947  result_flag,
948  cmpge_result,
949  cmpge_result_flag,
950  "ALU_cmps_gadget");
951  },
952  [w](size_t des, bool, size_t, size_t) -> size_t { return des; },
953  [w](size_t, bool, size_t x, size_t y) -> bool {
954  return (
955  libff::from_twos_complement(x, w) >
956  libff::from_twos_complement(y, w));
957  });
958  libff::print_time("cmpg tests successful");
959 }
960 
961 template<typename FieldT> void test_ALU_cmpge_gadget(const size_t w)
962 {
963  libff::print_time("starting cmpge test");
964  brute_force_arithmetic_gadget<ALU_cmps_gadget<FieldT>, FieldT>(
965  w,
966  tinyram_opcode_CMPGE,
967  [](tinyram_protoboard<FieldT> &pb,
968  pb_variable_array<FieldT> &opcode_indicators,
969  word_variable_gadget<FieldT> &desval,
970  word_variable_gadget<FieldT> &arg1val,
971  word_variable_gadget<FieldT> &arg2val,
972  pb_variable<FieldT> &flag,
973  pb_variable<FieldT> &result,
974  pb_variable<FieldT> &result_flag) -> ALU_cmps_gadget<FieldT> * {
975  pb_variable<FieldT> cmpg_result;
976  cmpg_result.allocate(pb, "cmpg_result");
977  pb_variable<FieldT> cmpg_result_flag;
978  cmpg_result_flag.allocate(pb, "cmpg_result_flag");
979  return new ALU_cmps_gadget<FieldT>(
980  pb,
981  opcode_indicators,
982  desval,
983  arg1val,
984  arg2val,
985  flag,
986  cmpg_result,
987  cmpg_result_flag,
988  result,
989  result_flag,
990  "ALU_cmps_gadget");
991  },
992  [w](size_t des, bool, size_t, size_t) -> size_t { return des; },
993  [w](size_t, bool, size_t x, size_t y) -> bool {
994  return (
995  libff::from_twos_complement(x, w) >=
996  libff::from_twos_complement(y, w));
997  });
998  libff::print_time("cmpge tests successful");
999 }
1000 
1001 template<typename FieldT>
1002 void ALU_umul_gadget<FieldT>::generate_r1cs_constraints()
1003 {
1004  /* do multiplication */
1005  this->pb.add_r1cs_constraint(
1006  r1cs_constraint<FieldT>(
1007  {this->arg1val.packed},
1008  {this->arg2val.packed},
1009  {mul_result.packed}),
1010  FMT(this->annotation_prefix, " main_constraint"));
1011  mul_result.generate_r1cs_constraints(true);
1012 
1013  /* pack result */
1014  pack_mull_result->generate_r1cs_constraints(false);
1015  pack_umulh_result->generate_r1cs_constraints(false);
1016 
1017  /* compute flag */
1018  compute_flag->generate_r1cs_constraints();
1019 
1020  this->pb.add_r1cs_constraint(
1021  r1cs_constraint<FieldT>({ONE}, {this->result_flag}, {mull_flag}),
1022  FMT(this->annotation_prefix, " mull_flag"));
1023 
1024  this->pb.add_r1cs_constraint(
1025  r1cs_constraint<FieldT>({ONE}, {this->result_flag}, {umulh_flag}),
1026  FMT(this->annotation_prefix, " umulh_flag"));
1027 }
1028 
1029 template<typename FieldT> void ALU_umul_gadget<FieldT>::generate_r1cs_witness()
1030 {
1031  /* do multiplication */
1032  this->pb.val(mul_result.packed) =
1033  this->pb.val(this->arg1val.packed) * this->pb.val(this->arg2val.packed);
1034  mul_result.generate_r1cs_witness_from_packed();
1035 
1036  /* pack result */
1037  pack_mull_result->generate_r1cs_witness_from_bits();
1038  pack_umulh_result->generate_r1cs_witness_from_bits();
1039 
1040  /* compute flag */
1041  compute_flag->generate_r1cs_witness();
1042 
1043  this->pb.val(mull_flag) = this->pb.val(this->result_flag);
1044  this->pb.val(umulh_flag) = this->pb.val(this->result_flag);
1045 }
1046 
1047 template<typename FieldT> void test_ALU_mull_gadget(const size_t w)
1048 {
1049  libff::print_time("starting mull test");
1050  brute_force_arithmetic_gadget<ALU_umul_gadget<FieldT>, FieldT>(
1051  w,
1052  tinyram_opcode_MULL,
1053  [](tinyram_protoboard<FieldT> &pb,
1054  pb_variable_array<FieldT> &opcode_indicators,
1055  word_variable_gadget<FieldT> &desval,
1056  word_variable_gadget<FieldT> &arg1val,
1057  word_variable_gadget<FieldT> &arg2val,
1058  pb_variable<FieldT> &flag,
1059  pb_variable<FieldT> &result,
1060  pb_variable<FieldT> &result_flag) -> ALU_umul_gadget<FieldT> * {
1061  pb_variable<FieldT> umulh_result;
1062  umulh_result.allocate(pb, "umulh_result");
1063  pb_variable<FieldT> umulh_flag;
1064  umulh_flag.allocate(pb, "umulh_flag");
1065  return new ALU_umul_gadget<FieldT>(
1066  pb,
1067  opcode_indicators,
1068  desval,
1069  arg1val,
1070  arg2val,
1071  flag,
1072  result,
1073  result_flag,
1074  umulh_result,
1075  umulh_flag,
1076  "ALU_umul_gadget");
1077  },
1078  [w](size_t, bool, size_t x, size_t y) -> size_t {
1079  return (x * y) % (1ul << w);
1080  },
1081  [w](size_t, bool, size_t x, size_t y) -> bool {
1082  return ((x * y) >> w) != 0;
1083  });
1084  libff::print_time("mull tests successful");
1085 }
1086 
1087 template<typename FieldT> void test_ALU_umulh_gadget(const size_t w)
1088 {
1089  libff::print_time("starting umulh test");
1090  brute_force_arithmetic_gadget<ALU_umul_gadget<FieldT>, FieldT>(
1091  w,
1092  tinyram_opcode_UMULH,
1093  [](tinyram_protoboard<FieldT> &pb,
1094  pb_variable_array<FieldT> &opcode_indicators,
1095  word_variable_gadget<FieldT> &desval,
1096  word_variable_gadget<FieldT> &arg1val,
1097  word_variable_gadget<FieldT> &arg2val,
1098  pb_variable<FieldT> &flag,
1099  pb_variable<FieldT> &result,
1100  pb_variable<FieldT> &result_flag) -> ALU_umul_gadget<FieldT> * {
1101  pb_variable<FieldT> mull_result;
1102  mull_result.allocate(pb, "mull_result");
1103  pb_variable<FieldT> mull_flag;
1104  mull_flag.allocate(pb, "mull_flag");
1105  return new ALU_umul_gadget<FieldT>(
1106  pb,
1107  opcode_indicators,
1108  desval,
1109  arg1val,
1110  arg2val,
1111  flag,
1112  mull_result,
1113  mull_flag,
1114  result,
1115  result_flag,
1116  "ALU_umul_gadget");
1117  },
1118  [w](size_t, bool, size_t x, size_t y) -> size_t {
1119  return (x * y) >> w;
1120  },
1121  [w](size_t, bool, size_t x, size_t y) -> bool {
1122  return ((x * y) >> w) != 0;
1123  });
1124  libff::print_time("umulh tests successful");
1125 }
1126 
1127 template<typename FieldT>
1128 void ALU_smul_gadget<FieldT>::generate_r1cs_constraints()
1129 {
1130  /* do multiplication */
1131  /*
1132  from two's complement: (packed - 2^w * bits[w-1])
1133  to two's complement: lower order bits of 2^{2w} + result_of_*
1134  */
1135 
1136  linear_combination<FieldT> a, b, c;
1137  a.add_term(this->arg1val.packed, 1);
1138  a.add_term(
1139  this->arg1val.bits[this->pb.ap.w - 1], -(FieldT(2) ^ this->pb.ap.w));
1140  b.add_term(this->arg2val.packed, 1);
1141  b.add_term(
1142  this->arg2val.bits[this->pb.ap.w - 1], -(FieldT(2) ^ this->pb.ap.w));
1143  c.add_term(mul_result.packed, 1);
1144  c.add_term(ONE, -(FieldT(2) ^ (2 * this->pb.ap.w)));
1145  this->pb.add_r1cs_constraint(
1146  r1cs_constraint<FieldT>(a, b, c),
1147  FMT(this->annotation_prefix, " main_constraint"));
1148 
1149  mul_result.generate_r1cs_constraints(true);
1150 
1151  /* pack result */
1152  pack_smulh_result->generate_r1cs_constraints(false);
1153 
1154  /* compute flag */
1155  pack_top->generate_r1cs_constraints(false);
1156 
1157  /*
1158  the gadgets below are FieldT specific:
1159  I * X = (1-R)
1160  R * X = 0
1161 
1162  if X = 0 then R = 1
1163  if X != 0 then R = 0 and I = X^{-1}
1164  */
1165  this->pb.add_r1cs_constraint(
1166  r1cs_constraint<FieldT>(
1167  {is_top_empty_aux}, {top}, {ONE, is_top_empty * (-1)}),
1168  FMT(this->annotation_prefix, " I*X=1-R (is_top_empty)"));
1169  this->pb.add_r1cs_constraint(
1170  r1cs_constraint<FieldT>({is_top_empty}, {top}, {ONE * 0}),
1171  FMT(this->annotation_prefix, " R*X=0 (is_top_full)"));
1172 
1173  this->pb.add_r1cs_constraint(
1174  r1cs_constraint<FieldT>(
1175  {is_top_full_aux},
1176  {top, ONE * (1l - (1ul << (this->pb.ap.w + 1)))},
1177  {ONE, is_top_full * (-1)}),
1178  FMT(this->annotation_prefix, " I*X=1-R (is_top_full)"));
1179  this->pb.add_r1cs_constraint(
1180  r1cs_constraint<FieldT>(
1181  {is_top_full},
1182  {top, ONE * (1l - (1ul << (this->pb.ap.w + 1)))},
1183  {ONE * 0}),
1184  FMT(this->annotation_prefix, " R*X=0 (is_top_full)"));
1185 
1186  /* smulh_flag = 1 - (is_top_full + is_top_empty) */
1187  this->pb.add_r1cs_constraint(
1188  r1cs_constraint<FieldT>(
1189  {ONE},
1190  {ONE, is_top_full * (-1), is_top_empty * (-1)},
1191  {smulh_flag}),
1192  FMT(this->annotation_prefix, " smulh_flag"));
1193 }
1194 
1195 template<typename FieldT> void ALU_smul_gadget<FieldT>::generate_r1cs_witness()
1196 {
1197  /* do multiplication */
1198  /*
1199  from two's complement: (packed - 2^w * bits[w-1])
1200  to two's complement: lower order bits of (2^{2w} + result_of_mul)
1201  */
1202  this->pb.val(mul_result.packed) =
1203  (this->pb.val(this->arg1val.packed) -
1204  (this->pb.val(this->arg1val.bits[this->pb.ap.w - 1]) *
1205  (FieldT(2) ^ this->pb.ap.w))) *
1206  (this->pb.val(this->arg2val.packed) -
1207  (this->pb.val(this->arg2val.bits[this->pb.ap.w - 1]) *
1208  (FieldT(2) ^ this->pb.ap.w))) +
1209  (FieldT(2) ^ (2 * this->pb.ap.w));
1210 
1211  mul_result.generate_r1cs_witness_from_packed();
1212 
1213  /* pack result */
1214  pack_smulh_result->generate_r1cs_witness_from_bits();
1215 
1216  /* compute flag */
1217  pack_top->generate_r1cs_witness_from_bits();
1218  size_t topval = this->pb.val(top).as_ulong();
1219 
1220  if (topval == 0) {
1221  this->pb.val(is_top_empty) = FieldT::one();
1222  this->pb.val(is_top_empty_aux) = FieldT::zero();
1223  } else {
1224  this->pb.val(is_top_empty) = FieldT::zero();
1225  this->pb.val(is_top_empty_aux) = this->pb.val(top).inverse();
1226  }
1227 
1228  if (topval == ((1ul << (this->pb.ap.w + 1)) - 1)) {
1229  this->pb.val(is_top_full) = FieldT::one();
1230  this->pb.val(is_top_full_aux) = FieldT::zero();
1231  } else {
1232  this->pb.val(is_top_full) = FieldT::zero();
1233  this->pb.val(is_top_full_aux) =
1234  (this->pb.val(top) - FieldT((1ul << (this->pb.ap.w + 1)) - 1))
1235  .inverse();
1236  }
1237 
1238  /* smulh_flag = 1 - (is_top_full + is_top_empty) */
1239  this->pb.val(smulh_flag) = FieldT::one() - (this->pb.val(is_top_full) +
1240  this->pb.val(is_top_empty));
1241 }
1242 
1243 template<typename FieldT> void test_ALU_smulh_gadget(const size_t w)
1244 {
1245  libff::print_time("starting smulh test");
1246  brute_force_arithmetic_gadget<ALU_smul_gadget<FieldT>, FieldT>(
1247  w,
1248  tinyram_opcode_SMULH,
1249  [](tinyram_protoboard<FieldT> &pb,
1250  pb_variable_array<FieldT> &opcode_indicators,
1251  word_variable_gadget<FieldT> &desval,
1252  word_variable_gadget<FieldT> &arg1val,
1253  word_variable_gadget<FieldT> &arg2val,
1254  pb_variable<FieldT> &flag,
1255  pb_variable<FieldT> &result,
1256  pb_variable<FieldT> &result_flag) -> ALU_smul_gadget<FieldT> * {
1257  return new ALU_smul_gadget<FieldT>(
1258  pb,
1259  opcode_indicators,
1260  desval,
1261  arg1val,
1262  arg2val,
1263  flag,
1264  result,
1265  result_flag,
1266  "ALU_smul_gadget");
1267  },
1268  [w](size_t, bool, size_t x, size_t y) -> size_t {
1269  const size_t res = libff::to_twos_complement(
1270  (libff::from_twos_complement(x, w) *
1271  libff::from_twos_complement(y, w)),
1272  2 * w);
1273  return res >> w;
1274  },
1275  [w](size_t, bool, size_t x, size_t y) -> bool {
1276  const int res = libff::from_twos_complement(x, w) *
1277  libff::from_twos_complement(y, w);
1278  const int truncated_res = libff::from_twos_complement(
1279  libff::to_twos_complement(res, 2 * w) & ((1ul << w) - 1), w);
1280  return (res != truncated_res);
1281  });
1282  libff::print_time("smulh tests successful");
1283 }
1284 
1285 template<typename FieldT>
1286 void ALU_divmod_gadget<FieldT>::generate_r1cs_constraints()
1287 {
1288  /* B_inv * B = B_nonzero */
1289  linear_combination<FieldT> a1, b1, c1;
1290  a1.add_term(B_inv, 1);
1291  b1.add_term(this->arg2val.packed, 1);
1292  c1.add_term(B_nonzero, 1);
1293 
1294  this->pb.add_r1cs_constraint(
1295  r1cs_constraint<FieldT>(a1, b1, c1),
1296  FMT(this->annotation_prefix, " B_inv*B=B_nonzero"));
1297 
1298  /* (1-B_nonzero) * B = 0 */
1299  linear_combination<FieldT> a2, b2, c2;
1300  a2.add_term(ONE, 1);
1301  a2.add_term(B_nonzero, -1);
1302  b2.add_term(this->arg2val.packed, 1);
1303  c2.add_term(ONE, 0);
1304 
1305  this->pb.add_r1cs_constraint(
1306  r1cs_constraint<FieldT>(a2, b2, c2),
1307  FMT(this->annotation_prefix, " (1-B_nonzero)*B=0"));
1308 
1309  /* B * q + r = A_aux = A * B_nonzero */
1310  linear_combination<FieldT> a3, b3, c3;
1311  a3.add_term(this->arg2val.packed, 1);
1312  b3.add_term(udiv_result, 1);
1313  c3.add_term(A_aux, 1);
1314  c3.add_term(umod_result, -1);
1315 
1316  this->pb.add_r1cs_constraint(
1317  r1cs_constraint<FieldT>(a3, b3, c3),
1318  FMT(this->annotation_prefix, " B*q+r=A_aux"));
1319 
1320  linear_combination<FieldT> a4, b4, c4;
1321  a4.add_term(this->arg1val.packed, 1);
1322  b4.add_term(B_nonzero, 1);
1323  c4.add_term(A_aux, 1);
1324 
1325  this->pb.add_r1cs_constraint(
1326  r1cs_constraint<FieldT>(a4, b4, c4),
1327  FMT(this->annotation_prefix, " A_aux=A*B_nonzero"));
1328 
1329  /* q * (1-B_nonzero) = 0 */
1330  linear_combination<FieldT> a5, b5, c5;
1331  a5.add_term(udiv_result, 1);
1332  b5.add_term(ONE, 1);
1333  b5.add_term(B_nonzero, -1);
1334  c5.add_term(ONE, 0);
1335 
1336  this->pb.add_r1cs_constraint(
1337  r1cs_constraint<FieldT>(a5, b5, c5),
1338  FMT(this->annotation_prefix, " q*B_nonzero=0"));
1339 
1340  /* A<B_gadget<FieldT>(B, r, less=B_nonzero, leq=ONE) */
1341  r_less_B->generate_r1cs_constraints();
1342 }
1343 
1344 template<typename FieldT>
1345 void ALU_divmod_gadget<FieldT>::generate_r1cs_witness()
1346 {
1347  if (this->pb.val(this->arg2val.packed) == FieldT::zero()) {
1348  this->pb.val(B_inv) = FieldT::zero();
1349  this->pb.val(B_nonzero) = FieldT::zero();
1350 
1351  this->pb.val(A_aux) = FieldT::zero();
1352 
1353  this->pb.val(udiv_result) = FieldT::zero();
1354  this->pb.val(umod_result) = FieldT::zero();
1355 
1356  this->pb.val(udiv_flag) = FieldT::one();
1357  this->pb.val(umod_flag) = FieldT::one();
1358  } else {
1359  this->pb.val(B_inv) = this->pb.val(this->arg2val.packed).inverse();
1360  this->pb.val(B_nonzero) = FieldT::one();
1361 
1362  const size_t A = this->pb.val(this->arg1val.packed).as_ulong();
1363  const size_t B = this->pb.val(this->arg2val.packed).as_ulong();
1364 
1365  this->pb.val(A_aux) = this->pb.val(this->arg1val.packed);
1366 
1367  this->pb.val(udiv_result) = FieldT(A / B);
1368  this->pb.val(umod_result) = FieldT(A % B);
1369 
1370  this->pb.val(udiv_flag) = FieldT::zero();
1371  this->pb.val(umod_flag) = FieldT::zero();
1372  }
1373 
1374  r_less_B->generate_r1cs_witness();
1375 }
1376 
1377 template<typename FieldT> void test_ALU_udiv_gadget(const size_t w)
1378 {
1379  libff::print_time("starting udiv test");
1380  brute_force_arithmetic_gadget<ALU_divmod_gadget<FieldT>, FieldT>(
1381  w,
1382  tinyram_opcode_UDIV,
1383  [](tinyram_protoboard<FieldT> &pb,
1384  pb_variable_array<FieldT> &opcode_indicators,
1385  word_variable_gadget<FieldT> &desval,
1386  word_variable_gadget<FieldT> &arg1val,
1387  word_variable_gadget<FieldT> &arg2val,
1388  pb_variable<FieldT> &flag,
1389  pb_variable<FieldT> &result,
1390  pb_variable<FieldT> &result_flag) -> ALU_divmod_gadget<FieldT> * {
1391  pb_variable<FieldT> umod_result;
1392  umod_result.allocate(pb, "umod_result");
1393  pb_variable<FieldT> umod_flag;
1394  umod_flag.allocate(pb, "umod_flag");
1395  return new ALU_divmod_gadget<FieldT>(
1396  pb,
1397  opcode_indicators,
1398  desval,
1399  arg1val,
1400  arg2val,
1401  flag,
1402  result,
1403  result_flag,
1404  umod_result,
1405  umod_flag,
1406  "ALU_divmod_gadget");
1407  },
1408  [w](size_t, bool, size_t x, size_t y) -> size_t {
1409  return (y == 0 ? 0 : x / y);
1410  },
1411  [w](size_t, bool, size_t, size_t y) -> bool { return (y == 0); });
1412  libff::print_time("udiv tests successful");
1413 }
1414 
1415 template<typename FieldT> void test_ALU_umod_gadget(const size_t w)
1416 {
1417  libff::print_time("starting umod test");
1418  brute_force_arithmetic_gadget<ALU_divmod_gadget<FieldT>, FieldT>(
1419  w,
1420  tinyram_opcode_UMOD,
1421  [](tinyram_protoboard<FieldT> &pb,
1422  pb_variable_array<FieldT> &opcode_indicators,
1423  word_variable_gadget<FieldT> &desval,
1424  word_variable_gadget<FieldT> &arg1val,
1425  word_variable_gadget<FieldT> &arg2val,
1426  pb_variable<FieldT> &flag,
1427  pb_variable<FieldT> &result,
1428  pb_variable<FieldT> &result_flag) -> ALU_divmod_gadget<FieldT> * {
1429  pb_variable<FieldT> udiv_result;
1430  udiv_result.allocate(pb, "udiv_result");
1431  pb_variable<FieldT> udiv_flag;
1432  udiv_flag.allocate(pb, "udiv_flag");
1433  return new ALU_divmod_gadget<FieldT>(
1434  pb,
1435  opcode_indicators,
1436  desval,
1437  arg1val,
1438  arg2val,
1439  flag,
1440  udiv_result,
1441  udiv_flag,
1442  result,
1443  result_flag,
1444  "ALU_divmod_gadget");
1445  },
1446  [w](size_t, bool, size_t x, size_t y) -> size_t {
1447  return (y == 0 ? 0 : x % y);
1448  },
1449  [w](size_t, bool, size_t, size_t y) -> bool { return (y == 0); });
1450  libff::print_time("umod tests successful");
1451 }
1452 
1453 template<typename FieldT>
1454 void ALU_shr_shl_gadget<FieldT>::generate_r1cs_constraints()
1455 {
1456  /*
1457  select the input for barrel shifter:
1458 
1459  r = arg1val * opcode_indicators[SHR] + reverse(arg1val) *
1460  (1-opcode_indicators[SHR]) r - reverse(arg1val) = (arg1val -
1461  reverse(arg1val)) * opcode_indicators[SHR]
1462  */
1463  pack_reversed_input->generate_r1cs_constraints(false);
1464 
1465  this->pb.add_r1cs_constraint(
1466  r1cs_constraint<FieldT>(
1467  {this->arg1val.packed, reversed_input * (-1)},
1468  {this->opcode_indicators[tinyram_opcode_SHR]},
1469  {barrel_right_internal[0], reversed_input * (-1)}),
1470  FMT(this->annotation_prefix, " select_arg1val_or_reversed"));
1471 
1472  /*
1473  do logw iterations of barrel shifts
1474  */
1475  for (size_t i = 0; i < logw; ++i) {
1476  /* assert that shifted out part is bits */
1477  for (size_t j = 0; j < 1ul << i; ++j) {
1478  generate_boolean_r1cs_constraint<FieldT>(
1479  this->pb,
1480  shifted_out_bits[i][j],
1481  FMT(this->annotation_prefix,
1482  " shifted_out_bits_%zu_%zu",
1483  i,
1484  j));
1485  }
1486 
1487  /*
1488  add main shifting constraint
1489 
1490 
1491  old_result =
1492  (shifted_result * 2^(i+1) + shifted_out_part) * need_to_shift +
1493  (shfited_result) * (1-need_to_shift)
1494 
1495  old_result - shifted_result = (shifted_result * (2^(i+1) - 1) +
1496  shifted_out_part) * need_to_shift
1497  */
1498  linear_combination<FieldT> a, b, c;
1499 
1500  a.add_term(
1501  barrel_right_internal[i + 1],
1502  (FieldT(2) ^ (i + 1)) - FieldT::one());
1503  for (size_t j = 0; j < 1ul << i; ++j) {
1504  a.add_term(shifted_out_bits[i][j], (FieldT(2) ^ j));
1505  }
1506 
1507  b.add_term(this->arg2val.bits[i], 1);
1508 
1509  c.add_term(barrel_right_internal[i], 1);
1510  c.add_term(barrel_right_internal[i + 1], -1);
1511 
1512  this->pb.add_r1cs_constraint(
1513  r1cs_constraint<FieldT>(a, b, c),
1514  FMT(this->annotation_prefix, " barrel_shift_%zu", i));
1515  }
1516 
1517  /*
1518  get result as the logw iterations or zero if shift was oversized
1519 
1520  result = (1-is_oversize_shift) * barrel_right_internal[logw]
1521  */
1522  check_oversize_shift->generate_r1cs_constraints();
1523  this->pb.add_r1cs_constraint(
1524  r1cs_constraint<FieldT>(
1525  {ONE, is_oversize_shift * (-1)},
1526  {barrel_right_internal[logw]},
1527  {this->result}),
1528  FMT(this->annotation_prefix, " result"));
1529 
1530  /*
1531  get reversed result for SHL
1532  */
1533  unpack_result->generate_r1cs_constraints(true);
1534  pack_reversed_result->generate_r1cs_constraints(false);
1535 
1536  /*
1537  select the correct output:
1538  r = result * opcode_indicators[SHR] + reverse(result) *
1539  (1-opcode_indicators[SHR]) r - reverse(result) = (result -
1540  reverse(result)) * opcode_indicators[SHR]
1541  */
1542  this->pb.add_r1cs_constraint(
1543  r1cs_constraint<FieldT>(
1544  {this->result, reversed_result * (-1)},
1545  {this->opcode_indicators[tinyram_opcode_SHR]},
1546  {shr_result, reversed_result * (-1)}),
1547  FMT(this->annotation_prefix, " shr_result"));
1548 
1549  this->pb.add_r1cs_constraint(
1550  r1cs_constraint<FieldT>(
1551  {this->result, reversed_result * (-1)},
1552  {this->opcode_indicators[tinyram_opcode_SHR]},
1553  {shr_result, reversed_result * (-1)}),
1554  FMT(this->annotation_prefix, " shl_result"));
1555 
1556  this->pb.add_r1cs_constraint(
1557  r1cs_constraint<FieldT>({ONE}, {this->arg1val.bits[0]}, {shr_flag}),
1558  FMT(this->annotation_prefix, " shr_flag"));
1559 
1560  this->pb.add_r1cs_constraint(
1561  r1cs_constraint<FieldT>(
1562  {ONE}, {this->arg1val.bits[this->pb.ap.w - 1]}, {shl_flag}),
1563  FMT(this->annotation_prefix, " shl_flag"));
1564 }
1565 
1566 template<typename FieldT>
1567 void ALU_shr_shl_gadget<FieldT>::generate_r1cs_witness()
1568 {
1569  /* select the input for barrel shifter */
1570  pack_reversed_input->generate_r1cs_witness_from_bits();
1571 
1572  this->pb.val(barrel_right_internal[0]) =
1573  (this->pb.val(this->opcode_indicators[tinyram_opcode_SHR]) ==
1574  FieldT::one()
1575  ? this->pb.val(this->arg1val.packed)
1576  : this->pb.val(reversed_input));
1577 
1578  /*
1579  do logw iterations of barrel shifts.
1580 
1581  old_result =
1582  (shifted_result * 2^i + shifted_out_part) * need_to_shift +
1583  (shfited_result) * (1-need_to_shift)
1584  */
1585 
1586  for (size_t i = 0; i < logw; ++i) {
1587  this->pb.val(barrel_right_internal[i + 1]) =
1588  (this->pb.val(this->arg2val.bits[i]) == FieldT::zero())
1589  ? this->pb.val(barrel_right_internal[i])
1590  : FieldT(
1591  this->pb.val(barrel_right_internal[i]).as_ulong() >>
1592  (i + 1));
1593 
1594  shifted_out_bits[i].fill_with_bits_of_ulong(
1595  this->pb,
1596  this->pb.val(barrel_right_internal[i]).as_ulong() % (2u << i));
1597  }
1598 
1599  /*
1600  get result as the logw iterations or zero if shift was oversized
1601 
1602  result = (1-is_oversize_shift) * barrel_right_internal[logw]
1603  */
1604  check_oversize_shift->generate_r1cs_witness();
1605  this->pb.val(this->result) =
1606  (FieldT::one() - this->pb.val(is_oversize_shift)) *
1607  this->pb.val(barrel_right_internal[logw]);
1608 
1609  /*
1610  get reversed result for SHL
1611  */
1612  unpack_result->generate_r1cs_witness_from_packed();
1613  pack_reversed_result->generate_r1cs_witness_from_bits();
1614 
1615  /*
1616  select the correct output:
1617  r = result * opcode_indicators[SHR] + reverse(result) *
1618  (1-opcode_indicators[SHR]) r - reverse(result) = (result -
1619  reverse(result)) * opcode_indicators[SHR]
1620  */
1621  this->pb.val(shr_result) =
1622  (this->pb.val(this->opcode_indicators[tinyram_opcode_SHR]) ==
1623  FieldT::one())
1624  ? this->pb.val(this->result)
1625  : this->pb.val(reversed_result);
1626 
1627  this->pb.val(shl_result) = this->pb.val(shr_result);
1628  this->pb.val(shr_flag) = this->pb.val(this->arg1val.bits[0]);
1629  this->pb.val(shl_flag) =
1630  this->pb.val(this->arg1val.bits[this->pb.ap.w - 1]);
1631 }
1632 
1633 template<typename FieldT> void test_ALU_shr_gadget(const size_t w)
1634 {
1635  libff::print_time("starting shr test");
1636  brute_force_arithmetic_gadget<ALU_shr_shl_gadget<FieldT>, FieldT>(
1637  w,
1638  tinyram_opcode_SHR,
1639  [](tinyram_protoboard<FieldT> &pb,
1640  pb_variable_array<FieldT> &opcode_indicators,
1641  word_variable_gadget<FieldT> &desval,
1642  word_variable_gadget<FieldT> &arg1val,
1643  word_variable_gadget<FieldT> &arg2val,
1644  pb_variable<FieldT> &flag,
1645  pb_variable<FieldT> &result,
1646  pb_variable<FieldT> &result_flag) -> ALU_shr_shl_gadget<FieldT> * {
1647  pb_variable<FieldT> shl_result;
1648  shl_result.allocate(pb, "shl_result");
1649  pb_variable<FieldT> shl_flag;
1650  shl_flag.allocate(pb, "shl_flag");
1651  return new ALU_shr_shl_gadget<FieldT>(
1652  pb,
1653  opcode_indicators,
1654  desval,
1655  arg1val,
1656  arg2val,
1657  flag,
1658  result,
1659  result_flag,
1660  shl_result,
1661  shl_flag,
1662  "ALU_shr_shl_gadget");
1663  },
1664  [w](size_t, bool, size_t x, size_t y) -> size_t { return (x >> y); },
1665  [w](size_t, bool, size_t x, size_t) -> bool { return (x & 1); });
1666  libff::print_time("shr tests successful");
1667 }
1668 
1669 template<typename FieldT> void test_ALU_shl_gadget(const size_t w)
1670 {
1671  libff::print_time("starting shl test");
1672  brute_force_arithmetic_gadget<ALU_shr_shl_gadget<FieldT>, FieldT>(
1673  w,
1674  tinyram_opcode_SHL,
1675  [](tinyram_protoboard<FieldT> &pb,
1676  pb_variable_array<FieldT> &opcode_indicators,
1677  word_variable_gadget<FieldT> &desval,
1678  word_variable_gadget<FieldT> &arg1val,
1679  word_variable_gadget<FieldT> &arg2val,
1680  pb_variable<FieldT> &flag,
1681  pb_variable<FieldT> &result,
1682  pb_variable<FieldT> &result_flag) -> ALU_shr_shl_gadget<FieldT> * {
1683  pb_variable<FieldT> shr_result;
1684  shr_result.allocate(pb, "shr_result");
1685  pb_variable<FieldT> shr_flag;
1686  shr_flag.allocate(pb, "shr_flag");
1687  return new ALU_shr_shl_gadget<FieldT>(
1688  pb,
1689  opcode_indicators,
1690  desval,
1691  arg1val,
1692  arg2val,
1693  flag,
1694  shr_result,
1695  shr_flag,
1696  result,
1697  result_flag,
1698  "ALU_shr_shl_gadget");
1699  },
1700  [w](size_t, bool, size_t x, size_t y) -> size_t {
1701  return (x << y) & ((1ul << w) - 1);
1702  },
1703  [w](size_t, bool, size_t x, size_t) -> bool { return (x >> (w - 1)); });
1704  libff::print_time("shl tests successful");
1705 }
1706 
1707 } // namespace libsnark
1708 
1709 #endif