Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
multi_scalar_mul.test.cpp
Go to the documentation of this file.
2#include "acir_format.hpp"
6
7#include <gtest/gtest.h>
8#include <vector>
9
10using namespace ::acir_format;
11
// Selects which MSM inputs a test instantiation supplies as circuit constants rather than witnesses.
enum class InputConstancy : uint8_t {
    None,    // points and scalars are both witnesses
    Points,  // points are constants; scalars are witnesses
    Scalars, // scalars are constants; points are witnesses
    Both     // points and scalars are both constants
};
13
23template <typename Builder_, InputConstancy Constancy> class MultiScalarMulTestingFunctions {
24 public:
25 using Builder = Builder_;
26 using AcirConstraint = MultiScalarMul;
29 using FF = bb::fr;
30
32 public:
33 enum class Target : uint8_t {
34 None,
35 Points, // Invalidate point inputs
36 Scalars, // Invalidate scalar inputs
37 Result // Invalidate result output
38 };
39
44
45 static std::vector<std::string> get_labels() { return { "None", "Points", "Scalars", "Result" }; }
46 };
47
48 static ProgramMetadata generate_metadata() { return ProgramMetadata{}; }
49
50 static void generate_constraints(AcirConstraint& msm_constraint, WitnessVector& witness_values)
51 {
52 // Generate a single point and scalar for simplicity
54 bb::fq scalar_native = bb::fq::random_element();
55 GrumpkinPoint result = point * scalar_native;
56 BB_ASSERT(result != GrumpkinPoint::one()); // Ensure that tampering works correctly
57
58 // Split scalar into low and high limbs (128 bits each) as FF for witness values
59 uint256_t scalar_u256 = uint256_t(scalar_native);
60 FF scalar_lo = scalar_u256.slice(0, 128);
61 FF scalar_hi = scalar_u256.slice(128, 256);
62
63 // Determine which inputs are constants based on the Constancy template parameter
64 constexpr bool points_are_constant = (Constancy == InputConstancy::Points || Constancy == InputConstancy::Both);
65 constexpr bool scalars_are_constant =
66 (Constancy == InputConstancy::Scalars || Constancy == InputConstancy::Both);
67
68 // Helper to add points: either as witnesses or constants based on Constancy.
69 // Points are encoded as (x, y); the point at infinity is encoded as (0, 0).
70 auto construct_points = [&]() -> std::vector<WitnessOrConstant<FF>> {
71 if constexpr (points_are_constant) {
72 return { WitnessOrConstant<FF>::from_constant(point.x), WitnessOrConstant<FF>::from_constant(point.y) };
73 }
74 std::vector<uint32_t> point_indices = add_to_witness_and_track_indices(witness_values, point);
75 return { WitnessOrConstant<FF>::from_index(point_indices[0]),
76 WitnessOrConstant<FF>::from_index(point_indices[1]) };
77 };
78
79 // Helper to add scalars: either as witnesses or constants based on Constancy
80 auto construct_scalars = [&]() -> std::vector<WitnessOrConstant<FF>> {
81 if constexpr (scalars_are_constant) {
82 // Scalars are constants
83 return { WitnessOrConstant<FF>::from_constant(scalar_lo),
84 WitnessOrConstant<FF>::from_constant(scalar_hi) };
85 }
86 // Scalars are witnesses
87 uint32_t scalar_lo_index = static_cast<uint32_t>(witness_values.size());
88 witness_values.emplace_back(scalar_lo);
89 uint32_t scalar_hi_index = static_cast<uint32_t>(witness_values.size());
90 witness_values.emplace_back(scalar_hi);
91 return { WitnessOrConstant<FF>::from_index(scalar_lo_index),
92 WitnessOrConstant<FF>::from_index(scalar_hi_index) };
93 };
94
95 // Add points and scalars according to constancy template parameter
96 auto point_fields = construct_points();
97 auto scalar_fields = construct_scalars();
98
99 // Construct result and predicate as witnesses
100 std::vector<uint32_t> result_indices = add_to_witness_and_track_indices(witness_values, result);
101 uint32_t predicate_index = static_cast<uint32_t>(witness_values.size());
102 witness_values.emplace_back(FF::one()); // predicate
103
104 // Build the constraint
105 msm_constraint = MultiScalarMul{
106 .points = point_fields,
107 .scalars = scalar_fields,
108 .predicate = WitnessOrConstant<FF>::from_index(predicate_index),
109 .out_point_x = result_indices[0],
110 .out_point_y = result_indices[1],
111 };
112 }
113
115 AcirConstraint constraint, WitnessVector witness_values, const InvalidWitness::Target& invalid_witness_target)
116 {
117 switch (invalid_witness_target) {
119 // Invalidate the point by adding 1 to x coordinate
120 if constexpr (Constancy == InputConstancy::None || Constancy == InputConstancy::Scalars) {
121 witness_values[constraint.points[0].index] += bb::fr(1);
122 } else {
123 constraint.points[0] = WitnessOrConstant<FF>::from_constant(constraint.points[0].value + bb::fr(1));
124 }
125 break;
126 }
128 // Invalidate the scalar by adding 1 to the low limb
129 if constexpr (Constancy == InputConstancy::None || Constancy == InputConstancy::Points) {
130 witness_values[constraint.scalars[0].index] += bb::fr(1);
131 } else {
132 constraint.scalars[0] = WitnessOrConstant<FF>::from_constant(constraint.scalars[0].value + bb::fr(1));
133 }
134 break;
135 }
137 // Tamper with the result by setting it to the generator point
138 witness_values[constraint.out_point_x] = GrumpkinPoint::one().x;
139 witness_values[constraint.out_point_y] = GrumpkinPoint::one().y;
140 break;
141 }
143 default:
144 break;
145 }
146
147 return { constraint, witness_values };
148 };
149};
150
151template <typename Builder>
153 : public ::testing::Test,
154 public TestClassWithPredicate<MultiScalarMulTestingFunctions<Builder, InputConstancy::None>> {
155 protected:
157};
158
159template <typename Builder>
161 : public ::testing::Test,
162 public TestClassWithPredicate<MultiScalarMulTestingFunctions<Builder, InputConstancy::Points>> {
163 protected:
165};
166
167template <typename Builder>
169 : public ::testing::Test,
170 public TestClassWithPredicate<MultiScalarMulTestingFunctions<Builder, InputConstancy::Scalars>> {
171 protected:
173};
174
175template <typename Builder>
177 : public ::testing::Test,
178 public TestClassWithPredicate<MultiScalarMulTestingFunctions<Builder, InputConstancy::Both>> {
179 protected:
181};
182
183using BuilderTypes = testing::Types<UltraCircuitBuilder, MegaCircuitBuilder>;
184
189
191{
193 TestFixture::template test_vk_independence<Flavor>();
194}
195
197{
199 TestFixture::test_constant_true(TestFixture::InvalidWitnessTarget::Result);
200}
201
203{
205 TestFixture::test_witness_true(TestFixture::InvalidWitnessTarget::Result);
206}
207
209{
211 TestFixture::test_witness_false_slow();
212}
213
215{
217 [[maybe_unused]] std::vector<std::string> _ = TestFixture::test_invalid_witnesses();
218}
219
221{
223 TestFixture::template test_vk_independence<Flavor>();
224}
225
227{
229 TestFixture::test_constant_true(TestFixture::InvalidWitnessTarget::Result);
230}
231
233{
235 TestFixture::test_witness_true(TestFixture::InvalidWitnessTarget::Result);
236}
237
239{
241 TestFixture::test_witness_false_slow();
242}
243
245{
247 [[maybe_unused]] std::vector<std::string> _ = TestFixture::test_invalid_witnesses();
248}
249
251{
253 TestFixture::template test_vk_independence<Flavor>();
254}
255
257{
259 TestFixture::test_constant_true(TestFixture::InvalidWitnessTarget::Result);
260}
261
263{
265 TestFixture::test_witness_true(TestFixture::InvalidWitnessTarget::Result);
266}
267
269{
271 TestFixture::test_witness_false_slow();
272}
273
275{
277 [[maybe_unused]] std::vector<std::string> _ = TestFixture::test_invalid_witnesses();
278}
279
281{
283 TestFixture::template test_vk_independence<Flavor>();
284}
285
287{
289 TestFixture::test_constant_true(TestFixture::InvalidWitnessTarget::Result);
290}
291
293{
295 TestFixture::test_witness_true(TestFixture::InvalidWitnessTarget::Result);
296}
297
299{
301 TestFixture::test_witness_false_slow();
302}
303
305{
307 [[maybe_unused]] std::vector<std::string> _ = TestFixture::test_invalid_witnesses();
308}
309
310// ============================================================
311// Infinity tests: the point at infinity is encoded as (0, 0).
312// ============================================================
313
315using MsmFF = bb::fr;
316
319 static MsmAcirPoint from_native(const MsmGrumpkinPoint& p) { return { p.x, p.y }; }
320 static MsmAcirPoint infinity() { return { MsmFF(0), MsmFF(0) }; }
321};
322
// Grumpkin scalar split into two field limbs: lo holds bits [0, 128) and hi holds bits [128, 256).
// For canonical scalars (< r < 2^254) the hi limb fits in 126 bits.
324struct MsmScalar {
326 static MsmScalar zero() { return { MsmFF(0), MsmFF(0) }; }
327 static MsmScalar from_native(const bb::fq& s)
328 {
329 uint256_t u = uint256_t(s);
330 return { u.slice(0, 128), u.slice(128, 256) };
331 }
332};
333
// Shared helpers for single-term MSM tests: build a one-point, one-scalar MSM constraint (predicate = 1)
// from explicit witness values, and run the resulting circuit.
336template <typename Builder> class MsmSingleTermFixture : public ::testing::Test {
337 protected:
339
340 // Push an MsmAcirPoint to witness; return [x, y] indices.
341 static std::array<uint32_t, 2> push_point(WitnessVector& witness, const MsmAcirPoint& pt)
342 {
343 uint32_t xi = static_cast<uint32_t>(witness.size());
344 witness.emplace_back(pt.x);
345 uint32_t yi = static_cast<uint32_t>(witness.size());
346 witness.emplace_back(pt.y);
347 return { xi, yi };
348 }
349
350 // Push a scalar (lo, hi) to witness; return [lo_idx, hi_idx].
351 static std::array<uint32_t, 2> push_scalar(WitnessVector& witness, const MsmScalar& s)
352 {
353 uint32_t lo_idx = static_cast<uint32_t>(witness.size());
354 witness.emplace_back(s.lo);
355 uint32_t hi_idx = static_cast<uint32_t>(witness.size());
356 witness.emplace_back(s.hi);
357 return { lo_idx, hi_idx };
358 }
359
360 // Build a single-term MSM constraint (predicate=1) from a point, scalar, and expected result.
361 // Returns the constraint and the populated witness vector.
363 {
364 WitnessVector witness;
365 auto p = push_point(witness, point);
366 auto s = push_scalar(witness, scalar);
367 auto r = push_point(witness, result);
368 uint32_t pred_idx = static_cast<uint32_t>(witness.size());
369 witness.emplace_back(MsmFF(1));
370
371 MultiScalarMul c{
372 .points = { WitnessOrConstant<MsmFF>::from_index(p[0]), WitnessOrConstant<MsmFF>::from_index(p[1]) },
373 .scalars = { WitnessOrConstant<MsmFF>::from_index(s[0]), WitnessOrConstant<MsmFF>::from_index(s[1]) },
374 .predicate = WitnessOrConstant<MsmFF>::from_index(pred_idx),
375 .out_point_x = r[0],
376 .out_point_y = r[1],
377 };
378 return { c, witness };
379 }
380
381 // Run the circuit and return (satisfied, error_string).
382 static std::pair<bool, std::string> run_circuit(MultiScalarMul constraint, WitnessVector witness)
383 {
384 AcirFormat cs = constraint_to_acir_format(constraint);
385 AcirProgram program{ cs, witness };
386 auto builder = create_circuit<Builder>(program, ProgramMetadata{});
387 bool ok = CircuitChecker::check(builder) && !builder.failed();
388 return { ok, builder.err() };
389 }
390};
391
392template <typename Builder> class MultiScalarMulInfinityTests : public MsmSingleTermFixture<Builder> {};
393
395
396// scalar=0 → result = (0, 0): valid circuit.
398{
401 auto [constraint, witness] =
402 TestFixture::make_msm(MsmAcirPoint::from_native(point), MsmScalar::zero(), MsmAcirPoint::infinity());
403
404 auto [ok, err] = TestFixture::run_circuit(constraint, witness);
405 EXPECT_TRUE(ok) << "0 * P = infinity should produce a valid circuit";
406}
407
408// ============================================================
409// Scalar field-bounds tests
410// ============================================================
411//
412// The MSM opcode receives a Grumpkin scalar as two field limbs: lo (low 128 bits) and hi (next 126
413// bits), reconstructing v = lo + hi * 2^128. cycle_scalar's public constructor adds an in-circuit
414// check that v < r, where r == bb::fq::modulus is the Grumpkin scalar field modulus (and also the
415// order of the Grumpkin group, since Grumpkin's scalar field is BN254's base field). batch_mul
416// additionally range-constrains the limbs to lo < 2^128 and hi < 2^126. These tests pin behaviour
417// at and beyond the modulus boundary: an out-of-range scalar must make the circuit unsatisfiable,
418// and the group law's s ≡ s + r equivalence must not let a caller smuggle a non-canonical scalar
419// through to barretenberg.
420
421namespace {
422// r = order of the Grumpkin group = bb::fq::modulus.
423const uint256_t grumpkin_scalar_modulus = bb::fq::modulus;
424
425// Build an MsmScalar straight from a uint256_t value, splitting at the 128-bit limb boundary with no
426// modular reduction (so out-of-field values can be expressed).
427MsmScalar msm_scalar_from_u256(const uint256_t& v)
428{
429 return { MsmFF(v.slice(0, 128)), MsmFF(v.slice(128, 256)) };
430}
431} // namespace
432
433template <typename Builder> class MultiScalarMulScalarBoundsTests : public MsmSingleTermFixture<Builder> {};
434
436
// scalar == r: rejected, because the in-circuit "scalar < r" check fails. Mathematically r·P = O, so the
// claimed result (the point at infinity) is the true group-law output, yet the circuit must still reject it.
440{
443 auto [constraint, witness] = TestFixture::make_msm(
444 MsmAcirPoint::from_native(point), msm_scalar_from_u256(grumpkin_scalar_modulus), MsmAcirPoint::infinity());
445
446 auto [ok, err] = TestFixture::run_circuit(constraint, witness);
447 EXPECT_FALSE(ok) << "scalar == Grumpkin scalar modulus must not produce a satisfiable circuit";
448}
449
450// scalar == r + 1: rejected, even though (r + 1)·P == 1·P == P. The in-circuit "scalar < r" check
451// fails despite both limbs being within their range constraints.
453{
456 auto [constraint, witness] = TestFixture::make_msm(MsmAcirPoint::from_native(point),
457 msm_scalar_from_u256(grumpkin_scalar_modulus + uint256_t(1)),
459
460 auto [ok, err] = TestFixture::run_circuit(constraint, witness);
461 EXPECT_FALSE(ok) << "scalar == Grumpkin scalar modulus + 1 must not produce a satisfiable circuit";
462}
463
// scalar == r - 1: the largest in-field scalar. It must produce a satisfiable circuit.
465TYPED_TEST(MultiScalarMulScalarBoundsTests, ScalarModulusMinusOneProves)
466{
469 bb::fq scalar_native = bb::fq(grumpkin_scalar_modulus - uint256_t(1));
470 MsmGrumpkinPoint result = point * scalar_native;
471 ASSERT_FALSE(result.is_point_at_infinity());
472 auto [constraint, witness] = TestFixture::make_msm(MsmAcirPoint::from_native(point),
473 msm_scalar_from_u256(grumpkin_scalar_modulus - uint256_t(1)),
475
476 auto [ok, err] = TestFixture::run_circuit(constraint, witness);
477 EXPECT_TRUE(ok) << "scalar == Grumpkin scalar modulus - 1 (largest in-field scalar) should prove. err: " << err;
478}
479
480// scalar == 2^254 - 1: the largest value the (128 + 126)-bit limb encoding can represent. Both limbs
481// satisfy their range constraints, so the only thing rejecting it is the in-circuit "scalar < r"
482// check — exercising the "limbs in range but value out of field" path.
483TYPED_TEST(MultiScalarMulScalarBoundsTests, MaxRepresentableScalarFails)
484{
487 uint256_t max_representable = (uint256_t(1) << 254) - uint256_t(1);
488 MsmGrumpkinPoint result = point * bb::fq(max_representable); // (2^254 - 1) mod r
489 auto [constraint, witness] = TestFixture::make_msm(
490 MsmAcirPoint::from_native(point), msm_scalar_from_u256(max_representable), MsmAcirPoint::from_native(result));
491
492 auto [ok, err] = TestFixture::run_circuit(constraint, witness);
493 EXPECT_FALSE(ok) << "scalar == 2^254 - 1 (> Grumpkin modulus) must not produce a satisfiable circuit";
494}
495
496// hi limb == 2^126 (one bit too wide), lo == 0, i.e. scalar value 2^254. Both the limb range
497// constraint (hi < 2^126) and the "scalar < r" check reject it.
498TYPED_TEST(MultiScalarMulScalarBoundsTests, ScalarWithOversizedHiLimbFails)
499{
502 uint256_t two_pow_254 = uint256_t(1) << 254;
503 MsmGrumpkinPoint result = point * bb::fq(two_pow_254);
504 MsmScalar scalar{ MsmFF(0), MsmFF(uint256_t(1) << 126) }; // hi has 127 bits
505 auto [constraint, witness] =
506 TestFixture::make_msm(MsmAcirPoint::from_native(point), scalar, MsmAcirPoint::from_native(result));
507
508 auto [ok, err] = TestFixture::run_circuit(constraint, witness);
509 EXPECT_FALSE(ok) << "scalar hi limb of 127 bits (value 2^254) must not produce a satisfiable circuit";
510}
511
512// Group-law equivalence does not transfer: s·P == (s + r)·P, but the circuit accepts only the
513// canonical scalar s. Adding the Grumpkin modulus to a scalar cannot reprove the same output.
514TYPED_TEST(MultiScalarMulScalarBoundsTests, AddingGrumpkinModulusDoesNotReproveSameOutput)
515{
518 bb::fq scalar_native = bb::fq(5);
519 MsmGrumpkinPoint result = point * scalar_native;
520 ASSERT_FALSE(result.is_point_at_infinity());
521
522 // Sanity: the canonical scalar proves the result.
523 {
524 auto [constraint, witness] = TestFixture::make_msm(
526 auto [ok, err] = TestFixture::run_circuit(constraint, witness);
527 EXPECT_TRUE(ok) << "canonical scalar should prove the MSM result. err: " << err;
528 }
529
530 // The non-canonical scalar s + r yields the same point mathematically, but the circuit rejects it.
531 {
532 auto [constraint, witness] = TestFixture::make_msm(MsmAcirPoint::from_native(point),
533 msm_scalar_from_u256(uint256_t(5) + grumpkin_scalar_modulus),
535 auto [ok, err] = TestFixture::run_circuit(constraint, witness);
536 EXPECT_FALSE(ok) << "scalar s + r must not reprove the output of scalar s";
537 }
538}
#define BB_ASSERT(expression,...)
Definition assert.hpp:70
#define BB_DISABLE_ASSERTS()
Definition assert.hpp:33
static std::pair< bool, std::string > run_circuit(MultiScalarMul constraint, WitnessVector witness)
static std::array< uint32_t, 2 > push_scalar(WitnessVector &witness, const MsmScalar &s)
static std::pair< MultiScalarMul, WitnessVector > make_msm(MsmAcirPoint point, MsmScalar scalar, MsmAcirPoint result)
static std::array< uint32_t, 2 > push_point(WitnessVector &witness, const MsmAcirPoint &pt)
Testing functions to generate the MultiScalarMul test suite. Constancy specifies which inputs to the ...
static ProgramMetadata generate_metadata()
static void generate_constraints(AcirConstraint &msm_constraint, WitnessVector &witness_values)
static std::pair< AcirConstraint, WitnessVector > invalidate_witness(AcirConstraint constraint, WitnessVector witness_values, const InvalidWitness::Target &invalid_witness_target)
static bool check(const Builder &circuit)
Check the witness satisfies the circuit.
constexpr bool is_point_at_infinity() const noexcept
static affine_element random_element(numeric::RNG *engine=nullptr) noexcept
Samples a random point on the curve.
static constexpr affine_element one() noexcept
group class. Represents an elliptic curve group element. Group is parametrised by Fq and Fr
Definition group.hpp:38
group_elements::affine_element< Fq, Fr, Params > affine_element
Definition group.hpp:44
constexpr uint256_t slice(uint64_t start, uint64_t end) const
AluTraceBuilder builder
Definition alu.test.cpp:124
TYPED_TEST(MultiScalarMulTestsNoneConstant, GenerateVKFromConstraints)
bb::fr MsmFF
TYPED_TEST_SUITE(MultiScalarMulTestsNoneConstant, BuilderTypes)
bb::group< bb::fr, bb::fq, G1Params > g1
Definition grumpkin.hpp:46
std::filesystem::path bb_crs_path()
void init_file_crs_factory(const std::filesystem::path &path)
field< Bn254FqParams > fq
Definition fq.hpp:153
field< Bn254FrParams > fr
Definition fr.hpp:155
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
::testing::Types< UltraCircuitBuilder, MegaCircuitBuilder > BuilderTypes
static MsmAcirPoint infinity()
static MsmAcirPoint from_native(const MsmGrumpkinPoint &p)
static MsmScalar zero()
static MsmScalar from_native(const bb::fq &s)
static constexpr field one()
static constexpr uint256_t modulus
static field random_element(numeric::RNG *engine=nullptr) noexcept