Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
scalar_multiplication_fast.cpp
Go to the documentation of this file.
2
13
14#include <algorithm>
15#include <atomic>
16#include <bit>
17#include <cstddef>
18#include <cstdint>
19#include <limits>
20#include <memory>
21#include <span>
22#include <vector>
23
24#ifdef __wasm_simd128__
25#include <wasm_simd128.h>
26#endif
27
29
31{
// Returns a small tier value chosen from the input size `n_input`.
// wasm builds: 1 for n_input <= 2^11, 2 for n_input <= 2^15, otherwise 4.
// Native builds: always 4; n_input is ignored.
32#ifdef __wasm__
33 if (n_input <= (size_t{ 1 } << 11)) {
34 return 1;
35 }
36 if (n_input <= (size_t{ 1 } << 15)) {
37 return 2;
38 }
39 return 4;
40#else
// Explicitly discard the parameter to silence unused-parameter warnings on native builds.
41 static_cast<void>(n_input);
42 return 4;
43#endif
44}
45
46namespace round_parallel_detail {
47
48// Anonymous namespace gives all TU-private helpers in `round_parallel_detail` internal
49// linkage (clang-tidy `misc-use-anonymous-namespace`). It is briefly closed and reopened
50// around `pippenger_round_parallel_jacobian_fast`, which has external linkage via
51// `extern template` declarations in the header.
52namespace {
53
54// Bulk-copy a 64-byte affine point (BN254 / Grumpkin layout: 8 × uint64_t).
55// On wasm, V8 TurboFan compiles the default struct copy to 8 i64 loads/stores; explicit
56// v128 loads/stores halve that and roughly double throughput on random-gather access.
57// On native, std::memcpy of a constant-size struct already lowers to 4 × movdqu.
58template <typename AffineElement>
59[[gnu::always_inline]] inline void copy_affine64(AffineElement& dst, const AffineElement& src) noexcept
60{
61 static_assert(sizeof(AffineElement) == 64, "copy_affine64 requires 64-byte affine point");
63 "AffineElement must be trivially copyable for memcpy / SIMD bulk copy "
64 "(also required by the bulk std::memcpy of reduce_chunk output into "
65 "ThreadScratch::window_pts in recursive_affine_bucket_reduce_strided's caller)");
66#ifdef __wasm_simd128__
// Four 16-byte lanes cover the 64-byte point. All four loads are issued before any store.
67 const auto* s = reinterpret_cast<const v128_t*>(&src);
68 auto* d = reinterpret_cast<v128_t*>(&dst);
69 const v128_t a = wasm_v128_load(s + 0);
70 const v128_t b = wasm_v128_load(s + 1);
71 const v128_t c = wasm_v128_load(s + 2);
72 const v128_t e = wasm_v128_load(s + 3);
73 wasm_v128_store(d + 0, a);
74 wasm_v128_store(d + 1, b);
75 wasm_v128_store(d + 2, c);
76 wasm_v128_store(d + 3, e);
77#else
// NOTE(review): std::memcpy requires non-overlapping dst/src; callers presumably never pass
// the same object for both (the wasm path above would tolerate it, this one formally does not).
78 std::memcpy(&dst, &src, sizeof(AffineElement));
79#endif
80}
81
82// Constantine signed-Booth window recoder (scalar + SIMD x4 paths) lives in
83// pippenger_constantine.hpp.
84
85// `choose_window_bits` and `build_var_window_schedule` are defined inline in
86// `pippenger_arena_layout.hpp` so the test suite can build identical schedules.
87// `VAR_WINDOW_MAX_WINDOWS` and `VariableWindowSchedule` likewise live there.
88
// Marker written to `msb_per_scalar[i]` when scalar i is zero. The valid msb positions are
// 0..253, so the largest uint8_t value (255) is free to serve as the sentinel. The companion
// `msb_hist` reserves bin 0 for zero scalars, so callers index it with `msb + 1`
// (msb == -1, the zero case, therefore lands in bin 0).
inline constexpr uint8_t MSB_ZERO_SENTINEL = std::numeric_limits<uint8_t>::max();
93
94// Batched-affine drain trigger. `tree_reduce_in_place` accumulates same-bucket pair
95// candidates into the per-thread `points_to_add` / `pair_dest` scratch and drains via a
96// single inversion + N-pair add when the queue hits this size. Sizing trade-off:
97// - higher = larger inversion amortisation = lower per-pair cost,
98// - lower = smaller scratch / less L1 pressure but more drain calls.
99// 256 was chosen empirically: keeps `points_to_add` (256 × 64 B = 16 KB) inside L1, is
100// well above the ~32-pair amortisation breakeven, and is the value the per-OS-thread
101// scratch buffers (`points_to_add`, `inversion_scratch`, `pair_dest`) are sized for.
102//
103// Deliberately a compile-time constant rather than a per-call parameter: the only sites
104// that ever passed a different value were chunks shorter than 256, where the early-drain
105// branch never fires anyway (the end-of-loop drain catches the residue). Keeping it
106// constexpr lets the compiler turn the per-iter `if (pair_count >= BATCH_CAPACITY)` into
107// a compare-against-immediate and fold the drain-trigger condition into the loop shape.
108// `BATCH_CAPACITY` is defined in `pippenger_arena_layout.hpp` so the layout struct can
109// reference it without depending on this TU.
110
// Index of the most significant set bit of the 128-bit value `hi:lo` (hi holds bits 64..127),
// or -1 when both halves are zero.
inline int msb_of_2limb(uint64_t lo, uint64_t hi) noexcept
{
    // A non-zero high half always dominates; otherwise fall back to the low half.
    const bool use_hi = hi != 0;
    const uint64_t word = use_hi ? hi : lo;
    if (word == 0) {
        return -1;
    }
    const int base_bit = use_hi ? 64 : 0;
    // __builtin_clzll is only called with a non-zero argument (it is undefined for 0).
    return base_bit + 63 - __builtin_clzll(word);
}
121
// Index of the most significant set bit of a 256-bit value held as four uint64_t limbs with
// d[0] least significant (d[k] carries bits 64k..64k+63), or -1 if every limb is zero.
// Accepts the raw `uint64_t[4]` `.data` of `uint256_t` / field elements directly.
inline int msb_of_4limb(const uint64_t (&d)[4]) noexcept // NOLINT(cppcoreguidelines-avoid-c-arrays)
{
    // Scan from the most significant limb down; the first non-zero limb fixes the msb.
    for (int limb = 3; limb >= 0; --limb) {
        const uint64_t word = d[limb];
        if (word != 0) {
            return (limb * 64) + 63 - __builtin_clzll(word);
        }
    }
    return -1;
}
139
140inline void record_msb(int msb, uint8_t& dst, std::array<uint32_t, 256>& th_hist) noexcept
141{
142 dst = (msb < 0) ? MSB_ZERO_SENTINEL : static_cast<uint8_t>(msb);
143 ++th_hist[static_cast<size_t>(msb) + 1];
144}
145
149// `AffineBucketChunkInfo` is defined in `pippenger_arena_layout.hpp` (included above).
150
// Scratch state for one worker thread of the Pippenger pipeline (the BATCH_CAPACITY note in
// this file describes these as per-OS-thread buffers). All buffers are non-owning std::span
// views; their storage is owned elsewhere (presumably carved out of an MsmArena — confirm at
// the allocation site).
159template <typename Curve> struct ThreadScratch {
160 using AffineElement = typename Curve::AffineElement;
161 using Element = typename Curve::Element;
162 using BaseField = typename Curve::BaseField;
163
164 // reduce_chunk's tree-reduce buffer. Per level the inner loop walks with a read cursor
165 // `i` and a write cursor `next_len ≤ i`, compacting in-place; the next level re-enters
166 // the same buffer without a swap.
168 std::span<uint32_t> curr_buckets;
169
170 // reduce_chunk's batch-affine scratch.
// pair_dest[k]: index in curr_pts that receives the k-th queued pair sum (written by
// tree_reduce_in_place, consumed by drain_batch).
173 std::span<uint32_t> pair_dest;
174
// Number of live (bucket, point) entries left in curr_buckets / curr_pts after the most
// recent tree_reduce_in_place call.
175 size_t result_len = 0;
176
177 // Stage 6a seam-overflow buffer: when a sub-chunk emits a partial for a slot whose
178 // dense bucket entry is already populated (i.e. the digit's run was split across two
179 // sub-chunks), the partial is deferred here and merged at end-of-window via a single
180 // Montgomery-batched tree reduce. Reset to length 0 between windows.
181 std::span<uint32_t> overflow_slots;
// Number of valid deferred entries in overflow_slots (and the parallel overflow point
// buffer); merge_overflow resets it to 0 after folding them into the dense buckets.
183 size_t overflow_len = 0;
184
185 // Recursive affine bucket reduction scratch (cross-window batched, sparse-aware).
186 // `dense_buckets` holds W chunks worth of dense AffineElement arrays back-to-back.
187 // Layout: dense_buckets[w * affine_bucket_stride + i] for window w and 0-indexed slot i.
188 // `is_present` is a parallel uint8_t array marking non-identity slots (0 = empty, 1 = present).
189 // `affine_bucket_pairs` is the scratch buffer for the real-pairs list (single pass: filtered
190 // inline as candidates are generated, no intermediate candidate buffer).
191 // `affine_bucket_indices` is the scratch index buffer for the doubling kernel.
192 // `affine_bucket_inversion_scratch` is reused for the indexed batch-affine kernels.
196 std::span<uint32_t> affine_bucket_indices;
199 // Per-window metadata consumed by recursive_affine_bucket_reduce_strided (lo, hi, buckets_padded,
200 // empty per window). Filled in the lambda before the call.
202};
203
// Simple bump allocator over one contiguous byte buffer.
// If `external_arena` is non-empty and at least `required_bytes` long, it is used directly and
// the whole span is available (capacity = external_arena.size()). Otherwise a locally owned
// buffer is used via `local_owner` and capacity = required_bytes.
// NOTE(review): on the local-fallback path capacity equals required_bytes exactly, while
// bump_alloc may add alignment padding; required_bytes must already include worst-case padding
// or BB_ASSERT_LTE will fire — confirm the caller's sizing.
204struct MsmArena {
205 std::unique_ptr<std::byte[]> local_owner; // NOLINT(cppcoreguidelines-avoid-c-arrays)
206 std::byte* data = nullptr;
// `data` as an integer, used for alignment arithmetic in bump_alloc.
207 uintptr_t base_addr = 0;
208 size_t capacity = 0;
// Bump offset (bytes from `data`) used by alloc().
209 size_t cursor = 0;
210
211 MsmArena(size_t required_bytes, std::span<std::byte> external_arena)
212 {
213 if (!external_arena.empty() && required_bytes <= external_arena.size()) {
214 data = external_arena.data();
215 capacity = external_arena.size();
216 } else {
217 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
219 data = local_owner.get();
220 capacity = required_bytes;
221 }
222 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
223 base_addr = reinterpret_cast<uintptr_t>(data);
224 }
225
// Carves `count` T's from the main region [0, capacity) using the arena-wide cursor.
226 template <typename T> std::span<T> alloc(size_t count) { return bump_alloc<T>(count, cursor, capacity, 0); }
227
// Carves `count` T's from the sub-region that starts at byte `base_offset`, advancing the
// caller-owned `local_cursor` (relative to base_offset) and asserting the end stays within
// `bound` (also relative to base_offset). The start is aligned against the absolute address,
// so the result is correctly aligned for T whatever the buffer's own alignment is; the padding
// counts against `bound`.
// No constructors run: the span simply views raw arena bytes (presumably only used for
// implicit-lifetime / trivially constructible T).
// NOTE(review): count * sizeof(T) is not checked for overflow, and nothing here checks
// base_offset + bound <= capacity — callers are responsible for both.
228 template <typename T> std::span<T> bump_alloc(size_t count, size_t& local_cursor, size_t bound, size_t base_offset)
229 {
230 const size_t align = alignof(T);
231 const uintptr_t cur_addr = base_addr + base_offset + local_cursor;
232 const uintptr_t aligned_addr = (cur_addr + align - 1) & ~(uintptr_t{ align } - 1);
233 const size_t aligned_local = static_cast<size_t>(aligned_addr - (base_addr + base_offset));
234 const size_t bytes = count * sizeof(T);
235 BB_ASSERT_LTE(aligned_local + bytes, bound);
236 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
237 T* p = reinterpret_cast<T*>(data + base_offset + aligned_local);
238 local_cursor = aligned_local + bytes;
239 return std::span<T>{ p, count };
240 }
241};
242
243template <typename Curve> inline void drain_batch(ThreadScratch<Curve>& s, size_t pair_count) noexcept
244{
245 if (pair_count == 0) {
246 return;
247 }
248 bb::group_elements::batch_affine_add_interleaved<typename Curve::AffineElement, typename Curve::BaseField>(
249 s.points_to_add.data(), 2 * pair_count, s.inversion_scratch.data());
250 // In-place compaction: each `pair_dest[i]` is the `next_len` value at the moment the
251 // pair was queued, which is < the read cursor `i_outer` and < the current `next_len`
252 // — so writing back into curr_pts at `pair_dest[i]` lands on a slot that is already
253 // past the read cursor. See reduce_chunk for the full invariant.
254 for (size_t i = 0; i < pair_count; ++i) {
255 s.curr_pts[s.pair_dest[i]] = s.points_to_add[pair_count + i];
256 }
257}
258
274template <typename Curve> void tree_reduce_in_place(ThreadScratch<Curve>& s, size_t initial_len) noexcept
275{
276 size_t curr_len = initial_len;
277
278 while (true) {
279 size_t i = 0;
280 size_t next_len = 0;
281 size_t pair_count = 0;
282 bool made_pair = false;
283
284 while (i < curr_len) {
285 if (i + 1 < curr_len && s.curr_buckets[i] == s.curr_buckets[i + 1]) {
286 const size_t slot = 2 * pair_count;
287 s.points_to_add[slot] = s.curr_pts[i];
288 s.points_to_add[slot + 1] = s.curr_pts[i + 1];
289 s.curr_buckets[next_len] = s.curr_buckets[i];
290 s.pair_dest[pair_count] = static_cast<uint32_t>(next_len);
291 ++next_len;
292 ++pair_count;
293 i += 2;
294 made_pair = true;
295
296 if (pair_count >= BATCH_CAPACITY) {
297 drain_batch<Curve>(s, pair_count);
298 pair_count = 0;
299 }
300 } else {
301 s.curr_pts[next_len] = s.curr_pts[i];
302 s.curr_buckets[next_len] = s.curr_buckets[i];
303 ++next_len;
304 ++i;
305 }
306 }
307
308 drain_batch<Curve>(s, pair_count);
309
310 if (!made_pair) {
311 break;
312 }
313
314 curr_len = next_len;
315 }
316
317 s.result_len = curr_len;
318}
319
336template <typename Curve>
337void merge_overflow(ThreadScratch<Curve>& s, typename Curve::AffineElement* dst_dense) noexcept
338{
339 if (s.overflow_len == 0) {
340 return;
341 }
342
343 size_t merge_len = 0;
344 size_t i = 0;
345 while (i < s.overflow_len) {
346 const uint32_t slot = s.overflow_slots[i];
347 s.curr_buckets[merge_len] = slot;
348 s.curr_pts[merge_len] = dst_dense[slot];
349 ++merge_len;
350 while (i < s.overflow_len && s.overflow_slots[i] == slot) {
351 s.curr_buckets[merge_len] = slot;
352 s.curr_pts[merge_len] = s.overflow_pts[i];
353 ++merge_len;
354 ++i;
355 }
356 }
357
358 tree_reduce_in_place<Curve>(s, merge_len);
359
360 for (size_t k = 0; k < s.result_len; ++k) {
361 dst_dense[s.curr_buckets[k]] = s.curr_pts[k];
362 }
363
364 s.overflow_len = 0;
365}
366
// Gathers the schedule entries in [chunk_lo, chunk_hi) into s.curr_buckets / s.curr_pts, then
// tree-reduces neighbouring same-bucket entries (tree_reduce_in_place). On return s.result_len
// holds the number of surviving (bucket, point) entries.
//
// Entries of bucket b occupy schedule[bucket_start[b], bucket_start[b + 1]), so the
// same-bucket entries the tree reduce merges are contiguous.
// `bucket_start` must be readable up to index chunk_bucket_hi + 1.
// `bucket_cursor` is in/out: the first bucket to examine. It is left on the bucket in which
// the chunk ended, so the next chunk can resume there.
// Each schedule entry encodes a point index (SCHEDULE_INDEX_MASK), a sign (SCHEDULE_SIGN_BIT:
// negate y), a dedup redirect (DEDUP_REDIRECT_BIT: read from dedup_extra_points) and a dedup
// skip flag (DEDUP_SKIP_BIT: entry dropped).
371template <typename Curve>
372void reduce_chunk(ThreadScratch<Curve>& s,
373 const uint32_t* schedule,
374 const size_t* bucket_start,
375 size_t chunk_lo,
376 size_t chunk_hi,
377 size_t& bucket_cursor,
378 size_t chunk_bucket_hi,
380 std::span<const typename Curve::AffineElement> dedup_extra_points = {}) noexcept
381{
382 const size_t chunk_len = chunk_hi - chunk_lo;
383 if (chunk_len == 0) {
384 s.result_len = 0;
385 return;
386 }
387
388 BB_ASSERT_LTE(chunk_len, s.curr_pts.size());
389 static_assert(BATCH_CAPACITY <= 4096, "BATCH_CAPACITY must fit in pair_dest scratch");
390
391 // Compact entries while loading: dedup non-rep entries (DEDUP_SKIP_BIT set in the
392 // schedule entry) carry no contribution — their points are already accumulated
393 // into the cluster's combined `extra_points[cid]` emitted at the rep's slot. Skip
394 // them to avoid double-counting and to shrink the tree-reduce input.
395 size_t valid_len = 0;
396 size_t bucket = bucket_cursor;
397 size_t pos = chunk_lo;
398 while (bucket <= chunk_bucket_hi && pos < chunk_hi) {
// Intersect this bucket's schedule range with the unread part of the chunk.
399 const size_t run_lo = std::max(pos, bucket_start[bucket]);
400 const size_t run_hi = std::min(chunk_hi, bucket_start[bucket + 1]);
401 if (run_lo >= run_hi) {
402 ++bucket;
403 continue;
404 }
405
406 const uint32_t bucket_u32 = static_cast<uint32_t>(bucket);
407 for (size_t i = run_lo; i < run_hi; ++i) {
408 const uint32_t e = schedule[i];
409 if ((e & DEDUP_SKIP_BIT) != 0) {
410 continue; // non-rep: skip, don't consume a curr_pts slot
411 }
412 const uint32_t raw_idx = e & SCHEDULE_INDEX_MASK;
413 const bool neg = (e & SCHEDULE_SIGN_BIT) != 0;
414 s.curr_buckets[valid_len] = bucket_u32;
415 // SIMD-widened gather: 4 × v128.load on WASM (2× faster than the
416 // default 8 × i64.load struct copy on V8 TurboFan); 4 × movdqu on
417 // native (already optimal). The conditional negation runs after the
418 // copy because Fq::operator-() is a modular subtract, not a bit flip,
419 // so it can't be folded into the SIMD load lanes.
420 auto& dst_pt = s.curr_pts[valid_len];
421 // Dedup redirect: if the redirect bit is set, fetch from the dedup
422 // extra-points buffer (combined point for a cluster of duplicate scalars)
423 // instead of the original points span. The branch is always-not-taken when
424 // dedup is inactive (`dedup_extra_points` empty) and predictably-mostly-taken-or-not
425 // when active, since cluster-rep scheduling is uniform per MSM_fast.
426 if ((e & DEDUP_REDIRECT_BIT) != 0) {
427 copy_affine64(dst_pt, dedup_extra_points[raw_idx]);
428 } else {
429 copy_affine64(dst_pt, points[raw_idx]);
430 }
431 if (neg) {
432 dst_pt.y = -dst_pt.y;
433 }
434 ++valid_len;
435 }
436 pos = run_hi;
// Move to the next bucket only while the chunk still has entries. If the chunk ended inside
// (or exactly at the end of) this bucket, stay on it; the next call skips it via the
// empty-run check above if it turns out to be exhausted.
437 if (pos < chunk_hi) {
438 ++bucket;
439 }
440 }
441 bucket_cursor = bucket;
442
443 tree_reduce_in_place<Curve>(s, valid_len);
444}
445
446// `ChunkOutput<Curve>` (Stage 6 per-chunk bucket-reduce output) is defined in
447// `pippenger_arena_layout.hpp` so the test suite can size the Zone S slot the
448// same way the live allocator does.
449
// `AffineBucketChunkInfo` is defined in `pippenger_arena_layout.hpp`, which is included
// ahead of ThreadScratch. It describes one chunk's contribution to the cross-window
// recursive affine bucket reduction (lo/hi digit bounds, buckets_padded, empty flag).
454
// Prepares `buckets[dst_idx] += buckets[src_idx]` for the indexed batch-affine kernel, using
// `is_present` to track which slots hold a non-identity point.
// - src is identity:      nothing to do.
// - dst is identity:      src is copied into dst, and dst becomes present.
// - equal x-coordinates:  resolved immediately by doubling (dst == src) or clearing dst to
//                         identity (dst == -src), because the batched kernel would invert zero.
// - otherwise:            (dst_idx, src_idx) is appended to real_pairs and real_count advances;
//                         buckets[dst_idx] is updated only when the caller runs the kernel.
// The src slot is never modified. Because the special cases mutate dst immediately while generic
// pairs are deferred, callers presumably keep all pairs within one batch disjoint.
474template <typename Curve>
475[[gnu::always_inline]] inline void try_filter_pair(typename Curve::AffineElement* buckets,
476 uint8_t* is_present,
477 uint32_t dst_idx,
478 uint32_t src_idx,
480 size_t& real_count) noexcept
481{
482 using Element = typename Curve::Element;
483 using AffineElement = typename Curve::AffineElement;
484
485 if (is_present[src_idx] == 0) {
486 return; // src is identity → no-op
487 }
488 if (is_present[dst_idx] == 0) {
489 buckets[dst_idx] = buckets[src_idx]; // dst was identity → just copy
490 is_present[dst_idx] = 1;
491 return;
492 }
493 // Edge case: dst.x == src.x. Since both points are on-curve, this means either
494 // dst == src (doubling case) or dst == -src (inverse case, result is identity).
495 // batch_affine_add_indexed_impl would invert zero here, so handle out-of-band.
496 if (buckets[dst_idx].x == buckets[src_idx].x) {
497 if (buckets[dst_idx].y == buckets[src_idx].y) {
498 // dst == src → result is 2 * dst.
499 Element doubled = Element(buckets[dst_idx]);
500 doubled.self_dbl();
501 buckets[dst_idx] = AffineElement{ doubled };
502 } else {
503 // dst == -src → result is identity.
504 buckets[dst_idx].self_set_infinity();
505 is_present[dst_idx] = 0;
506 }
507 return;
508 }
// Generic case: defer to the caller's batch_affine_add_indexed_impl call.
509 real_pairs[real_count++] = { dst_idx, src_idx };
510}
511
// Appends `idx` to `real_indices` and advances `real_count`, but only when `is_present[idx]`
// is non-zero. Identity slots are thereby dropped before the indexed batch-doubling kernel.
[[gnu::always_inline]] inline void try_filter_idx(const uint8_t* is_present,
                                                  uint32_t idx,
                                                  uint32_t* real_indices,
                                                  size_t& real_count) noexcept
{
    if (is_present[idx] == 0) {
        return; // identity slot: nothing to double
    }
    real_indices[real_count] = idx;
    ++real_count;
}
526
// Cross-window recursive affine bucket reduction.
//
// For each of the `windows_in_batch` windows, this reads the dense buckets
// s.dense_buckets[w * stride + i] for i < chunk_infos[w].buckets_padded, with presence flags in
// s.is_present. It writes two Jacobian results to outputs_base[w * output_stride]:
//   R = sum of all present buckets (saved from slot 0 after Phase B), and
//   L = the weighted bucket sum. The stride <= 2 fast path spells out the intended weights:
//       L = B_0 + 2 * B_1, i.e. sum over i of (i + 1) * B_i.
// Empty windows (chunk_infos[w].empty != 0) get R = L = point at infinity.
// Phases A, B, C and the affine part of D each gather one level's candidates across all
// windows into a single indexed batch-affine call, so they share one inversion.
// s.dense_buckets and s.is_present are updated in place by every phase, so the bucket contents
// are not meaningful as buckets afterwards.
// Requires s.affine_bucket_stride to be a power of two, with every non-empty window's
// buckets_padded <= stride (both asserted below).
564template <typename Curve>
565void recursive_affine_bucket_reduce_strided(ThreadScratch<Curve>& s,
566 const AffineBucketChunkInfo* chunk_infos,
567 size_t windows_in_batch,
568 ChunkOutput<Curve>* outputs_base,
569 size_t output_stride) noexcept
570{
571 using AffineElement = typename Curve::AffineElement;
572 using Element = typename Curve::Element;
573
574 auto out_at = [outputs_base, output_stride](size_t w) -> ChunkOutput<Curve>& {
575 return outputs_base[w * output_stride];
576 };
577
578 if (windows_in_batch == 0) {
579 return;
580 }
581
582 // Stride is the caller's pre-sized layout width (`s.affine_bucket_stride`, set via
583 // `ensure_affine_bucket_capacity`). The densification step in the caller scattered buckets at
584 // `w * s.affine_bucket_stride + i`, so we MUST use the same value for our own indexing — any
585 // re-derivation that disagrees with the layout would index neighbouring windows. The
586 // pre-size already enforces `stride ≥ max_w(buckets_padded_w)` AND `stride ≥ 2` AND
587 // `stride is a power of two`, so the trivial-stride fast path and the 4-phase math
588 // both stay valid here. Per-window buckets_padded controls how many slots each window walks
589 // and is bounded by `stride` — verified below in debug.
590 const size_t stride = s.affine_bucket_stride;
591 bool any_nonempty = false;
592 for (size_t w = 0; w < windows_in_batch; ++w) {
593 if (chunk_infos[w].empty == 0) {
594 any_nonempty = true;
595 BB_ASSERT_LTE(chunk_infos[w].buckets_padded, stride);
596 }
597 }
598 if (!any_nonempty) {
599 for (size_t w = 0; w < windows_in_batch; ++w) {
600 out_at(w).R = Curve::Group::point_at_infinity;
601 out_at(w).L = Curve::Group::point_at_infinity;
602 }
603 return;
604 }
605
606 AffineElement* const buckets = s.dense_buckets.data();
607 uint8_t* const is_present = s.is_present.data();
608
609 // Pick L0 (the leaf-partition size). c0 = floor(log2(stride) / 2)
610 // gives L0 ≈ sqrt(stride) — balances Phase A batch size (W·D) vs Phase A iter count
611 // (L0 - 1). Both L0 and D = stride / L0 must be powers of two.
612 BB_ASSERT_GT(stride, size_t{ 0 });
613 const size_t c_log = static_cast<size_t>(std::countr_zero(stride));
614 BB_ASSERT_EQ(static_cast<size_t>(1) << c_log, stride);
615 // Trivial-stride fast paths. The 4-phase algorithm requires c_log ≥ 2 (so we can pick
616 // c0 ∈ [1, c_log - 1]) — fall back to direct computation for stride ∈ {1, 2}.
617 if (stride <= 2) {
618 for (size_t w = 0; w < windows_in_batch; ++w) {
619 if (chunk_infos[w].empty != 0) {
620 out_at(w).R = Curve::Group::point_at_infinity;
621 out_at(w).L = Curve::Group::point_at_infinity;
622 continue;
623 }
624 // Walk the (up to two) populated slots directly.
625 const size_t base = w * stride;
626 Element R = Curve::Group::point_at_infinity;
627 Element L = Curve::Group::point_at_infinity;
628 for (size_t i = 0; i < chunk_infos[w].buckets_padded; ++i) {
629 if (is_present[base + i] == 0) {
630 continue;
631 }
632 R += Element(buckets[base + i]);
633 L += Element(buckets[base + i]); // weight 1
634 if (i == 1) {
635 L += Element(buckets[base + i]); // weight 2 for i=1
636 }
637 }
638 out_at(w).R = R;
639 out_at(w).L = L;
640 }
641 return;
642 }
643
644 // Choose c0 = floor(c_log / 2), clamped so that 1 ≤ c0 ≤ c_log - 1.
645 size_t c0 = c_log / 2;
646 if (c0 == 0) {
647 c0 = 1;
648 }
649 if (c0 >= c_log) {
650 c0 = c_log - 1;
651 }
652 const size_t L0 = static_cast<size_t>(1) << c0;
653 const size_t D = stride >> c0; // == stride / L0
654 BB_ASSERT_EQ(L0 * D, stride);
655 BB_ASSERT_GTE(L0, size_t{ 2 });
656 BB_ASSERT_GTE(D, size_t{ 2 });
657
658 auto* const reals = s.affine_bucket_pairs.data();
659 auto* const dbl_reals = s.affine_bucket_indices.data();
660 auto* const inv_scratch = s.affine_bucket_inversion_scratch.data();
661
662 // Phase A: per-sub-partition running-sum (suffix sums).
663 // For each window w and each sub-partition d, walk slots from L0-1 down to 1 within the
664 // sub-partition, accumulating buckets[w*stride + d*L0 + l - 1] += buckets[... l]. All
665 // (w, d, l) triples for a fixed l share one batch-affine inversion (up to windows_in_batch
666 // · D pairs). Short windows (my_M_w < L0) are treated as a single sub-partition of length
667 // my_M_w to skip dead candidates; effective per-(w, d) length is min(L0, my_M_w - d·L0).
668 {
// `l` is unsigned. The loop ends when --l takes it from 1 to 0 and `l >= 1` fails, so it
// never wraps.
669 for (size_t l = L0 - 1; l >= 1; --l) {
670 size_t real_count = 0;
671 for (size_t w = 0; w < windows_in_batch; ++w) {
672 if (chunk_infos[w].empty != 0) {
673 continue;
674 }
675 const size_t my_M_w = chunk_infos[w].buckets_padded;
676 const size_t base = w * stride;
677 if (my_M_w < L0) {
678 // Short window: single sub-partition of effective length `my_M_w`.
679 if (l >= my_M_w) {
680 continue; // l is in the empty-padding region, skip
681 }
682 const uint32_t src = static_cast<uint32_t>(base + l);
683 const uint32_t dst = static_cast<uint32_t>(base + l - 1);
684 try_filter_pair<Curve>(buckets, is_present, dst, src, reals, real_count);
685 } else {
686 const size_t my_D = my_M_w >> c0; // ≥ 1
687 for (size_t d = 0; d < my_D; ++d) {
688 const uint32_t src = static_cast<uint32_t>(base + (d * L0) + l);
689 const uint32_t dst = static_cast<uint32_t>(base + (d * L0) + l - 1);
690 try_filter_pair<Curve>(buckets, is_present, dst, src, reals, real_count);
691 }
692 }
693 }
694 if (real_count > 0) {
695 bb::group_elements::batch_affine_add_indexed_impl<typename Curve::AffineElement,
696 typename Curve::BaseField>(
697 buckets, reals, real_count, inv_scratch);
698 }
699 }
700 }
701
702 // After Phase A, each window's slot 0 holds the simple sum of its sub-partition 0,
703 // and slot d*L0 (d ≥ 1) holds the simple sum of sub-partition d. The other slots within
704 // each sub-partition hold suffix sums that Phase D will combine.
705
706 // Phase B: log-recombine sub-partition simple sums into slot 0.
707 // For L1 = L0, 2*L0, 4*L0, ..., stride/2: pair (slot 2d*L1, slot (2d+1)*L1).
708 {
709 size_t L1 = L0;
710 while (L1 < stride) {
711 size_t real_count = 0;
712 const size_t step = 2 * L1;
713 for (size_t w = 0; w < windows_in_batch; ++w) {
714 if (chunk_infos[w].empty != 0) {
715 continue;
716 }
717 const size_t my_M = chunk_infos[w].buckets_padded;
718 if (step > my_M) {
719 continue;
720 }
721 const size_t base = w * stride;
722 const size_t num_pairs_w = my_M / step;
723 for (size_t d = 0; d < num_pairs_w; ++d) {
724 const uint32_t dst = static_cast<uint32_t>(base + ((2 * d) * L1));
725 const uint32_t src = static_cast<uint32_t>(base + (((2 * d) + 1) * L1));
726 try_filter_pair<Curve>(buckets, is_present, dst, src, reals, real_count);
727 }
728 }
729 if (real_count > 0) {
730 bb::group_elements::batch_affine_add_indexed_impl<typename Curve::AffineElement,
731 typename Curve::BaseField>(
732 buckets, reals, real_count, inv_scratch);
733 }
734 L1 *= 2;
735 }
736 }
737
738 // After Phase B, each window's slot 0 holds Σ_d B_{c,d} = R_c. Save R_c into outputs
739 // before Phase D's tree-add overwrites slot 0.
740 for (size_t w = 0; w < windows_in_batch; ++w) {
741 if (chunk_infos[w].empty != 0) {
742 out_at(w).R = Curve::Group::point_at_infinity;
743 continue;
744 }
745 const AffineElement& slot0 = buckets[w * stride];
746 if (is_present[w * stride] == 0) {
747 out_at(w).R = Curve::Group::point_at_infinity;
748 } else {
749 out_at(w).R = Element(slot0);
750 }
751 }
752
753 // Phase C: doublings.
754 // The candidate index list for the initial pass is constant across all c0 iters —
755 // every slot d*L0 for d ∈ [1, my_D - 1] in every non-empty window. Build the empty-
756 // filtered list once and chain c0 doublings on it instead of filtering c0 times.
757 // Subsequent levels (L1 = 2*L0, 4*L0, ...) do one doubling per level on level-specific
758 // index sets handled separately below.
759 {
760 size_t real_count = 0;
761 for (size_t w = 0; w < windows_in_batch; ++w) {
762 if (chunk_infos[w].empty != 0) {
763 continue;
764 }
765 const size_t my_M_w = chunk_infos[w].buckets_padded;
766 const size_t my_D = (my_M_w >= L0) ? (my_M_w >> c0) : size_t{ 0 };
767 const size_t base = w * stride;
768 for (size_t d = 1; d < my_D; ++d) {
769 try_filter_idx(is_present, static_cast<uint32_t>(base + (d * L0)), dbl_reals, real_count);
770 }
771 }
772 // c0 chained doublings on the same real list.
773 if (real_count > 0) {
774 for (size_t j = 0; j < c0; ++j) {
775 bb::group_elements::batch_affine_double_indexed_impl<typename Curve::AffineElement,
776 typename Curve::BaseField>(
777 buckets, dbl_reals, real_count, inv_scratch);
778 }
779 }
780 }
781 // Successive: at L1 = 2*L0, 4*L0, ..., stride/2: every d ≥ 1 in the sub-partition
782 // grid of size `stride / L1` gets one more doubling.
783 {
784 size_t L1 = 2 * L0;
785 while (L1 < stride) {
786 size_t real_count = 0;
787 for (size_t w = 0; w < windows_in_batch; ++w) {
788 if (chunk_infos[w].empty != 0) {
789 continue;
790 }
791 const size_t my_M = chunk_infos[w].buckets_padded;
792 if (L1 >= my_M) {
793 continue; // this window has no sub-partitions at this hierarchy
794 }
795 const size_t my_D1 = my_M / L1;
796 const size_t base = w * stride;
797 for (size_t d = 1; d < my_D1; ++d) {
798 try_filter_idx(is_present, static_cast<uint32_t>(base + (d * L1)), dbl_reals, real_count);
799 }
800 }
801 if (real_count > 0) {
802 bb::group_elements::batch_affine_double_indexed_impl<typename Curve::AffineElement,
803 typename Curve::BaseField>(
804 buckets, dbl_reals, real_count, inv_scratch);
805 }
806 L1 *= 2;
807 }
808 }
809
810 // Phase D: flat tree-add over the buckets_padded slots. For m = 1, 2, 4, ...,
811 // buckets_padded/2: pair (slot pos, slot pos+m) for pos = 0, 2m, 4m, ...
812 // Once the level's candidate count drops below BATCH_AFFINE_BREAKEVEN, the per-batch
813 // inversion overhead exceeds the projective per-add cost; bail and finish in Jacobian.
814 constexpr size_t BATCH_AFFINE_BREAKEVEN = 32;
815 size_t m = 1;
816 while (m < stride) {
817 // Live-slot count after this iter: stride / (2m) per window worst-case.
818 // Decision: would this iter's batch be too small? Estimate as
819 // `windows_in_batch * stride / (2m)` (upper bound on candidates).
820 const size_t est_cands_this_iter = windows_in_batch * (stride / (2 * m));
821 if (est_cands_this_iter < BATCH_AFFINE_BREAKEVEN) {
822 break;
823 }
824 size_t real_count = 0;
825 const size_t step = 2 * m;
826 for (size_t w = 0; w < windows_in_batch; ++w) {
827 if (chunk_infos[w].empty != 0) {
828 continue;
829 }
830 const size_t my_M = chunk_infos[w].buckets_padded;
831 if (m >= my_M) {
832 continue;
833 }
834 const size_t base = w * stride;
835 for (size_t pos = 0; pos + m < my_M; pos += step) {
836 try_filter_pair<Curve>(buckets,
838 static_cast<uint32_t>(base + pos),
839 static_cast<uint32_t>(base + pos + m),
840 reals,
841 real_count);
842 }
843 }
844 if (real_count > 0) {
845 bb::group_elements::batch_affine_add_indexed_impl<typename Curve::AffineElement, typename Curve::BaseField>(
846 buckets, reals, real_count, inv_scratch);
847 }
848 m *= 2;
849 }
850
851 // Write L_c. After Phase D's loop, `m` is the level NOT performed (or `stride` if all
852 // levels ran). The "live" slots — those holding cumulative tree-sums of consecutive m
853 // original buckets each — are {0, m, 2m, 3m, ...} ∩ [0, my_M):
854 // - loop completed (m == stride): only slot 0 is live; it holds the final L.
855 // - loop broke at level m: sum the live slots in Jacobian (live_step = m).
856 // - loop broke at m == 1: every original bucket is still live, sum them all.
857 // The Jacobian sum recovers what the unfinished levels would have computed in the
858 // batch-affine inner loop.
859 for (size_t w = 0; w < windows_in_batch; ++w) {
860 if (chunk_infos[w].empty != 0) {
861 out_at(w).L = Curve::Group::point_at_infinity;
862 continue;
863 }
864 const size_t base = w * stride;
865 const size_t my_M = chunk_infos[w].buckets_padded;
866 Element L = Curve::Group::point_at_infinity;
867 const size_t live_step = m; // distance between live slots after the affine phase
868 for (size_t pos = 0; pos < my_M; pos += live_step) {
869 if (is_present[base + pos] != 0) {
870 L += Element(buckets[base + pos]);
871 }
872 }
873 out_at(w).L = L;
874 }
875}
876
892template <typename Curve>
893[[gnu::always_inline]] inline typename Curve::Element chunk_contribution(const ChunkOutput<Curve>& chunk) noexcept
894{
895 using Element = typename Curve::Element;
896 if (chunk.empty != 0) {
897 return Curve::Group::point_at_infinity;
898 }
899 const uint32_t k = chunk.lo - 1;
900 Element acc = chunk.L;
901 if (k != 0) {
902 Element p = chunk.R;
903 uint32_t kk = k;
904 while (kk != 0) {
905 if ((kk & 1U) != 0) {
906 acc += p;
907 }
908 kk >>= 1;
909 if (kk != 0) {
910 p.self_dbl();
911 }
912 }
913 }
914 return acc;
915}
916
917} // namespace
918// `pippenger_round_parallel_jacobian_fast` has external linkage via the `extern template`
919// declarations in the header (used by the batched driver). Defined at namespace scope.
920
940template <typename Curve>
944 size_t min_pts_per_thread_override) noexcept
945{
946 using Element = typename Curve::Element;
947 using ScalarField = typename Curve::ScalarField;
948 using BaseField = typename Curve::BaseField;
949
950 const size_t n = scalars.size();
951 if (n == 0) {
952 return Curve::Group::point_at_infinity;
953 }
954
955 constexpr size_t NUM_BITS = ScalarField::modulus.get_msb() + 1;
956
957 // Cost-model window-size selection (mirrors MSM_fast<Curve>::get_optimal_log_num_buckets,
958 // with BUCKET_ACCUMULATION_COST = 5 = J-J-add-equiv-muls / J-A-add-equiv-muls ≈ 16/11
959 // rounded up). We do NOT delegate to the public method — keeping it self-contained
960 // avoids dragging the AffineAddition / AFFINE_TRICK_THRESHOLD machinery in here.
961 constexpr size_t BUCKET_ACCUMULATION_COST = 5;
962 constexpr uint32_t MAX_C = 18;
963 auto cost = [n](uint32_t bits) -> size_t {
964 size_t rounds = (NUM_BITS + bits - 1) / bits;
965 size_t buckets = size_t{ 1 } << bits;
966 return rounds * (n + buckets * BUCKET_ACCUMULATION_COST);
967 };
968 uint32_t window_bits = 1;
969 size_t best_cost = cost(1);
970 for (uint32_t b = 2; b <= MAX_C; ++b) {
971 const size_t this_cost = cost(b);
972 if (this_cost < best_cost) {
973 best_cost = this_cost;
974 window_bits = b;
975 }
976 }
977 const size_t num_buckets = size_t{ 1 } << window_bits;
978 const uint32_t num_rounds = static_cast<uint32_t>((NUM_BITS + window_bits - 1) / window_bits);
979 const uint32_t last_round_bits =
980 static_cast<uint32_t>(NUM_BITS - (static_cast<size_t>(num_rounds - 1) * window_bits));
981
982 // Each thread owns a num_buckets-sized scratch slice and runs num_rounds passes; below
983 // ~256 points per thread the parallel_for wakeup + per-call bucket reset dominate.
984 // wasm is forced single-threaded — its barrier cost is much higher than native.
985#ifdef __wasm__
986 constexpr size_t MIN_PTS_PER_THREAD_DEFAULT = SIZE_MAX;
987#else
988 constexpr size_t MIN_PTS_PER_THREAD_DEFAULT = 256;
989#endif
990 const size_t MIN_PTS_PER_THREAD =
991 (min_pts_per_thread_override == 0) ? MIN_PTS_PER_THREAD_DEFAULT : min_pts_per_thread_override;
992 const size_t max_threads = get_num_cpus();
993 size_t num_threads = std::min(std::max<size_t>(1, n / MIN_PTS_PER_THREAD), max_threads);
994 if (num_threads == 0) {
995 num_threads = 1;
996 }
997
998 // Allocate the per-thread bucket + presence scratch ONCE, indexed by tid inside the
999 // parallel_for. Allocating inside the lambda body would re-malloc on every call (and
1000 // on WASM the malloc cost is non-trivial relative to the arithmetic work at small n).
1001 std::vector<Element> per_thread_results(num_threads);
1002 std::vector<Element> all_buckets(num_threads * num_buckets);
1003 std::vector<uint8_t> all_present(num_threads * num_buckets);
1004
1005 auto thread_body = [&](size_t tid) {
1006 const size_t lo = (tid * n) / num_threads;
1007 const size_t hi = ((tid + 1) * n) / num_threads;
1008
1009 Element* const buckets = all_buckets.data() + (tid * num_buckets);
1010 uint8_t* const present = all_present.data() + (tid * num_buckets);
1011
1012 Element result = Curve::Group::point_at_infinity;
1013
1014 for (uint32_t round = 0; round < num_rounds; ++round) {
1015 std::memset(present, 0, num_buckets);
1016
1017 const size_t hi_bit = NUM_BITS - (static_cast<size_t>(round) * window_bits);
1018 const size_t lo_bit = (hi_bit < window_bits) ? size_t{ 0 } : (hi_bit - window_bits);
1019 const size_t actual_size = hi_bit - lo_bit;
1020 const size_t start_limb = lo_bit >> 6;
1021 const size_t end_limb = hi_bit >> 6;
1022 const size_t lo_off = lo_bit & 63;
1023 const size_t lo_bits = (64 - lo_off < actual_size) ? (64 - lo_off) : actual_size;
1024 const size_t hi_bits = actual_size - lo_bits;
1025 const uint64_t lo_mask = (lo_bits == 64) ? ~uint64_t{ 0 } : ((uint64_t{ 1 } << lo_bits) - 1);
1026 const uint64_t hi_mask = (hi_bits == 0) ? uint64_t{ 0 } : ((uint64_t{ 1 } << hi_bits) - 1);
1027
1028 for (size_t i = lo; i < hi; ++i) {
1029 const uint64_t s_lo = (scalars[i].data[start_limb] >> lo_off) & lo_mask;
1030 const uint64_t s_hi = (start_limb != end_limb) ? (scalars[i].data[end_limb] & hi_mask) : uint64_t{ 0 };
1031 const uint32_t slice = static_cast<uint32_t>(s_lo | (s_hi << lo_bits));
1032 if (slice == 0) {
1033 continue;
1034 }
1035 if (present[slice] == 0) {
1036 buckets[slice].x = points[i].x;
1037 buckets[slice].y = points[i].y;
1038 buckets[slice].z = BaseField::one();
1039 present[slice] = 1;
1040 } else {
1041 buckets[slice] += points[i];
1042 }
1043 }
1044
1045 // Running suffix sum over populated buckets only.
1046 // acc = Σ_{j ≥ i, present[j]} bucket[j]
1047 // bucket_sum = Σ_{i in [first_pop_low, top]} acc(i) = Σ_k k * bucket[k]
1048 // Bucket 0 carries no contribution and is never added.
1049 std::ptrdiff_t top = static_cast<std::ptrdiff_t>(num_buckets) - 1;
1050 while (top >= 1 && present[static_cast<size_t>(top)] == 0) {
1051 --top;
1052 }
1053 Element bucket_sum = Curve::Group::point_at_infinity;
1054 if (top >= 1) {
1055 Element acc = buckets[static_cast<size_t>(top)];
1056 bucket_sum = acc;
1057 for (std::ptrdiff_t i = top - 1; i >= 1; --i) {
1058 if (present[static_cast<size_t>(i)] != 0) {
1059 acc += buckets[static_cast<size_t>(i)];
1060 }
1061 bucket_sum += acc;
1062 }
1063 }
1064
1065 const uint32_t doublings = (round == num_rounds - 1) ? last_round_bits : window_bits;
1066 for (uint32_t d = 0; d < doublings; ++d) {
1067 result.self_dbl();
1068 }
1069 result += bucket_sum;
1070 }
1071
1072 per_thread_results[tid] = result;
1073 };
1074
1075 if (num_threads == 1) {
1076 thread_body(0);
1077 } else {
1078 bb::parallel_for(num_threads, thread_body);
1079 }
1080
1081 Element total = per_thread_results[0];
1082 for (size_t t = 1; t < num_threads; ++t) {
1083 total += per_thread_results[t];
1084 }
1085 return total;
1086}
1087
1088// PerWorkerArenaLayout (and its dependencies BATCH_CAPACITY, DEDUP_MAX_CHUNK_MEMBERS,
1089// AffineBucketChunkInfo) lives in `pippenger_arena_layout.hpp`. Used by the sizer
1090// below, the live allocator in `pippenger_round_parallel`, and the arena-layout
1091// regression test.
1092} // namespace round_parallel_detail
1093
1112
1113// Compute the arena bytes a single MSM_fast of `n_input` points will need.
1114// Mirrors the inline budget calculation inside `pippenger_round_parallel`.
1115// Returns 0 when N is small enough that we'll fall back to the Jacobian fast path
1116// (no affine arena needed). Exposed (declared in `scalar_multiplication_fast.hpp`)
1117// so the test suite can exercise the same sizer the live allocator uses.
//
// The result is a conservative bound, not an exact figure. The live path calls this
// before Phase 1, so it cannot yet know the effective scalar bit width. The estimate is
// therefore sized from the full NUM_BITS budget. On the GLV / n_input >= 2^17 paths it
// also takes the max over every smaller bit budget.
//
// @param n_input               number of input (scalar, point) pairs, before any GLV doubling.
// @param external_glv_provided true when the caller supplies the pre-doubled GLV point buffer.
//                              This forces the GLV layout (2·n_input working pairs, 128-bit scalars).
// @param dedup_active          true when the dedup pre-pass will run. Its arena state is then added:
//                              4 bytes per working scalar plus the fixed extra_points cap.
// @return arena size in bytes. Returns 0 for n_input < 4 (trivial path), or when the
//         thread-count gate below routes the MSM to the Jacobian fallback.
1118template <typename Curve>
1119size_t compute_arena_bytes_for_msm(size_t n_input, bool external_glv_provided, bool dedup_active) noexcept
1120{
1121 using ScalarField = typename Curve::ScalarField;
1122 constexpr size_t FULL_NUM_BITS = ScalarField::modulus.get_msb() + 1;
1123
1124 if (n_input < 4) {
1125 return 0; // trivial path
1126 }
1127
// Same GLV gate as the live path: use GLV when the caller supplied doubled points, or
// when n_input is at or below GLV_SMALL_N_THRESHOLD.
1128 const bool use_glv = external_glv_provided || (n_input <= round_parallel_detail::GLV_SMALL_N_THRESHOLD);
1129 const size_t n = use_glv ? 2 * n_input : n_input;
1130 const size_t NUM_BITS = use_glv ? size_t{ 128 } : FULL_NUM_BITS;
1131 BB_ASSERT_LTE(n,
1133 "working scalar indices must fit in the 29-bit schedule payload");
1134
1139
1140 // window-bits selection uses the ideal per-window oversubscription factor (not the dispatch lmul).
1141 const size_t num_logical_threads_for_c = bb::get_num_cpus() * window_bits_tuning_oversub_factor(n_input);
1142 const size_t window_bits =
1143 round_parallel_detail::choose_window_bits(n, NUM_BITS, n_input, num_logical_threads_for_c);
// NOTE(review): the "+ 2" in num_windows presumably reserves headroom for a signed-digit
// recoding carry. Likewise, num_buckets = 2^(c-1) + 1 looks like a signed-digit bucket
// range. Confirm both against build_var_window_schedule.
1144 const size_t num_windows = (NUM_BITS + 2 + window_bits - 1) / window_bits;
1145 const size_t num_buckets = (size_t{ 1 } << (window_bits - 1)) + 1;
1146
// Thread-count gate: each thread should receive at least MIN_BATCH_CAPACITY working points.
// If fewer than ceil(desired_threads / MIN_AFFINE_THREAD_RATIO) threads can be fed that much,
// no affine arena is sized and the MSM takes the Jacobian fallback.
1147 const size_t desired_threads = std::max<size_t>(1, bb::get_num_cpus());
1148 const size_t max_threads_for_min_batch = n / MIN_BATCH_CAPACITY;
1149 const size_t min_threads_allowed =
1150 std::max<size_t>(1, (desired_threads + MIN_AFFINE_THREAD_RATIO - 1) / MIN_AFFINE_THREAD_RATIO);
1151
1152 if (max_threads_for_min_batch < min_threads_allowed) {
1153 return 0; // jacobian-fast fallback, no affine arena
1154 }
1155
// min_threads_allowed >= 1, so the early return above already guarantees
// max_threads_for_min_batch >= 1. The inner std::max is defensive only.
1156 const size_t num_threads = std::min(desired_threads, std::max<size_t>(1, max_threads_for_min_batch));
1157
1158 // num_threads sizes the per-task arrays; worker_total sizes the per-OS-thread scratch
1159 // (FIFO-shared by every task that lands on that OS thread).
1160 const size_t worker_total_for_budget = num_threads;
1161 const size_t dense_stride_est = round_parallel_detail::compute_dense_stride(num_buckets, num_threads);
1162
1163 // Pre-schedule conservative per-window cost: uses `num_buckets` (= 2^(c-1)+1) as the
1164 // B upper bound. `arena_bytes_for_window_layout` below recomputes the cost from the built
// schedule's widest window (B_eff_layout), but that pass only runs on the GLV /
// n_input >= 2^17 path.
// NOTE(review): the live path sizes per-window bytes from
// B_eff = max(num_buckets, sched.num_buckets[w]). If a schedule window can exceed
// 2^(c-1)+1 buckets, this estimate could undershoot on the non-GLV n_input < 2^17 path.
// Confirm that it cannot.
1165 const size_t per_window_bytes = round_parallel_detail::compute_per_window_bytes<Curve>(
1166 num_threads, num_buckets, n, dense_stride_est, worker_total_for_budget);
1167
1168 const size_t global_max_overflow_per_window =
1169 round_parallel_detail::compute_global_max_overflow_per_window(n, num_threads, SUBCHUNK_ENTRIES_CAP);
1170
1171 const bool inline_glv_double = use_glv && !external_glv_provided;
1172 const size_t profile_threads = std::max<size_t>(1, bb::get_num_cpus());
1173 const size_t phase_one_prologue_bytes =
1174 round_parallel_detail::compute_phase_one_prologue_bytes(n, use_glv, inline_glv_double, profile_threads);
1175
1176 const auto phase_a_caps = round_parallel_detail::compute_phase_a_caps(n, num_threads);
1177 const size_t phase_a_cluster_members_cap = phase_a_caps.members_cap;
1178 const size_t phase_a_cluster_offsets_cap = phase_a_caps.offsets_cap;
1179
1180 // Zone W per-worker UNION via the canonical layout walk. Stage 6a, Stage 6b, and
1181 // Phase A overlay the same per-worker bytes; the struct returns the max-of-layouts
1182 // (the Stage 6 wpb-dependent tail is added below once `windows_per_batch` is known).
1183 // Passing `windows_per_batch = 0` here skips the tail — we only need the union bytes
1184 // for the fixed_overhead → wpb solve.
1185 const round_parallel_detail::PerWorkerArenaLayout<Curve> union_layout(/*chunk_capacity=*/SUBCHUNK_ENTRIES_CAP,
1186 global_max_overflow_per_window,
1187 dedup_active,
1188 phase_a_cluster_members_cap,
1189 phase_a_cluster_offsets_cap,
1190 /*windows_per_batch=*/0,
1191 /*dense_stride_est=*/0);
1192 const size_t worker_union_bytes = union_layout.per_worker_union_bytes;
1193
// wpb-independent bytes: the per-worker union slabs, the window-sum accumulators, the
// bucket partition offsets and the Phase 1 prologue buffers.
1194 const size_t fixed_overhead = (worker_union_bytes * worker_total_for_budget) +
1195 (size_t{ 96 } * round_parallel_detail::VAR_WINDOW_MAX_WINDOWS) // window_sums_storage
1196 + (size_t{ 8 } * (num_threads + 1)) // rebalanced_bucket_lo_partition
1197 + phase_one_prologue_bytes;
1198
1199 // wpb fallback when fixed_overhead has eaten the BATCH_MEM_BUDGET headroom: the inline
1200 // `solve_wpb` in `pippenger_round_parallel` returns `W_R` (the whole region) — running
1201 // every window in a single batch — when `available_budget == 0`. Previously the sizer
1202 // returned `wpb = 1` and relied on a `worst_case_arena = BATCH_MEM_BUDGET + 32K` floor;
1203 // that floor failed for large num_threads where fixed_overhead alone exceeds the budget.
1204 const size_t available_budget_outer =
1205 (BATCH_MEM_BUDGET > fixed_overhead) ? (BATCH_MEM_BUDGET - fixed_overhead) : size_t{ 0 };
1206 const size_t windows_per_batch =
1207 round_parallel_detail::solve_wpb(per_window_bytes, available_budget_outer, num_windows);
1208 // Dedup state lives in the arena (allocated post-Phase-1, retained through Stage 6a).
1209 // Worst-case sizes: redirect_lookup is one uint32 per working scalar (4n bytes);
1210 // extra_points is the fixed DEDUP_MAX_CLUSTERS cap (≈1 MB) regardless of n.
1211 const size_t dedup_bytes = dedup_active ? ((size_t{ 4 } * n) + (size_t{ sizeof(typename Curve::AffineElement) } *
1213 : size_t{ 0 };
// Full arena size for one candidate bit budget, computed in these steps:
//   1. re-choose window_bits and build that schedule;
//   2. size the per-window bytes from the schedule's widest window (B_eff_layout);
//   3. re-solve wpb;
//   4. add the same fixed_overhead, 32 KiB pad and dedup terms as the main estimate.
// num_threads and fixed_overhead do not depend on bit_budget and are captured unchanged.
1214 auto arena_bytes_for_window_layout = [&](size_t bit_budget) {
1215 const size_t wb = round_parallel_detail::choose_window_bits(n, bit_budget, n_input, num_logical_threads_for_c);
1216 const auto layout_sched = round_parallel_detail::build_var_window_schedule(bit_budget, wb);
1217 size_t B_eff_layout = (size_t{ 1 } << (wb - 1)) + 1;
1218 for (size_t w = 0; w < layout_sched.num_windows; ++w) {
1219 B_eff_layout = std::max(B_eff_layout, static_cast<size_t>(layout_sched.num_buckets[w]));
1220 }
1221 const size_t dense_stride_layout = round_parallel_detail::compute_dense_stride(B_eff_layout, num_threads);
1222 const size_t per_window_bytes_layout = round_parallel_detail::compute_per_window_bytes<Curve>(
1223 num_threads, B_eff_layout, n, dense_stride_layout, worker_total_for_budget);
1224
// Same value as available_budget_outer: fixed_overhead is independent of bit_budget.
1225 const size_t available_budget =
1226 (BATCH_MEM_BUDGET > fixed_overhead) ? (BATCH_MEM_BUDGET - fixed_overhead) : size_t{ 0 };
1227 const size_t wpb = round_parallel_detail::solve_wpb(
1228 per_window_bytes_layout, available_budget, static_cast<size_t>(layout_sched.num_windows));
1229 return fixed_overhead + (wpb * per_window_bytes_layout) + 32768 + dedup_bytes;
1230 };
1231
1232 // Tight return: the arena holds `fixed_overhead + wpb · per_window_bytes` of typed
1233 // buffers plus a 32 KiB alignment pad and the dedup state (when active). Sizing
1234 // tightly — rather than padding up to BATCH_MEM_BUDGET — matters for many-MSM_fast flows
1235 // (e.g. PerMsmChonk's 256 separate per-circuit MSMs) where every per-MSM_fast
1236 // `make_unique_for_overwrite<std::byte[]>` mmap/munmaps the buffer above glibc's
1237 // M_MMAP_THRESHOLD; a 32 MiB floor here would tax every MSM_fast with the page-fault
1238 // first-touch cost regardless of how much of the arena the small MSM_fast actually uses.
1239 size_t arena_bytes = fixed_overhead + (windows_per_batch * per_window_bytes) + 32768 + dedup_bytes;
1240
1241 // The live pipeline shrinks NUM_BITS to the observed max scalar bit before choosing
1242 // window_bits. GLV MSMs and large non-GLV MSMs can therefore select a different
1243 // schedule/zone layout than the full-bit pre-sizer. Keep the common Chonk wire/IPA
1244 // non-GLV sizes on the original tight path.
// The live path's effective bit width is clamped to [1, NUM_BITS], and it calls the same
// choose_window_bits. Taking the max over every bit_budget in that range therefore covers
// any layout it can select on this path.
1245 if (use_glv || n_input >= (size_t{ 1 } << 17)) {
1246 for (size_t bit_budget = 1; bit_budget <= NUM_BITS; ++bit_budget) {
1247 arena_bytes = std::max(arena_bytes, arena_bytes_for_window_layout(bit_budget));
1248 }
1249 }
1250 return arena_bytes;
1251}
1252
1253// Round-parallel Pippenger MSM_fast.
1254// `external_glv_doubled` — optional caller-supplied [P_0, φP_0, …, P_{n-1}, φP_{n-1}]
1255// buffer (length 2·n_input). When non-empty, forces use_glv=true and skips the
1256// internal doubling pass. The interleaved layout means longer-prefix aliasing
1257// (length 2·Nmax) is valid for any n ≤ Nmax with no copy.
1258// `external_arena` — optional caller-supplied scratch buffer ≥ this MSM_fast's required
1259// bytes. When empty, allocate per-MSM_fast via make_unique_for_overwrite and free at
1260// return. The batched driver supplies a single arena sized to the largest member.
1261template <typename Curve>
1262// NOLINTNEXTLINE(readability-function-size, readability-function-cognitive-complexity,
1263// google-readability-function-size)
1266 bool dedup_hint,
1268 std::span<std::byte> external_arena) noexcept
1269{
1270 using Element = typename Curve::Element;
1271 using AffineElement = typename Curve::AffineElement;
1272 using ScalarField = typename Curve::ScalarField;
1273 using BaseField = typename Curve::BaseField;
1274
1275 const size_t n_input = scalars_span.size();
1276 if (n_input == 0) {
1277 return Curve::Group::point_at_infinity;
1278 }
1279
1280 // Bail to trivial_msm_threaded when each worker would own fewer than
1281 // MIN_PTS_PER_THREAD_FOR_PIPPENGER points — pippenger_fast's per-window scaffolding loses
1282 // to straus_msm at this density. Caller-supplied GLV doubling is wasted at this size,
1283 // but the overhead is negligible.
1284 {
1285 const size_t max_threads = bb::get_num_cpus();
1286 const size_t num_threads_dispatch = std::max<size_t>(1, std::min(n_input, max_threads));
1287 const size_t pts_per_thread = (n_input + num_threads_dispatch - 1) / num_threads_dispatch;
1288 if (pts_per_thread < MIN_PTS_PER_THREAD_FOR_PIPPENGER) {
1289 return trivial_msm_threaded<Curve>(scalars_span, all_points);
1290 }
1291 }
1292
1293 BB_ASSERT_GTE(all_points.size(), scalars_span.start_index + n_input);
1294 std::span<const AffineElement> input_points(&all_points[scalars_span.start_index], n_input);
1295
1296 constexpr size_t FULL_NUM_BITS = ScalarField::modulus.get_msb() + 1;
1297
1298 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
1299 ScalarField* scalar_ptr = const_cast<ScalarField*>(&scalars_span[scalars_span.start_index]);
1300 std::span<ScalarField> input_scalars(scalar_ptr, n_input);
1301
1302 // GLV: split k ≡ k1 − k2·λ (mod r), giving 2n pairs at NUM_BITS=128. Halves num_windows;
1303 // costs an extra n point doubles. Applied only below GLV_SMALL_N_THRESHOLD where the
1304 // win-on-windows beats the lose-on-doubled-scan, OR forced on by the batched dispatcher
1305 // supplying `external_glv_doubled` (it amortises the doubling across the whole batch).
1306 // Empirical crossover (best-of-3 sweep at HC=16, P ∈ {4, 8, 16}): wasmtime keeps GLV up
1307 // to n=2^16; native to n=2^13 (clang's branchless bias-decode is fast enough that the 2×
1308 // point-count cost dominates above that). Threshold is platform-conditional in the
1309 // hoisted GLV_SMALL_N_THRESHOLD declaration.
1310 const bool external_glv_provided = !external_glv_doubled.empty();
1311 const bool use_glv = external_glv_provided || n_input <= round_parallel_detail::GLV_SMALL_N_THRESHOLD;
1312
1313 // Stage 6 splits into 6a (per-thread bucket partials over the contiguous-by-schedule-
1314 // index partition) and 6b (cross-thread bucket reduction over a uniform-width digit
1315 // slice). Small MSMs short-circuit to trivial_msm_threaded above this point.
1316
1317 // n is the working scalar/point count (GLV doubles it); NUM_BITS is the post-recoding
1318 // window-bit budget (128 for GLV, FULL_NUM_BITS otherwise).
1319 const size_t n = use_glv ? (2 * n_input) : n_input;
1320 const size_t NUM_BITS = use_glv ? size_t{ 128 } : FULL_NUM_BITS;
1321 BB_ASSERT_LTE(n,
1323 "working scalar indices must fit in the 29-bit schedule payload");
1324 std::span<ScalarField> scalars;
1325 std::span<const AffineElement> points;
1326 const bool inline_glv_double = use_glv && !external_glv_provided;
1327
1328 // Activation gate: caller-supplied hint opts this MSM_fast into the dedup pre-pass.
1329 // Hint-driven so polynomials with low duplicate density (PC counters, range checks)
1330 // skip the O(n) tagging cost. The small-n bail above (pts_per_thread <
1331 // MIN_PTS_PER_THREAD_FOR_PIPPENGER) already shed every case where dedup wouldn't fit
1332 // — n ≥ MIN_PTS_PER_THREAD_FOR_PIPPENGER * 1 = 24 here.
1333 const bool dedup_active = dedup_hint;
1334
1335 // ---------------------------------------------------------------------------------------
1336 // Arena setup (pre-Phase-1).
1337 //
1338 // The per-MSM_fast arena is allocated BEFORE Phase 1 so the Phase 1 prologue (msb_per_scalar,
1339 // glv_*_storage, per_thread_msb_hist) lives inside the arena instead of on the heap.
1340 // Once Phase 1 finishes and the window schedule is known (T, B_eff, dense_stride, wpb),
1341 // we partition the remaining capacity into three named zones
1342 // (Zone P / Zone W / Zone S) — see the "Arena zone layout" block after the wpb solve.
1343 //
1344 // We size the buffer using `compute_arena_bytes_for_msm`, whose conservative bound
1345 // dominates the inline-tight (P + W + S) sum for any wpb we choose below.
1346 // ---------------------------------------------------------------------------------------
1347 const size_t arena_total_bytes = compute_arena_bytes_for_msm<Curve>(n_input, external_glv_provided, dedup_active);
1348 round_parallel_detail::MsmArena arena(arena_total_bytes, external_arena);
1349
1350 // ---------------------------------------------------------------------------------------
1351 // Phase 1 — convert scalars from Montgomery, optionally GLV-split, populate msb buffer.
1352 // The msb_per_scalar buffer feeds max-msb num_windows selection;
1353 // per-thread msb_hist counts (bin 0 = zero, bin k+1 = msb == k) feed the n_active gate
1354 // and the active-scalar gate.
1355 //
1356 // When dedup is active the per-scalar dedup work (hash + linear-probe shared atomic
1357 // table, per-thread dup_pair recording) is fused into the same per-thread loop so
1358 // scalars stay hot in L1 between from-Mont and the hash. The post-pass (sort, cluster
1359 // build, chunked tree-reduce, redirect_lookup) runs sequentially after the parallel_for
1360 // — see `dedup_finalize_parallel`.
1361 // ---------------------------------------------------------------------------------------
1362 using round_parallel_detail::MSB_ZERO_SENTINEL;
1363 const size_t profile_threads = std::max<size_t>(1, bb::get_num_cpus());
1364 auto msb_per_scalar = arena.template alloc<uint8_t>(n);
1365 auto per_thread_msb_hist = arena.template alloc<std::array<uint32_t, 256>>(profile_threads);
1366 // MsmArena::alloc returns uninitialised memory; the histograms must be zero-initialised so
1367 // record_msb's increments land on a clean slate.
1368 std::fill_n(per_thread_msb_hist.data(), profile_threads, std::array<uint32_t, 256>{});
1369
1370 // GLV storage (optional). `glv_scalars_storage` is the GLV-split working scalar buffer;
1371 // `glv_points_storage` is the inline-doubled point buffer (skipped when the caller
1372 // supplied an external doubled buffer). Both span empty when `use_glv` is false.
1373 std::span<ScalarField> glv_scalars_storage;
1374 std::span<AffineElement> glv_points_storage;
1375 if (use_glv) {
1376 glv_scalars_storage = arena.template alloc<ScalarField>(n);
1377 if (inline_glv_double) {
1378 glv_points_storage = arena.template alloc<AffineElement>(n);
1379 } else {
1380 BB_ASSERT_EQ(external_glv_doubled.size(), n);
1381 }
1382 }
1383
1384 if (use_glv) {
1385 // Convert each input scalar from-Mont into a stack local, GLV-split it, store both
1386 // 128-bit halves and their msb into the profile buffer. input_scalars is read-only on
1387 // this path so the user's buffer is preserved (no Montgomery restore needed). Inline
1388 // path additionally GLV-doubles the points in the same parallel pass; external path
1389 // aliases the caller-supplied doubled buffer.
1390 const BaseField beta = inline_glv_double ? BaseField::cube_root_of_unity() : BaseField{};
1391 bb::parallel_for(bb::get_num_cpus(), [&](const ThreadChunk& chunk) {
1392 auto& th_hist = per_thread_msb_hist[chunk.thread_index];
1393 for (size_t i : chunk.range(n_input)) {
1394 const ScalarField canonical = input_scalars[i].from_montgomery_form_reduced();
1395 const auto split = ScalarField::split_into_endomorphism_scalars(canonical);
1396 const auto& k1 = split.first;
1397 const auto& k2 = split.second;
1398 glv_scalars_storage[2 * i].data[0] = k1[0];
1399 glv_scalars_storage[2 * i].data[1] = k1[1];
1400 glv_scalars_storage[2 * i].data[2] = 0;
1401 glv_scalars_storage[2 * i].data[3] = 0;
1402 glv_scalars_storage[(2 * i) + 1].data[0] = k2[0];
1403 glv_scalars_storage[(2 * i) + 1].data[1] = k2[1];
1404 glv_scalars_storage[(2 * i) + 1].data[2] = 0;
1405 glv_scalars_storage[(2 * i) + 1].data[3] = 0;
1406 if (inline_glv_double) {
1407 glv_points_storage[2 * i] = input_points[i];
1408 glv_points_storage[(2 * i) + 1].x = input_points[i].x * beta;
1409 glv_points_storage[(2 * i) + 1].y = -input_points[i].y;
1410 }
1411 round_parallel_detail::record_msb(
1412 round_parallel_detail::msb_of_2limb(k1[0], k1[1]), msb_per_scalar[2 * i], th_hist);
1413 round_parallel_detail::record_msb(
1414 round_parallel_detail::msb_of_2limb(k2[0], k2[1]), msb_per_scalar[(2 * i) + 1], th_hist);
1415 }
1416 });
1417 points =
1418 inline_glv_double ? std::span<const AffineElement>(glv_points_storage.data(), n) : external_glv_doubled;
1419 scalars = glv_scalars_storage;
1420 } else {
1421 // Non-GLV path: in-place from-Mont (later restored in the Stage-7 epilogue).
1422 bb::parallel_for(bb::get_num_cpus(), [&](const ThreadChunk& chunk) {
1423 auto& th_hist = per_thread_msb_hist[chunk.thread_index];
1424 for (size_t i : chunk.range(n_input)) {
1425 input_scalars[i].self_from_montgomery_form_reduced();
1426 round_parallel_detail::record_msb(
1427 round_parallel_detail::msb_of_4limb(input_scalars[i].data), msb_per_scalar[i], th_hist);
1428 }
1429 });
1430 scalars = input_scalars;
1431 points = input_points;
1432 }
1433
1434 std::array<uint64_t, 256> msb_hist{};
1435 for (size_t t = 0; t < profile_threads; ++t) {
1436 for (size_t b = 0; b < 256; ++b) {
1437 msb_hist[b] += per_thread_msb_hist[t][b];
1438 }
1439 }
1440 const size_t n_active_early = n - static_cast<size_t>(msb_hist[0]);
1441
1442 // ---------------------------------------------------------------------------------------
1443 // Phase 2 — bail to trivial_msm_threaded when n_active is too small to amortise pippenger_fast's
1444 // per-window scaffolding. trivial_msm_threaded -> straus_msm wants Montgomery scalars, so
1445 // re-Mont-form them in parallel before dispatching.
1446 // ---------------------------------------------------------------------------------------
1447 {
1448 const size_t max_threads_dispatch = bb::get_num_cpus();
1449 const size_t threads_for_dispatch = std::max<size_t>(1, std::min(n_active_early, max_threads_dispatch));
1450 const size_t pts_per_thread = (n_active_early + threads_for_dispatch - 1) / threads_for_dispatch;
1451 if (pts_per_thread < MIN_PTS_PER_THREAD_FOR_PIPPENGER) {
1452 bb::parallel_for(bb::get_num_cpus(), [&](const ThreadChunk& chunk) {
1453 for (size_t i : chunk.range(n)) {
1454 scalars[i].self_to_montgomery_form();
1455 }
1456 });
1457 std::span<const ScalarField> scalars_const(scalars.data(), n);
1458 PolynomialSpan<const ScalarField> ps(0, scalars_const);
1459 return trivial_msm_threaded<Curve>(ps, points);
1460 }
1461 }
1462
1463 // ---------------------------------------------------------------------------------------
1464 // Phase 3 — pick the window layout, build the schedule, run the pipeline, sum into the result.
1465 // ---------------------------------------------------------------------------------------
1466 const size_t num_logical_threads_for_c = bb::get_num_cpus() * window_bits_tuning_oversub_factor(n_input);
1467
1468 // Shrink the bit budget to the highest non-empty msb_hist bin so num_windows is determined
1469 // by the actual data, not the conservative GLV / FULL_NUM_BITS bound.
1470 size_t effective_num_bits = 0;
1471 for (size_t bin = 256; bin > 1;) {
1472 --bin;
1473 if (msb_hist[bin] != 0) {
1474 effective_num_bits = bin;
1475 break;
1476 }
1477 }
1478 if (effective_num_bits == 0 || effective_num_bits > NUM_BITS) {
1479 effective_num_bits = NUM_BITS;
1480 }
1481 const size_t window_bits =
1482 round_parallel_detail::choose_window_bits(n, effective_num_bits, n_input, num_logical_threads_for_c);
1483 const size_t num_buckets = (size_t{ 1 } << (window_bits - 1)) + 1;
1484
1485 // Schedule-based dedup state. The two arrays are allocated from the per-MSM_fast arena
1486 // after Phase 1.
1487 // Until then, both spans are empty.
1488 // Lifetimes:
1489 // redirect_lookup — written by Phase A; read by Stage 4b's dedup_patch_schedule per batch
1490 // extra_points — written by Phase A; read by Stage 6a's reduce_chunk per batch
1491 // Both must survive until the last Stage 6a, so they sit in the arena (which is freed
1492 // when this function returns).
1494
1495 // Variable-window split was removed from the production path after Chonk traces showed
1496 // it regressing this rewrite. Keep the schedule uniform and run one region over all
1497 // non-zero scalars.
1498 const auto sched = round_parallel_detail::build_var_window_schedule(effective_num_bits, window_bits);
1499 BB_ASSERT_LTE(sched.num_windows,
1501 "window schedule exceeds compile-time max window count");
1502
1507
1508 // Thread count: aim for `lmul × physical_cpus` logical tasks so the rpmsm pool can
1509 // FIFO-balance heterogeneous P/E cores; cap at `n / MIN_BATCH_CAPACITY` so each chunk
1510 // can saturate the batched-affine drains. `bb::get_num_cpus() <= 1` is the chonk
1511 // batch-verifier's signal that outer parallelism owns all cores — run sequentially.
1512 const size_t desired_threads = std::max<size_t>(1, bb::get_num_cpus());
1513 const size_t max_threads_for_min_batch = std::max<size_t>(1, n / MIN_BATCH_CAPACITY);
1514 const size_t num_threads = std::min(desired_threads, max_threads_for_min_batch);
1515
1516 // Stage 6's tree-reduce splits each thread's chunk into sub-chunks of at most
1517 // SUBCHUNK_ENTRIES_CAP entries before calling reduce_chunk, bounding per-thread scratch
1518 // independent of n. 2048 keeps level-0 saturated (≥ 4 BATCH_CAPACITY drains at typical
1519 // c=16) while the deepest level still hits BATCH_AFFINE_BREAKEVEN (~32 pairs); halving
1520 // breaks the deep levels and doubling wastes memory.
1521 // Pick windows_per_batch so the per-MSM_fast working set fits in ~32 MB. Empirically 32 MB
1522 // performs as well as 128 MB on the WASM grid (the recursive affine bucket reduction
1523 // recovers most of the small-batch loss).
1524 // The per_window_bytes / fixed_overhead formulas below mirror this enum of allocations
1525 // exactly. Anyone adding an arena buffer must update both the alloc and the corresponding
1526 // term in those formulas, otherwise windows_per_batch drifts off the BATCH_MEM_BUDGET.
1527
1528 // Per-(w, t) slot stride must fit the widest schedule window.
1529 size_t B_eff = num_buckets;
1530 for (size_t w = 0; w < sched.num_windows; ++w) {
1531 B_eff = std::max(B_eff, static_cast<size_t>(sched.num_buckets[w]));
1532 }
1533
1534 const size_t worker_total_for_budget = num_threads;
1535 const size_t dense_stride_est = round_parallel_detail::compute_dense_stride(B_eff, num_threads);
1536 const size_t bucket_partials_per_window_max =
1538 const size_t per_window_bytes_lo = round_parallel_detail::compute_per_window_bytes<Curve>(
1539 num_threads, B_eff, n, dense_stride_est, worker_total_for_budget);
1540
1541 const size_t global_max_overflow_per_window_for_budget =
1542 round_parallel_detail::compute_global_max_overflow_per_window(n, num_threads, SUBCHUNK_ENTRIES_CAP);
1543
1544 const size_t phase_one_prologue_bytes =
1545 round_parallel_detail::compute_phase_one_prologue_bytes(n, use_glv, inline_glv_double, profile_threads);
1546
1547 const auto phase_a_caps = round_parallel_detail::compute_phase_a_caps(n, num_threads);
1548 const size_t phase_a_cluster_members_cap = phase_a_caps.members_cap;
1549 const size_t phase_a_cluster_offsets_cap = phase_a_caps.offsets_cap;
1550
1551 // Zone W per-worker UNION via the canonical layout walk. The wpb-dependent Stage 6
1552 // tail is added separately after `windows_per_batch` is solved; here we only need
1553 // the union bytes for the fixed_overhead → wpb budget.
1555 /*chunk_capacity=*/SUBCHUNK_ENTRIES_CAP,
1556 global_max_overflow_per_window_for_budget,
1557 dedup_active,
1558 phase_a_cluster_members_cap,
1559 phase_a_cluster_offsets_cap,
1560 /*windows_per_batch=*/0,
1561 /*dense_stride_est=*/0);
1562 const size_t worker_union_bytes_for_budget = budget_layout.per_worker_union_bytes;
1563
1564 const size_t fixed_overhead = (worker_union_bytes_for_budget * worker_total_for_budget) +
1565 (size_t{ 96 } * round_parallel_detail::VAR_WINDOW_MAX_WINDOWS) // window_sums_storage
1566 + (size_t{ 8 } * (num_threads + 1)) // rebalanced_bucket_lo_partition
1567 + phase_one_prologue_bytes;
1568
1569 // Solve `wpb · per_window_bytes ≤ BATCH_MEM_BUDGET − fixed_overhead`.
1570 const size_t available_budget =
1571 (BATCH_MEM_BUDGET > fixed_overhead) ? (BATCH_MEM_BUDGET - fixed_overhead) : size_t{ 0 };
1572 const size_t windows_per_batch =
1573 round_parallel_detail::solve_wpb(per_window_bytes_lo, available_budget, sched.num_windows);
1574
1575 // Per-thread chunk-capacity scratch sizing. A thread's per-window slice is split into
1576 // sub-chunks of at most SUBCHUNK_ENTRIES_CAP entries. Worst-case overflow per
1577 // (thread, window) is one partial per sub-chunk boundary that lands mid-run, bounded
1578 // above by `ceil(max_chunk_len / SUBCHUNK_ENTRIES_CAP)` where max_chunk_len ≤ n/T.
1579 // The Stage 6a end-of-window overflow merge runs tree_reduce on `2 × overflow` entries
1580 // (each affected slot contributes a dense head + ≥1 overflow entry). Tree-reduce
1581 // scratch must fit either a sub-chunk's reduce_chunk input (up to SUBCHUNK_ENTRIES_CAP)
1582 // or a full overflow merge — take the max.
1583 const size_t global_max_chunk_len = (n + num_threads - 1) / num_threads;
1584 const size_t global_max_overflow_per_window =
1585 (global_max_chunk_len + SUBCHUNK_ENTRIES_CAP - 1) / SUBCHUNK_ENTRIES_CAP;
1586 const size_t chunk_capacity = std::max(SUBCHUNK_ENTRIES_CAP, 2 * global_max_overflow_per_window);
1587
1588 // Per-OS-thread scratch. The rpmsm pool dispatches `num_threads` logical tasks across
1589 // `worker_total = num_threads = physical_cpus` OS threads. Tasks on the same
1590 // OS thread run sequentially (FIFO claim), so they share scratch — every field in
1591 // ThreadScratch is overwritten fresh at task start, never read across tasks. Indexing
1592 // by `worker_id` (rather than `tid`) keeps memory linear in physical_cpus instead of
1593 // num_threads = lmul × physical_cpus.
1594 const size_t worker_total = num_threads;
1595 std::vector<round_parallel_detail::ThreadScratch<Curve>> thread_scratch(worker_total);
1597 if (dedup_active) {
1598 phase_a_scratch.resize(worker_total);
1599 }
1600
1601 // ---------------------------------------------------------------------------------------
1602 // Arena zone layout — set up after Phase 1 and schedule selection (see
1603 // https://gist.github.com/AztecBot/7c5ef0581350f6fdb9711679552fd86f §1, §4, §5).
1604 //
1605 // [0 .. bytes_P) Zone P — whole-MSM_fast permanent
1606 // msb_per_scalar (already alloc'd above)
1607 // glv_scalars / glv_points (already alloc'd above)
1608 // per_thread_msb_hist (already alloc'd above)
1609 // window_sums (Stage 7 accumulator)
1610 // redirect_lookup, extra_points (dedup, if active)
1611 // [bytes_P .. bytes_P + bytes_W) Zone W — per-worker union slab × T
1612 // Stage 6a/6b ThreadScratch fields and PhaseA
1613 // scratch overlay the same per-worker bytes; the
1614 // wpb-dependent Stage 6 fields sit immediately
1615 // after the union. Stage 6a, Stage 6b, and Phase A
1616 // run in distinct parallel_for invocations and
1617 // never co-exist on a worker.
1618 // [bytes_P + bytes_W .. arena.capacity)
1619 // Zone S — per-batch swing region (schedule, HIST slot,
1620 // DENSE slot, partition metadata).
1621 // HIST slot overlays H ↔ O on one byte slab:
1622 // H (S1-S4): digit_cursors
1623 // O (S6b-S7): chunk_outputs/window_partial_sums
1624 // Slot per-window = max(H, O). At chonk this is
1625 // H-bound (~256 KiB/window).
1626 // DENSE slot is dedicated for D (S6a-S6b):
1627 // bucket_partials_dense / _present
1628 // (~135 KiB/window at chonk). The D-class was
1629 // moved out of the HIST slot to eliminate L1
1630 // cache aliasing on the Stage 6a scatter writes
1631 // (+1.29% regression observed when D was overlaid
1632 // at the HIST offset).
1633 //
1634 // wpb solve: BATCH_MEM_BUDGET - bytes_P - bytes_W_fixed - bytes_S_shared - 32 KiB pad,
1635 // divided by (bytes_S_per_window + bytes_W_per_wpb). per_window_bytes_shared accounts
1636 // for HIST + DENSE as two separate slots.
1637 // ---------------------------------------------------------------------------------------
1638
1639 // Freeze Zone P prefix at the post-Phase-1 cursor — everything allocated so far
1640 // (msb_per_scalar, glv storage, per_thread_msb_hist) is Zone P permanent state.
1641 const size_t bytes_P_prefix = arena.cursor;
1642
1643 // Per-worker fixed-bytes "union": ThreadScratch's wpb-independent fields overlay the
1644 // PhaseAScratch fields. Compute each layout's strict byte requirement (including the
1645 // alignment slop a bump cursor would consume), then take the max.
1646 auto align_up = [](size_t off, size_t align) -> size_t { return (off + align - 1) & ~(align - 1); };
1647 auto layout_add = [&](size_t& off, size_t bytes, size_t align) { off = align_up(off, align) + bytes; };
1648
1649 // Per-worker layout via the canonical walk (single source of truth shared with
1650 // `compute_arena_bytes_for_msm`). Pre-wpb-solve usage there passes wpb=0; here we
1651 // pass the actual windows_per_batch so the Stage 6 wpb-dependent tail is included.
1652 const round_parallel_detail::PerWorkerArenaLayout<Curve> worker_layout(chunk_capacity,
1653 global_max_overflow_per_window,
1654 dedup_active,
1655 phase_a_cluster_members_cap,
1656 phase_a_cluster_offsets_cap,
1657 windows_per_batch,
1658 dense_stride_est);
// Per-worker byte counts from the canonical layout:
//   per_worker_union_bytes: the overlaid ThreadScratch-fixed / PhaseA prefix.
//   per_worker_bytes:       the full slab (union + wpb-dependent Stage 6 tail).
const size_t per_worker_union_bytes = worker_layout.per_worker_union_bytes;
const size_t per_worker_bytes = worker_layout.per_worker_bytes;

// Zone P extra (post-decision permanent state): window_sums plus dedup state. It is sized
// with the alignment a bump cursor would apply. The walk lists the same types, counts and
// order as the later zone_P_alloc calls: window_sums first, then redirect_lookup and
// extra_points when dedup is active.
// NOTE(review): this walk starts at offset 0, while bump_alloc aligns absolute addresses
// (see the comment below). The WORKER_SLAB_ALIGN bias applied to bytes_P appears to leave
// enough slack for the difference. Re-check if any alignment or the allocation order
// changes.
constexpr size_t VAR_WINDOW_WINDOW_SUMS_CAP = round_parallel_detail::VAR_WINDOW_MAX_WINDOWS;
size_t bytes_P_extra_layout = 0;
layout_add(bytes_P_extra_layout, sizeof(Element) * VAR_WINDOW_WINDOW_SUMS_CAP, alignof(Element));
if (dedup_active) {
    // redirect_lookup: one uint32_t per input scalar.
    layout_add(bytes_P_extra_layout, sizeof(uint32_t) * n, alignof(uint32_t));
    // extra_points: one affine point per possible dedup cluster.
    layout_add(bytes_P_extra_layout,
               sizeof(AffineElement) * round_parallel_detail::DEDUP_MAX_CLUSTERS,
               alignof(AffineElement));
}

// Zone sizes. The Zone W slab uses `MsmArena::bump_alloc`, which aligns in ABSOLUTE address
// space. The arena buffer base is only `__STDCPP_DEFAULT_NEW_ALIGNMENT__`-aligned, but
// AffineElement is alignas(64). The per-worker layout calc assumes the slab starts on a
// 64-byte boundary. To make the two agree, bytes_P is biased so that the absolute address
// `arena.data + bytes_P` is 64-aligned: bytes_P is the smallest value >= bytes_P_min with
// (base misalignment + bytes_P) a multiple of WORKER_SLAB_ALIGN.
const size_t arena_base_misalign = static_cast<size_t>(arena.base_addr & (WORKER_SLAB_ALIGN - 1));
const size_t bytes_P_min = align_up(bytes_P_prefix, alignof(Element)) + bytes_P_extra_layout;
const size_t bytes_P = align_up(bytes_P_min + arena_base_misalign, WORKER_SLAB_ALIGN) - arena_base_misalign;
// bytes_W: per_worker_bytes is already rounded to WORKER_SLAB_ALIGN, so consecutive slabs
// stay aligned once the first slab is aligned.
const size_t bytes_W = per_worker_bytes * worker_total;

// Sanity: P + W must fit in the arena; Zone S gets whatever remains. The conservative
// `compute_arena_bytes_for_msm` upper bound sizes the buffer to
// `BATCH_MEM_BUDGET + 32K + dedup_bytes` at worst. That bound is intended to dominate every
// reachable (P + W + S) sum at the wpb chosen above.
BB_ASSERT_LTE(bytes_P + bytes_W, arena.capacity);
const size_t bytes_S_total = arena.capacity - bytes_P - bytes_W;

// Per-zone bump cursors. Zone P continues from `bytes_P_prefix` with `bytes_P` as its bound,
// so the Zone P extras are checked against the Zone P slot even if alignment padding makes
// them slightly larger than the layout estimate. Zone S starts fresh at its own base
// (bytes_P + bytes_W) with bound bytes_S_total.
size_t zone_P_cursor = bytes_P_prefix;
size_t zone_S_cursor = 0;
auto zone_P_alloc = [&]<typename T>(size_t count) -> std::span<T> {
    return arena.template bump_alloc<T>(count, zone_P_cursor, bytes_P, 0);
};
auto zone_S_alloc = [&]<typename T>(size_t count) -> std::span<T> {
    return arena.template bump_alloc<T>(count, zone_S_cursor, bytes_S_total, bytes_P + bytes_W);
};
// Zone W is carved into per-worker slabs directly via `MsmArena::bump_alloc` below. Each
// worker gets its own (cursor, bound) pair, so a single zone-wide allocator would not
// capture the per-worker discipline.
// The pre-Phase-1 `MsmArena::alloc` cursor is retired here: every subsequent allocation
// goes through `zone_P_alloc`, the per-worker Zone W allocators, or `zone_S_alloc`.
1710 // Zone W: per-worker union slab — Stage6a/6b ThreadScratch and PhaseA fields overlay the
1711 // same per-worker bytes, with the wpb-dependent Stage 6 fields immediately after.
1712 for (size_t t = 0; t < worker_total; ++t) {
1713 // Each worker's slab is a contiguous `per_worker_bytes` window inside Zone W.
1714 const size_t slab_base = t * per_worker_bytes;
1715 auto& s = thread_scratch[t];
1716
1717 // ThreadScratch fixed fields — first view into the union. Bound = union size.
1718 size_t ts_fixed_cur = 0;
1719 auto ts_fixed_alloc = [&]<typename T>(size_t count) -> std::span<T> {
1720 return arena.template bump_alloc<T>(count, ts_fixed_cur, per_worker_union_bytes, bytes_P + slab_base);
1721 };
1722 s.curr_pts = ts_fixed_alloc.template operator()<AffineElement>(chunk_capacity);
1723 s.curr_buckets = ts_fixed_alloc.template operator()<uint32_t>(chunk_capacity);
1724 s.points_to_add = ts_fixed_alloc.template operator()<AffineElement>(2 * BATCH_CAPACITY);
1725 s.inversion_scratch = ts_fixed_alloc.template operator()<BaseField>(BATCH_CAPACITY);
1726 s.pair_dest = ts_fixed_alloc.template operator()<uint32_t>(BATCH_CAPACITY);
1727 s.overflow_slots = ts_fixed_alloc.template operator()<uint32_t>(global_max_overflow_per_window);
1728 s.overflow_pts = ts_fixed_alloc.template operator()<AffineElement>(global_max_overflow_per_window);
1729
1730 // PhaseA fields — second view, overlays the SAME per-worker union bytes. PhaseA's
1731 // parallel_for never overlaps Stage 6a/6b on the same worker, so reusing the bytes is
1732 // safe; the union's size is max(ts_fixed_layout, pa_layout) by construction.
1733 if (dedup_active) {
1734 size_t pa_cur = 0;
1735 auto pa_alloc = [&]<typename T>(size_t count) -> std::span<T> {
1736 return arena.template bump_alloc<T>(count, pa_cur, per_worker_union_bytes, bytes_P + slab_base);
1737 };
1738 auto& ps = phase_a_scratch[t];
1740 ps.cluster_members = pa_alloc.template operator()<uint32_t>(phase_a_cluster_members_cap);
1741 ps.cluster_offsets = pa_alloc.template operator()<uint32_t>(phase_a_cluster_offsets_cap);
1742 ps.dirty_slots = pa_alloc.template operator()<uint16_t>(PWAL::PHASE_A_DIRTY_SLOTS_CAP);
1743 ps.bucket_rep = pa_alloc.template operator()<uint32_t>(PWAL::PHASE_A_BUCKET_REP_CAP);
1744 ps.staged = pa_alloc.template operator()<std::pair<uint32_t, uint32_t>>(PWAL::PHASE_A_STAGED_CAP);
1745 ps.chunk_pts = pa_alloc.template operator()<AffineElement>(PWAL::PHASE_A_CHUNK_CAP);
1746 ps.chunk_ids = pa_alloc.template operator()<uint32_t>(PWAL::PHASE_A_CHUNK_CAP);
1747 }
1748
1749 // Stage 6 wpb-dependent fields — tail of the per-worker slab, BEYOND the union. Bound
1750 // = full per-worker slab size; cursor starts at per_worker_union_bytes so we don't
1751 // overwrite the union region.
1752 size_t ts_tail_cur = per_worker_union_bytes;
1753 auto ts_tail_alloc = [&]<typename T>(size_t count) -> std::span<T> {
1754 return arena.template bump_alloc<T>(count, ts_tail_cur, per_worker_bytes, bytes_P + slab_base);
1755 };
1756 const size_t dense_total = windows_per_batch * dense_stride_est;
1757 const size_t dense_pair_max = dense_total / 2;
1758 s.dense_buckets = ts_tail_alloc.template operator()<AffineElement>(dense_total);
1759 s.is_present = ts_tail_alloc.template operator()<uint8_t>(dense_total);
1760 s.affine_bucket_pairs = ts_tail_alloc.template operator()<std::pair<uint32_t, uint32_t>>(dense_pair_max);
1761 s.affine_bucket_indices = ts_tail_alloc.template operator()<uint32_t>(dense_pair_max);
1762 s.affine_bucket_inversion_scratch = ts_tail_alloc.template operator()<BaseField>(dense_pair_max);
1763 s.chunk_infos =
1764 ts_tail_alloc.template operator()<round_parallel_detail::AffineBucketChunkInfo>(windows_per_batch);
1765 std::fill_n(s.chunk_infos.begin(), windows_per_batch, round_parallel_detail::AffineBucketChunkInfo{});
1766 s.affine_bucket_stride = dense_stride_est;
1767 }
1768
1769 // Zone S: per-batch swing region — schedule + HIST slot + DENSE slot + partition metadata.
1770 const size_t schedule_total = windows_per_batch * n;
1771 auto schedule = zone_S_alloc.template operator()<uint32_t>(schedule_total);
1772
1773 // ----- HIST slot ------------------------------------------------------------------
1774 // Single byte slab backing two non-coexisting lifetime classes:
1775 // Epoch H (Stages 1-4): digit_cursors.
1776 // Epoch O (Stages 6b-7): chunk_outputs, window_partial_sums.
1777 // H dies before O is born (Stage 4 cursor advance ends before Stage 6b first writes
1778 // chunk_outputs / window_partial_sums).
1779 //
1780 // D-class (bucket_partials_dense + bucket_partials_present) previously overlaid this
1781 // slot too, but a 10× interleaved WASM Chonk bench showed Stage 6a regressed +1.29%
1782 // (t=+58) because of L1 cache aliasing on the `dense[slot]/present[slot]` scatter
1783 // writes when D sat at the HIST-overlaid offset. D-class now has its own dedicated
1784 // Zone-S DENSE slot below — see "DENSE slot" comment block.
1785 //
1786 // Phase 4: `digit_cursors` is dual-role within epoch H. After Stage 1 it holds
1787 // per-(w, t) counts of digit d; Stage 2 walks each (w, d) column from t = 0..T-1
1788 // reading the count from slot k and writing back the exclusive prefix-sum offset
1789 // (the count is consumed into `running` BEFORE the slot is overwritten, so the
1790 // in-place transform is mathematically identical to the previous out-of-place
1791 // version). Stage 4 then advances each (w, t) slice as a per-thread cursor.
1792 // Strict aliasing: every access goes through a std::span<T> obtained by
1793 // reinterpret_cast<T*>(hist_slot.data() + offset)
1794 // which is well-defined because std::byte is allowed by [basic.lval] to alias any
1795 // POD type. All overlaid types (uint32_t, size_t, Element, ChunkOutput<Curve>) are
1796 // trivially copyable / standard layout so the two epochs do not require construction
1797 // or destruction calls when the role of the bytes changes.
1798 static_assert(alignof(Element) <= 32, "HIST slot O layout assumes alignof(Element) <= 32");
1799 static_assert(alignof(round_parallel_detail::ChunkOutput<Curve>) <= 32,
1800 "HIST slot O layout assumes alignof(ChunkOutput) <= 32");
1801
1802 auto align_up_local = [](size_t off, size_t a) -> size_t { return (off + a - 1) & ~(a - 1); };
1803
1804 // Exact byte requirements for each epoch (matches the budget formula above).
1805 const size_t hist_h_bytes_total = (size_t{ 4 } * windows_per_batch * num_threads * B_eff); // digit_cursors
1806
1807 // O epoch layout — chunk_outputs first, then window_partial_sums. Both are alignof
1808 // <= 32; align each up to its own alignment.
1809 size_t o_layout_cur = 0;
1810 o_layout_cur = align_up_local(o_layout_cur, alignof(round_parallel_detail::ChunkOutput<Curve>));
1811 const size_t off_chunk_outputs = o_layout_cur;
1812 o_layout_cur += sizeof(round_parallel_detail::ChunkOutput<Curve>) * windows_per_batch * num_threads;
1813 o_layout_cur = align_up_local(o_layout_cur, alignof(typename Curve::Element));
1814 const size_t off_window_partial_sums = o_layout_cur;
1815 o_layout_cur += sizeof(typename Curve::Element) * num_threads * windows_per_batch;
1816 const size_t hist_o_bytes_total = o_layout_cur;
1817
1818 const size_t hist_slot_bytes_total = std::max(hist_h_bytes_total, hist_o_bytes_total);
1819 // Round up to AffineElement size so the bump allocator below treats the slot as a
1820 // whole number of 64-byte alignas(64) cells. Allocate via AffineElement to force the
1821 // slot base to be 64-byte aligned in absolute address space — sufficient for the
1822 // H-epoch uint32 digit_cursors span (alignof 4) and the O-epoch ChunkOutput/Element
1823 // spans (alignof ≤ 32).
1824 const size_t hist_slot_cells = (hist_slot_bytes_total + sizeof(AffineElement) - 1) / sizeof(AffineElement);
1825 auto hist_slot_cells_span = zone_S_alloc.template operator()<AffineElement>(hist_slot_cells);
1826 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
1827 std::byte* const hist_slot_bytes = reinterpret_cast<std::byte*>(hist_slot_cells_span.data());
1828
1829 // H-epoch view — live S1..S4. `digit_cursors[(w*T + t) * stride + d]` holds three
1830 // distinct meanings depending on stage:
1831 // * After Stage 1: per-(w, t) count of digit d's occurrences in thread t's slice.
1832 // * After Stage 2: per-(w, t) exclusive prefix-sum offset (cursor base) for the
1833 // bucket-d run inside that window's schedule slot.
1834 // * After Stage 4: offset + count (final cursor end-state); dead from then on.
1835 // Stage 2 reads each (w, t, d) count from this buffer and writes the running prefix
1836 // sum back to the SAME slot before advancing `running`, so the count is preserved
1837 // long enough to feed the accumulator. Stage 4's `++` post-increment on each
1838 // thread's slice runs without atomics because each thread owns its (w, t, *) row
1839 // exclusively.
1840 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
1841 auto digit_cursors =
1842 std::span<uint32_t>{ reinterpret_cast<uint32_t*>(hist_slot_bytes), windows_per_batch * num_threads * B_eff };
1843
1844 // O-epoch views — live S6b..S7. Backed by the SAME bytes as above; H contents are
1845 // dead by the time these are touched. ChunkOutput<Curve> and Curve::Element have
1846 // user-defined constructors so are not formally trivially_copyable, but they are
1847 // standard-layout PODs of fixed bytes (Element is alignas(32) over a fixed-width Fq
1848 // field array). The existing arena pre-Phase-3 already aliases them through std::byte
1849 // buffers via `make_unique_for_overwrite<std::byte[]>` + reinterpret_cast; the
1850 // std::byte aliasing rule in [basic.lval] applies regardless of trivial-copyability.
1852 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
1853 reinterpret_cast<round_parallel_detail::ChunkOutput<Curve>*>(hist_slot_bytes + off_chunk_outputs),
1854 windows_per_batch * num_threads
1855 };
1856 auto window_partial_sums = std::span<typename Curve::Element>{
1857 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
1858 reinterpret_cast<typename Curve::Element*>(hist_slot_bytes + off_window_partial_sums),
1859 num_threads * windows_per_batch
1860 };
// window_partial_sums is reset to identity at the start of each Stage 6b worker
// (the `my_partials[w] = point_at_infinity` loop), so we deliberately do NOT initialise
// it here. chunk_outputs is written unconditionally per (w, tprime) in Stage 6b
// (the empty path sets `out.empty = 1`), so it needs no pre-init either.
// ----- end HIST slot --------------------------------------------------------------

