Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
pippenger_batched.hpp
Go to the documentation of this file.
1#pragma once
2
3// Implementation fragment included from scalar_multiplication_fast.cpp inside
4// bb::scalar_multiplication, after pippenger_round_parallel is defined.
5
6// Multi-MSM driver for `MSM_fast<>::batch_multi_scalar_mul`. The hot path
7// (`CommitmentKey::batch_commit` from `commit_to_wires`) batches K MSMs sharing the same
8// SRS subspan. We do NOT interleave K MSMs inside a single parallel_for body — that
9// multiplies the per-thread working set by K and forces windows_in_batch=1; the single-MSM
10// hot path is tuned to fit ~4 MiB in L2 and we want to preserve that. The loop is just
11// for m in 0..K: run the single-MSM dispatch for MSM m.
12// The only cross-MSM amortisation is the GLV-doubled point set: when every member of a
13// shared-SRS-prefix group wants GLV, we double the prefix once into a shared buffer and
14// each per-MSM call aliases its prefix instead of doubling its own.
16
// Bookkeeping record for one group of MSMs in a batch that read their points from the same
// starting address. It owns no memory: `doubled` is only a view.
17// One per shared-SRS-prefix group. Membership is keyed on identical
18// `point_arrays[m].data()` pointers — that is the actual sharing relation
19// `commit_to_wires` exposes.
20// NOTE(review): this comment used to say the struct is "static-lifetime so the doubled buffer survives across calls". The batched driver visible in this file builds its groups in a function-local vector and points `doubled` into a function-local allocation, so nothing persists between calls there — confirm whether any other user keeps groups alive.
21template <typename Curve> struct BatchMsmGlvGroup {
22 const typename Curve::AffineElement* base_ptr = nullptr; // start address of the point span shared by every member (the grouping key)
23 size_t group_max_n = 0; // max n_input across MSMs in this group
24 std::span<typename Curve::AffineElement> doubled; // non-owning view, length 2 * group_max_n, into the
25 // shared doubled buffer that the batched driver fills
26 // once per call; left empty for groups that skip GLV. Layout
27 // `[P_0, φP_0, P_1, φP_1, …]` — the first 2*n
28 // entries are the per-MSM view for n ≤ group_max_n.
29 std::vector<size_t> member_msms; // indices into `scalar_arrays` of MSMs in this group
30};
31
32} // namespace round_parallel_detail
33
34namespace {
// Runs K multi-scalar multiplications one after another through `pippenger_round_parallel`,
// sharing two allocations across the batch:
//   * one interleaved GLV-doubled copy of the points, used by every group whose largest input
//     is at most GLV_SMALL_N_THRESHOLD, and
//   * one scratch arena sized to the largest per-MSM requirement.
//
// scalar_arrays - one scalar span per MSM. MSM m consumes
//                 min(scalar_arrays[m].size(), point_arrays[m].size()) entries.
// dedup_hints   - optional per-MSM flags. Entry m counts as true only when it exists and is
//                 non-zero; a span shorter than the batch means "false" for the tail.
//
// Output: out_results is resized to K and every slot starts as the point at infinity; slots of
// MSMs with zero usable inputs keep that value, all others receive the MSM result.
// A batch of exactly one MSM bypasses the shared buffers and calls the single-MSM routine directly.
//
// NOTE(review): this listing omits two lines of the signature (original lines 39-40).
// `point_arrays` and `out_results`, used throughout the body, are presumably declared there —
// confirm against the real header.
35// NOLINTNEXTLINE(readability-function-size, readability-function-cognitive-complexity,
36// google-readability-function-size)
37template <typename Curve>
38void pippenger_round_parallel_batched(std::span<std::span<typename Curve::ScalarField>> scalar_arrays,
41 std::span<const uint8_t> dedup_hints = {}) noexcept
42{
43 using AffineElement = typename Curve::AffineElement;
44 using ScalarField = typename Curve::ScalarField;
45 using BaseField = typename Curve::BaseField;
46
47 BB_BENCH_NAME("MSM_fast::pippenger_round_parallel_batched");
48
49 const size_t K = scalar_arrays.size();
50 BB_ASSERT_EQ(point_arrays.size(), K);
    // Pre-fill every result so the early returns and the `continue`s below leave a defined value.
51 out_results.assign(K, Curve::Group::point_at_infinity);
52
    // Bounds-checked hint lookup: a missing entry and a zero entry both mean "no dedup" for MSM m.
53 auto hint_for = [&](size_t m) noexcept -> bool { return m < dedup_hints.size() && dedup_hints[m] != 0; };
54
55 if (K == 0) {
56 return;
57 }
    // Single-MSM batch: nothing to share, so call the single-MSM routine without an external
    // doubled buffer or arena.
58 if (K == 1) {
59 const size_t n = std::min(scalar_arrays[0].size(), point_arrays[0].size());
60 if (n == 0) {
61 return;
62 }
63 PolynomialSpan<const ScalarField> sp(0, std::span<const ScalarField>(scalar_arrays[0].data(), n));
64 out_results[0] = pippenger_round_parallel<Curve>(sp, point_arrays[0], hint_for(0));
65 return;
66 }
67
    // Usable length of each MSM: the shorter of its scalar span and its point span.
68 std::vector<size_t> n_input(K);
69 for (size_t m = 0; m < K; ++m) {
70 n_input[m] = std::min(scalar_arrays[m].size(), point_arrays[m].size());
71 }
72
73 // Group MSMs by shared SRS pointer; one shared GLV-doubled buffer per group, sized to
74 // group_max_n. group_uses_glv is a per-group bool but the per-MSM_fast internal dispatch keeps
75 // each MSM_fast's own GLV decision in case shared doubling is skipped.
    // NOTE(review): `GlvGroup` is not declared in the visible lines (original line 76 is missing
    // from this listing); presumably an alias for round_parallel_detail::BatchMsmGlvGroup<Curve> — confirm.
77 std::vector<GlvGroup> glv_groups;
78
    // Linear scan for a group with this base pointer; widens its group_max_n on a hit, appends a
    // new group on a miss. Returns the group's index in glv_groups.
79 auto find_or_create_group = [&](const AffineElement* base_ptr, size_t n) -> size_t {
80 for (size_t g = 0; g < glv_groups.size(); ++g) {
81 if (glv_groups[g].base_ptr == base_ptr) {
82 glv_groups[g].group_max_n = std::max(glv_groups[g].group_max_n, n);
83 return g;
84 }
85 }
86 GlvGroup g{};
87 g.base_ptr = base_ptr;
88 g.group_max_n = n;
89 glv_groups.push_back(std::move(g));
90 return glv_groups.size() - 1;
91 };
92
    // msm_to_group[m] stays at the size_t max sentinel for empty MSMs, which join no group.
93 std::vector<size_t> msm_to_group(K, std::numeric_limits<size_t>::max());
94 for (size_t m = 0; m < K; ++m) {
95 if (n_input[m] == 0) {
96 continue;
97 }
98 const size_t g = find_or_create_group(point_arrays[m].data(), n_input[m]);
99 glv_groups[g].member_msms.push_back(m);
100 msm_to_group[m] = g;
101 }
102
103 std::vector<bool> group_uses_glv(glv_groups.size(), false);
104 for (size_t g = 0; g < glv_groups.size(); ++g) {
105 // GLV decision is per-group on group_max_n. Within a group, every MSM_fast has
106 // n[m] <= group_max_n; if group_max_n is in the small-N regime, every MSM_fast
107 // is too, so they all want GLV. If group_max_n is in the large-N regime,
108 // no MSM_fast in the group wants GLV (they'd be slower with it).
109 group_uses_glv[g] = glv_groups[g].group_max_n <= round_parallel_detail::GLV_SMALL_N_THRESHOLD;
110 }
111
112 // Build ONE shared GLV-doubled buffer covering the union of every GLV-using group's
113 // SRS range, then alias each group's `doubled` into a slice of that buffer.
114 //
115 // Every production / test caller of batch_multi_scalar_mul is `commitment_key.batch_commit`,
116 // which constructs each MSM_fast's point span as `get_monomial_points().subspan(start_index)`
117 // — sub-spans of a single contiguous `std::vector<AffineElement>` SRS. So in every
118 // batch every group's `base_ptr` lives in the same allocation and offsets are
119 // necessarily integer multiples of `sizeof(AffineElement)`. The asserts below
120 // catch a future caller that violates that contract.
121 std::unique_ptr<AffineElement[]> master_doubled_owner; // NOLINT(cppcoreguidelines-avoid-c-arrays)
122 {
123 BB_BENCH_NAME("MSM_fast::pippenger_round_parallel_batched/glv_double_points");
124
    // Pass 1: clear every group's view and find the lowest base address among GLV-using groups.
    // min_base stays null when no group uses GLV, which skips the whole allocation below.
125 const AffineElement* min_base = nullptr;
126 for (size_t g = 0; g < glv_groups.size(); ++g) {
127 glv_groups[g].doubled = {};
128 if (!group_uses_glv[g]) {
129 continue;
130 }
131 if (min_base == nullptr || std::less<const AffineElement*>{}(glv_groups[g].base_ptr, min_base)) {
132 min_base = glv_groups[g].base_ptr;
133 }
134 }
135
136 if (min_base != nullptr) {
    // Pass 2: measure, in AffineElement units from min_base, how far the furthest
    // GLV-using group reaches.
137 const auto min_addr = reinterpret_cast<uintptr_t>(min_base);
138 size_t max_extent_units = 0;
139 for (size_t g = 0; g < glv_groups.size(); ++g) {
140 if (!group_uses_glv[g]) {
141 continue;
142 }
143 const auto base_addr = reinterpret_cast<uintptr_t>(glv_groups[g].base_ptr);
144 const uintptr_t offset_bytes = base_addr - min_addr;
145 BB_ASSERT_EQ(offset_bytes % sizeof(AffineElement),
146 size_t{ 0 },
147 "GLV group base_ptr not aligned to AffineElement boundary "
148 "(point spans must be subranges of a contiguous AffineElement array)");
149 const size_t offset_units = offset_bytes / sizeof(AffineElement);
150 const size_t end_units = offset_units + glv_groups[g].group_max_n;
151 max_extent_units = std::max(max_extent_units, end_units);
152 }
153
    // NOTE(review): the start of this statement (original line 154) is missing from the listing;
    // presumably it assigns a fresh AffineElement array of 2 * max_extent_units entries to
    // master_doubled_owner — confirm.
155 2 * max_extent_units); // NOLINT(cppcoreguidelines-avoid-c-arrays)
156 AffineElement* const master_buf = master_doubled_owner.get();
157 const BaseField beta = BaseField::cube_root_of_unity();
    // Fill the interleaved buffer: slot 2i is a copy of min_base[i], slot 2i+1 is (x * beta, -y).
    // Every index in [0, max_extent_units) is read from min_base, including any points that sit
    // between two groups' ranges. (`chunk.range` presumably yields this thread's share — confirm.)
158 bb::parallel_for(bb::get_num_cpus(), [&](const ThreadChunk& chunk) {
159 for (size_t i : chunk.range(max_extent_units)) {
160 master_buf[2 * i] = min_base[i];
161 master_buf[(2 * i) + 1].x = min_base[i].x * beta;
162 master_buf[(2 * i) + 1].y = -min_base[i].y;
163 }
164 });
165
    // Pass 3: point each GLV-using group's `doubled` view at its slice of the master buffer.
    // A group that starts offset_units points after min_base starts 2 * offset_units slots in.
166 for (size_t g = 0; g < glv_groups.size(); ++g) {
167 if (!group_uses_glv[g]) {
168 continue;
169 }
170 const auto base_addr = reinterpret_cast<uintptr_t>(glv_groups[g].base_ptr);
171 const size_t offset_units = (base_addr - min_addr) / sizeof(AffineElement);
172 glv_groups[g].doubled =
173 std::span<AffineElement>(master_buf + (2 * offset_units), 2 * glv_groups[g].group_max_n);
174 }
175 }
176 }
177
178 // Shared dynamically-sized arena for all per-MSM_fast internal calls. Sized to the max
179 // requirement across the batch so each MSM_fast finds enough space. Single allocation
180 // across the batch (vs one per MSM_fast if we passed {} down). Freed at return.
181 // dedup_active varies per MSM_fast (gated by per-MSM_fast hint), so the budget query must
182 // mirror the predicate used inside pippenger_round_parallel.
183 size_t shared_arena_bytes = 0;
184 for (size_t m = 0; m < K; ++m) {
185 if (n_input[m] == 0) {
186 continue;
187 }
188 const size_t g = msm_to_group[m];
189 const bool ext_glv =
190 g != std::numeric_limits<size_t>::max() && group_uses_glv[g] && !glv_groups[g].doubled.empty();
191 // The internal short-circuits to trivial_msm_threaded for tiny MSMs, so the hint
192 // alone is the right arena-sizing predicate (over-sizing for a path that bails
193 // is harmless — under-sizing would crash).
194 const bool dedup_active_m = hint_for(m);
195 const size_t bytes = compute_arena_bytes_for_msm<Curve>(n_input[m], ext_glv, dedup_active_m);
196 shared_arena_bytes = std::max(shared_arena_bytes, bytes);
197 }
    // The arena span stays empty when the computed budget is zero.
198 std::unique_ptr<std::byte[]> shared_arena_owner; // NOLINT(cppcoreguidelines-avoid-c-arrays)
199 std::span<std::byte> shared_arena;
200 if (shared_arena_bytes > 0) {
201 shared_arena_owner =
202 std::make_unique_for_overwrite<std::byte[]>(shared_arena_bytes); // NOLINT(cppcoreguidelines-avoid-c-arrays)
203 shared_arena = std::span<std::byte>(shared_arena_owner.get(), shared_arena_bytes);
204 }
205
206 // Per-MSM_fast dispatch. Each call runs the full single-MSM_fast pipeline (its own from-Mont and
207 // to-Mont, schedule, Stage 1-6b). The only batched amortisation we share is the doubled
208 // SRS prefix above; the rest of the hot path runs at single-MSM_fast cost.
209 for (size_t m = 0; m < K; ++m) {
210 const size_t n = n_input[m];
211 if (n == 0) {
212 continue;
213 }
214 PolynomialSpan<const ScalarField> sp(0, std::span<const ScalarField>(scalar_arrays[m].data(), n));
215
216 const size_t g = msm_to_group[m];
    // Stays an empty span for MSMs whose group does not use GLV.
217 std::span<const AffineElement> external_glv;
218 if (g != std::numeric_limits<size_t>::max() && group_uses_glv[g]) {
219 // `group.doubled` is interleaved `[P_0, φP_0, …]` of length 2*Nmax. The
220 // first 2*n entries are exactly the per-MSM_fast `[P_0, φP_0, …, P_{n-1}, φP_{n-1}]`
221 // view, regardless of whether n == Nmax (uniform batch) or n < Nmax (ragged).
222 external_glv = std::span<const AffineElement>(glv_groups[g].doubled.data(), 2 * n);
223 }
224
225 out_results[m] = pippenger_round_parallel<Curve>(sp, point_arrays[m], hint_for(m), external_glv, shared_arena);
226 }
227}
228} // namespace
229
// Computes one affine result per entry of `scalars`, all reading from the single shared
// `points` span. MSM i reads points starting at scalars[i].start_index and running to the
// end of `points`.
//
// handle_edge_cases - when true, each MSM is sent individually through `pippenger_fast`;
//                     when false, the whole batch goes through
//                     `pippenger_round_parallel_batched` and each result is converted to affine.
// dedup_hints       - optional per-MSM flags; entry m counts as true only when it exists and is
//                     non-zero.
// Returns `results`, a vector with one AffineElement per MSM, in input order.
//
// NOTE(review): the declarator lines (original 231-233) are missing from this listing. The
// bench label below and the file header suggest this is
// `MSM_fast<Curve>::batch_multi_scalar_mul`, with `scalars` and `points` as its leading
// parameters — confirm against the real header.
230template <typename Curve>
234 bool handle_edge_cases,
235 std::span<const uint8_t> dedup_hints) noexcept
236{
237 BB_BENCH_NAME("MSM_fast::batch_multi_scalar_mul");
238 const size_t k = scalars.size();
239
240 // Adapt the new (single shared points span + per-MSM_fast PolynomialSpan scalars) API to
241 // the internal dispatcher, which still takes one point sub-span per MSM_fast. Each MSM_fast's
242 // sub-span is `points[start_index .. start_index + size)`; the dispatcher's existing
243 // GLV-doubled-buffer grouping then deduplicates across MSMs that fall in the same
244 // underlying allocation.
    // NOTE(review): the declaration of `point_subspans` (original line 245) is missing from this
    // listing; presumably a vector of point spans parallel to `scalar_subspans` — confirm.
246 std::vector<std::span<ScalarField>> scalar_subspans;
247 point_subspans.reserve(k);
248 scalar_subspans.reserve(k);
249 for (size_t i = 0; i < k; ++i) {
250 const size_t start_i = scalars[i].start_index;
251 BB_ASSERT_LTE(start_i, points.size(), "scalars[m].start_index exceeds shared points span");
    // The point sub-span runs to the end of `points`; it is not trimmed to the scalar count here.
252 point_subspans.push_back(points.subspan(start_i, points.size() - start_i));
253 scalar_subspans.push_back(scalars[i].span);
254 }
255
    // Bounds-checked hint lookup: a missing entry and a zero entry both mean "no dedup" for MSM m.
256 auto hint_for = [&](size_t m) noexcept -> bool { return m < dedup_hints.size() && dedup_hints[m] != 0; };
257
    // Edge-case path: one `pippenger_fast` call per MSM, no batching.
258 if (handle_edge_cases) {
259 std::vector<AffineElement> results(k);
260 for (size_t i = 0; i < k; ++i) {
261 const size_t n = std::min(point_subspans[i].size(), scalar_subspans[i].size());
    // NOTE(review): the start of the `scalar_span` declaration (original line 262) is missing
    // from this listing; only its trailing argument is visible on the next line.
263 std::span<const ScalarField>(scalar_subspans[i].data(), n));
264 results[i] =
265 AffineElement(pippenger_fast<Curve>(scalar_span, point_subspans[i], handle_edge_cases, hint_for(i)));
266 }
267 return results;
268 }
269
    // NOTE(review): the declaration of `per_msm_jac` (original line 270) is missing from this
    // listing; it receives the batched driver's per-MSM outputs, which are converted below.
271 pippenger_round_parallel_batched<Curve>(scalar_subspans, point_subspans, per_msm_jac, dedup_hints);
272
273 std::vector<AffineElement> results(k);
274 for (size_t i = 0; i < k; ++i) {
275 results[i] = AffineElement(per_msm_jac[i]);
276 }
277 return results;
278}
#define BB_ASSERT_EQ(actual, expected,...)
Definition assert.hpp:83
#define BB_ASSERT_LTE(left, right,...)
Definition assert.hpp:158
#define BB_BENCH_NAME(name)
Definition bb_bench.hpp:264
bb::fq BaseField
Definition bn254.hpp:19
typename Group::affine_element AffineElement
Definition bn254.hpp:22
bb::fr ScalarField
Definition bn254.hpp:18
size_t get_num_cpus()
Definition thread.cpp:33
void parallel_for(size_t num_iterations, const std::function< void(size_t)> &func)
Definition thread.cpp:111
constexpr void g(state_array &state, size_t a, size_t b, size_t c, size_t d, uint32_t x, uint32_t y)
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
uintptr_t base_addr
std::byte * data
std::span< typename Curve::AffineElement > doubled