Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
scalar_multiplication.hpp
Go to the documentation of this file.
1// === AUDIT STATUS ===
2// internal: { status: Planned, auditors: [Sergei], commit: }
3// external_1: { status: not started, auditors: [], commit: }
4// external_2: { status: not started, auditors: [], commit: }
5// =====================
6
7#pragma once
8// This header hosts TWO implementations behind one facade:
9// * `bb::scalar_multiplication::legacy::*` — the pre-rewrite Pippenger MSM, bodies
10// byte-identical to merge-train (only wrapped in the `legacy` sub-namespace).
11// * the round-parallel rewrite in scalar_multiplication_fast.hpp (`*_fast`, `MSM_fast`).
12// The public facade (`pippenger`, `pippenger_unsafe`, `MSM`) at the bottom dispatches to
13// the rewrite by default, or to `legacy::` when `use_legacy_msm()` (env BB_MSM_LEGACY).
14// Remove the legacy half + the facade dispatch once the rewrite has soaked.
18
22
23#include "./bitvector.hpp"
24#include "./process_buckets.hpp"
26
27template <typename Curve> class MSM {
28 public:
29 using Element = typename Curve::Element;
31 using BaseField = typename Curve::BaseField;
33
34 static constexpr size_t NUM_BITS_IN_FIELD = ScalarField::modulus.get_msb() + 1;
35
36 // ======================= Algorithm Tuning Constants =======================
37 //
38 // These constants control the behavior of the Pippenger MSM algorithm.
39 // They are empirically tuned for performance on typical hardware.
40
41 // Below this threshold, use naive scalar multiplication instead of Pippenger
42 static constexpr size_t PIPPENGER_THRESHOLD = 16;
43
44 // Below this threshold, the affine batch inversion trick is not beneficial
45 // (cost of inversions exceeds savings from cheaper affine additions)
46 static constexpr size_t AFFINE_TRICK_THRESHOLD = 128;
47
48 // Maximum bits per scalar slice (2^20 = 1M buckets, far beyond practical use)
49 static constexpr size_t MAX_SLICE_BITS = 20;
50 static_assert(MAX_SLICE_BITS < 64,
51 "get_scalar_slice uses 1ULL << lo_slice_bits where lo_slice_bits <= MAX_SLICE_BITS - 1; "
52 "shifting uint64_t by >= 64 is UB.");
53
54 // Number of points to look ahead for memory prefetching
55 static constexpr size_t PREFETCH_LOOKAHEAD = 32;
56
57 // Prefetch every N iterations (must be power of 2); mask is N-1 for efficient modulo
58 static constexpr size_t PREFETCH_INTERVAL = 16;
59 static constexpr size_t PREFETCH_INTERVAL_MASK = PREFETCH_INTERVAL - 1;
60
61 // ======================= Cost Model Constants =======================
62 //
63 // These constants define the relative costs of various operations,
64 // used to decide between algorithm variants.
65
66 // Cost of bucket accumulation relative to a single point addition
67 // (2 Jacobian adds per bucket, each ~2.5x cost of affine add)
68 static constexpr size_t BUCKET_ACCUMULATION_COST = 5;
69
70 // Field multiplications saved per group operation when using affine trick
71 static constexpr size_t AFFINE_TRICK_SAVINGS_PER_OP = 5;
72
73 // Extra cost of Jacobian group operation when Z coordinate != 1
74 static constexpr size_t JACOBIAN_Z_NOT_ONE_PENALTY = 5;
75
76 // Cost of computing 4-bit lookup table for modular exponentiation (14 muls)
77 static constexpr size_t INVERSION_TABLE_COST = 14;
78 // ===========================================================================
79
80 // Offset generator used in bucket reduction to probabilistically avoid incomplete-addition
81 // edge cases in the accumulator. Derived from domain-separated precomputed generators.
83 {
84 static const AffineElement offset_generator = []() {
86 return get_precomputed_generators<typename Curve::Group, "ECCVM_OFFSET_GENERATOR", 1>()[0];
87 } else {
88 return get_precomputed_generators<typename Curve::Group, "DEFAULT_DOMAIN_SEPARATOR", 8>()[0];
89 }
90 }();
91 return offset_generator;
92 }
93
// MSMWorkUnit describes an MSM that may be part of a larger MSM: `size` consecutive entries,
// beginning at `start_index`, of the nonzero-scalar index list belonging to MSM number
// `batch_msm_index` of a batch.
struct MSMWorkUnit {
    size_t batch_msm_index{ 0 }; // Which MSM of the batch this unit belongs to
    size_t start_index{ 0 };     // First covered position in that MSM's index list
    size_t size{ 0 };            // Number of covered index entries (0 => nothing to do)
};
108
113 struct MSMData {
114 std::span<const ScalarField> scalars; // Scalars (non-Montgomery form)
115 std::span<const AffineElement> points; // Input points
116 std::span<const uint32_t> scalar_indices; // Indices of nonzero scalars
117 std::span<uint64_t> point_schedule; // Scratch space for point scheduling
118
123 static MSMData from_work_unit(std::span<std::span<ScalarField>> all_scalars,
124 std::span<std::span<const AffineElement>> all_points,
125 const std::vector<std::vector<uint32_t>>& all_indices,
126 std::span<uint64_t> point_schedule_buffer,
127 const MSMWorkUnit& work_unit) noexcept
128 {
129 const auto& indices = all_indices[work_unit.batch_msm_index];
130 // Avoid indexing into an empty vector when all scalars are zero (work_unit.size == 0)
131 std::span<const uint32_t> scalar_indices =
132 work_unit.size > 0 ? std::span<const uint32_t>{ &indices[work_unit.start_index], work_unit.size }
133 : std::span<const uint32_t>{};
134 return MSMData{
135 .scalars = all_scalars[work_unit.batch_msm_index],
136 .points = all_points[work_unit.batch_msm_index],
137 .scalar_indices = scalar_indices,
138 .point_schedule = point_schedule_buffer,
139 };
140 }
141 };
142
151 std::vector<AffineElement> buckets;
153
154 BucketAccumulators(size_t num_buckets) noexcept
155 : buckets(num_buckets)
156 , bucket_exists(num_buckets)
157 {}
158 };
159
168 std::vector<Element> buckets;
170
172 : buckets(num_buckets)
173 , bucket_exists(num_buckets)
174 {}
175 };
180 static constexpr size_t BATCH_SIZE = 2048;
181 // when adding affine points, we have an edge case where the number of points in the batch can overflow by 2
182 static constexpr size_t BATCH_OVERFLOW_SIZE = 2;
183 std::vector<AffineElement> points_to_add;
184 std::vector<BaseField> inversion_scratch_space; // Used for Montgomery batch inversion denominators
186 AffineElement null_location{}; // Dummy write target for branchless conditional moves
187
193 };
194
200 uint64_t data;
201
202 [[nodiscard]] static constexpr PointScheduleEntry create(uint32_t point_index, uint32_t bucket_index) noexcept
203 {
204 return { (static_cast<uint64_t>(point_index) << 32) | bucket_index };
205 }
206 [[nodiscard]] constexpr uint32_t point_index() const noexcept { return static_cast<uint32_t>(data >> 32); }
207 [[nodiscard]] constexpr uint32_t bucket_index() const noexcept { return static_cast<uint32_t>(data); }
208 };
209
210 // ======================= Public Methods =======================
211 // See README.md for algorithm details and mathematical derivations.
212
218 static AffineElement msm(std::span<const AffineElement> points,
220 bool handle_edge_cases = false) noexcept;
221
/**
 * @brief Compute multiple MSMs in parallel with work balancing.
 *
 * @param points One span of input points per MSM.
 * @param scalars One span of scalars per MSM. Taken as mutable spans.
 *        NOTE(review): presumably the scalars are transformed in place — confirm against the implementation.
 * @param handle_edge_cases Defaults to true here. NOTE(review): presumably selects the Jacobian-bucket
 *        variant that handles doubling/infinity rather than the affine-trick variant — confirm.
 * @return One affine result per MSM.
 */