// ----- DENSE slot -----------------------------------------------------------------
// Dedicated Zone-S slot for D-class (bucket_partials_dense + bucket_partials_present),
// live in Stages 6a-6b only. It is isolated from the HIST slot so that Stage 6a's tight
// scatter loop
//   `dst_dense[slot] = pt; dst_present[slot] = 1;`
// does not L1-alias against the HIST slot's H/O bytes. The previous co-located layout
// caused a +1.29% Stage 6a regression in WASM (t=+58 across 10× interleaved runs).
// The dense ↔ present pair stays packed at fixed aligned offsets within this slot. They
// MUST stay close because Stage 6a reads `present[slot]` and then writes `dense[slot]` /
// `present[slot]` in tandem in the inner loop.
static_assert(alignof(AffineElement) == 64, "DENSE slot D layout assumes alignof(AffineElement) == 64");
// Bucket-partial capacity across every window of a batch.
const size_t bp_total = windows_per_batch * bucket_partials_per_window_max;
// Slot layout: [bp_total AffineElements (dense)][bp_total bytes (present)].
size_t d_layout_cur = 0;
const size_t off_dense = d_layout_cur;
d_layout_cur += sizeof(AffineElement) * bp_total; // bucket_partials_dense
const size_t off_present = d_layout_cur;
d_layout_cur += sizeof(uint8_t) * bp_total; // bucket_partials_present
const size_t dense_slot_bytes_total = d_layout_cur;
// Ceil-divide into whole AffineElement cells.
const size_t dense_slot_cells = (dense_slot_bytes_total + sizeof(AffineElement) - 1) / sizeof(AffineElement);
// Allocate via AffineElement to force 64-byte alignment for the leading
// bucket_partials_dense view.
auto dense_slot_cells_span = zone_S_alloc.template operator()<AffineElement>(dense_slot_cells);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
std::byte* const dense_slot_bytes = reinterpret_cast<std::byte*>(dense_slot_cells_span.data());

