Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
pippenger_dedup.hpp
Go to the documentation of this file.
1// Input-scalar dedup pre-pass for the round-parallel Pippenger MSM (Phase A).
2//
3// Detects clusters of input scalars whose canonical (non-Montgomery) value is identical
4// and spans more than one signed-Booth window, then combines each cluster's base points
5// into a single (rep, combined_point) pair via a chunked batched-affine tree-reduce.
6// Stage 4 then sees a redirect_lookup that rewrites the cluster's schedule entries:
7// the rep gets DEDUP_REDIRECT_BIT|extra_idx (fetched from extra_points[]), the rest get
8// DEDUP_SKIP_BIT and contribute nothing. This carves out two bits of the 32-bit schedule
9// encoding (bit 30 = redirect, bit 29 = skip), which is why this header also owns the
10// full schedule-bit encoding constants (the sign bit, the dedup bits, and the index
11// mask are all co-defined).
12//
13// The encoding constants and the `dedup_*` workers are pulled into one file so:
14// * scalar_multiplication.cpp's Stage 4 / Stage 6a schedule readers see the bit
15// constants via this header;
16// * the dedup machinery (Phase A workers, hash table, cluster tree-reduce, redirect
17// finalize) lives as a self-contained module rather than being scattered through
18// the MSM driver. Pure code motion — every function is inline / templated, so the
19// compiler sees identical code at identical call sites and codegen is unchanged.
20
21#pragma once
22
24
29
30#include <cstddef>
31#include <cstdint>
32#include <span>
33#include <utility>
34
36
37// 32-bit schedule-entry encoding. Stage 4 stores only the point sign and scalar index;
38// bucket magnitude is recovered from Stage 3's bucket_start ranges in Stage 5/6 because
39// the schedule is bucket-contiguous.
40// bit 31: sign bit from the packed signed digit
41// bit 30: dedup redirect — fetch from extra_points[payload]
42// bit 29: dedup skip — non-rep duplicate, carries no contribution
43// bits 0..28: scalar_idx, or extra_points index when redirect is set
// Flag bits of a 32-bit schedule entry (one bit each) and the mask selecting its index payload.
inline constexpr uint32_t SCHEDULE_SIGN_BIT = uint32_t{ 1 } << 31;   // sign of the packed signed digit
inline constexpr uint32_t DEDUP_REDIRECT_BIT = uint32_t{ 1 } << 30;  // payload indexes extra_points[]
inline constexpr uint32_t DEDUP_SKIP_BIT = uint32_t{ 1 } << 29;      // non-rep duplicate, no contribution
inline constexpr uint32_t SCHEDULE_INDEX_MASK = DEDUP_SKIP_BIT - 1;  // bits 0..28

// The three flag bits must be pairwise distinct.
static_assert((SCHEDULE_SIGN_BIT & DEDUP_REDIRECT_BIT) == 0, "sign and redirect bits overlap");
static_assert((SCHEDULE_SIGN_BIT & DEDUP_SKIP_BIT) == 0, "sign and skip bits overlap");
static_assert((DEDUP_REDIRECT_BIT & DEDUP_SKIP_BIT) == 0, "redirect and skip bits overlap");
// The index payload must not alias any flag bit, and flags + payload must tile the whole word.
static_assert((SCHEDULE_INDEX_MASK & (SCHEDULE_SIGN_BIT | DEDUP_REDIRECT_BIT | DEDUP_SKIP_BIT)) == 0,
              "schedule index mask overlaps a flag bit");
static_assert((SCHEDULE_SIGN_BIT | DEDUP_REDIRECT_BIT | DEDUP_SKIP_BIT | SCHEDULE_INDEX_MASK) == ~uint32_t{ 0 },
              "schedule entry encoding must account for all 32 bits");

// Sentinel stored in redirect_lookup for a scalar that needs no patch.
inline constexpr uint32_t DEDUP_INVALID_EXTRA = ~uint32_t{ 0 };
// A redirect payload sets only DEDUP_REDIRECT_BIT and a skip payload only DEDUP_SKIP_BIT, so a
// sentinel with both flags set can never compare equal to either kind of payload.
static_assert((DEDUP_INVALID_EXTRA & (DEDUP_REDIRECT_BIT | DEDUP_SKIP_BIT)) == (DEDUP_REDIRECT_BIT | DEDUP_SKIP_BIT),
              "no-patch sentinel must be distinguishable from redirect and skip payloads");