227 static std::vector<AffineElement> batch_multi_scalar_mul(std::span<std::span<const AffineElement>> points,
228 std::span<std::span<ScalarField>> scalars,
229 bool handle_edge_cases = true) noexcept;
230
231 // ======================= Test-Visible Methods =======================
232 // Exposed for unit testing; not part of the public API.
233
234 static uint32_t get_num_rounds(size_t num_points) noexcept
235 {
236 const uint32_t bits_per_slice = get_optimal_log_num_buckets(num_points);
237 return static_cast<uint32_t>((NUM_BITS_IN_FIELD + bits_per_slice - 1) / bits_per_slice);
238 }
239
/**
 * @brief Batch add n/2 independent point pairs using Montgomery's trick.
 *
 * @param points Points to add. NOTE(review): presumably modified in place to hold the sums — confirm.
 * @param num_points Number of entries in `points` (presumably the `n` of the brief — confirm).
 * @param scratch_space Base-field scratch buffer. NOTE(review): presumably holds the batch-inversion
 *        denominators; required length is not visible here — confirm.
 */
241 static void add_affine_points(AffineElement* points,
242 const size_t num_points,
243 typename Curve::BaseField* scratch_space) noexcept;
244
/**
 * @brief Extract c-bit slice from scalar for bucket index computation.
 *
 * NOTE(review): `slice_size` is presumably the slice width c in bits and `round` the slice number —
 * confirm. The MAX_SLICE_BITS static_assert in this class exists to keep this function's
 * `1ULL << lo_slice_bits` shift well-defined.
 */