// The NOLINTNEXTLINE markers sit directly above the lines holding the casts; a marker
// above the declaration line would not cover a cast wrapped onto the next line.
auto bucket_partials_dense =
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
    std::span<AffineElement>{ reinterpret_cast<AffineElement*>(dense_slot_bytes + off_dense), bp_total };
auto bucket_partials_present =
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
    std::span<uint8_t>{ reinterpret_cast<uint8_t*>(dense_slot_bytes + off_present), bp_total };
// ----- end DENSE slot -------------------------------------------------------------

// Per-window bucket start offsets: B_eff + 1 cells per window (filled by Stages 2-3).
auto bucket_start_all = zone_S_alloc.template operator()<size_t>(windows_per_batch * (B_eff + 1));
// Per-window chunk boundaries: num_threads + 1 cells per window.
auto chunk_start_all = zone_S_alloc.template operator()<size_t>(windows_per_batch * (num_threads + 1));
// chunk_bucket_lo_all[w*(T+1) + t] = bucket index of the first schedule entry in chunk t of
// window w.
// chunk_bucket_hi_all[w*T + t] = bucket index of the last schedule entry in chunk t.
// Chunks are partitioned by schedule index (uniform t·m/T), not by bucket boundary, so a
// bucket's run can straddle threads. Both threads then carry a partial for that shared
// bucket, and Stage 7's chunk_contribution sum (Σ_d d · partial_d_in_t over t) combines
// them without an explicit merge step.
auto chunk_bucket_lo_all = zone_S_alloc.template operator()<size_t>(windows_per_batch * (num_threads + 1));
auto chunk_bucket_hi_all = zone_S_alloc.template operator()<size_t>(windows_per_batch * num_threads);