54
// One-limb fingerprint of a scalar: the first 64-bit limb (`scalar_data[0]`).
// `scalar_data` must point to at least one readable limb. Scalars with identical limbs always
// share a fingerprint; the converse does not hold, so callers must still compare full values.
[[nodiscard]] inline uint64_t dedup_scalar_fingerprint(const uint64_t* scalar_data) noexcept
{
    const uint64_t first_limb = *scalar_data;
    return first_limb;
}
59
// Maps a 64-bit fingerprint to a hash-table slot in [0, mask].
// The fingerprint is scrambled by a wrapping multiply with an odd 64-bit constant, the upper
// 32 bits of the product are XOR-folded into the lower half, and the result is reduced with
// `mask`. The visible caller passes `table_size - 1` for a power-of-two table size.
[[nodiscard]] inline size_t dedup_fingerprint_slot(uint64_t fingerprint, size_t mask) noexcept
{
    constexpr uint64_t MIX_MULTIPLIER = 0x9E3779B97F4A7C15ULL;
    const uint64_t product = fingerprint * MIX_MULTIPLIER;
    const uint64_t folded = product ^ (product >> 32);
    return mask & static_cast<size_t>(folded);
}
66
67// ===================================================================================
68// Input-scalar dedup pre-pass.
69// ===================================================================================
70//
71// For each cluster of input scalars whose canonical value is identical and spans more
72// than one bucket window of width c (msb >= c), combine the cluster's base points into
73// a single (rep, combined_point) pair so Pippenger only iterates the cluster once
74// instead of `cluster_size` times.
75//
76// Detection: within each window-0 bucket, insert the long scalars into a small
77// open-addressing hash table keyed on the one-limb fingerprint `scalars[i].data[0]`
78// (equal-value scalars always share data[0], so they land on the same probe chain).
79// A fingerprint hit is only accepted as a duplicate after a full-width memcmp.
80//
81// Combine: build a flat (cluster_pts, cluster_ids) array with same-cluster entries
82// contiguous, then run an in-place tree-reduce that pairs adjacent same-cluster-id
83// entries via `batch_affine_add_interleaved` (one inversion per BATCH_CAPACITY pairs)
84// until each cluster has a single surviving entry. Avoids the per-cluster Element += /
85// AffineElement(cast) round-trip that does one inversion per cluster.
86//
87// Output: a redirect_lookup[n] mapping scalar_idx → final dedup schedule payload
88// (DEDUP_REDIRECT_BIT | extra_idx, DEDUP_SKIP_BIT | scalar_idx, or INVALID = no patch).
89// Stage 4b ORs that payload with the preserved sign bit. The underlying canonical scalar
90// value is left untouched (`scalars` aliases the caller's polynomial and is restored to
91// Mont form on exit; mutating it would corrupt downstream consumers).
92
93template <typename Curve> struct DedupResult {
94 std::span<uint32_t> redirect_lookup; // size n; INVALID or encoded dedup payload.
95 // Allocated from the pippenger arena
96 // (no zero-init); filled with INVALID
97 // by a parallel_for before Phase A.
98 std::span<typename Curve::AffineElement> extra_points; // size DEDUP_MAX_CLUSTERS; arena-allocated.
99 // Phase A writes per-cluster aggregates
100 // into thread-disjoint cid ranges.
101 size_t n_dedup_extras = 0; // # extra_points populated by Phase A
102};
103
104// In-place batched-affine tree-reduce over (pts[0..initial_len), ids[0..initial_len)) with
105// same-cluster entries contiguous. Returns result_len: pts[0..result_len) holds one combined
106// point per cluster (input order preserved) and ids[0..result_len) the matching cluster id.
107// Caller-provided scratch: `scratch_pts` holds 2 * BATCH_CAPACITY points (two per pair),
108// `pair_dest` BATCH_CAPACITY entries; `inversion_scratch` is forwarded to the batch adder.
109template <typename Curve>
111 uint32_t* ids,
112 size_t initial_len,
113 typename Curve::AffineElement* scratch_pts,
114 uint32_t* pair_dest,
115 typename Curve::BaseField* inversion_scratch) noexcept
116{
117 using AffineElement = typename Curve::AffineElement;
118 using BaseField = typename Curve::BaseField;
119
120 const auto drain = [&](size_t pair_count) noexcept {
121 if (pair_count == 0) {
122 return;
123 }
124 bb::group_elements::batch_affine_add_interleaved<AffineElement, BaseField>(
125 scratch_pts, 2 * pair_count, inversion_scratch);
126 for (size_t k = 0; k < pair_count; ++k) {
127 pts[pair_dest[k]] = scratch_pts[pair_count + k];
128 }
129 };
130
131 size_t curr_len = initial_len;
132 while (true) {
133 size_t i = 0;
134 size_t next_len = 0;
135 size_t pair_count = 0;
136 bool made_pair = false;
137
138 while (i < curr_len) {
139 if (i + 1 < curr_len && ids[i] == ids[i + 1]) {
140 scratch_pts[2 * pair_count] = pts[i];
141 scratch_pts[(2 * pair_count) + 1] = pts[i + 1];
142 ids[next_len] = ids[i];
143 pair_dest[pair_count] = static_cast<uint32_t>(next_len);
144 ++next_len;
145 ++pair_count;
146 i += 2;
147 made_pair = true;
148 if (pair_count >= BATCH_CAPACITY) {
149 drain(pair_count);
150 pair_count = 0;
151 }
152 } else {
153 pts[next_len] = pts[i];
154 ids[next_len] = ids[i];
155 ++next_len;
156 ++i;
157 }
158 }
159 drain(pair_count);
160
161 if (!made_pair) {
162 break;
163 }
164 curr_len = next_len;
165 }
166 return curr_len;
167}
168
169// Hard caps that bound the worst-case dedup-on memory at ≤ 4 MB above dedup-off, for
170// all possible inputs.
171//
172// redirect_lookup[n] = 4 n bytes (≤ 2 MB at n = 2^19)
173// extra_points[MAX_CLUSTERS] = 64 × MAX_CLUSTERS bytes
174// per-thread Phase A scratch: bounded by per-thread chunk size + chunk_pts buffer
175//
176// All phases ≤ 4 MB regardless of input shape. The caps degrade gracefully: when hit
177// we leave un-deduped scalars on the standard pippenger path (still correct, just
178// less savings).
179// `DEDUP_MAX_CLUSTERS`, `DEDUP_MAX_MEMBERS`, and `DEDUP_MAX_CHUNK_MEMBERS` are defined
180// in `pippenger_arena_layout.hpp` so the test harness can size the matching slabs.
181static_assert(DEDUP_MAX_CLUSTERS <= size_t{ SCHEDULE_INDEX_MASK } + 1,
182 "dedup extra-point ids must fit in the schedule payload");
183
184// Per-worker Phase A scratch backed by the pippenger arena. Replaces the prior
185// `thread_local std::vector<...>` slabs so process-resident memory after the MSM
186// drops back to zero and the per-worker working set is deterministic.
187//
188// All caps below are *loose upper bounds* — when a runtime population exceeds them,
189// the cluster-scan / tree-reduce inner loops already fall through to "leave un-deduped"
190// behaviour via `clusters_opened >= cid_max - cid_lo` and the `room` calculation in the
191// tree-reduce chunk-fill loop.
192template <typename Curve> struct PhaseAScratch {
193 // Worst case = every cluster on this worker has exactly one member. Per-worker
194 // cluster budget is `DEDUP_MAX_CLUSTERS / num_threads`; `DEDUP_MAX_MEMBERS / num_threads`
195 // is a looser, structurally simpler bound. +1 covers the final partition rounding slop.
196 std::span<uint32_t> cluster_members;
197 // One entry per opened cluster + the initial 0 sentinel pushed at function entry.
198 // Cap = (DEDUP_MAX_CLUSTERS / num_threads) + 2 (covers the +1 sentinel and rounding).
199 std::span<uint32_t> cluster_offsets;
200 // One uint16_t per hash-table slot dirtied since the last bucket; HT_SIZE = 4096 is
201 // the structural cap — every slot can at most be dirtied once per bucket.
202 std::span<uint16_t> dirty_slots;
203 // Per-bucket cluster representative scalar_idx. Current code reserves 32; widening
204 // to 256 covers chonk-wire worst cases (mega-buckets) without resizing.
205 std::span<uint32_t> bucket_rep;
206 // Per-bucket staged (bucket_cid, idx) pairs awaiting cluster emission. Current code
207 // reserves 64; widening to 1024 covers the chonk-wire mega-bucket worst case.
209 // Tree-reduce per-iteration working sets. Both capped at DEDUP_MAX_CHUNK_MEMBERS=2048;
210 // the constant is defined in `pippenger_arena_layout.hpp`.
211 std::span<typename Curve::AffineElement> chunk_pts;
212 std::span<uint32_t> chunk_ids;
213};
214
215// Per-bucket hash-based dedup. Each thread owns a contiguous range of buckets in
216// window 0's schedule. For each bucket, we build a tiny open-addressing hash
217// table over the long-scalar entries (msb >= c_threshold) — short entries are
218// skipped because their dedup savings (W_nz ≈ 1) are zero. Slot selection uses
219// a cheap one-limb fingerprint; full 4-limb memcmp still gates every match.
220// Hash collisions resolve via linear probing; same-value collisions become cluster matches.
221// Replaces the old "std::sort each bucket then run consecutive-pair walk"
222// approach: hash is O(K) per bucket vs O(K log K), avoids the 32-byte memcmp
223// comparator entirely (one-limb hash on insert, full compare only on fingerprint
224// hits), and keeps thread balance uniform because skipping shorts removes the
225// mega-bucket bottleneck.
226//
227// Output: per-thread cluster_members + cluster_offsets feeding a chunked
228// batched-affine tree-reduce, plus encoded redirect_lookup writes
229// (rep -> DEDUP_REDIRECT_BIT | cid, non_rep -> DEDUP_SKIP_BIT | idx).
230// The thread's cid space is the disjoint per-thread sub-range [cid_lo, cid_max).
231template <typename Curve>
232size_t dedup_phase_a_worker_hash(const uint32_t* schedule_w0,
233 const size_t* w0_bucket_start,
234 size_t b_lo,
235 size_t b_hi,
238 std::span<typename Curve::AffineElement> extra_points,
239 std::span<uint32_t> redirect_lookup,
240 const uint8_t* msb_per_scalar,
241 size_t c_threshold,
242 uint32_t cid_lo,
243 uint32_t cid_max,
244 PhaseAScratch<Curve>& scratch) noexcept
245{
246 using AffineElement = typename Curve::AffineElement;
247 using BaseField = typename Curve::BaseField;
248 constexpr uint32_t HT_EMPTY = ~uint32_t{ 0 };
249
250 // Per-thread hash table — sized for the largest expected bucket. Long-
251 // scalar density per bucket is highly NON-uniform on chonk wires: the few
252 // buckets corresponding to digit_0 ∈ {1,2,3,…} hold 700+ long entries with
253 // 500+ distinct values. A 256-slot table fills up and the open-addressing
254 // probe goes infinite. A 4096-slot table keeps load <25% even on the worst
255 // bucket. 4096 × 4 = 16 KB per thread.
256 //
257 // We use LAZY CLEARING via a dirty-slot list rather than std::fill_n per
258 // bucket: a 16 KB fill × ~2 K buckets × 8 threads = 256 MB of write traffic
259 // per Phase A, which dominates the cluster-scan wall (≈ 280 ms / 450 ms on
260 // the WASM trace). With lazy clear the per-bucket reset cost scales with
261 // the number of slots ACTUALLY written (typically 25-700), not 4096.
262 constexpr size_t HT_SIZE = 4096;
263 constexpr size_t HT_MASK = HT_SIZE - 1;
264 static_assert((HT_SIZE & (HT_SIZE - 1)) == 0, "HT_SIZE must be a power of 2");
266
267 // Each occupied hash-table slot holds either (a) the scalar_idx of a value seen
268 // exactly once so far in this bucket, or (b) a tag for an already-opened cluster:
269 // the in-bucket cluster id (an index into bucket_rep) with HT_CLUSTER_BIT set.
270 // Rather than keeping a separate parallel slot-status array, the two cases are
271 // told apart by ONE tag bit at the top of the uint32_t entry:
272 // high bit clear → slot holds a singleton scalar_idx (just one observation)
273 // high bit set → slot holds (bucket_cid | HT_CLUSTER_BIT)
274 // scalar_idx values fit the schedule payload (29 bits), so the top 3 bits are free
275 // and a stored idx can never be mistaken for a cluster tag.
276 constexpr uint32_t HT_CLUSTER_BIT = uint32_t{ 1 } << 31;
277
278 // Per-worker arena-backed scratch spans. Caller allocates `scratch` once at the start
279 // of the MSM (see `pippenger_round_parallel_internal`); we treat them as bounded
280 // capacity buffers with a logical-size cursor. No allocator churn, no thread_local
281 // process state to clean up after the MSM returns.
282 uint32_t* const cluster_members_data = scratch.cluster_members.data();
283 const size_t cluster_members_cap = scratch.cluster_members.size();
284 size_t cluster_members_size = 0;
285 uint32_t* const cluster_offsets_data = scratch.cluster_offsets.data();
286 const size_t cluster_offsets_cap = scratch.cluster_offsets.size();
287 size_t cluster_offsets_size = 0;
288 uint16_t* const dirty_slots_data = scratch.dirty_slots.data();
289 const size_t dirty_slots_cap = scratch.dirty_slots.size();
290 size_t dirty_slots_size = 0;
291 {
292 BB_BENCH_NAME("MSM::PhaseA/alloc_buffers");
293 // Cluster offsets always pushes a 0 sentinel first.
294 BB_ASSERT_GTE(cluster_offsets_cap, size_t{ 1 });
295 cluster_offsets_data[cluster_offsets_size++] = 0;
296 }
297
298 // Initial fill — uninitialised stack memory could match HT_EMPTY values
299 // by coincidence. After this, clearing is incremental via dirty_slots.
300 std::fill_n(ht.data(), HT_SIZE, HT_EMPTY);
301 // Slot-local one-limb fingerprints for occupied hash-table entries. They are
302 // valid iff `ht[slot] != HT_EMPTY`; lazy clearing only needs to reset `ht`.
303 std::array<uint64_t, HT_SIZE> ht_fingerprint;
304
305 uint32_t clusters_opened = 0;
306 {
307 BB_BENCH_NAME("MSM::PhaseA/cluster_scan");
308 // Per-bucket scratch — both backed by arena spans (caller-allocated).
309 // - `bucket_rep[bucket_cid]` = scalar_idx of the rep for that in-bucket cluster.
310 // - `staged[..]` = (bucket_cid, idx) pairs awaiting cluster emission.
311 uint32_t* const bucket_rep_data = scratch.bucket_rep.data();
312 const size_t bucket_rep_cap = scratch.bucket_rep.size();
313 size_t bucket_rep_size = 0;
314 std::pair<uint32_t, uint32_t>* const staged_data = scratch.staged.data();
315 const size_t staged_cap = scratch.staged.size();
316 size_t staged_size = 0;
317
318 for (size_t b = b_lo; b < b_hi; ++b) {
319 const size_t lo = w0_bucket_start[b];
320 const size_t hi = w0_bucket_start[b + 1];
321 if (hi - lo < 2) {
322 continue;
323 }
324
325 // Lazy clear: reset only slots dirtied by the previous bucket.
326 for (size_t k = 0; k < dirty_slots_size; ++k) {
327 ht[dirty_slots_data[k]] = HT_EMPTY;
328 }
329 dirty_slots_size = 0;
330 bucket_rep_size = 0;
331 staged_size = 0;
332
333 for (size_t i = lo; i < hi; ++i) {
334 const uint32_t idx = schedule_w0[i] & SCHEDULE_INDEX_MASK;
335 if (static_cast<size_t>(msb_per_scalar[idx]) < c_threshold) {
336 continue;
337 }
338 const uint64_t* d = scalars[idx].data;
339 const uint64_t fingerprint = dedup_scalar_fingerprint(d);
340 size_t slot = dedup_fingerprint_slot(fingerprint, HT_MASK);
341
342 // Probe-count safety net. With HT_SIZE = 4096 and per-bucket distinct-
343 // long-value counts up to ~700 on chonk wires, table load is ≤ 17 %
344 // and the average probe length is ≈ 1.1 — but if any future workload
345 // produces a bucket dense enough to fill the table, fall back to
346 // "treat as singleton, don't dedup" rather than infinite-loop.
347 size_t probe_count = 0;
348 while (true) {
349 if (++probe_count > HT_SIZE) {
350 break;
351 }
352 const uint32_t entry = ht[slot];
353 if (entry == HT_EMPTY) {
354 ht[slot] = idx;
355 ht_fingerprint[slot] = fingerprint;
356 // The dirty-slot list must never overflow: a slot that goes unrecorded is not
357 // cleared by the lazy reset and would leak into every subsequent bucket. The
358 // cap is HT_SIZE, so the bounds guard below is structurally unreachable.
359 if (BB_LIKELY(dirty_slots_size < dirty_slots_cap)) {
360 dirty_slots_data[dirty_slots_size++] = static_cast<uint16_t>(slot);
361 }
362 break;
363 }
364 if ((entry & HT_CLUSTER_BIT) != 0) {
365 const uint32_t bucket_cid = entry & ~HT_CLUSTER_BIT;
366 const uint32_t rep = bucket_rep_data[bucket_cid];
367 if (ht_fingerprint[slot] == fingerprint &&
368 std::memcmp(d, scalars[rep].data, sizeof(scalars[rep].data)) == 0) {
369 // Out of staged-pair capacity: leave this duplicate un-deduped
370 // (it will go through the standard pippenger path).
371 if (BB_UNLIKELY(staged_size >= staged_cap)) {
372 break;
373 }
374 staged_data[staged_size++] = { bucket_cid, idx };
375 break;
376 }
377 slot = (slot + 1) & HT_MASK;
378 continue;
379 }
380 // Singleton at slot: compare values.
381 if (ht_fingerprint[slot] == fingerprint &&
382 std::memcmp(d, scalars[entry].data, sizeof(scalars[entry].data)) == 0) {
383 if (clusters_opened >= (cid_max - cid_lo)) {
384 break; // cap reached, leave un-deduped
385 }
386 // Out of bucket_rep / staged capacity: leave un-deduped.
387 if (BB_UNLIKELY(bucket_rep_size >= bucket_rep_cap || staged_size >= staged_cap)) {
388 break;
389 }
390 const uint32_t bucket_cid = static_cast<uint32_t>(bucket_rep_size);
391 bucket_rep_data[bucket_rep_size++] = entry;
392 staged_data[staged_size++] = { bucket_cid, idx };
393 ht[slot] = HT_CLUSTER_BIT | bucket_cid;
394 ++clusters_opened;
395 break;
396 }
397 slot = (slot + 1) & HT_MASK;
398 }
399 }
400
401 if (bucket_rep_size == 0) {
402 continue;
403 }
404
405 // Sort staged non-reps by bucket_cid so each cluster's members are
406 // contiguous; then emit (rep, non-reps...) per cluster.
407 std::stable_sort(staged_data,
408 staged_data + staged_size,
410 const std::pair<uint32_t, uint32_t>& b) noexcept { return a.first < b.first; });
411 size_t staged_cursor = 0;
412 for (size_t bc = 0; bc < bucket_rep_size; ++bc) {
413 // Compute this cluster's member count up front (rep + staged non-reps with
414 // matching bucket_cid) so we never split a cluster across the slab cap.
415 // When the next cluster would overflow cluster_members_cap, break cleanly:
416 // un-flattened cluster reps/members never get a redirect_lookup entry, so
417 // Stage 4/6a process them as normal scalars with their original signed
418 // digits. The MSM sum is unchanged; we just deliver less dedup work.
419 size_t this_cluster_members = 1; // rep
420 for (size_t sc = staged_cursor; sc < staged_size && staged_data[sc].first == bc; ++sc) {
421 ++this_cluster_members;
422 }
423 if (cluster_members_size + this_cluster_members > cluster_members_cap) {
424 break;
425 }
426 cluster_members_data[cluster_members_size++] = bucket_rep_data[bc];
427 while (staged_cursor < staged_size && staged_data[staged_cursor].first == bc) {
428 cluster_members_data[cluster_members_size++] = staged_data[staged_cursor].second;
429 ++staged_cursor;
430 }
431 // cluster_offsets cap is provably non-overflow given clusters_opened ≤
432 // cids_per_thread and cluster_offsets_cap = cids_per_thread + 2; the
433 // initial 0 sentinel plus at most cids_per_thread end-offsets fits.
434 cluster_offsets_data[cluster_offsets_size++] = static_cast<uint32_t>(cluster_members_size);
435 }
436 }
437 } // MSM::PhaseA/cluster_scan
438
439 // Only flattened clusters are published. `clusters_opened` counts every promoted
440 // hash-table singleton, including clusters later skipped because cluster_members_cap
441 // would be exceeded. Skipped clusters intentionally fall through the normal Pippenger
442 // path because they never get redirect_lookup entries.
443 const size_t num_clusters = cluster_offsets_size - 1;
444 if (num_clusters == 0) {
445 return 0;
446 }
447
448 // For tree_reduce we need a single contiguous member list; cluster_members_data is
449 // already such a list, with [cluster_offsets[k], cluster_offsets[k+1]) per cluster.
450 // cluster_offsets_size = num_clusters + 1 (initial 0 sentinel + one push per cluster).
451 BB_ASSERT_EQ(cluster_offsets_size, num_clusters + 1, "cluster_offsets layout mismatch");
452
453 {
454 BB_BENCH_NAME("MSM::PhaseA/tree_reduce");
455 typename Curve::AffineElement* const chunk_pts_data = scratch.chunk_pts.data();
456 uint32_t* const chunk_ids_data = scratch.chunk_ids.data();
457 const size_t chunk_cap = scratch.chunk_pts.size();
459 BB_ASSERT_GTE(scratch.chunk_ids.size(), DEDUP_MAX_CHUNK_MEMBERS);
460 size_t chunk_size = 0;
461 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
463 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
465 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
467
468 size_t cid_cursor = 0;
469 size_t member_offset_in_cluster = 0;
470 AffineElement carry{};
471 bool has_carry = false;
472
473 while (cid_cursor < num_clusters || has_carry) {
474 chunk_size = 0;
475 if (has_carry) {
476 chunk_pts_data[chunk_size] = carry;
477 chunk_ids_data[chunk_size] = static_cast<uint32_t>(cid_cursor);
478 ++chunk_size;
479 has_carry = false;
480 }
481 while (cid_cursor < num_clusters && chunk_size < DEDUP_MAX_CHUNK_MEMBERS) {
482 const size_t cluster_lo = cluster_offsets_data[cid_cursor] + member_offset_in_cluster;
483 const size_t cluster_hi = cluster_offsets_data[cid_cursor + 1];
484 const size_t available = cluster_hi - cluster_lo;
485 const size_t room = DEDUP_MAX_CHUNK_MEMBERS - chunk_size;
486 if (available <= room) {
487 for (size_t k = 0; k < available; ++k) {
488 chunk_pts_data[chunk_size] = points[cluster_members_data[cluster_lo + k]];
489 chunk_ids_data[chunk_size] = static_cast<uint32_t>(cid_cursor);
490 ++chunk_size;
491 }
492 ++cid_cursor;
493 member_offset_in_cluster = 0;
494 } else {
495 for (size_t k = 0; k < room; ++k) {
496 chunk_pts_data[chunk_size] = points[cluster_members_data[cluster_lo + k]];
497 chunk_ids_data[chunk_size] = static_cast<uint32_t>(cid_cursor);
498 ++chunk_size;
499 }
500 member_offset_in_cluster += room;
501 break;
502 }
503 }
504 const size_t result_len = dedup_tree_reduce_in_place<Curve>(chunk_pts_data,
505 chunk_ids_data,
506 chunk_size,
507 scratch_pts.data(),
508 pair_dest.data(),
509 inversion_scratch.data());
510 const bool last_is_partial = (cid_cursor < num_clusters) && (member_offset_in_cluster > 0);
511 const size_t whole_count = last_is_partial ? result_len - 1 : result_len;
512 for (size_t k = 0; k < whole_count; ++k) {
513 const uint32_t local_cid = chunk_ids_data[k];
514 extra_points[cid_lo + local_cid] = chunk_pts_data[k];
515 }
516 if (last_is_partial) {
517 carry = chunk_pts_data[result_len - 1];
518 has_carry = true;
519 }
520 }
521 } // MSM::PhaseA/tree_reduce
522
523 {
524 BB_BENCH_NAME("MSM::PhaseA/publish_redirects");
525 for (size_t k = 0; k < num_clusters; ++k) {
526 const size_t mlo = cluster_offsets_data[k];
527 const size_t mhi = cluster_offsets_data[k + 1];
528 const uint32_t rep_idx = cluster_members_data[mlo];
529 const uint32_t global_cid = cid_lo + static_cast<uint32_t>(k);
530 redirect_lookup[rep_idx] = DEDUP_REDIRECT_BIT | global_cid;
531 for (size_t m = mlo + 1; m < mhi; ++m) {
532 const uint32_t non_rep_idx = cluster_members_data[m];
533 redirect_lookup[non_rep_idx] = DEDUP_SKIP_BIT | non_rep_idx;
534 }
535 }
536 }
537
538 return num_clusters;
539}
540
541// Post-Phase-A schedule patcher. Walks a window's already-emitted bucket runs,
542// rewrites entries whose scalar_idx has an encoded dedup payload, and compacts
543// non-rep DEDUP_SKIP entries out of the schedule.
544// The hot Stage 4 emit loop is now dedup-unaware (plain `sched_w[idx] = sign | scalar_idx`);
545// all dedup tagging happens here.
546//
547// This is a free function — NOT a lambda capturing dedup_state by reference — so
548// `redirect_lookup` is passed as a raw pointer argument and the inner loop has no
549// closure-indirection chain. The only random load per iter is the single
550// `redirect_lookup[scalar_idx]` lookup, which lands in L2 for typical MSM sizes.
551// `bucket_start` is rewritten in place, so each old bucket end is saved before
552// its prefix slot is overwritten with the compacted end.
553template <typename Curve>
554[[gnu::flatten]] inline void dedup_patch_schedule_window(uint32_t* __restrict sched_w,
555 size_t* __restrict bucket_start,
556 size_t num_buckets,
557 const uint32_t* __restrict redirect_lookup) noexcept
558{
559 static_cast<void>(static_cast<Curve*>(nullptr)); // template tag for symbol disambiguation
560 size_t write = 0;
561 size_t old_bucket_start = bucket_start[0];
562 bucket_start[0] = 0;
563 for (size_t bucket = 0; bucket < num_buckets; ++bucket) {
564 const size_t old_bucket_end = bucket_start[bucket + 1];
565 for (size_t read = old_bucket_start; read < old_bucket_end; ++read) {
566 const uint32_t e = sched_w[read];
567 const uint32_t idx = e & SCHEDULE_INDEX_MASK;
568 const uint32_t patch = redirect_lookup[idx];
569 uint32_t out = e;
570 if (BB_UNLIKELY(patch != DEDUP_INVALID_EXTRA)) {
571 if ((patch & DEDUP_SKIP_BIT) != 0) {
572 continue;
573 }
574 out = (e & SCHEDULE_SIGN_BIT) | patch;
575 }
576 if (write != read || out != e) {
577 sched_w[write] = out;
578 }
579 ++write;
580 }
581 old_bucket_start = old_bucket_end;
582 bucket_start[bucket + 1] = write;
583 }
584}
585
586} // namespace bb::scalar_multiplication::round_parallel_detail
#define BB_ASSERT_GTE(left, right,...)
Definition assert.hpp:128
#define BB_ASSERT_EQ(actual, expected,...)
Definition assert.hpp:83
#define BB_BENCH_NAME(name)
Definition bb_bench.hpp:264
typename Group::affine_element AffineElement
Definition grumpkin.hpp:64
#define BB_UNLIKELY(x)
#define BB_LIKELY(x)
FF a
FF b
size_t dedup_tree_reduce_in_place(typename Curve::AffineElement *pts, uint32_t *ids, size_t initial_len, typename Curve::AffineElement *scratch_pts, uint32_t *pair_dest, typename Curve::BaseField *inversion_scratch) noexcept
uint64_t dedup_scalar_fingerprint(const uint64_t *scalar_data) noexcept
void dedup_patch_schedule_window(uint32_t *__restrict sched_w, size_t *__restrict bucket_start, size_t num_buckets, const uint32_t *__restrict redirect_lookup) noexcept
size_t dedup_fingerprint_slot(uint64_t fingerprint, size_t mask) noexcept
size_t dedup_phase_a_worker_hash(const uint32_t *schedule_w0, const size_t *w0_bucket_start, size_t b_lo, size_t b_hi, std::span< const typename Curve::ScalarField > scalars, std::span< const typename Curve::AffineElement > points, std::span< typename Curve::AffineElement > extra_points, std::span< uint32_t > redirect_lookup, const uint8_t *msb_per_scalar, size_t c_threshold, uint32_t cid_lo, uint32_t cid_max, PhaseAScratch< Curve > &scratch) noexcept
void read(B &it, field2< base_field, Params > &value)
void write(B &buf, field2< base_field, Params > const &value)
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
std::byte * data
std::span< uint32_t > pair_dest
std::span< BaseField > inversion_scratch
std::span< typename Curve::AffineElement > extra_points