246 static uint32_t get_scalar_slice(const ScalarField& scalar, size_t round, size_t slice_size) noexcept;
247
/**
 * @brief Compute optimal bits per slice by minimizing cost over c in [1, MAX_SLICE_BITS).
 *
 * The result is used by get_num_rounds as the number of bits consumed per round.
 */
249 static uint32_t get_optimal_log_num_buckets(size_t num_points) noexcept;
250
/**
 * @brief Partition per-MSM scalar weights into num_threads work units of approximately equal
 *        cumulative weight.
 *
 * @param msm_scalar_weights One weight vector per MSM.
 *        NOTE(review): presumably the output of compute_scalar_slice_weights — confirm.
 * @param num_threads Number of partitions to produce.
 */
256 static std::vector<ThreadWorkUnits> partition_by_weight(std::span<const std::vector<uint16_t>> msm_scalar_weights,
257 size_t num_threads) noexcept;
258
261 std::span<const AffineElement> points,
262 AffineAdditionData& affine_data,
263 BucketAccumulators& bucket_data) noexcept;
264
266 template <typename BucketType> static Element accumulate_buckets(BucketType& bucket_accumulators) noexcept
267 {
268 auto& buckets = bucket_accumulators.buckets;
269 BB_ASSERT_DEBUG(buckets.size() > static_cast<size_t>(0));
270 int starting_index = static_cast<int>(buckets.size() - 1);
271 Element running_sum;
272 bool found_start = false;
273 while (!found_start && starting_index > 0) {
274 const size_t idx = static_cast<size_t>(starting_index);
275 if (bucket_accumulators.bucket_exists.get(idx)) {
276
277 running_sum = buckets[idx];
278 found_start = true;
279 } else {
280 starting_index -= 1;
281 }
282 }
283 if (!found_start) {
284 return Curve::Group::point_at_infinity;
285 }
286 BB_ASSERT_DEBUG(starting_index > 0);
287 const auto& offset_generator = get_offset_generator();
288 Element sum = running_sum + offset_generator;
289 for (int i = starting_index - 1; i > 0; --i) {
290 size_t idx = static_cast<size_t>(i);
291 BB_ASSERT_DEBUG(idx < bucket_accumulators.bucket_exists.size());
292 if (bucket_accumulators.bucket_exists.get(idx)) {
293 running_sum += buckets[idx];
294 }
295 sum += running_sum;
296 }
297 return sum - offset_generator;
298 }
299
300 private:
301 // ======================= Private Implementation =======================
302
305 std::vector<uint32_t>& nonzero_scalar_indices) noexcept;
306
/**
 * @brief Compute per-scalar slice-count weights ceil(bit_length / bits_per_slice).
 *
 * @param scalars Input scalars.
 * @param nonzero_indices Indices of the scalars to weigh.
 *        NOTE(review): presumably indexes into `scalars` — confirm.
 * @param bits_per_slice Slice width used as the divisor in the weight formula.
 * @param weights Output vector. NOTE(review): whether it is resized or appended to is not visible here — confirm.
 */