// bucket_partials_offsets is the index table mapping (thread, window) -> slot start in
// bucket_partials_dense/present. It lives S5..S6b alongside chunk_start_all and stays its
// own Zone S allocation, separate from the DENSE slot.
auto bucket_partials_offsets = zone_S_alloc.template operator()<size_t>((num_threads * windows_per_batch) + 1);

// Stage 6b rebalanced-task partition. The bucket range [1, num_buckets) is split evenly
// across `num_threads` rebalanced tasks t'. The partition is uniform in num_buckets, so we
// store T+1 boundaries (not per-window). For each window we record the half-open range of
// original threads whose chunk range intersects each task t' (usually 1-2 originals per
// task).
auto rebalanced_bucket_lo_partition = zone_S_alloc.template operator()<size_t>(num_threads + 1);
auto orig_thread_lo = zone_S_alloc.template operator()<size_t>(windows_per_batch * num_threads);
auto orig_thread_hi = zone_S_alloc.template operator()<size_t>(windows_per_batch * num_threads);

// Zone P: window_sums (Stage 7 accumulator; survives the whole MSM_fast). All
// VAR_WINDOW_WINDOW_SUMS_CAP entries start at the group identity.
auto window_sums = zone_P_alloc.template operator()<typename Curve::Element>(VAR_WINDOW_WINDOW_SUMS_CAP);
std::fill_n(window_sums.begin(), VAR_WINDOW_WINDOW_SUMS_CAP, Curve::Group::point_at_infinity);

