Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
small_msm_matrix.bench.cpp
Go to the documentation of this file.
1
#include <algorithm>
#include <array>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <span>
#include <string>
#include <vector>
36
41
42namespace {
43
// Returns the median wall-clock duration, in nanoseconds, of `iters`
// individually timed invocations of `run`. Timing uses std::chrono::steady_clock,
// which is monotonic and so unaffected by system clock adjustments.
// For an even sample count the upper of the two middle samples is returned.
// If `iters` is 0, `run` is never invoked and 0.0 is returned (there is no
// sample to take a median of).
template <typename Run> double median_ns(Run&& run, size_t iters)
{
    if (iters == 0) {
        return 0.0;
    }

    std::vector<double> samples;
    samples.reserve(iters);
    for (size_t i = 0; i < iters; ++i) {
        const auto t0 = std::chrono::steady_clock::now();
        run();
        const auto t1 = std::chrono::steady_clock::now();
        samples.push_back(
            static_cast<double>(std::chrono::duration_cast<std::chrono::nanoseconds>(t1 - t0).count()));
    }

    // Only the middle order statistic is needed, so a partial selection
    // (average linear time) is enough; a full sort is unnecessary.
    const auto mid = samples.begin() + static_cast<std::ptrdiff_t>(samples.size() / 2);
    std::nth_element(samples.begin(), mid, samples.end());
    return *mid;
}
57
// Chooses how many timed iterations to spend on an MSM of size `n`.
// Larger problems get fewer iterations, so each benchmark cell costs a
// broadly similar amount of wall time. The counts were tuned for WASM builds.
// Thresholds are inclusive upper bounds on `n`; anything above the last
// threshold uses the fallback count.
size_t pick_iters(size_t n)
{
    struct Tier {
        size_t max_n;
        size_t iters;
    };
    static constexpr Tier TIERS[] = {
        { 4, 800 }, { 16, 400 }, { 64, 200 }, { 256, 100 }, { 1024, 32 }, { 4096, 16 },
    };

    for (const Tier& tier : TIERS) {
        if (n <= tier.max_n) {
            return tier.iters;
        }
    }
    return 8;
}
82
// Prints the header line of the timing matrix. The "N" label is left-aligned
// in a 24-character field (matching the row-label width), followed by each
// problem size right-aligned in a 12-character column, then a newline.
void print_matrix_header(const std::vector<size_t>& ns)
{
    std::printf("%-24s", "N");
    for (size_t col = 0, cols = ns.size(); col != cols; ++col) {
        std::printf(" %12zu", ns[col]);
    }
    std::printf("\n");
}
91
// Prints one row of the timing matrix:
// - `label` left-aligned in a 24-character field;
// - one right-aligned 12-character cell per entry of `ns_per_run`, rounded to
//   whole nanoseconds;
// - a newline.
// A cell is printed as "-" when `mask[i]` is false, i.e. the method was not
// measured at that N. Columns past the end of `mask` are also printed as "-",
// so a mask shorter than `ns_per_run` cannot cause an out-of-range read.
void print_matrix_row(const char* label, const std::vector<double>& ns_per_run, const std::vector<bool>& mask)
{
    std::printf("%-24s", label);
    for (size_t i = 0; i < ns_per_run.size(); ++i) {
        const bool measured = i < mask.size() && mask[i];
        if (measured) {
            std::printf(" %12.0f", ns_per_run[i]);
        } else {
            std::printf(" %12s", "-");
        }
    }
    std::printf("\n");
}
105
106// Phase 1: precise crossover sweep — at every N in {32, 34, ..., 64}, compare
107// single-threaded `straus_msm` against single-threaded `jac_fast`. Returns the
108// smallest N where jac_fast wins (or 0 if jac never wins in-range).
109size_t run_crossover_sweep(std::span<const G1> all_points, std::span<const Fr> scalars)
110{
111 std::printf("\n=== MIN_JACOBIAN_SIZE crossover sweep (single-threaded straus_msm vs jac_fast, ns) ===\n\n");
112 std::printf("%-8s %12s %12s %10s\n", "N", "straus", "jac_st", "delta_%");
113
114 size_t crossover = 0;
115 constexpr size_t REPEATS = 3;
116 for (size_t n = 32; n <= 64; n += 2) {
117 std::span<const G1> points = all_points.subspan(0, n);
118 std::span<const Fr> scalars_view(scalars.data(), n);
119 const size_t iters = pick_iters(n);
120
121 std::vector<double> straus_samples(REPEATS);
122 std::vector<double> jac_samples(REPEATS);
123 for (size_t r = 0; r < REPEATS; ++r) {
124 straus_samples[r] = median_ns(
125 [&] {
126 volatile auto v = Element::straus_msm(points, scalars_view);
127 (void)v;
128 },
129 iters);
130 jac_samples[r] = median_ns(
131 [&] {
132 volatile auto v =
133 bb::scalar_multiplication::round_parallel_detail::pippenger_round_parallel_jacobian_fast<Curve>(
134 scalars_view, points, /*min_pts_per_thread_override=*/SIZE_MAX);
135 (void)v;
136 },
137 iters);
138 }
139 std::sort(straus_samples.begin(), straus_samples.end());
140 std::sort(jac_samples.begin(), jac_samples.end());
141 const double straus = straus_samples[REPEATS / 2];
142 const double jac = jac_samples[REPEATS / 2];
143 const double delta_pct = 100.0 * (jac - straus) / straus;
144 std::printf("%-8zu %12.0f %12.0f %+10.2f\n", n, straus, jac, delta_pct);
145 if (crossover == 0 && jac < straus) {
146 crossover = n;
147 }
148 }
149 if (crossover != 0) {
150 std::printf("\nFirst N where jac_fast_st_always beats straus_msm: %zu\n", crossover);
151 } else {
152 std::printf("\nstraus_msm wins across the entire 32..64 sweep.\n");
153 }
154 return crossover;
155}
156
// Phase 2 driver for the small-MSM benchmark on BN254.
// It times five MSM strategies at a fixed set of problem sizes N (1 .. 16384)
// and prints a table of median wall-clock ns per run, with one row per
// strategy and one column per N. It then reports the fastest measured
// strategy for each N. Every cell uses the first N entries of one shared
// point buffer (SRS monomial points) and one shared scalar buffer.
void run_matrix()
{
    constexpr size_t MAX_N = 1U << 14;

    // Initialise SRS once and reuse the same point span across all cells.
    auto srs = bb::srs::get_crs_factory<Curve>()->get_crs(MAX_N);
    std::span<const G1> all_points = srs->get_monomial_points().subspan(0, MAX_N);

    // Scalar buffer shared by every cell; each cell reads its first N entries.
    std::vector<Fr> scalars(MAX_N);
    // NOTE(review): as shown here this loop never assigns `s`. The file's
    // cross-references point to `Fr::random_element(&engine)` as the intended
    // body. Confirm that the scalars really are randomised before trusting timings.
    for (auto& s : scalars) {
    }

    // Phase 1 (run_crossover_sweep) is not run here. Taking its address only
    // keeps the otherwise-unused function from tripping an unused-function
    // warning. NOTE(review): re-enable the call if the crossover table is needed.
    (void)&run_crossover_sweep;

    // Phase 2: full matrix.
    // Column set: powers of two plus intermediate sizes, covering the
    // small-MSM regime where the measured methods may disagree on which is
    // fastest, extended out to 16384. Per the original author,
    // small_mul_threaded was still beating jac_fast_mt at 8192.
    const std::vector<size_t> ns = {
        1, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48, 64, 96,
        128, 192, 256, 384, 512, 768, 1024, 2048, 4096, 8192, 12288, 16384,
    };

    // Per-column masks:
    // - straus_msm is only measured for N < 256;
    // - pippenger_round_parallel ("pippenger_internal") is only measured for N >= 64;
    // - all other rows are measured at every N.
    // Masked cells print as "-" and are excluded from the best-method comparison.
    // NOTE(review): the stated reason for dropping straus_msm at large N is
    // that its cost dominates the run time; this file does not show that.
    std::vector<bool> straus_mask(ns.size());
    std::vector<bool> internal_mask(ns.size());
    for (size_t i = 0; i < ns.size(); ++i) {
        straus_mask[i] = (ns[i] < 256);
        internal_mask[i] = (ns[i] >= 64);
    }
    std::vector<bool> all_mask(ns.size(), true);

    // One timing row per method, indexed by column. Masked cells keep their
    // value-initialised 0.0.
    std::vector<double> row_jac_mt(ns.size());
    std::vector<double> row_jac_st(ns.size());
    std::vector<double> row_threaded(ns.size());
    std::vector<double> row_straus(ns.size());
    std::vector<double> row_internal(ns.size());

    for (size_t col = 0; col < ns.size(); ++col) {
        const size_t n = ns[col];
        std::span<const G1> points = all_points.subspan(0, n);
        std::span<const Fr> scalars_view(scalars.data(), n);
        // Mutable view over the same scalars, used only for pippenger_round_parallel.
        // NOTE(review): if that routine modifies scalars in place, later
        // iterations and columns are timed on modified data. Confirm it leaves
        // them unchanged (or restores them).
        std::span<Fr> mut_scalars_view(scalars.data(), n);
        bb::PolynomialSpan<const Fr> poly_scalars(0, scalars_view);
        bb::PolynomialSpan<Fr> mut_poly_scalars(0, mut_scalars_view);
        const size_t iters = pick_iters(n);

        // Each timed result goes through a `volatile` local, presumably so the
        // compiler cannot discard the call being measured.

        // jac_fast with a per-thread minimum of 1 point ("mt" row).
        row_jac_mt[col] = median_ns(
            [&] {
                volatile auto r =
                    bb::scalar_multiplication::round_parallel_detail::pippenger_round_parallel_jacobian_fast<Curve>(
                        scalars_view, points, /*min_pts_per_thread_override=*/1);
                (void)r;
            },
            iters);

        // Same routine with a per-thread minimum of SIZE_MAX ("st" row).
        row_jac_st[col] = median_ns(
            [&] {
                volatile auto r =
                    bb::scalar_multiplication::round_parallel_detail::pippenger_round_parallel_jacobian_fast<Curve>(
                        scalars_view, points, /*min_pts_per_thread_override=*/SIZE_MAX);
                (void)r;
            },
            iters);

        row_threaded[col] = median_ns(
            [&] {
                volatile auto r = bb::scalar_multiplication::trivial_msm_threaded<Curve>(poly_scalars, points);
                (void)r;
            },
            iters);

        if (straus_mask[col]) {
            row_straus[col] = median_ns(
                [&] {
                    volatile auto r = Element::straus_msm(points, scalars_view);
                    (void)r;
                },
                iters);
        }

        if (internal_mask[col]) {
            row_internal[col] = median_ns(
                [&] {
                    volatile auto r =
                        bb::scalar_multiplication::pippenger_round_parallel<Curve>(mut_poly_scalars, points);
                    (void)r;
                },
                iters);
        }
    }

    std::printf("\n=== small-MSM crossover matrix (median wall-clock ns per run, BN254) ===\n\n");
    print_matrix_header(ns);
    print_matrix_row("jac_fast_mt_always", row_jac_mt, all_mask);
    print_matrix_row("jac_fast_st_always", row_jac_st, all_mask);
    print_matrix_row("small_mul_threaded", row_threaded, all_mask);
    print_matrix_row("straus_msm", row_straus, straus_mask);
    print_matrix_row("pippenger_internal", row_internal, internal_mask);

    // Best method per N — masked candidates are excluded from the comparison.
    // jac_mt, jac_st and threaded are always active, so `best` is never null.
    std::printf("\nBest method per N:\n");
    for (size_t i = 0; i < ns.size(); ++i) {
        struct Cand {
            const char* name;
            double v;
            bool active;
        };
        std::array<Cand, 5> c{ { { "jac_mt", row_jac_mt[i], true },
                                 { "jac_st", row_jac_st[i], true },
                                 { "threaded", row_threaded[i], true },
                                 { "straus", row_straus[i], straus_mask[i] },
                                 { "internal", row_internal[i], internal_mask[i] } } };
        const Cand* best = nullptr;
        for (const Cand& cand : c) {
            if (cand.active && (best == nullptr || cand.v < best->v)) {
                best = &cand;
            }
        }
        std::printf("  N=%-6zu best=%-12s (%.0f ns)\n", ns[i], best->name, best->v);
    }
}
286
287} // namespace
288
289int main()
290{
291 run_matrix();
292 return 0;
293}
typename Group::element Element
Definition bn254.hpp:21
typename Group::affine_element AffineElement
Definition bn254.hpp:22
bb::fr ScalarField
Definition bn254.hpp:18
numeric::RNG & engine
RNG & get_debug_randomness(bool reset, std::uint_fast64_t seed)
Definition engine.cpp:245
std::filesystem::path bb_crs_path()
void init_file_crs_factory(const std::filesystem::path &path)
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
Curve::AffineElement G1
Curve::Element Element
static field random_element(numeric::RNG *engine=nullptr) noexcept