311 static void compute_scalar_slice_weights(std::span<const ScalarField> scalars,
312 std::span<const uint32_t> nonzero_indices,
313 uint32_t bits_per_slice,
314 std::vector<uint16_t>& weights) noexcept;
315
322 std::vector<std::vector<uint32_t>>& msm_scalar_indices) noexcept;
323
/** @brief Decide if batch inversion saves work vs Jacobian additions. */
325 static bool use_affine_trick(size_t num_points, size_t num_buckets) noexcept;
326
/** @brief Pippenger using Jacobian buckets (handles edge cases: doubling, infinity). */
328 static Element jacobian_pippenger_with_transformed_scalars(MSMData& msm_data) noexcept;
329
/** @brief Pippenger using affine buckets with batch inversion (faster, no edge case handling). */
331 static Element affine_pippenger_with_transformed_scalars(MSMData& msm_data) noexcept;
332
333 // Helpers for batch_accumulate_points_into_buckets. Inlined for performance.
334
335 // Process single point: if bucket has accumulator, pair them for addition; else cache in bucket.
336 __attribute__((always_inline)) static void process_single_point(size_t bucket,
340 size_t& scratch_it,
341 size_t& point_it) noexcept
342 {
343 bool has_accumulator = bucket_data.bucket_exists.get(bucket);
344 if (has_accumulator) {
346 affine_data.points_to_add[scratch_it + 1] = bucket_data.buckets[bucket];
347 bucket_data.bucket_exists.set(bucket, false);
348 affine_data.addition_result_bucket_destinations[scratch_it >> 1] = static_cast<uint32_t>(bucket);
349 scratch_it += 2;
350 } else {
351 bucket_data.buckets[bucket] = *point_source;
352 bucket_data.bucket_exists.set(bucket, true);
353 }
355 }
356
357 // Branchless bucket pair processing. Updates point_it (by 2 if same bucket, else 1) and scratch_it.
358 // See README.md "batch_accumulate_points_into_buckets Algorithm" for case analysis.
359 __attribute__((always_inline)) static void process_bucket_pair(size_t lhs_bucket,
365 size_t& scratch_it,
366 size_t& point_it) noexcept
367 {
368 bool has_bucket_accumulator = bucket_data.bucket_exists.get(lhs_bucket);
369 bool buckets_match = lhs_bucket == rhs_bucket;
370 bool do_affine_add = buckets_match || has_bucket_accumulator;
371
373
378
380 dest_bucket = do_affine_add ? static_cast<uint32_t>(lhs_bucket) : dest_bucket;
381
384
385 bucket_data.bucket_exists.set(lhs_bucket, (has_bucket_accumulator && buckets_match) || !do_affine_add);
387 point_it += (do_affine_add && buckets_match) ? 2 : 1;
388 }
389};
390
392template <typename Curve>
395 bool handle_edge_cases = true) noexcept;
396
/**
 * @brief Fast MSM wrapper for linearly independent points (no edge case handling).
 *
 * Declared in the `legacy` sub-namespace. Unlike the facade's `pippenger`, it has no
 * `handle_edge_cases` parameter.
 */