// Zone P: dedup state. It is written by Phase A and read through Stage 6a of every batch,
// so it must outlive every batch.
// - redirect_lookup: parallel-filled with DEDUP_INVALID_EXTRA below before Phase A reads it.
// - extra_points: no init needed. Phase A writes per-thread cid ranges, and consumers only
//   read indices Phase A actually populated.
1935 if (dedup_active) {
1936 dedup_state.redirect_lookup = zone_P_alloc.template operator()<uint32_t>(n);
1937 dedup_state.extra_points =
1938 zone_P_alloc.template operator()<AffineElement>(round_parallel_detail::DEDUP_MAX_CLUSTERS);
1939 BB_BENCH_NAME("MSM_fast::dedup/redirect_invalid_fill");
1940 uint32_t* const rl = dedup_state.redirect_lookup.data();
1941 bb::parallel_for(bb::get_num_cpus(), [&](const ThreadChunk& chunk) {
1942 for (size_t i : chunk.range(n)) {
1944 }
1945 });
1946 }
1947
// Digits from get_constantine_packed_digit carry the sign in bit 31 and the unsigned
// bucket index in the low 31 bits. ANDing with BUCKET_MASK (every bit except bit 31)
// drops the sign and leaves just the bucket index.
constexpr uint32_t BUCKET_MASK = ~(uint32_t{ 1 } << 31);

// Phase A (dedup cluster discovery) runs at most once per MSM_fast call, not once per batch.
// - Cluster membership is decided by scalar value (memcmp), so it does not depend on which
//   window is being walked.
// - True duplicates land in the same bucket of every window, so bucket adjacency holds in
//   any window's sorted schedule.
// Phase A therefore runs on window 0 of the first batch, fills
// `dedup_state.{redirect_lookup, extra_points}` once, and later batches reuse that result.
// This flag records whether that has happened yet.
bool phase_a_done = false;
1959
1960 auto run_batch = [&](size_t batch_start, size_t windows_in_batch, size_t B_R) noexcept {
1961 // Per-(w, t) slot stride uses `B_eff` = max(num_buckets, B_lo, B_hi); each call
1962 // iterates only the region's first B_R entries. The arena was sized for B_eff per slot.
1963 const size_t bucket_stride = B_eff;
1964 // Per-window slice params. The final window can be narrower when the bit budget
1965 // does not divide evenly by the default window size; the Booth recoder must use
1966 // that narrower width or it encroaches on bits beyond the schedule.
1967 constexpr size_t SCALAR_UINT64_LIMBS = sizeof(ScalarField) / sizeof(uint64_t);
1974 std::array<uint8_t, 128> per_window_bits{};
1975 constexpr size_t SCALAR_U32_LIMBS = sizeof(ScalarField) / sizeof(uint32_t);
1976 for (size_t w = 0; w < windows_in_batch; ++w) {
1977 const size_t global_w = batch_start + w;
1978 const size_t window_bits_w = sched.window_bits_per_window[global_w];
1979 per_window_bits[w] = static_cast<uint8_t>(window_bits_w);
1981 sched.bit_base[global_w], window_bits_w, SCALAR_UINT64_LIMBS);
1983 sched.bit_base[global_w], window_bits_w, SCALAR_U32_LIMBS);
1984 slice_paths[w] = round_parallel_detail::classify_slice_path_u32(slice_params_u32[w]);
1985 const uint32_t lo_mask = slice_params_u32[w].lo_mask;
1986 const uint32_t hi_mask = slice_params_u32[w].hi_mask;
1987 const uint32_t val_mask = (uint32_t{ 1 } << static_cast<uint32_t>(window_bits_w)) - 1;
1988 lo_mask_vectors[w] = round_parallel_detail::SimdU32x4{ lo_mask, lo_mask, lo_mask, lo_mask };
1989 hi_mask_vectors[w] = round_parallel_detail::SimdU32x4{ hi_mask, hi_mask, hi_mask, hi_mask };
1990 val_mask_vectors[w] = round_parallel_detail::SimdU32x4{ val_mask, val_mask, val_mask, val_mask };
1991 }
1992
1993 constexpr size_t SIMD_BATCH = 64;
1994 static_assert(SIMD_BATCH % 4 == 0, "SIMD_BATCH must be divisible by 4");
1995 constexpr size_t LIMBS_PER_SCALAR = sizeof(ScalarField) / sizeof(uint32_t);
1996 const auto* scalars_u32 = reinterpret_cast<const uint32_t*>(scalars.data());
1998 auto fill_packed_digit_buffer = [&](size_t w, size_t i, uint32_t* packed_buf) noexcept {
1999 const auto& sp32 = slice_params_u32[w];
2000 const uint32_t window_bits_w = static_cast<uint32_t>(per_window_bits[w]);
2002 for (size_t k = 0; k < SIMD_BATCH; k += 4) {
2004 packed_buf + k,
2005 scalars_u32 + ((i + k + 0) * LIMBS_PER_SCALAR),
2006 scalars_u32 + ((i + k + 1) * LIMBS_PER_SCALAR),
2007 scalars_u32 + ((i + k + 2) * LIMBS_PER_SCALAR),
2008 scalars_u32 + ((i + k + 3) * LIMBS_PER_SCALAR),
2009 sp32.lo_limb,
2010 sp32.lo_off,
2011 lo_mask_vectors[w],
2012 one_v,
2013 val_mask_vectors[w],
2014 window_bits_w);
2015 }
2016 } else if (slice_paths[w] == round_parallel_detail::ConstantineSlicePath::Bottom) {
2017 for (size_t k = 0; k < SIMD_BATCH; k += 4) {
2019 packed_buf + k,
2020 scalars_u32 + ((i + k + 0) * LIMBS_PER_SCALAR),
2021 scalars_u32 + ((i + k + 1) * LIMBS_PER_SCALAR),
2022 scalars_u32 + ((i + k + 2) * LIMBS_PER_SCALAR),
2023 scalars_u32 + ((i + k + 3) * LIMBS_PER_SCALAR),
2024 sp32.hi_limb,
2025 sp32.lo_bits,
2026 hi_mask_vectors[w],
2027 one_v,
2028 val_mask_vectors[w],
2029 window_bits_w);
2030 }
2031 } else {
2032 for (size_t k = 0; k < SIMD_BATCH; k += 4) {
2034 packed_buf + k,
2035 scalars_u32 + ((i + k + 0) * LIMBS_PER_SCALAR),
2036 scalars_u32 + ((i + k + 1) * LIMBS_PER_SCALAR),
2037 scalars_u32 + ((i + k + 2) * LIMBS_PER_SCALAR),
2038 scalars_u32 + ((i + k + 3) * LIMBS_PER_SCALAR),
2039 sp32.lo_limb,
2040 sp32.hi_limb,
2041 sp32.lo_off,
2042 sp32.lo_bits,
2043 lo_mask_vectors[w],
2044 hi_mask_vectors[w],
2045 one_v,
2046 val_mask_vectors[w],
2047 window_bits_w);
2048 }
2049 }
2050 };
2051
2052 // Capture the dedup state before Stage 1. The first batch must build the ordinary
2053 // R14 schedule so Phase A can discover clusters, then patch+compact that batch.
2054 // Later batches can schedule cluster reps directly and omit non-reps up front.
2055 const bool phase_a_done_at_batch_start = phase_a_done;
2056 const bool dedup_known_for_batch =
2057 dedup_active && phase_a_done_at_batch_start && dedup_state.n_dedup_extras != 0;
2058
2059 // Stage 1 (digit extraction): per-thread per-window bucket histograms. Work is
2060 // scalar-blocked across the windows in this batch so scalars/msb/dedup metadata are
2061 // read once per block and reused while still hot.
2062 auto stage1_digit_extract = [&]<bool DedupKnown>(size_t tid) noexcept {
2063 [[maybe_unused]] const uint32_t* const rl_data = dedup_state.redirect_lookup.data();
2064 for (size_t w = 0; w < windows_in_batch; ++w) {
2065 uint32_t* my_counts = digit_cursors.data() + (((w * num_threads) + tid) * bucket_stride);
2066 std::memset(my_counts, 0, B_R * sizeof(uint32_t));
2067 }
2068 const size_t start = tid * n / num_threads;
2069 const size_t end = (tid + 1) * n / num_threads;
2070
2071 alignas(16) std::array<uint32_t, SIMD_BATCH> packed_buf{};
2072 // Pack the per-block filter into a uint64 bitmask. When every scalar in the block
2073 // is active (common in dense workloads), the inner scatter takes an all_included
2074 // fast path that drops the per-element predicate; mixed blocks bit-scan the mask.
2075 auto compute_include_mask = [&](size_t block_start) noexcept -> uint64_t {
2076 uint64_t include_mask = 0;
2077 for (size_t k = 0; k < SIMD_BATCH; ++k) {
2078 const size_t scalar_idx = block_start + k;
2079 const uint8_t m = msb_per_scalar[scalar_idx];
2080 bool include = (m != MSB_ZERO_SENTINEL);
2081 if constexpr (DedupKnown) {
2082 if (include) {
2083 const uint32_t patch = rl_data[scalar_idx];
2084 include = (patch == round_parallel_detail::DEDUP_INVALID_EXTRA ||
2086 }
2087 }
2088 include_mask |= static_cast<uint64_t>(include) << k;
2089 }
2090 return include_mask;
2091 };
2092
2093 size_t i = start;
2094 while (i + SIMD_BATCH <= end) {
2095 const uint64_t include_mask = compute_include_mask(i);
2096 if (include_mask == 0) {
2097 i += SIMD_BATCH;
2098 continue;
2099 }
2100 const bool all_included = include_mask == ~uint64_t{ 0 };
2101 for (size_t w = 0; w < windows_in_batch; ++w) {
2102 fill_packed_digit_buffer(w, i, packed_buf.data());
2103 uint32_t* my_counts = digit_cursors.data() + (((w * num_threads) + tid) * bucket_stride);
2104 if (all_included) {
2105 for (size_t k = 0; k < SIMD_BATCH; ++k) {
2106 ++my_counts[packed_buf[k] & BUCKET_MASK];
2107 }
2108 } else {
2109 uint64_t scatter_mask = include_mask;
2110 for (size_t k = 0; k < SIMD_BATCH; ++k) {
2111 if ((scatter_mask & uint64_t{ 1 }) != 0) {
2112 ++my_counts[packed_buf[k] & BUCKET_MASK];
2113 }
2114 scatter_mask >>= 1;
2115 }
2116 }
2117 }
2118 i += SIMD_BATCH;
2119 }
2120
2121 // Tail (0..SIMD_BATCH-1 scalars). Same scalar-major loop order; per-scalar
2122 // active check inlined since the block is short.
2123 for (; i < end; ++i) {
2124 const uint8_t m = msb_per_scalar[i];
2125 if (m == MSB_ZERO_SENTINEL) {
2126 continue;
2127 }
2128 if constexpr (DedupKnown) {
2129 const uint32_t patch = rl_data[i];
2132 continue;
2133 }
2134 }
2135 for (size_t w = 0; w < windows_in_batch; ++w) {
2136 uint32_t* my_counts = digit_cursors.data() + (((w * num_threads) + tid) * bucket_stride);
2137 const round_parallel_detail::ConstantineSliceParams sp = slice_params[w];
2138 const uint32_t window_bits_w = static_cast<uint32_t>(per_window_bits[w]);
2139 const uint32_t packed =
2141 sp.lo_limb,
2142 sp.hi_limb,
2143 sp.lo_off,
2144 sp.lo_bits,
2145 sp.lo_mask,
2146 sp.hi_mask,
2148 window_bits_w);
2149 ++my_counts[packed & BUCKET_MASK];
2150 }
2151 }
2152 };
2153 if (dedup_known_for_batch) {
2154 bb::parallel_for(num_threads, [&](size_t tid) { stage1_digit_extract.template operator()<true>(tid); });
2155 } else {
2156 bb::parallel_for(num_threads, [&](size_t tid) { stage1_digit_extract.template operator()<false>(tid); });
2157 }
2158
2159 // Stage 2 (bucket histogram): per-window per-digit totals + per-thread within-digit
2160 // offsets. Parallelised over digit-chunks; each worker handles its slice of 2^window_bits
2161 // for all windows_in_batch windows. In-place exclusive prefix-sum: each slot
2162 // `digit_cursors[(w*T + t) * stride + d]` is read for its Stage 1 count and then
2163 // overwritten with the running prefix sum (== the cursor base Stage 4 needs). The
2164 // count must be read BEFORE the write or `running` would skip its contribution.
2165 // Phase 5: the per-digit total `running` is written directly into
2166 // `bucket_start_all[w][d+1]` (one cell past where Stage 3 will read), so Stage 3 can
2167 // prefix-sum in place without a separate `bucket_total_counts` buffer. The size_t
2168 // bucket_start cell widens the uint32_t total implicitly.
2169 bb::parallel_for(num_threads, [&](size_t tid) {
2170 const size_t d_start = tid * B_R / num_threads;
2171 const size_t d_end = (tid + 1) * B_R / num_threads;
2172 for (size_t w = 0; w < windows_in_batch; ++w) {
2173 size_t* const bucket_start_w = bucket_start_all.data() + (w * (bucket_stride + 1));
2174 for (size_t d = d_start; d < d_end; ++d) {
2175 if (d == 0) {
2176 continue;
2177 }
2178 uint32_t running = 0;
2179 for (size_t t = 0; t < num_threads; ++t) {
2180 const size_t k = (((w * num_threads) + t) * bucket_stride) + d;
2181 const uint32_t cnt = digit_cursors[k];
2182 digit_cursors[k] = running;
2183 running += cnt;
2184 }
2185 bucket_start_w[d + 1] = running;
2186 }
2187 }
2188 });
2189
2190 // Stage 3 (bucket offsets / prefix sum): per-window serial prefix sum in place.
2191 // Stage 2 already deposited each digit's per-window total at bucket_start[d+1];
2192 // the loop accumulates the running prefix-sum without a separate counts buffer.
2193 {
2194 BB_BENCH_NAME("MSM_fast::Stage2_3_bucket_offsets");
2195 auto build_bucket_offsets_for_window = [&](size_t w) noexcept {
2196 size_t* bucket_start = bucket_start_all.data() + (w * (bucket_stride + 1));
2197 bucket_start[0] = 0;
2198 bucket_start[1] = 0;
2199 for (size_t d = 1; d < B_R; ++d) {
2200 bucket_start[d + 1] += bucket_start[d];
2201 }
2202 };
2203 const size_t offset_threads = std::min(num_threads, windows_in_batch);
2204 if (offset_threads <= 1) {
2205 for (size_t w = 0; w < windows_in_batch; ++w) {
2206 build_bucket_offsets_for_window(w);
2207 }
2208 } else {
2209 bb::parallel_for(offset_threads, [&](size_t tid) {
2210 for (size_t w = tid; w < windows_in_batch; w += offset_threads) {
2211 build_bucket_offsets_for_window(w);
2212 }
2213 });
2214 }
2215 }
2216
2217 // Stage 4 (digit scatter): scalar-cache-blocked, window-local scatter. Re-decodes each
2218 // (point, window) signed digit via the same Constantine carry-less recoder Stage 1 used.
2219 // Stage 4 stores only `sign | scalar_idx`; bucket magnitude is recovered later from
2220 // bucket_start ranges.
2221 // Stage 1 benefits from full scalar-major order because it only updates compact
2222 // per-window histograms. Stage 4 writes large bucket schedules, so full scalar-major
2223 // order opens too many cold write/cursor streams. Instead, process a scalar tile across
2224 // all windows: scalar/msb/dedup metadata are reused while the tile is cache-hot, but each
2225 // inner loop still scatters to one window's schedule at a time.
2226 //
2227 // First-batch Stage 4 is dedup-unaware: every scalar is emitted as
2228 // `sched_w[idx] = sign | scalar_idx`, then Phase A + patch/compact tags cluster
2229 // reps and removes non-reps. Later batches with known dedup state skip non-reps
2230 // here and emit redirect reps directly.
2231 // Splitting the dedup work out of this hot loop avoids a per-iteration
2232 // closure-indirection chain through `dedup_state.redirect_lookup[i]`
2233 // that the WASM JIT does not hoist (~13 ns/iter penalty observed).
2234 auto stage4_emit = [&]<bool DedupKnown>(size_t tid) noexcept {
2235 [[maybe_unused]] const uint32_t* const rl_data = dedup_state.redirect_lookup.data();
2236 const size_t start = tid * n / num_threads;
2237 const size_t end = (tid + 1) * n / num_threads;
2239 std::array<const size_t*, 128> bucket_starts{};
2240 std::array<uint32_t*, 128> schedules{};
2241 for (size_t w = 0; w < windows_in_batch; ++w) {
2242 cursors[w] = digit_cursors.data() + (((w * num_threads) + tid) * bucket_stride);
2243 bucket_starts[w] = bucket_start_all.data() + (w * (bucket_stride + 1));
2244 schedules[w] = schedule.data() + (w * n);
2245 }
2246
2247 alignas(16) std::array<uint32_t, SIMD_BATCH> packed_buf{};
2248 constexpr size_t STAGE4_SCALAR_TILE = 2048;
2250 [[maybe_unused]] std::array<uint32_t, STAGE4_SCALAR_TILE> out_base_tile{};
2251
2252 for (size_t tile_start = start; tile_start < end; tile_start += STAGE4_SCALAR_TILE) {
2253 const size_t tile_end = std::min(end, tile_start + STAGE4_SCALAR_TILE);
2254 const size_t tile_len = tile_end - tile_start;
2255 for (size_t j = 0; j < tile_len; ++j) {
2256 const size_t scalar_idx = tile_start + j;
2257 const uint8_t m = msb_per_scalar[scalar_idx];
2258 bool include = (m != MSB_ZERO_SENTINEL);
2259 if constexpr (DedupKnown) {
2260 uint32_t out_base = static_cast<uint32_t>(scalar_idx);
2261 if (include) {
2262 const uint32_t patch = rl_data[scalar_idx];
2264 include = (patch & round_parallel_detail::DEDUP_SKIP_BIT) == 0;
2265 out_base = patch;
2266 }
2267 }
2268 out_base_tile[j] = out_base;
2269 }
2270 active_tile[j] = static_cast<uint8_t>(include);
2271 }
2272
2273 for (size_t w = 0; w < windows_in_batch; ++w) {
2274 uint32_t* my_cursor = cursors[w];
2275 const size_t* bucket_start = bucket_starts[w];
2276 uint32_t* sched_w = schedules[w];
2277 size_t i = tile_start;
2278 while (i + SIMD_BATCH <= tile_end) {
2279 const size_t rel = i - tile_start;
2280 uint64_t include_mask = 0;
2281 for (size_t k = 0; k < SIMD_BATCH; ++k) {
2282 include_mask |= static_cast<uint64_t>(active_tile[rel + k]) << k;
2283 }
2284 if (include_mask == 0) {
2285 i += SIMD_BATCH;
2286 continue;
2287 }
2288 fill_packed_digit_buffer(w, i, packed_buf.data());
2289 uint64_t scatter_mask = include_mask;
2290 for (size_t k = 0; k < SIMD_BATCH; ++k) {
2291 if ((scatter_mask & uint64_t{ 1 }) != 0) {
2292 const uint32_t packed = packed_buf[k];
2293 const uint32_t bucket_idx = packed & BUCKET_MASK;
2294 if (bucket_idx != 0) {
2295 const uint32_t idx =
2296 static_cast<uint32_t>(bucket_start[bucket_idx]) + my_cursor[bucket_idx]++;
2297 uint32_t out = packed & round_parallel_detail::SCHEDULE_SIGN_BIT;
2298 if constexpr (DedupKnown) {
2299 out |= out_base_tile[rel + k];
2300 } else {
2301 out |= static_cast<uint32_t>(i + k);
2302 }
2303 sched_w[idx] = out;
2304 }
2305 }
2306 scatter_mask >>= 1;
2307 }
2308 i += SIMD_BATCH;
2309 }
2310 for (; i < tile_end; ++i) {
2311 const size_t rel = i - tile_start;
2312 if (active_tile[rel] == 0) {
2313 continue;
2314 }
2315 const round_parallel_detail::ConstantineSliceParams sp = slice_params[w];
2317 scalars[i].data,
2318 sp.lo_limb,
2319 sp.hi_limb,
2320 sp.lo_off,
2321 sp.lo_bits,
2322 sp.lo_mask,
2323 sp.hi_mask,
2325 static_cast<uint32_t>(per_window_bits[w]));
2326 const uint32_t bucket_idx = packed & BUCKET_MASK;
2327 if (bucket_idx != 0) {
2328 const uint32_t idx =
2329 static_cast<uint32_t>(bucket_start[bucket_idx]) + my_cursor[bucket_idx]++;
2330 uint32_t out = packed & round_parallel_detail::SCHEDULE_SIGN_BIT;
2331 if constexpr (DedupKnown) {
2332 out |= out_base_tile[rel];
2333 } else {
2334 out |= static_cast<uint32_t>(i);
2335 }
2336 sched_w[idx] = out;
2337 }
2338 }
2339 }
2340 }
2341 };
2342
2343 if (dedup_known_for_batch) {
2344 bb::parallel_for(num_threads, [&](size_t tid) { stage4_emit.template operator()<true>(tid); });
2345 } else {
2346 bb::parallel_for(num_threads, [&](size_t tid) { stage4_emit.template operator()<false>(tid); });
2347 }
2348
2349 // Phase A: schedule-based dedup detection on window 0. Each thread owns a
2350 // contiguous range of window 0's schedule. Detects duplicate clusters via
2351 // consecutive-pair check (same bucket + memcmp on full scalar value), tree-reduces
2352 // members into an aggregate, and publishes results into `dedup_state.extra_points`,
2353 // `dedup_state.redirect_lookup`, and zeroed `msb_per_scalar` entries for non-reps.
2354 // Per-thread cluster-id ranges keep writes disjoint — no atomics needed.
2355 // Phase A: schedule-based dedup detection. Runs at most ONCE per MSM_fast (gated on
2356 // `phase_a_done` from the enclosing function scope). Cluster membership is decided
2357 // by scalar value (memcmp), so any window's bucket-sorted schedule places duplicates
2358 // consecutively — Phase A on this first-batch's window-0 schedule produces the
2359 // correct redirect_lookup + extra_points for all subsequent batches. We deliberately
2360 // do not re-run Phase A per batch: the dedup_state is populated once and reused.
2361 if (dedup_active && windows_in_batch > 0 && !phase_a_done) {
2362 BB_BENCH_NAME("MSM_fast::PhaseA_dedup_detect");
2363 uint32_t* sched_w0 = schedule.data();
2364 // Pre-Phase-A bucket sort: Stage 4 emits each bucket's run in scalar-emit
2365 // order, so different-value scalars that happen to share a window-0 digit
2366 // (bucket collisions are common — c=11 → 2048 buckets vs 60-90k entries)
2367 // interleave with same-value entries and break Phase A's consecutive-pair
2368 // detection. Sorting each bucket's run by scalar value makes same-value
2369 // entries adjacent so the simple consecutive-pair walk finds every cluster.
2370 // Sort cost: per bucket of size K, ~K log K comparisons × 32-byte memcmp;
2371 // for typical K=44 this is ~500 cycles per bucket × 2048 buckets = ~1 ms
2372 // wall (parallelized across threads).
2373 const uint32_t cids_per_thread =
2374 static_cast<uint32_t>(round_parallel_detail::DEDUP_MAX_CLUSTERS / num_threads);
2375 // Hash-based per-bucket dedup detection: every thread owns a
2376 // contiguous bucket range of window-0's schedule and runs an
2377 // open-addressing hash table over that range's long-scalar entries.
2378 // O(K) per bucket, avoids the 32-byte memcmp comparator inside any
2379 // sort, and keeps thread balance uniform because short-scalar
2380 // entries (the source of mega-buckets like digit_0 = 1) are skipped.
2381 // Catches ~99.94 % of long-scalar duplicates against MSM_DUMP's
2382 // theoretical maximum (`dup_input_extras`).
2383 {
2384 BB_BENCH_NAME("MSM_fast::PhaseA_dedup_detect_hash");
2385 const size_t* const w0_bucket_start = bucket_start_all.data();
2386 std::atomic<size_t> dedup_cluster_count{ 0 };
2387 bb::parallel_for(num_threads, [&, w0_bucket_start](size_t tid) noexcept {
2388 const size_t b_lo = 1 + ((tid * (B_R - 1)) / num_threads);
2389 const size_t b_hi = 1 + (((tid + 1) * (B_R - 1)) / num_threads);
2390 const uint32_t cid_lo = static_cast<uint32_t>(tid) * cids_per_thread;
2391 const uint32_t cid_max = cid_lo + cids_per_thread;
2392 const size_t local_clusters = round_parallel_detail::dedup_phase_a_worker_hash<Curve>(
2393 sched_w0,
2394 w0_bucket_start,
2395 b_lo,
2396 b_hi,
2397 std::span<const ScalarField>(scalars.data(), n),
2398 points,
2400 std::span<uint32_t>(dedup_state.redirect_lookup),
2401 msb_per_scalar.data(),
2402 window_bits,
2403 cid_lo,
2404 cid_max,
2405 phase_a_scratch[tid]);
2406 if (local_clusters != 0) {
2407 dedup_cluster_count.fetch_add(local_clusters, std::memory_order_relaxed);
2408 }
2409 });
2410 dedup_state.n_dedup_extras = dedup_cluster_count.load(std::memory_order_relaxed);
2411 }
2412 phase_a_done = true;
2413 }
2414
2415 // Schedule patch post-pass: tags cluster-member entries with SKIP/REDIRECT bits.
2416 // Runs only for the batch that just ran Phase A: later batches with known dedup
2417 // state skip non-reps in Stage 1/4 and emit redirect reps directly.
2418 // Parallel by window (one window per worker) because each window's slice of the
2419 // schedule is disjoint. Hoisting `redirect_lookup.data()` to a raw pointer outside
2420 // the lambda + passing it by value into the inner function avoids the per-iter
2421 // closure-indirection chain that made the inline form 3× slower per iter on WASM.
2422 auto partition_chunks_for_window = [&](size_t w) noexcept {
2423 const size_t* bucket_start = bucket_start_all.data() + (w * (bucket_stride + 1));
2424 const size_t* const bucket_start_end = bucket_start + B_R + 1;
2425 size_t* chunk_start = chunk_start_all.data() + (w * (num_threads + 1));
2426 size_t* chunk_bucket_lo = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2427 size_t* chunk_bucket_hi = chunk_bucket_hi_all.data() + (w * num_threads);
2428 const size_t m = bucket_start[B_R];
2429 const size_t* search_begin = bucket_start + 1;
2430 size_t lo = 0;
2431 chunk_start[0] = lo;
2432 for (size_t t = 0; t < num_threads; ++t) {
2433 const size_t hi = ((t + 1) == num_threads) ? m : (((t + 1) * m) / num_threads);
2434 chunk_start[t + 1] = hi;
2435 if (lo < hi) {
2436 const size_t* const lo_it = std::upper_bound(search_begin, bucket_start_end, lo);
2437 const size_t lo_bucket = static_cast<size_t>(lo_it - bucket_start - 1);
2438 const size_t* const hi_it = std::upper_bound(lo_it, bucket_start_end, hi - 1);
2439 const size_t hi_bucket = static_cast<size_t>(hi_it - bucket_start - 1);
2440 chunk_bucket_lo[t] = lo_bucket;
2441 chunk_bucket_hi[t] = hi_bucket;
2442 search_begin = hi_it;
2443 } else {
2444 chunk_bucket_lo[t] = B_R;
2445 chunk_bucket_hi[t] = 0;
2446 }
2447 lo = hi;
2448 }
2449 chunk_bucket_lo[num_threads] = B_R;
2450 };
2451
2452 bool chunk_partition_done = false;
2453 if (dedup_active && windows_in_batch > 0 && phase_a_done && !phase_a_done_at_batch_start) {
2454 BB_BENCH_NAME("MSM_fast::dedup_patch_schedule");
2455 const uint32_t* const rl_data = dedup_state.redirect_lookup.data();
2456 const size_t bs_stride = bucket_stride + 1;
2457 const size_t br = B_R;
2458 const size_t cap_R = n;
2459 bb::parallel_for(num_threads, [&, rl_data, bs_stride, br, cap_R](size_t tid) noexcept {
2460 for (size_t w = tid; w < windows_in_batch; w += num_threads) {
2461 uint32_t* sched_w = schedule.data() + (w * cap_R);
2462 size_t* bucket_start_w = bucket_start_all.data() + (w * bs_stride);
2463 round_parallel_detail::dedup_patch_schedule_window<Curve>(sched_w, bucket_start_w, br, rl_data);
2464 partition_chunks_for_window(w);
2465 }
2466 });
2467 chunk_partition_done = true;
2468 }
2469
2470 // Per-window chunk partition at schedule-index granularity (chunk_start[t] = t·m/T).
2471 // Balances across threads regardless of bucket-distribution skew. When the partition
2472 // lands mid-bucket, both adjacent threads build their own partial into the boundary
2473 // bucket; chunk_contribution combines them in Stage 7.
2474 {
2475 BB_BENCH_NAME("MSM_fast::Stage5_chunk_partition");
2476 if (!chunk_partition_done) {
2477 for (size_t w = 0; w < windows_in_batch; ++w) {
2478 partition_chunks_for_window(w);
2479 }
2480 }
2481 }
2482
2483 // Stage 6 bucket accumulation per thread:
2484 // (1) For each window w: reduce_chunk emits a digit-sorted (point, digit) list,
2485 // which we densify into a per-window dense bucket array at
2486 // tid's affine bucket buffer + w * stride. Empty slots stay identity.
2487 // (2) Call recursive_affine_bucket_reduce_strided once across all windows_in_batch
2488 // chunks; it computes (R_w, L_w) for each non-empty chunk via batch-affine
2489 // arithmetic, amortising the inversion across windows at every phase step.
2490 // (3) chunk_contribution(out) folds L_w + (lo_w-1)·R_w into the thread's per-window
2491 // partial.
2492 // The Stage-6 scratch is pre-sized for every thread BEFORE entering the parallel_for
2493 // so the per-thread vector resizes don't race the heap allocator.
2494 auto next_pow2 = [](size_t x) -> size_t {
2495 if (x <= 1) {
2496 return 1;
2497 }
2498 size_t p = 1;
2499 while (p < x) {
2500 p <<= 1;
2501 }
2502 return p;
2503 };
2504 // Drives reduce_chunk's per-thread tree-reduce buffer sizing.
2505 size_t max_chunk_len = 0;
2506 for (size_t t = 0; t < num_threads; ++t) {
2507 for (size_t w = 0; w < windows_in_batch; ++w) {
2508 const size_t* chunk_start = chunk_start_all.data() + (w * (num_threads + 1));
2509 const size_t entries_in_chunk = chunk_start[t + 1] - chunk_start[t];
2510 if (entries_in_chunk == 0) {
2511 continue;
2512 }
2513 max_chunk_len = std::max(max_chunk_len, entries_in_chunk);
2514 }
2515 }
2516
2517 // global_stride drives the per-thread `dense_buckets` layout (sized via
2518 // `ensure_affine_bucket_capacity` below). Stage 6a writes its per-thread bucket
2519 // partials into `bucket_partials_dense` (a separate buffer packed via
2520 // `bucket_partials_offsets`, no power-of-two stride); Stage 6b copies them into
2521 // `s.dense_buckets` keyed by Stage 6b's uniform bucket-index slice of width
2522 // `buckets_per_task ≈ ⌈(num_buckets-1)/T⌉`. The recursive bucket-reduction
2523 // algorithm (phases A-D) operates on `s.dense_buckets` with power-of-two row
2524 // stride — that's where `next_pow2` matters.
2525 size_t global_stride = 0;
2526
2527 {
2528 // Stage 6b's bucket-balanced partition. Uniform across windows: each rebalanced
2529 // task t' owns active digits [d_lo'[t'], d_hi'[t']] where d_lo'[t'] = 1 + t · (B-1) / T.
2530 const size_t active_digits = (B_R > 0) ? (B_R - 1) : 0;
2531 for (size_t t = 0; t <= num_threads; ++t) {
2532 rebalanced_bucket_lo_partition[t] = 1 + (t * active_digits) / num_threads;
2533 }
2534 rebalanced_bucket_lo_partition[num_threads] = B_R;
2535 size_t max_buckets_per_task = 0;
2536 for (size_t t = 0; t + 1 <= num_threads; ++t) {
2537 const size_t hi_d = (t + 1 == num_threads) ? (B_R - 1) : (rebalanced_bucket_lo_partition[t + 1] - 1);
2538 const size_t lo_d = rebalanced_bucket_lo_partition[t];
2539 if (hi_d >= lo_d) {
2540 max_buckets_per_task = std::max(max_buckets_per_task, hi_d - lo_d + 1);
2541 }
2542 }
2543 global_stride = next_pow2(max_buckets_per_task);
2544 global_stride = std::max<size_t>(global_stride, 2);
2545
2546 // Per-window orig-thread contributing ranges (O(W·T·T) total — only paid for
2547 // the rebalance path, where T is small enough that this is sub-µs).
2548 for (size_t w = 0; w < windows_in_batch; ++w) {
2549 const size_t* chunk_bucket_lo = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2550 const size_t* chunk_bucket_hi = chunk_bucket_hi_all.data() + (w * num_threads);
2551 const size_t* chunk_start_w = chunk_start_all.data() + (w * (num_threads + 1));
2552 for (size_t tprime = 0; tprime < num_threads; ++tprime) {
2553 const size_t lo_d = rebalanced_bucket_lo_partition[tprime];
2554 const size_t hi_d =
2555 (tprime + 1 == num_threads) ? (B_R - 1) : (rebalanced_bucket_lo_partition[tprime + 1] - 1);
2556 size_t lo_orig = num_threads;
2557 size_t hi_orig = 0;
2558 for (size_t t = 0; t < num_threads; ++t) {
2559 const size_t entries = chunk_start_w[t + 1] - chunk_start_w[t];
2560 if (entries == 0) {
2561 continue;
2562 }
2563 const size_t cl = chunk_bucket_lo[t];
2564 const size_t ch = chunk_bucket_hi[t];
2565 if (ch < lo_d || cl > hi_d) {
2566 continue;
2567 }
2568 if (lo_orig == num_threads) {
2569 lo_orig = t;
2570 }
2571 hi_orig = t;
2572 }
2573 orig_thread_lo[(w * num_threads) + tprime] = lo_orig;
2574 orig_thread_hi[(w * num_threads) + tprime] = hi_orig;
2575 }
2576 }
2577
2578 // bucket_partials_dense / _present packed via bucket_partials_offsets — each
2579 // (thread, window) row holds exactly buckets_per_thread[t][w] AffineElements (no
2580 // padding). The arena pre-sized to `windows_per_batch · (num_buckets - 1 + T)`
2581 // (covers the T-1 boundary-bucket shares); only the actual prefix is touched.
2582 size_t bucket_partials_cursor = 0;
2583 for (size_t t = 0; t < num_threads; ++t) {
2584 for (size_t w = 0; w < windows_in_batch; ++w) {
2585 bucket_partials_offsets[(t * windows_in_batch) + w] = bucket_partials_cursor;
2586 const size_t* chunk_bucket_lo_w = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2587 const size_t* chunk_bucket_hi_w = chunk_bucket_hi_all.data() + (w * num_threads);
2588 const size_t* chunk_start_w = chunk_start_all.data() + (w * (num_threads + 1));
2589 const size_t entries = chunk_start_w[t + 1] - chunk_start_w[t];
2590 if (entries > 0) {
2591 bucket_partials_cursor += chunk_bucket_hi_w[t] - chunk_bucket_lo_w[t] + 1;
2592 }
2593 }
2594 }
2595 bucket_partials_offsets[num_threads * windows_in_batch] = bucket_partials_cursor;
2596 const size_t bucket_partials_total = bucket_partials_cursor;
2597 BB_ASSERT_LTE(bucket_partials_total, bucket_partials_dense.size());
2598 std::memset(bucket_partials_present.data(), 0, bucket_partials_total);
2599 }
2600
2601 // thread_scratch is worker-indexed (one slot per OS thread, FIFO-shared by tasks);
2602 // update the stride on each worker's slot.
2603 for (size_t t = 0; t < worker_total; ++t) {
2604 thread_scratch[t].affine_bucket_stride = global_stride;
2605 }
2606
2607 {
2608 // Stage 6a — per-thread bucket partials. Each thread `tid` reduces its schedule
2609 // slice via reduce_chunk and scatters the (digit, point) output directly into the
2610 // per-thread dense bucket buffer at slot `(digit - chunk_bucket_lo[tid])`. Stage
2611 // 6b then reads this buffer with O(1) slot lookup. `bucket_partials_present` is
2612 // pre-zeroed per batch.
2613 auto bucket_partials_per_thread_lambda = [&](size_t tid) {
2614 auto& s = thread_scratch[tid];
2615 for (size_t w = 0; w < windows_in_batch; ++w) {
2616 const size_t* chunk_start_w = chunk_start_all.data() + (w * (num_threads + 1));
2617 const size_t cs_lo = chunk_start_w[tid];
2618 const size_t cs_hi = chunk_start_w[tid + 1];
2619 if (cs_lo == cs_hi) {
2620 continue;
2621 }
2622 const uint32_t* sched_w = schedule.data() + (w * n);
2623 const size_t* bucket_start = bucket_start_all.data() + (w * (bucket_stride + 1));
2624 AffineElement* dst_dense =
2625 bucket_partials_dense.data() + bucket_partials_offsets[(tid * windows_in_batch) + w];
2626 uint8_t* dst_present =
2627 bucket_partials_present.data() + bucket_partials_offsets[(tid * windows_in_batch) + w];
2628 const size_t* chunk_bucket_lo = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2629 const uint32_t my_lo = static_cast<uint32_t>(chunk_bucket_lo[tid]);
2630 const size_t my_hi = chunk_bucket_hi_all[(w * num_threads) + tid];
2631 size_t bucket_cursor = my_lo;
2632
2633 for (size_t pos = cs_lo; pos < cs_hi;) {
2634 const size_t end = std::min(pos + SUBCHUNK_ENTRIES_CAP, cs_hi);
2635 reduce_chunk<Curve>(s,
2636 sched_w,
2637 bucket_start,
2638 pos,
2639 end,
2640 bucket_cursor,
2641 my_hi,
2642 points,
2643 std::span<const AffineElement>(dedup_state.extra_points));
2644 const size_t len = s.result_len;
2645 for (size_t k = 0; k < len; ++k) {
2646 const uint32_t d = s.curr_buckets[k];
2647 const size_t slot = d - my_lo;
2648 if (dst_present[slot]) {
2649 s.overflow_slots[s.overflow_len] = static_cast<uint32_t>(slot);
2650 s.overflow_pts[s.overflow_len] = s.curr_pts[k];
2651 ++s.overflow_len;
2652 } else {
2653 dst_dense[slot] = s.curr_pts[k];
2654 dst_present[slot] = 1;
2655 }
2656 }
2657 pos = end;
2658 }
2659 merge_overflow<Curve>(s, dst_dense);
2660 }
2661 };
2662
2663 // Stage 6b (cross-thread bucket reduction): each rebalanced task `tprime` owns a
2664 // uniform-width slice of the bucket-index space [d_lo'(tprime), d_hi'(tprime)].
2665 // For each window in the batch, walk the contributing original threads' Stage 6a
2666 // dense outputs (range [orig_thread_lo, orig_thread_hi]), filter to digits in
2667 // this task's slice, scatter into the task's local dense_buckets (with
2668 // projective-add accumulation on the at-most-2 boundary digits per pair of
2669 // contributing originals), then run recursive_affine_bucket_reduce_strided +
2670 // chunk_contribution on a guaranteed-equal buckets_padded across all tasks.
2671 auto bucket_reduce_cross_thread_lambda = [&](size_t tprime) {
2672 auto& s = thread_scratch[tprime];
2673 Element* my_partials = window_partial_sums.data() + (tprime * windows_per_batch);
2674 for (size_t w = 0; w < windows_in_batch; ++w) {
2675 my_partials[w] = Curve::Group::point_at_infinity;
2676 }
2677
2678 const size_t stride = s.affine_bucket_stride;
2679 std::memset(s.is_present.data(), 0, windows_in_batch * stride);
2680
2681 const size_t lo_d = rebalanced_bucket_lo_partition[tprime];
2682 const size_t hi_d =
2683 (tprime + 1 == num_threads) ? (B_R - 1) : (rebalanced_bucket_lo_partition[tprime + 1] - 1);
2684 const uint32_t lo_d_u = static_cast<uint32_t>(lo_d);
2685 const uint32_t hi_d_u = static_cast<uint32_t>(hi_d);
2686
2687 bool any_nonempty = false;
2688 for (size_t w = 0; w < windows_in_batch; ++w) {
2689 auto& info = s.chunk_infos[w];
2690 auto& out = chunk_outputs[(w * num_threads) + tprime];
2691 if (lo_d > hi_d) {
2692 info.empty = 1;
2693 info.lo = 0;
2694 info.hi = 0;
2695 info.buckets_padded = 0;
2696 out.empty = 1;
2697 continue;
2698 }
2699 const size_t orig_lo = orig_thread_lo[(w * num_threads) + tprime];
2700 const size_t orig_hi = orig_thread_hi[(w * num_threads) + tprime];
2701 if (orig_lo == num_threads) {
2702 info.empty = 1;
2703 info.lo = 0;
2704 info.hi = 0;
2705 info.buckets_padded = 0;
2706 out.empty = 1;
2707 continue;
2708 }
2709 const size_t base = w * stride;
2710 bool has_data = false;
2711
2712 // bucket_partials_dense holds per-(orig_t, w, slot) bucket points with
2713 // bucket_partials_present as the populated-slot bitmap. For each
2714 // contributing orig_t, intersect its [chunk_bucket_lo, chunk_bucket_hi]
2715 // range with this task's [lo_d, hi_d] slice and walk the intersection
2716 // only — no sorted scan, O(1) lookup per slot.
2717 const size_t* chunk_bucket_lo_w = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2718 const size_t* chunk_bucket_hi_w = chunk_bucket_hi_all.data() + (w * num_threads);
2719 for (size_t t = orig_lo; t <= orig_hi; ++t) {
2720 const size_t cl = chunk_bucket_lo_w[t];
2721 const size_t ch = chunk_bucket_hi_w[t];
2722 const size_t d_lo_clip = std::max<size_t>(lo_d, cl);
2723 const size_t d_hi_clip = std::min<size_t>(hi_d, ch);
2724 if (d_lo_clip > d_hi_clip) {
2725 continue;
2726 }
2727 const AffineElement* src_dense =
2728 bucket_partials_dense.data() + bucket_partials_offsets[(t * windows_in_batch) + w];
2729 const uint8_t* src_present =
2730 bucket_partials_present.data() + bucket_partials_offsets[(t * windows_in_batch) + w];
2731 for (size_t d = d_lo_clip; d <= d_hi_clip; ++d) {
2732 const size_t src_slot = d - cl;
2733 if (src_present[src_slot] == 0) {
2734 continue;
2735 }
2736 const size_t dst_slot = base + (d - lo_d);
2737 if (s.is_present[dst_slot] == 0) {
2738 s.dense_buckets[dst_slot] = src_dense[src_slot];
2739 s.is_present[dst_slot] = 1;
2740 } else {
2741 // Boundary digit shared between two consecutive originals
2742 // — projective add then re-normalise to affine. Under the
2743 // contiguous-by-schedule-index partition there are at most
2744 // W boundary points per task.
2745 Element acc = Element(s.dense_buckets[dst_slot]);
2746 acc += Element(src_dense[src_slot]);
2747 s.dense_buckets[dst_slot] = AffineElement(acc);
2748 }
2749 has_data = true;
2750 }
2751 }
2752 if (!has_data) {
2753 info.empty = 1;
2754 info.lo = 0;
2755 info.hi = 0;
2756 info.buckets_padded = 0;
2757 out.empty = 1;
2758 continue;
2759 }
2760 any_nonempty = true;
2761 const size_t M = hi_d - lo_d + 1;
2762 const uint32_t buckets_padded =
2763 (M == 1) ? 1 : (uint32_t{ 1 } << (32 - __builtin_clz(static_cast<uint32_t>(M - 1))));
2764 info.empty = 0;
2765 info.lo = lo_d_u;
2766 info.hi = hi_d_u;
2767 info.buckets_padded = buckets_padded;
2768 out.empty = 0;
2769 out.lo = lo_d_u;
2770 out.hi = hi_d_u;
2771 }
2772
2773 if (!any_nonempty) {
2774 return;
2775 }
2776
2777 round_parallel_detail::recursive_affine_bucket_reduce_strided<Curve>(
2778 s, s.chunk_infos.data(), windows_in_batch, chunk_outputs.data() + tprime, num_threads);
2779
2780 for (size_t w = 0; w < windows_in_batch; ++w) {
2781 auto& out = chunk_outputs[(w * num_threads) + tprime];
2782 if (out.empty == 0) {
2783 my_partials[w] = round_parallel_detail::chunk_contribution<Curve>(out);
2784 }
2785 }
2786 };
2787
2788 bb::parallel_for(num_threads, bucket_partials_per_thread_lambda);
2789 bb::parallel_for(num_threads, bucket_reduce_cross_thread_lambda);
2790 }
2791
2792 // Stage 7 (cross-window combine): per-window reduce of `num_threads` per-thread partials.
2793 // (Algebraic identity: `Σ_t (L_t + (lo_t − 1) · R_t) = window's bucket sum`,
2794 // with the per-chunk contributions already accumulated above.)
2795 {
2796 const size_t reduce_threads = std::min(num_threads, windows_in_batch);
2797 bb::parallel_for(reduce_threads, [&](size_t rid) {
2798 const size_t lo = rid * windows_in_batch / reduce_threads;
2799 const size_t hi = (rid + 1) * windows_in_batch / reduce_threads;
2800 for (size_t w = lo; w < hi; ++w) {
2801 Element sum = Curve::Group::point_at_infinity;
2802 for (size_t tid = 0; tid < num_threads; ++tid) {
2803 sum += window_partial_sums[(tid * windows_per_batch) + w];
2804 }
2805 window_sums[batch_start + w] = sum;
2806 }
2807 });
2808 }
2809 };
2810
2811 // Uniform-schedule dispatch over all windows.
2812 {
2813 const size_t B_R = (size_t{ 1 } << (window_bits - 1)) + 1;
2814 for (size_t batch_start = 0; batch_start < sched.num_windows; batch_start += windows_per_batch) {
2815 const size_t windows_in_batch = std::min(windows_per_batch, sched.num_windows - batch_start);
2816 run_batch(batch_start, windows_in_batch, B_R);
2817 }
2818 }
2819
2820 // Stage 7 horner: walk high-to-low, doubling by `window_bits_per_window[w]` between adjacent windows.
2821 // Init from the top window to skip a wasted doubling on identity.
2822 Element result = (sched.num_windows == 0) ? Curve::Group::point_at_infinity : window_sums[sched.num_windows - 1];
2823 for (size_t w_rev = sched.num_windows - 1; w_rev > 0; --w_rev) {
2824 const size_t window_bits_w = sched.window_bits_per_window[w_rev - 1];
2825 for (size_t d = 0; d < window_bits_w; ++d) {
2826 result.self_dbl();
2827 }
2828 result += window_sums[w_rev - 1];
2829 }
2830
2831 // GLV path leaves input_scalars untouched (it reads via from_montgomery_form_reduced into
2832 // a temporary). Non-GLV path mutated in place above and must restore.
2833 if (!use_glv) {
2834 bb::parallel_for(bb::get_num_cpus(), [&](const ThreadChunk& chunk) {
2835 for (size_t i : chunk.range(n_input)) {
2836 input_scalars[i].self_to_montgomery_form();
2837 }
2838 });
2839 }
2840
2841 return result;
2842}
2843
// Thin forwarding wrapper: delegates straight to pippenger_round_parallel<Curve>,
// passing `scalars`, `points` and `dedup_hint` through unchanged. Unlike
// pippenger_fast below, there is no edge-case routing here — callers get the
// affine round-parallel path unconditionally.
2844template <typename Curve>
2847 bool dedup_hint) noexcept
2848{
2849 return pippenger_round_parallel<Curve>(scalars, points, dedup_hint);
2850}
2851
// MSM entry point with optional edge-case handling.
//
// Dispatch order:
//   - handle_edge_cases == false: forward to pippenger_round_parallel<Curve> (affine
//     round-parallel path) with dedup_hint.
//   - handle_edge_cases == true:
//       1. empty scalar span                      -> point at infinity
//       2. fewer than 4 scalars                   -> trivial_msm<Curve>(scalars, points)
//       3. scalars.start_index >= points.size()   -> point at infinity
//       4. otherwise: clip to the overlap of the scalar span and points[start_index..],
//          convert those scalars out of Montgomery form IN PLACE, run the Jacobian
//          fast path, then convert them back to Montgomery form.
//
// Notes:
//   - dedup_hint is not used on the edge-case path (neither trivial_msm nor the
//     Jacobian fast path receives it).
//   - The `n < 4` shortcut is taken on the raw scalar count, before the start_index
//     bounds check and before clipping to the available points, so trivial_msm is
//     presumably responsible for its own start_index handling — TODO confirm.
//   - Side effect: the caller's scalar storage is written through a const_cast for the
//     duration of the call. Concurrent readers of the same buffer would observe
//     non-Montgomery values, and writing through const_cast is only well-defined if
//     the underlying storage was not itself declared const.
//   - NOTE(review): the restore step re-encodes a *reduced* value, so if the caller's
//     limbs were unreduced on entry they may not be bit-identical afterwards (same
//     field element) — confirm callers do not depend on exact limbs.
2852template <typename Curve>
2855 bool handle_edge_cases,
2856 bool dedup_hint) noexcept
2857{
2858 using Element = typename Curve::Element;
2859 using ScalarField = typename Curve::ScalarField;
2860 if (!handle_edge_cases) {
2861 return pippenger_round_parallel<Curve>(scalars, points, dedup_hint);
2862 }
2863 // Edge-case-handling path: route through the Jacobian fast-path. It uses
2864 // Jacobian additions throughout, so point-at-infinity and equal-x bucket
2865 // collisions don't trigger the affine-add edge-case bug. We need to convert
2866 // PolynomialSpan to a plain ScalarField span: the jacobian fast-path takes
2867 // a contiguous std::span and ignores `start_index`.
2868 const size_t n = scalars.span.size();
2869 if (n == 0) {
2870 return Curve::Group::point_at_infinity;
2871 }
2872 // Trivially small N: skip Pippenger / Jacobian-fast-path scaffolding entirely.
2873 // Affine operator* + Jacobian sum already handles all edge cases.
2874 if (n < 4) {
2875 return trivial_msm<Curve>(scalars, points);
2876 }
// No points are available at or after start_index: nothing can contribute.
2877 const auto& start = scalars.start_index;
2878 if (start >= points.size()) {
2879 return Curve::Group::point_at_infinity;
2880 }
// Scalar i pairs with points[start + i]; clip so neither span runs past the other.
2881 const size_t n_used = std::min<size_t>(n, points.size() - start);
2882 std::span<const typename Curve::AffineElement> point_slice(points.data() + start, n_used);
2883 std::span<const ScalarField> scalar_slice(scalars.span.data(), n_used);
2884 // Convert scalars to non-Montgomery form for the jacobian path's bit-extraction loop,
2885 // then restore. Mirrors the round-parallel fast-path's scalar lifecycle.
2886 // Use the `_reduced` variant: the bit-extraction loop reads only bits 0..253
2887 // (NUM_BITS = 254). Plain `self_from_montgomery_form` leaves the value in [0, 2p),
2888 // so values in [2^254, 2p) would have bit 254 set and silently drop the contribution
2889 // of that bit. `_reduced` brings the value into [0, p) ⊂ [0, 2^254).
2890 auto* mutable_scalars =
2891 const_cast<ScalarField*>(scalar_slice.data()); // NOLINT(cppcoreguidelines-pro-type-const-cast)
2892 bb::parallel_for(bb::get_num_cpus(), [&](const ThreadChunk& chunk) {
2893 for (size_t i : chunk.range(n_used)) {
2894 mutable_scalars[i].self_from_montgomery_form_reduced();
2895 }
2896 });
// Third argument is min_pts_per_thread_override (per the explicit instantiation);
// 0 presumably means "no override, use the default heuristic" — TODO confirm.
2897 const Element result =
2898 round_parallel_detail::pippenger_round_parallel_jacobian_fast<Curve>(scalar_slice, point_slice, 0);
// Restore the caller's scalars to Montgomery form (only the n_used clipped prefix was converted).
2899 bb::parallel_for(bb::get_num_cpus(), [&](const ThreadChunk& chunk) {
2900 for (size_t i : chunk.range(n_used)) {
2901 mutable_scalars[i].self_to_montgomery_form();
2902 }
2903 });
2904 return result;
2905}
2906
// Convenience wrapper: runs pippenger_fast<Curve> with the same four arguments
// (scalars, points, handle_edge_cases, dedup_hint) and converts the resulting
// Element to an AffineElement. Any in-place scalar mutation/restore performed by
// pippenger_fast on the edge-case path applies here as well.
2907template <typename Curve>
2910 bool handle_edge_cases,
2911 bool dedup_hint) noexcept
2912{
2913 return AffineElement(pippenger_fast<Curve>(scalars, points, handle_edge_cases, dedup_hint));
2914}
2915
2916#include "./pippenger_batched.hpp"
2917
2918// Explicit instantiations.
2922 bool dedup_hint) noexcept;
2926 bool dedup_hint) noexcept;
2929 bool handle_edge_cases,
2930 bool dedup_hint) noexcept;
2934 bool handle_edge_cases,
2935 bool dedup_hint) noexcept;
2936template class MSM_fast<curve::BN254>;
2937template class MSM_fast<curve::Grumpkin>;
2938
2942 bool dedup_hint,
2944 std::span<std::byte> external_arena) noexcept;
2945
2949 bool dedup_hint,
2951 std::span<std::byte> external_arena) noexcept;
2952
2956
2960
2964
2968
2969namespace round_parallel_detail {
2970template curve::BN254::Element pippenger_round_parallel_jacobian_fast<curve::BN254>(
2973 size_t min_pts_per_thread_override) noexcept;
2974
2975template curve::Grumpkin::Element pippenger_round_parallel_jacobian_fast<curve::Grumpkin>(
2978 size_t min_pts_per_thread_override) noexcept;
2979} // namespace round_parallel_detail
2980
2981template size_t compute_arena_bytes_for_msm<curve::BN254>(size_t, bool, bool) noexcept;
2982
2983} // namespace bb::scalar_multiplication
#define BB_ASSERT_GTE(left, right,...)
Definition assert.hpp:128
#define BB_ASSERT_GT(left, right,...)
Definition assert.hpp:113
#define BB_ASSERT_EQ(actual, expected,...)
Definition assert.hpp:83
#define BB_ASSERT_LTE(left, right,...)
Definition assert.hpp:158
#define BB_BENCH_NAME(name)
Definition bb_bench.hpp:264
typename Group::element Element
Definition bn254.hpp:21
typename Group::affine_element AffineElement
Definition bn254.hpp:22
typename Group::element Element
Definition grumpkin.hpp:63
typename Group::affine_element AffineElement
Definition grumpkin.hpp:64
static AffineElement msm(std::span< const AffineElement > points, PolynomialSpan< const ScalarField > scalars, bool handle_edge_cases=false, bool dedup_hint=false) noexcept
Single MSM_fast convenience wrapper — returns the result as an AffineElement.
#define info(...)
Definition log.hpp:93
FF a
FF b
uint32_t get_constantine_packed_digit(const uint64_t *scalar_data, uint32_t lo_limb, uint32_t hi_limb, uint32_t lo_off, uint32_t lo_bits, uint32_t lo_mask, uint32_t hi_mask, bool slice_localised_to_one_u64, size_t window_bits) noexcept
Read (window_bits+1) bits from scalar_data (uint64 limbs) using precomputed slice params and apply Co...
ConstantineSlicePath classify_slice_path_u32(const ConstantineSliceParamsU32 &sp) noexcept
size_t compute_global_max_overflow_per_window(size_t n, size_t num_threads, size_t subchunk_entries_cap) noexcept
size_t compute_phase_one_prologue_bytes(size_t n, bool use_glv, bool inline_glv_double, size_t profile_threads) noexcept
void store_constantine_packed_digits_x4_bottom(uint32_t *dst, const uint32_t *scalar_data_0, const uint32_t *scalar_data_1, const uint32_t *scalar_data_2, const uint32_t *scalar_data_3, uint32_t hi_limb, uint32_t lo_bits, SimdU32x4 hi_mask_v, SimdU32x4 one_v, SimdU32x4 val_mask, uint32_t window_bits) noexcept
size_t solve_wpb(size_t per_window_bytes, size_t available_budget, size_t W_R) noexcept
void store_constantine_packed_digits_x4_boundary(uint32_t *dst, const uint32_t *scalar_data_0, const uint32_t *scalar_data_1, const uint32_t *scalar_data_2, const uint32_t *scalar_data_3, uint32_t lo_limb, uint32_t hi_limb, uint32_t lo_off, uint32_t lo_bits, SimdU32x4 lo_mask_v, SimdU32x4 hi_mask_v, SimdU32x4 one_v, SimdU32x4 val_mask, uint32_t window_bits) noexcept
size_t compute_bucket_partials_max(size_t B_eff, size_t num_threads) noexcept
uint32_t __attribute__((vector_size(16))) SimdU32x4
PhaseACaps compute_phase_a_caps(size_t n, size_t num_threads) noexcept
ConstantineSliceParams compute_constantine_slice_params(size_t bit_offset, size_t window_bits, size_t num_uint64_limbs) noexcept
Curve::Element pippenger_round_parallel_jacobian_fast(std::span< const typename Curve::ScalarField > scalars, std::span< const typename Curve::AffineElement > points, size_t min_pts_per_thread_override) noexcept
Small-N fast-path: per-thread Jacobian Pippenger over a partition of the input.
void store_constantine_packed_digits_x4_localised(uint32_t *dst, const uint32_t *scalar_data_0, const uint32_t *scalar_data_1, const uint32_t *scalar_data_2, const uint32_t *scalar_data_3, uint32_t lo_limb, uint32_t lo_off, SimdU32x4 lo_mask_v, SimdU32x4 one_v, SimdU32x4 val_mask, uint32_t window_bits) noexcept
size_t compute_dense_stride(size_t B_eff, size_t num_threads) noexcept
uint32_t choose_window_bits(size_t num_points, size_t num_bits, size_t n_input, size_t num_logical_threads) noexcept
VariableWindowSchedule build_var_window_schedule(size_t num_bits, size_t window_bits) noexcept
ConstantineSliceParamsU32 compute_constantine_slice_params_u32(size_t bit_offset, size_t window_bits, size_t num_u32_limbs) noexcept
template curve::BN254::Element pippenger_fast< curve::BN254 >(PolynomialSpan< const curve::BN254::ScalarField > scalars, std::span< const curve::BN254::AffineElement > points, bool handle_edge_cases, bool dedup_hint) noexcept
size_t compute_arena_bytes_for_msm(size_t n_input, bool external_glv_provided, bool dedup_active) noexcept
Round-parallel Pippenger MSM_fast. Windows process sequentially (high-to-low) but each window is full...
template size_t compute_arena_bytes_for_msm< curve::BN254 >(size_t, bool, bool) noexcept
Curve::Element pippenger_unsafe_fast(PolynomialSpan< const typename Curve::ScalarField > scalars, std::span< const typename Curve::AffineElement > points, bool dedup_hint) noexcept
template curve::BN254::Element pippenger_round_parallel< curve::BN254 >(PolynomialSpan< const curve::BN254::ScalarField > scalars, std::span< const curve::BN254::AffineElement > points, bool dedup_hint, std::span< const curve::BN254::AffineElement > external_glv_doubled, std::span< std::byte > external_arena) noexcept
template curve::BN254::Element trivial_msm_threaded< curve::BN254 >(PolynomialSpan< const curve::BN254::ScalarField > scalars_span, std::span< const curve::BN254::AffineElement > all_points) noexcept
template curve::Grumpkin::Element pippenger_round_parallel< curve::Grumpkin >(PolynomialSpan< const curve::Grumpkin::ScalarField > scalars, std::span< const curve::Grumpkin::AffineElement > points, bool dedup_hint, std::span< const curve::Grumpkin::AffineElement > external_glv_doubled, std::span< std::byte > external_arena) noexcept
template curve::Grumpkin::Element trivial_msm< curve::Grumpkin >(PolynomialSpan< const curve::Grumpkin::ScalarField > scalars_span, std::span< const curve::Grumpkin::AffineElement > all_points) noexcept
template curve::Grumpkin::Element pippenger_fast< curve::Grumpkin >(PolynomialSpan< const curve::Grumpkin::ScalarField > scalars, std::span< const curve::Grumpkin::AffineElement > points, bool handle_edge_cases, bool dedup_hint) noexcept
Curve::Element pippenger_fast(PolynomialSpan< const typename Curve::ScalarField > scalars, std::span< const typename Curve::AffineElement > points, bool handle_edge_cases, bool dedup_hint) noexcept
template curve::BN254::Element trivial_msm< curve::BN254 >(PolynomialSpan< const curve::BN254::ScalarField > scalars_span, std::span< const curve::BN254::AffineElement > all_points) noexcept
template curve::Grumpkin::Element pippenger_unsafe_fast< curve::Grumpkin >(PolynomialSpan< const curve::Grumpkin::ScalarField > scalars, std::span< const curve::Grumpkin::AffineElement > points, bool dedup_hint) noexcept
size_t window_bits_tuning_oversub_factor(size_t n_input)
N-dependent oversubscription factor used ONLY for choose_window_bits' target_load formula (not for ac...
template curve::BN254::Element pippenger_unsafe_fast< curve::BN254 >(PolynomialSpan< const curve::BN254::ScalarField > scalars, std::span< const curve::BN254::AffineElement > points, bool dedup_hint) noexcept
Curve::Element pippenger_round_parallel(PolynomialSpan< const typename Curve::ScalarField > scalars_span, std::span< const typename Curve::AffineElement > all_points, bool dedup_hint, std::span< const typename Curve::AffineElement > external_glv_doubled, std::span< std::byte > external_arena) noexcept
State of the art pippenger_fast multiscalar multiplication algorithm.
template curve::Grumpkin::Element trivial_msm_threaded< curve::Grumpkin >(PolynomialSpan< const curve::Grumpkin::ScalarField > scalars_span, std::span< const curve::Grumpkin::AffineElement > all_points) noexcept
size_t get_num_cpus()
Definition thread.cpp:33
C slice(C const &container, size_t start)
Definition container.hpp:9
Inner sum(Cont< Inner, Args... > const &in)
Definition container.hpp:70
void parallel_for(size_t num_iterations, const std::function< void(size_t)> &func)
Definition thread.cpp:111
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
uint8_t len
std::span< uint32_t > affine_bucket_indices
uintptr_t base_addr
std::byte * data
std::span< BaseField > affine_bucket_inversion_scratch
std::span< AffineElement > points_to_add
std::span< uint32_t > pair_dest
std::span< uint8_t > is_present
std::span< AffineElement > overflow_pts
std::unique_ptr< std::byte[]> local_owner
std::span< std::pair< uint32_t, uint32_t > > affine_bucket_pairs
std::span< uint32_t > overflow_slots
size_t affine_bucket_stride
std::span< AffineElement > curr_pts
std::span< uint32_t > curr_buckets
std::span< BaseField > inversion_scratch
std::span< AffineElement > dense_buckets
std::span< AffineBucketChunkInfo > chunk_infos
Curve::Element Element
size_t thread_index
Definition thread.hpp:150
auto range(size_t size, size_t offset=0) const
Definition thread.hpp:152
Per-window precomputed slice parameters for the carry-less signed-Booth window recoding....
std::span< typename Curve::AffineElement > extra_points