398template <typename Curve>
399typename Curve::Element pippenger_unsafe(PolynomialSpan<const typename Curve::ScalarField> scalars,
400 std::span<const typename Curve::AffineElement> points) noexcept;
401
402extern template class MSM<curve::Grumpkin>;
403extern template class MSM<curve::BN254>;
404
405} // namespace bb::scalar_multiplication::legacy
406
407// ===================================================================================
408// Public MSM facade — the surface every caller uses. Dispatches to the `_fast` rewrite
409// by default, or `legacy::` when use_legacy_msm() (env BB_MSM_LEGACY, read once).
410// Signatures match the rewrite; the legacy branch adapts (legacy has no dedup pre-pass,
411// and its batch entry takes per-MSM point spans).
412// ===================================================================================
413namespace bb::scalar_multiplication {
414
/**
 * @brief Whether the public facade should dispatch to the `legacy::` implementation instead of
 *        the `_fast` rewrite.
 *
 * Per the facade comment above, this is controlled by env BB_MSM_LEGACY, which is read once.
 */
415[[nodiscard]] bool use_legacy_msm() noexcept;
416
417template <typename Curve>
420 bool handle_edge_cases = true,
421 bool dedup_hint = false) noexcept;
422
423template <typename Curve>
426 bool dedup_hint = false) noexcept;
427
430 bool handle_edge_cases,
431 bool dedup_hint) noexcept;
435 bool handle_edge_cases,
436 bool dedup_hint) noexcept;
440 bool dedup_hint) noexcept;
444 bool dedup_hint) noexcept;
445
446template <typename Curve> class MSM {
447 public:
448 using Element = typename Curve::Element;
451
452 static AffineElement msm(std::span<const AffineElement> points,
454 bool handle_edge_cases = false,
455 bool dedup_hint = false) noexcept;
456
457 static std::vector<AffineElement> batch_multi_scalar_mul(std::span<const AffineElement> points,
458 std::span<PolynomialSpan<ScalarField>> scalars,
459 bool handle_edge_cases = true,
460 std::span<const uint8_t> dedup_hints = {}) noexcept;
461};
462
463extern template class MSM<curve::BN254>;
464extern template class MSM<curve::Grumpkin>;
465
466} // namespace bb::scalar_multiplication
#define BB_ASSERT_DEBUG(expression,...)
Definition assert.hpp:55
Custom class to handle packed vectors of bits.
Definition bitvector.hpp:23
typename Group::element Element
Definition bn254.hpp:21
typename Group::element Element
Definition grumpkin.hpp:63
typename grumpkin::g1 Group
Definition grumpkin.hpp:62
typename Group::affine_element AffineElement
Definition grumpkin.hpp:64
typename Curve::ScalarField ScalarField
typename Curve::AffineElement AffineElement
static const AffineElement & get_offset_generator() noexcept
static void transform_scalar_and_get_nonzero_scalar_indices(std::span< ScalarField > scalars, std::vector< uint32_t > &nonzero_scalar_indices) noexcept
Convert scalars from Montgomery form and collect indices of nonzero scalars.
__attribute__((always_inline)) static void process_single_point(size_t bucket
size_t const AffineElement const AffineElement * rhs_source_if_match
const AffineElement AffineAdditionData BucketAccumulators & bucket_data
static Element jacobian_pippenger_with_transformed_scalars(MSMData &msm_data) noexcept
Pippenger using Jacobian buckets (handles edge cases: doubling, infinity)
typename Curve::AffineElement AffineElement
static uint32_t get_num_rounds(size_t num_points) noexcept
static void compute_scalar_slice_weights(std::span< const ScalarField > scalars, std::span< const uint32_t > nonzero_indices, uint32_t bits_per_slice, std::vector< uint16_t > &weights) noexcept
Compute per-scalar slice-count weights ceil(bit_length / bits_per_slice).
static std::vector< ThreadWorkUnits > partition_by_weight(std::span< const std::vector< uint16_t > > msm_scalar_weights, size_t num_threads) noexcept
Partition per-MSM scalar weights into num_threads work units of approximately equal cumulative weight...
static uint32_t get_optimal_log_num_buckets(size_t num_points) noexcept
Compute optimal bits per slice by minimizing cost over c in [1, MAX_SLICE_BITS)
static void add_affine_points(AffineElement *points, const size_t num_points, typename Curve::BaseField *scratch_space) noexcept
Batch add n/2 independent point pairs using Montgomery's trick.
static std::vector< AffineElement > batch_multi_scalar_mul(std::span< std::span< const AffineElement > > points, std::span< std::span< ScalarField > > scalars, bool handle_edge_cases=true) noexcept
Compute multiple MSMs in parallel with work balancing.
static AffineElement msm(std::span< const AffineElement > points, PolynomialSpan< const ScalarField > scalars, bool handle_edge_cases=false) noexcept
Main entry point for single MSM computation.
static std::vector< ThreadWorkUnits > get_work_units(std::span< std::span< ScalarField > > scalars, std::vector< std::vector< uint32_t > > &msm_scalar_indices) noexcept
Distribute multiple MSMs across threads with balanced bucket-accumulation work.
__attribute__((always_inline)) static void process_bucket_pair(size_t lhs_bucket
static bool use_affine_trick(size_t num_points, size_t num_buckets) noexcept
Decide if batch inversion saves work vs Jacobian additions.
const AffineElement AffineAdditionData & affine_data
const AffineElement AffineAdditionData BucketAccumulators size_t & scratch_it
static Element affine_pippenger_with_transformed_scalars(MSMData &msm_data) noexcept
Pippenger using affine buckets with batch inversion (faster, no edge case handling)
static uint32_t get_scalar_slice(const ScalarField &scalar, size_t round, size_t slice_size) noexcept
Extract c-bit slice from scalar for bucket index computation.
static Element accumulate_buckets(BucketType &bucket_accumulators) noexcept
Reduce buckets to single point using running (suffix) sum from high to low: R = sum(k * B_k)
const AffineElement AffineAdditionData BucketAccumulators size_t size_t &point_it noexcept
static void batch_accumulate_points_into_buckets(std::span< const uint64_t > point_schedule, std::span< const AffineElement > points, AffineAdditionData &affine_data, BucketAccumulators &bucket_data) noexcept
Process sorted point schedule into bucket accumulators using batched affine additions.
template curve::BN254::Element pippenger< curve::BN254 >(PolynomialSpan< const curve::BN254::ScalarField > scalars, std::span< const curve::BN254::AffineElement > points, bool handle_edge_cases=true)
Curve::Element pippenger_unsafe(PolynomialSpan< const typename Curve::ScalarField > scalars, std::span< const typename Curve::AffineElement > points) noexcept
Fast MSM wrapper for linearly independent points (no edge case handling)
template curve::Grumpkin::Element pippenger_unsafe< curve::Grumpkin >(PolynomialSpan< const curve::Grumpkin::ScalarField > scalars, std::span< const curve::Grumpkin::AffineElement > points)
Curve::Element pippenger(PolynomialSpan< const typename Curve::ScalarField > scalars, std::span< const typename Curve::AffineElement > points, bool handle_edge_cases) noexcept
Safe MSM wrapper (defaults to handle_edge_cases=true)
template curve::Grumpkin::Element pippenger< curve::Grumpkin >(PolynomialSpan< const curve::Grumpkin::ScalarField > scalars, std::span< const curve::Grumpkin::AffineElement > points, bool handle_edge_cases=true) noexcept
template curve::BN254::Element pippenger_unsafe< curve::BN254 >(PolynomialSpan< const curve::BN254::ScalarField > scalars, std::span< const curve::BN254::AffineElement > points)
Entry point for Barretenberg command-line interface.
Definition api.hpp:5
Inner sum(Cont< Inner, Args... > const &in)
Definition container.hpp:70
constexpr std::span< const typename Group::affine_element > get_precomputed_generators()
@ BN254
Definition types.hpp:10
STL namespace.
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
Curve::Element Element
Scratch space for batched affine point additions (one per thread)
Affine bucket accumulators for the fast affine-trick Pippenger variant.
Jacobian bucket accumulators for the safe Pippenger variant.
Container for MSM input data passed between algorithm stages.
static MSMData from_work_unit(std::span< std::span< ScalarField > > all_scalars, std::span< std::span< const AffineElement > > all_points, const std::vector< std::vector< uint32_t > > &all_indices, std::span< uint64_t > point_schedule_buffer, const MSMWorkUnit &work_unit) noexcept
Factory method to construct MSMData from a work unit.
MSMWorkUnit describes an MSM that may be part of a larger MSM.
Packed point schedule entry: (point_index << 32) | bucket_index.
constexpr uint32_t point_index() const noexcept
uint64_t data
constexpr uint32_t bucket_index() const noexcept
static constexpr PointScheduleEntry create(uint32_t point_index, uint32_t bucket_index) noexcept