1275 const size_t n_input = scalars_span.size();
1277 return Curve::Group::point_at_infinity;
1286 const size_t num_threads_dispatch =
std::max<size_t>(1, std::min(n_input, max_threads));
1287 const size_t pts_per_thread = (n_input + num_threads_dispatch - 1) / num_threads_dispatch;
1289 return trivial_msm_threaded<Curve>(scalars_span, all_points);
1293 BB_ASSERT_GTE(all_points.size(), scalars_span.start_index + n_input);
1294 std::span<const AffineElement> input_points(&all_points[scalars_span.start_index], n_input);
1296 constexpr size_t FULL_NUM_BITS = ScalarField::modulus.get_msb() + 1;
1299 ScalarField* scalar_ptr =
const_cast<ScalarField*
>(&scalars_span[scalars_span.start_index]);
1310 const bool external_glv_provided = !external_glv_doubled.empty();
1319 const size_t n = use_glv ? (2 * n_input) : n_input;
1320 const size_t NUM_BITS = use_glv ?
size_t{ 128 } : FULL_NUM_BITS;
1323 "working scalar indices must fit in the 29-bit schedule payload");
1325 std::span<const AffineElement> points;
1326 const bool inline_glv_double = use_glv && !external_glv_provided;
1333 const bool dedup_active = dedup_hint;
1347 const size_t arena_total_bytes = compute_arena_bytes_for_msm<Curve>(n_input, external_glv_provided, dedup_active);
1348 round_parallel_detail::MsmArena arena(arena_total_bytes, external_arena);
1362 using round_parallel_detail::MSB_ZERO_SENTINEL;
1364 auto msb_per_scalar = arena.template alloc<uint8_t>(n);
1365 auto per_thread_msb_hist = arena.template alloc<std::array<uint32_t, 256>>(profile_threads);
1376 glv_scalars_storage = arena.template alloc<ScalarField>(n);
1377 if (inline_glv_double) {
1378 glv_points_storage = arena.template alloc<AffineElement>(n);
1390 const BaseField beta = inline_glv_double ? BaseField::cube_root_of_unity() : BaseField{};
1392 auto& th_hist = per_thread_msb_hist[chunk.
thread_index];
1393 for (
size_t i : chunk.
range(n_input)) {
1394 const ScalarField canonical = input_scalars[i].from_montgomery_form_reduced();
1395 const auto split = ScalarField::split_into_endomorphism_scalars(canonical);
1396 const auto& k1 = split.first;
1397 const auto& k2 = split.second;
1398 glv_scalars_storage[2 * i].data[0] = k1[0];
1399 glv_scalars_storage[2 * i].data[1] = k1[1];
1400 glv_scalars_storage[2 * i].data[2] = 0;
1401 glv_scalars_storage[2 * i].data[3] = 0;
1402 glv_scalars_storage[(2 * i) + 1].
data[0] = k2[0];
1403 glv_scalars_storage[(2 * i) + 1].
data[1] = k2[1];
1404 glv_scalars_storage[(2 * i) + 1].
data[2] = 0;
1405 glv_scalars_storage[(2 * i) + 1].
data[3] = 0;
1406 if (inline_glv_double) {
1407 glv_points_storage[2 * i] = input_points[i];
1408 glv_points_storage[(2 * i) + 1].x = input_points[i].x * beta;
1409 glv_points_storage[(2 * i) + 1].y = -input_points[i].y;
1411 round_parallel_detail::record_msb(
1412 round_parallel_detail::msb_of_2limb(k1[0], k1[1]), msb_per_scalar[2 * i], th_hist);
1413 round_parallel_detail::record_msb(
1414 round_parallel_detail::msb_of_2limb(k2[0], k2[1]), msb_per_scalar[(2 * i) + 1], th_hist);
1418 inline_glv_double ? std::span<const AffineElement>(glv_points_storage.data(), n) : external_glv_doubled;
1419 scalars = glv_scalars_storage;
1423 auto& th_hist = per_thread_msb_hist[chunk.
thread_index];
1424 for (
size_t i : chunk.
range(n_input)) {
1425 input_scalars[i].self_from_montgomery_form_reduced();
1426 round_parallel_detail::record_msb(
1427 round_parallel_detail::msb_of_4limb(input_scalars[i].
data), msb_per_scalar[i], th_hist);
1430 scalars = input_scalars;
1431 points = input_points;
1435 for (
size_t t = 0; t < profile_threads; ++t) {
1436 for (
size_t b = 0;
b < 256; ++
b) {
1437 msb_hist[
b] += per_thread_msb_hist[t][
b];
1440 const size_t n_active_early = n -
static_cast<size_t>(msb_hist[0]);
1449 const size_t threads_for_dispatch =
std::max<size_t>(1, std::min(n_active_early, max_threads_dispatch));
1450 const size_t pts_per_thread = (n_active_early + threads_for_dispatch - 1) / threads_for_dispatch;
1453 for (
size_t i : chunk.
range(n)) {
1454 scalars[i].self_to_montgomery_form();
1457 std::span<const ScalarField> scalars_const(scalars.data(), n);
1459 return trivial_msm_threaded<Curve>(ps, points);
1470 size_t effective_num_bits = 0;
1471 for (
size_t bin = 256; bin > 1;) {
1473 if (msb_hist[bin] != 0) {
1474 effective_num_bits = bin;
1478 if (effective_num_bits == 0 || effective_num_bits > NUM_BITS) {
1479 effective_num_bits = NUM_BITS;
1481 const size_t window_bits =
1483 const size_t num_buckets = (
size_t{ 1 } << (window_bits - 1)) + 1;
1501 "window schedule exceeds compile-time max window count");
1513 const size_t max_threads_for_min_batch =
std::max<size_t>(1, n / MIN_BATCH_CAPACITY);
1514 const size_t num_threads = std::min(desired_threads, max_threads_for_min_batch);
1529 size_t B_eff = num_buckets;
1530 for (
size_t w = 0; w < sched.num_windows; ++w) {
1531 B_eff =
std::max(B_eff,
static_cast<size_t>(sched.num_buckets[w]));
1534 const size_t worker_total_for_budget = num_threads;
1536 const size_t bucket_partials_per_window_max =
1538 const size_t per_window_bytes_lo = round_parallel_detail::compute_per_window_bytes<Curve>(
1539 num_threads, B_eff, n, dense_stride_est, worker_total_for_budget);
1541 const size_t global_max_overflow_per_window_for_budget =
1544 const size_t phase_one_prologue_bytes =
1548 const size_t phase_a_cluster_members_cap = phase_a_caps.members_cap;
1549 const size_t phase_a_cluster_offsets_cap = phase_a_caps.offsets_cap;
1555 SUBCHUNK_ENTRIES_CAP,
1556 global_max_overflow_per_window_for_budget,
1558 phase_a_cluster_members_cap,
1559 phase_a_cluster_offsets_cap,
1564 const size_t fixed_overhead = (worker_union_bytes_for_budget * worker_total_for_budget) +
1566 + (
size_t{ 8 } * (num_threads + 1))
1567 + phase_one_prologue_bytes;
1570 const size_t available_budget =
1571 (BATCH_MEM_BUDGET > fixed_overhead) ? (BATCH_MEM_BUDGET - fixed_overhead) :
size_t{ 0 };
1572 const size_t windows_per_batch =
1583 const size_t global_max_chunk_len = (n + num_threads - 1) / num_threads;
1584 const size_t global_max_overflow_per_window =
1585 (global_max_chunk_len + SUBCHUNK_ENTRIES_CAP - 1) / SUBCHUNK_ENTRIES_CAP;
1586 const size_t chunk_capacity =
std::max(SUBCHUNK_ENTRIES_CAP, 2 * global_max_overflow_per_window);
1594 const size_t worker_total = num_threads;
1598 phase_a_scratch.resize(worker_total);
1641 const size_t bytes_P_prefix = arena.cursor;
// Rounds `off` up to the next multiple of `align` and returns the result.
// The add-then-mask arithmetic is only correct when `align` is a nonzero
// power of two; the lambda itself does not check that.
1646 auto align_up = [](
size_t off,
size_t align) ->
size_t {
return (off + align - 1) & ~(align - 1); };
// Offset-accounting helper: rounds the running offset `off` up to `align`
// (via align_up) and then advances it by `bytes`. `off` is updated in place
// through the reference parameter; the lambda only does arithmetic on it.
1647 auto layout_add = [&](
size_t& off,
size_t bytes,
size_t align) { off = align_up(off, align) + bytes; };
1653 global_max_overflow_per_window,
1655 phase_a_cluster_members_cap,
1656 phase_a_cluster_offsets_cap,
1666 size_t bytes_P_extra_layout = 0;
1667 layout_add(bytes_P_extra_layout,
sizeof(
Element) * VAR_WINDOW_WINDOW_SUMS_CAP,
alignof(
Element));
1669 layout_add(bytes_P_extra_layout,
sizeof(uint32_t) * n,
alignof(uint32_t));
1670 layout_add(bytes_P_extra_layout,
1672 alignof(AffineElement));
1680 const size_t arena_base_misalign =
static_cast<size_t>(arena.base_addr & (WORKER_SLAB_ALIGN - 1));
1681 const size_t bytes_P_min = align_up(bytes_P_prefix,
alignof(
Element)) + bytes_P_extra_layout;
1682 const size_t bytes_P = align_up(bytes_P_min + arena_base_misalign, WORKER_SLAB_ALIGN) - arena_base_misalign;
1685 const size_t bytes_W = per_worker_bytes * worker_total;
1691 const size_t bytes_S_total = arena.capacity - bytes_P - bytes_W;
1696 size_t zone_P_cursor = bytes_P_prefix;
1697 size_t zone_S_cursor = 0;
1698 auto zone_P_alloc = [&]<
typename T>(
size_t count) ->
std::span<T> {
1699 return arena.template bump_alloc<T>(count, zone_P_cursor, bytes_P, 0);
1701 auto zone_S_alloc = [&]<
typename T>(
size_t count) ->
std::span<T> {
1702 return arena.template bump_alloc<T>(count, zone_S_cursor, bytes_S_total, bytes_P + bytes_W);
1712 for (
size_t t = 0; t < worker_total; ++t) {
1714 const size_t slab_base = t * per_worker_bytes;
1715 auto& s = thread_scratch[t];
1718 size_t ts_fixed_cur = 0;
1719 auto ts_fixed_alloc = [&]<
typename T>(
size_t count) ->
std::span<T> {
1720 return arena.template bump_alloc<T>(count, ts_fixed_cur, per_worker_union_bytes, bytes_P + slab_base);
1722 s.curr_pts = ts_fixed_alloc.template operator()<AffineElement>(chunk_capacity);
1723 s.curr_buckets = ts_fixed_alloc.template operator()<uint32_t>(chunk_capacity);
1724 s.points_to_add = ts_fixed_alloc.template operator()<AffineElement>(2 * BATCH_CAPACITY);
1725 s.inversion_scratch = ts_fixed_alloc.template operator()<BaseField>(BATCH_CAPACITY);
1726 s.pair_dest = ts_fixed_alloc.template operator()<uint32_t>(BATCH_CAPACITY);
1727 s.overflow_slots = ts_fixed_alloc.template operator()<uint32_t>(global_max_overflow_per_window);
1728 s.overflow_pts = ts_fixed_alloc.template operator()<AffineElement>(global_max_overflow_per_window);
1735 auto pa_alloc = [&]<
typename T>(
size_t count) ->
std::span<T> {
1736 return arena.template bump_alloc<T>(count, pa_cur, per_worker_union_bytes, bytes_P + slab_base);
1738 auto& ps = phase_a_scratch[t];
1740 ps.cluster_members = pa_alloc.template operator()<uint32_t>(phase_a_cluster_members_cap);
1741 ps.cluster_offsets = pa_alloc.template operator()<uint32_t>(phase_a_cluster_offsets_cap);
1742 ps.dirty_slots = pa_alloc.template operator()<uint16_t>(PWAL::PHASE_A_DIRTY_SLOTS_CAP);
1743 ps.bucket_rep = pa_alloc.template operator()<uint32_t>(PWAL::PHASE_A_BUCKET_REP_CAP);
1745 ps.chunk_pts = pa_alloc.template operator()<AffineElement>(PWAL::PHASE_A_CHUNK_CAP);
1746 ps.chunk_ids = pa_alloc.template operator()<uint32_t>(PWAL::PHASE_A_CHUNK_CAP);
1752 size_t ts_tail_cur = per_worker_union_bytes;
1753 auto ts_tail_alloc = [&]<
typename T>(
size_t count) ->
std::span<T> {
1754 return arena.template bump_alloc<T>(count, ts_tail_cur, per_worker_bytes, bytes_P + slab_base);
1756 const size_t dense_total = windows_per_batch * dense_stride_est;
1757 const size_t dense_pair_max = dense_total / 2;
1758 s.dense_buckets = ts_tail_alloc.template operator()<AffineElement>(dense_total);
1759 s.is_present = ts_tail_alloc.template operator()<uint8_t>(dense_total);
1761 s.affine_bucket_indices = ts_tail_alloc.template operator()<uint32_t>(dense_pair_max);
1762 s.affine_bucket_inversion_scratch = ts_tail_alloc.template operator()<BaseField>(dense_pair_max);
1766 s.affine_bucket_stride = dense_stride_est;
1770 const size_t schedule_total = windows_per_batch * n;
1771 auto schedule = zone_S_alloc.template operator()<uint32_t>(schedule_total);
1798 static_assert(
alignof(
Element) <= 32,
"HIST slot O layout assumes alignof(Element) <= 32");
1800 "HIST slot O layout assumes alignof(ChunkOutput) <= 32");
// Rounds `off` up to the next multiple of `a` and returns the result.
// Same add-then-mask formula as a power-of-two round-up: it is only correct
// when `a` is a nonzero power of two, which is not checked here.
1802 auto align_up_local = [](
size_t off,
size_t a) ->
size_t {
return (off +
a - 1) & ~(
a - 1); };
1805 const size_t hist_h_bytes_total = (
size_t{ 4 } * windows_per_batch * num_threads * B_eff);
1809 size_t o_layout_cur = 0;
1811 const size_t off_chunk_outputs = o_layout_cur;
1813 o_layout_cur = align_up_local(o_layout_cur,
alignof(
typename Curve::Element));
1814 const size_t off_window_partial_sums = o_layout_cur;
1815 o_layout_cur +=
sizeof(
typename Curve::Element) * num_threads * windows_per_batch;
1816 const size_t hist_o_bytes_total = o_layout_cur;
1818 const size_t hist_slot_bytes_total =
std::max(hist_h_bytes_total, hist_o_bytes_total);
1824 const size_t hist_slot_cells = (hist_slot_bytes_total +
sizeof(AffineElement) - 1) /
sizeof(AffineElement);
1825 auto hist_slot_cells_span = zone_S_alloc.template operator()<AffineElement>(hist_slot_cells);
1827 std::byte*
const hist_slot_bytes =
reinterpret_cast<std::byte*
>(hist_slot_cells_span.data());
1841 auto digit_cursors =
1842 std::span<uint32_t>{
reinterpret_cast<uint32_t*
>(hist_slot_bytes), windows_per_batch * num_threads * B_eff };
1854 windows_per_batch * num_threads
1858 reinterpret_cast<typename
Curve::Element*
>(hist_slot_bytes + off_window_partial_sums),
1859 num_threads * windows_per_batch
1877 static_assert(
alignof(AffineElement) == 64,
"DENSE slot D layout assumes alignof(AffineElement) == 64");
1878 const size_t bp_total = windows_per_batch * bucket_partials_per_window_max;
1879 size_t d_layout_cur = 0;
1880 const size_t off_dense = d_layout_cur;
1881 d_layout_cur +=
sizeof(AffineElement) * bp_total;
1882 const size_t off_present = d_layout_cur;
1883 d_layout_cur +=
sizeof(uint8_t) * bp_total;
1884 const size_t dense_slot_bytes_total = d_layout_cur;
1885 const size_t dense_slot_cells = (dense_slot_bytes_total +
sizeof(AffineElement) - 1) /
sizeof(AffineElement);
1888 auto dense_slot_cells_span = zone_S_alloc.template operator()<AffineElement>(dense_slot_cells);
1890 std::byte*
const dense_slot_bytes =
reinterpret_cast<std::byte*
>(dense_slot_cells_span.data());
1893 auto bucket_partials_dense =
1896 auto bucket_partials_present =
1897 std::span<uint8_t>{
reinterpret_cast<uint8_t*
>(dense_slot_bytes + off_present), bp_total };
1900 auto bucket_start_all = zone_S_alloc.template operator()<
size_t>(windows_per_batch * (B_eff + 1));
1901 auto chunk_start_all = zone_S_alloc.template operator()<
size_t>(windows_per_batch * (num_threads + 1));
1909 auto chunk_bucket_lo_all = zone_S_alloc.template operator()<
size_t>(windows_per_batch * (num_threads + 1));
1910 auto chunk_bucket_hi_all = zone_S_alloc.template operator()<
size_t>(windows_per_batch * num_threads);
1915 auto bucket_partials_offsets = zone_S_alloc.template operator()<
size_t>((num_threads * windows_per_batch) + 1);
1922 auto rebalanced_bucket_lo_partition = zone_S_alloc.template operator()<
size_t>(num_threads + 1);
1923 auto orig_thread_lo = zone_S_alloc.template operator()<
size_t>(windows_per_batch * num_threads);
1924 auto orig_thread_hi = zone_S_alloc.template operator()<
size_t>(windows_per_batch * num_threads);
1927 auto window_sums = zone_P_alloc.template operator()<
typename Curve::Element>(VAR_WINDOW_WINDOW_SUMS_CAP);
1928 std::fill_n(window_sums.begin(), VAR_WINDOW_WINDOW_SUMS_CAP, Curve::Group::point_at_infinity);
1936 dedup_state.
redirect_lookup = zone_P_alloc.template operator()<uint32_t>(n);
1942 for (
size_t i : chunk.
range(n)) {
1950 constexpr uint32_t BUCKET_MASK = (uint32_t{ 1 } << 31) - 1;
1958 bool phase_a_done =
false;
1960 auto run_batch = [&](
size_t batch_start,
size_t windows_in_batch,
size_t B_R)
noexcept {
1963 const size_t bucket_stride = B_eff;
1967 constexpr size_t SCALAR_UINT64_LIMBS =
sizeof(ScalarField) /
sizeof(uint64_t);
1975 constexpr size_t SCALAR_U32_LIMBS =
sizeof(ScalarField) /
sizeof(uint32_t);
1976 for (
size_t w = 0; w < windows_in_batch; ++w) {
1977 const size_t global_w = batch_start + w;
1978 const size_t window_bits_w = sched.window_bits_per_window[global_w];
1979 per_window_bits[w] =
static_cast<uint8_t
>(window_bits_w);
1981 sched.bit_base[global_w], window_bits_w, SCALAR_UINT64_LIMBS);
1983 sched.bit_base[global_w], window_bits_w, SCALAR_U32_LIMBS);
1985 const uint32_t lo_mask = slice_params_u32[w].lo_mask;
1986 const uint32_t hi_mask = slice_params_u32[w].hi_mask;
1987 const uint32_t val_mask = (uint32_t{ 1 } <<
static_cast<uint32_t
>(window_bits_w)) - 1;
1993 constexpr size_t SIMD_BATCH = 64;
1994 static_assert(SIMD_BATCH % 4 == 0,
"SIMD_BATCH must be divisible by 4");
1995 constexpr size_t LIMBS_PER_SCALAR =
sizeof(ScalarField) /
sizeof(uint32_t);
1996 const auto* scalars_u32 =
reinterpret_cast<const uint32_t*
>(scalars.data());
1998 auto fill_packed_digit_buffer = [&](
size_t w,
size_t i, uint32_t* packed_buf)
noexcept {
1999 const auto& sp32 = slice_params_u32[w];
2000 const uint32_t window_bits_w =
static_cast<uint32_t
>(per_window_bits[w]);
2002 for (
size_t k = 0; k < SIMD_BATCH; k += 4) {
2005 scalars_u32 + ((i + k + 0) * LIMBS_PER_SCALAR),
2006 scalars_u32 + ((i + k + 1) * LIMBS_PER_SCALAR),
2007 scalars_u32 + ((i + k + 2) * LIMBS_PER_SCALAR),
2008 scalars_u32 + ((i + k + 3) * LIMBS_PER_SCALAR),
2013 val_mask_vectors[w],
2017 for (
size_t k = 0; k < SIMD_BATCH; k += 4) {
2020 scalars_u32 + ((i + k + 0) * LIMBS_PER_SCALAR),
2021 scalars_u32 + ((i + k + 1) * LIMBS_PER_SCALAR),
2022 scalars_u32 + ((i + k + 2) * LIMBS_PER_SCALAR),
2023 scalars_u32 + ((i + k + 3) * LIMBS_PER_SCALAR),
2028 val_mask_vectors[w],
2032 for (
size_t k = 0; k < SIMD_BATCH; k += 4) {
2035 scalars_u32 + ((i + k + 0) * LIMBS_PER_SCALAR),
2036 scalars_u32 + ((i + k + 1) * LIMBS_PER_SCALAR),
2037 scalars_u32 + ((i + k + 2) * LIMBS_PER_SCALAR),
2038 scalars_u32 + ((i + k + 3) * LIMBS_PER_SCALAR),
2046 val_mask_vectors[w],
2055 const bool phase_a_done_at_batch_start = phase_a_done;
2056 const bool dedup_known_for_batch =
2057 dedup_active && phase_a_done_at_batch_start && dedup_state.
n_dedup_extras != 0;
2062 auto stage1_digit_extract = [&]<
bool DedupKnown>(
size_t tid)
noexcept {
2063 [[maybe_unused]]
const uint32_t*
const rl_data = dedup_state.
redirect_lookup.data();
2064 for (
size_t w = 0; w < windows_in_batch; ++w) {
2065 uint32_t* my_counts = digit_cursors.data() + (((w * num_threads) + tid) * bucket_stride);
2066 std::memset(my_counts, 0, B_R *
sizeof(uint32_t));
2068 const size_t start = tid * n / num_threads;
2069 const size_t end = (tid + 1) * n / num_threads;
2075 auto compute_include_mask = [&](
size_t block_start)
noexcept -> uint64_t {
2076 uint64_t include_mask = 0;
2077 for (
size_t k = 0; k < SIMD_BATCH; ++k) {
2078 const size_t scalar_idx = block_start + k;
2079 const uint8_t m = msb_per_scalar[scalar_idx];
2080 bool include = (m != MSB_ZERO_SENTINEL);
2081 if constexpr (DedupKnown) {
2083 const uint32_t patch = rl_data[scalar_idx];
2088 include_mask |=
static_cast<uint64_t
>(include) << k;
2090 return include_mask;
2094 while (i + SIMD_BATCH <= end) {
2095 const uint64_t include_mask = compute_include_mask(i);
2096 if (include_mask == 0) {
2100 const bool all_included = include_mask == ~uint64_t{ 0 };
2101 for (
size_t w = 0; w < windows_in_batch; ++w) {
2102 fill_packed_digit_buffer(w, i, packed_buf.data());
2103 uint32_t* my_counts = digit_cursors.data() + (((w * num_threads) + tid) * bucket_stride);
2105 for (
size_t k = 0; k < SIMD_BATCH; ++k) {
2106 ++my_counts[packed_buf[k] & BUCKET_MASK];
2109 uint64_t scatter_mask = include_mask;
2110 for (
size_t k = 0; k < SIMD_BATCH; ++k) {
2111 if ((scatter_mask & uint64_t{ 1 }) != 0) {
2112 ++my_counts[packed_buf[k] & BUCKET_MASK];
2123 for (; i < end; ++i) {
2124 const uint8_t m = msb_per_scalar[i];
2125 if (m == MSB_ZERO_SENTINEL) {
2128 if constexpr (DedupKnown) {
2129 const uint32_t patch = rl_data[i];
2135 for (
size_t w = 0; w < windows_in_batch; ++w) {
2136 uint32_t* my_counts = digit_cursors.data() + (((w * num_threads) + tid) * bucket_stride);
2138 const uint32_t window_bits_w =
static_cast<uint32_t
>(per_window_bits[w]);
2139 const uint32_t packed =
2149 ++my_counts[packed & BUCKET_MASK];
2153 if (dedup_known_for_batch) {
2154 bb::parallel_for(num_threads, [&](
size_t tid) { stage1_digit_extract.template operator()<
true>(tid); });
2156 bb::parallel_for(num_threads, [&](
size_t tid) { stage1_digit_extract.template operator()<
false>(tid); });
2170 const size_t d_start = tid * B_R / num_threads;
2171 const size_t d_end = (tid + 1) * B_R / num_threads;
2172 for (
size_t w = 0; w < windows_in_batch; ++w) {
2173 size_t*
const bucket_start_w = bucket_start_all.data() + (w * (bucket_stride + 1));
2174 for (
size_t d = d_start; d < d_end; ++d) {
2178 uint32_t running = 0;
2179 for (
size_t t = 0; t < num_threads; ++t) {
2180 const size_t k = (((w * num_threads) + t) * bucket_stride) + d;
2181 const uint32_t cnt = digit_cursors[k];
2182 digit_cursors[k] = running;
2185 bucket_start_w[d + 1] = running;
2195 auto build_bucket_offsets_for_window = [&](
size_t w)
noexcept {
2196 size_t* bucket_start = bucket_start_all.data() + (w * (bucket_stride + 1));
2197 bucket_start[0] = 0;
2198 bucket_start[1] = 0;
2199 for (
size_t d = 1; d < B_R; ++d) {
2200 bucket_start[d + 1] += bucket_start[d];
2203 const size_t offset_threads = std::min(num_threads, windows_in_batch);
2204 if (offset_threads <= 1) {
2205 for (
size_t w = 0; w < windows_in_batch; ++w) {
2206 build_bucket_offsets_for_window(w);
2210 for (
size_t w = tid; w < windows_in_batch; w += offset_threads) {
2211 build_bucket_offsets_for_window(w);
2234 auto stage4_emit = [&]<
bool DedupKnown>(
size_t tid)
noexcept {
2235 [[maybe_unused]]
const uint32_t*
const rl_data = dedup_state.
redirect_lookup.data();
2236 const size_t start = tid * n / num_threads;
2237 const size_t end = (tid + 1) * n / num_threads;
2241 for (
size_t w = 0; w < windows_in_batch; ++w) {
2242 cursors[w] = digit_cursors.data() + (((w * num_threads) + tid) * bucket_stride);
2243 bucket_starts[w] = bucket_start_all.data() + (w * (bucket_stride + 1));
2244 schedules[w] = schedule.data() + (w * n);
2248 constexpr size_t STAGE4_SCALAR_TILE = 2048;
2252 for (
size_t tile_start = start; tile_start < end; tile_start += STAGE4_SCALAR_TILE) {
2253 const size_t tile_end = std::min(end, tile_start + STAGE4_SCALAR_TILE);
2254 const size_t tile_len = tile_end - tile_start;
2255 for (
size_t j = 0; j < tile_len; ++j) {
2256 const size_t scalar_idx = tile_start + j;
2257 const uint8_t m = msb_per_scalar[scalar_idx];
2258 bool include = (m != MSB_ZERO_SENTINEL);
2259 if constexpr (DedupKnown) {
2260 uint32_t out_base =
static_cast<uint32_t
>(scalar_idx);
2262 const uint32_t patch = rl_data[scalar_idx];
2268 out_base_tile[j] = out_base;
2270 active_tile[j] =
static_cast<uint8_t
>(include);
2273 for (
size_t w = 0; w < windows_in_batch; ++w) {
2274 uint32_t* my_cursor = cursors[w];
2275 const size_t* bucket_start = bucket_starts[w];
2276 uint32_t* sched_w = schedules[w];
2277 size_t i = tile_start;
2278 while (i + SIMD_BATCH <= tile_end) {
2279 const size_t rel = i - tile_start;
2280 uint64_t include_mask = 0;
2281 for (
size_t k = 0; k < SIMD_BATCH; ++k) {
2282 include_mask |=
static_cast<uint64_t
>(active_tile[rel + k]) << k;
2284 if (include_mask == 0) {
2288 fill_packed_digit_buffer(w, i, packed_buf.data());
2289 uint64_t scatter_mask = include_mask;
2290 for (
size_t k = 0; k < SIMD_BATCH; ++k) {
2291 if ((scatter_mask & uint64_t{ 1 }) != 0) {
2292 const uint32_t packed = packed_buf[k];
2293 const uint32_t bucket_idx = packed & BUCKET_MASK;
2294 if (bucket_idx != 0) {
2295 const uint32_t idx =
2296 static_cast<uint32_t
>(bucket_start[bucket_idx]) + my_cursor[bucket_idx]++;
2298 if constexpr (DedupKnown) {
2299 out |= out_base_tile[rel + k];
2301 out |=
static_cast<uint32_t
>(i + k);
2310 for (; i < tile_end; ++i) {
2311 const size_t rel = i - tile_start;
2312 if (active_tile[rel] == 0) {
2325 static_cast<uint32_t
>(per_window_bits[w]));
2326 const uint32_t bucket_idx = packed & BUCKET_MASK;
2327 if (bucket_idx != 0) {
2328 const uint32_t idx =
2329 static_cast<uint32_t
>(bucket_start[bucket_idx]) + my_cursor[bucket_idx]++;
2331 if constexpr (DedupKnown) {
2332 out |= out_base_tile[rel];
2334 out |=
static_cast<uint32_t
>(i);
2343 if (dedup_known_for_batch) {
2344 bb::parallel_for(num_threads, [&](
size_t tid) { stage4_emit.template operator()<
true>(tid); });
2346 bb::parallel_for(num_threads, [&](
size_t tid) { stage4_emit.template operator()<
false>(tid); });
2361 if (dedup_active && windows_in_batch > 0 && !phase_a_done) {
2363 uint32_t* sched_w0 = schedule.data();
2373 const uint32_t cids_per_thread =
2385 const size_t*
const w0_bucket_start = bucket_start_all.data();
2386 std::atomic<size_t> dedup_cluster_count{ 0 };
2388 const size_t b_lo = 1 + ((tid * (B_R - 1)) / num_threads);
2389 const size_t b_hi = 1 + (((tid + 1) * (B_R - 1)) / num_threads);
2390 const uint32_t cid_lo =
static_cast<uint32_t
>(tid) * cids_per_thread;
2391 const uint32_t cid_max = cid_lo + cids_per_thread;
2392 const size_t local_clusters = round_parallel_detail::dedup_phase_a_worker_hash<Curve>(
2397 std::span<const ScalarField>(scalars.data(), n),
2401 msb_per_scalar.data(),
2405 phase_a_scratch[tid]);
2406 if (local_clusters != 0) {
2412 phase_a_done =
true;
2422 auto partition_chunks_for_window = [&](
size_t w)
noexcept {
2423 const size_t* bucket_start = bucket_start_all.data() + (w * (bucket_stride + 1));
2424 const size_t*
const bucket_start_end = bucket_start + B_R + 1;
2425 size_t* chunk_start = chunk_start_all.data() + (w * (num_threads + 1));
2426 size_t* chunk_bucket_lo = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2427 size_t* chunk_bucket_hi = chunk_bucket_hi_all.data() + (w * num_threads);
2428 const size_t m = bucket_start[B_R];
2429 const size_t* search_begin = bucket_start + 1;
2431 chunk_start[0] = lo;
2432 for (
size_t t = 0; t < num_threads; ++t) {
2433 const size_t hi = ((t + 1) == num_threads) ? m : (((t + 1) * m) / num_threads);
2434 chunk_start[t + 1] = hi;
2436 const size_t*
const lo_it =
std::upper_bound(search_begin, bucket_start_end, lo);
2437 const size_t lo_bucket =
static_cast<size_t>(lo_it - bucket_start - 1);
2438 const size_t*
const hi_it =
std::upper_bound(lo_it, bucket_start_end, hi - 1);
2439 const size_t hi_bucket =
static_cast<size_t>(hi_it - bucket_start - 1);
2440 chunk_bucket_lo[t] = lo_bucket;
2441 chunk_bucket_hi[t] = hi_bucket;
2442 search_begin = hi_it;
2444 chunk_bucket_lo[t] = B_R;
2445 chunk_bucket_hi[t] = 0;
2449 chunk_bucket_lo[num_threads] = B_R;
2452 bool chunk_partition_done =
false;
2453 if (dedup_active && windows_in_batch > 0 && phase_a_done && !phase_a_done_at_batch_start) {
2456 const size_t bs_stride = bucket_stride + 1;
2457 const size_t br = B_R;
2458 const size_t cap_R = n;
2459 bb::parallel_for(num_threads, [&, rl_data, bs_stride, br, cap_R](
size_t tid)
noexcept {
2460 for (
size_t w = tid; w < windows_in_batch; w += num_threads) {
2461 uint32_t* sched_w = schedule.data() + (w * cap_R);
2462 size_t* bucket_start_w = bucket_start_all.data() + (w * bs_stride);
2463 round_parallel_detail::dedup_patch_schedule_window<Curve>(sched_w, bucket_start_w, br, rl_data);
2464 partition_chunks_for_window(w);
2467 chunk_partition_done =
true;
2476 if (!chunk_partition_done) {
2477 for (
size_t w = 0; w < windows_in_batch; ++w) {
2478 partition_chunks_for_window(w);
2494 auto next_pow2 = [](
size_t x) ->
size_t {
2505 size_t max_chunk_len = 0;
2506 for (
size_t t = 0; t < num_threads; ++t) {
2507 for (
size_t w = 0; w < windows_in_batch; ++w) {
2508 const size_t* chunk_start = chunk_start_all.data() + (w * (num_threads + 1));
2509 const size_t entries_in_chunk = chunk_start[t + 1] - chunk_start[t];
2510 if (entries_in_chunk == 0) {
2513 max_chunk_len =
std::max(max_chunk_len, entries_in_chunk);
2525 size_t global_stride = 0;
2530 const size_t active_digits = (B_R > 0) ? (B_R - 1) : 0;
2531 for (
size_t t = 0; t <= num_threads; ++t) {
2532 rebalanced_bucket_lo_partition[t] = 1 + (t * active_digits) / num_threads;
2534 rebalanced_bucket_lo_partition[num_threads] = B_R;
2535 size_t max_buckets_per_task = 0;
2536 for (
size_t t = 0; t + 1 <= num_threads; ++t) {
2537 const size_t hi_d = (t + 1 == num_threads) ? (B_R - 1) : (rebalanced_bucket_lo_partition[t + 1] - 1);
2538 const size_t lo_d = rebalanced_bucket_lo_partition[t];
2540 max_buckets_per_task =
std::max(max_buckets_per_task, hi_d - lo_d + 1);
2543 global_stride = next_pow2(max_buckets_per_task);
2548 for (
size_t w = 0; w < windows_in_batch; ++w) {
2549 const size_t* chunk_bucket_lo = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2550 const size_t* chunk_bucket_hi = chunk_bucket_hi_all.data() + (w * num_threads);
2551 const size_t* chunk_start_w = chunk_start_all.data() + (w * (num_threads + 1));
2552 for (
size_t tprime = 0; tprime < num_threads; ++tprime) {
2553 const size_t lo_d = rebalanced_bucket_lo_partition[tprime];
2555 (tprime + 1 == num_threads) ? (B_R - 1) : (rebalanced_bucket_lo_partition[tprime + 1] - 1);
2556 size_t lo_orig = num_threads;
2558 for (
size_t t = 0; t < num_threads; ++t) {
2559 const size_t entries = chunk_start_w[t + 1] - chunk_start_w[t];
2563 const size_t cl = chunk_bucket_lo[t];
2564 const size_t ch = chunk_bucket_hi[t];
2565 if (ch < lo_d || cl > hi_d) {
2568 if (lo_orig == num_threads) {
2573 orig_thread_lo[(w * num_threads) + tprime] = lo_orig;
2574 orig_thread_hi[(w * num_threads) + tprime] = hi_orig;
2582 size_t bucket_partials_cursor = 0;
2583 for (
size_t t = 0; t < num_threads; ++t) {
2584 for (
size_t w = 0; w < windows_in_batch; ++w) {
2585 bucket_partials_offsets[(t * windows_in_batch) + w] = bucket_partials_cursor;
2586 const size_t* chunk_bucket_lo_w = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2587 const size_t* chunk_bucket_hi_w = chunk_bucket_hi_all.data() + (w * num_threads);
2588 const size_t* chunk_start_w = chunk_start_all.data() + (w * (num_threads + 1));
2589 const size_t entries = chunk_start_w[t + 1] - chunk_start_w[t];
2591 bucket_partials_cursor += chunk_bucket_hi_w[t] - chunk_bucket_lo_w[t] + 1;
2595 bucket_partials_offsets[num_threads * windows_in_batch] = bucket_partials_cursor;
2596 const size_t bucket_partials_total = bucket_partials_cursor;
2597 BB_ASSERT_LTE(bucket_partials_total, bucket_partials_dense.size());
2598 std::memset(bucket_partials_present.data(), 0, bucket_partials_total);
2603 for (
size_t t = 0; t < worker_total; ++t) {
2604 thread_scratch[t].affine_bucket_stride = global_stride;
2613 auto bucket_partials_per_thread_lambda = [&](
size_t tid) {
2614 auto& s = thread_scratch[tid];
2615 for (
size_t w = 0; w < windows_in_batch; ++w) {
2616 const size_t* chunk_start_w = chunk_start_all.data() + (w * (num_threads + 1));
2617 const size_t cs_lo = chunk_start_w[tid];
2618 const size_t cs_hi = chunk_start_w[tid + 1];
2619 if (cs_lo == cs_hi) {
2622 const uint32_t* sched_w = schedule.data() + (w * n);
2623 const size_t* bucket_start = bucket_start_all.data() + (w * (bucket_stride + 1));
2624 AffineElement* dst_dense =
2625 bucket_partials_dense.data() + bucket_partials_offsets[(tid * windows_in_batch) + w];
2626 uint8_t* dst_present =
2627 bucket_partials_present.data() + bucket_partials_offsets[(tid * windows_in_batch) + w];
2628 const size_t* chunk_bucket_lo = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2629 const uint32_t my_lo =
static_cast<uint32_t
>(chunk_bucket_lo[tid]);
2630 const size_t my_hi = chunk_bucket_hi_all[(w * num_threads) + tid];
2631 size_t bucket_cursor = my_lo;
2633 for (
size_t pos = cs_lo; pos < cs_hi;) {
2634 const size_t end = std::min(pos + SUBCHUNK_ENTRIES_CAP, cs_hi);
2635 reduce_chunk<Curve>(s,
2643 std::span<const AffineElement>(dedup_state.
extra_points));
2644 const size_t len = s.result_len;
2645 for (
size_t k = 0; k <
len; ++k) {
2646 const uint32_t d = s.curr_buckets[k];
2647 const size_t slot = d - my_lo;
2648 if (dst_present[
slot]) {
2649 s.overflow_slots[s.overflow_len] =
static_cast<uint32_t
>(
slot);
2650 s.overflow_pts[s.overflow_len] = s.curr_pts[k];
2653 dst_dense[
slot] = s.curr_pts[k];
2654 dst_present[
slot] = 1;
2659 merge_overflow<Curve>(s, dst_dense);
2671 auto bucket_reduce_cross_thread_lambda = [&](
size_t tprime) {
2672 auto& s = thread_scratch[tprime];
2673 Element* my_partials = window_partial_sums.data() + (tprime * windows_per_batch);
2674 for (
size_t w = 0; w < windows_in_batch; ++w) {
2675 my_partials[w] = Curve::Group::point_at_infinity;
2678 const size_t stride = s.affine_bucket_stride;
2679 std::memset(s.is_present.data(), 0, windows_in_batch * stride);
2681 const size_t lo_d = rebalanced_bucket_lo_partition[tprime];
2683 (tprime + 1 == num_threads) ? (B_R - 1) : (rebalanced_bucket_lo_partition[tprime + 1] - 1);
2684 const uint32_t lo_d_u =
static_cast<uint32_t
>(lo_d);
2685 const uint32_t hi_d_u =
static_cast<uint32_t
>(hi_d);
2687 bool any_nonempty =
false;
2688 for (
size_t w = 0; w < windows_in_batch; ++w) {
2689 auto&
info = s.chunk_infos[w];
2690 auto& out = chunk_outputs[(w * num_threads) + tprime];
2695 info.buckets_padded = 0;
2699 const size_t orig_lo = orig_thread_lo[(w * num_threads) + tprime];
2700 const size_t orig_hi = orig_thread_hi[(w * num_threads) + tprime];
2701 if (orig_lo == num_threads) {
2705 info.buckets_padded = 0;
2709 const size_t base = w * stride;
2710 bool has_data =
false;
2717 const size_t* chunk_bucket_lo_w = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2718 const size_t* chunk_bucket_hi_w = chunk_bucket_hi_all.data() + (w * num_threads);
2719 for (
size_t t = orig_lo; t <= orig_hi; ++t) {
2720 const size_t cl = chunk_bucket_lo_w[t];
2721 const size_t ch = chunk_bucket_hi_w[t];
2724 if (d_lo_clip > d_hi_clip) {
2727 const AffineElement* src_dense =
2728 bucket_partials_dense.data() + bucket_partials_offsets[(t * windows_in_batch) + w];
2729 const uint8_t* src_present =
2730 bucket_partials_present.data() + bucket_partials_offsets[(t * windows_in_batch) + w];
2731 for (
size_t d = d_lo_clip; d <= d_hi_clip; ++d) {
2732 const size_t src_slot = d - cl;
2733 if (src_present[src_slot] == 0) {
2736 const size_t dst_slot = base + (d - lo_d);
2737 if (s.is_present[dst_slot] == 0) {
2738 s.dense_buckets[dst_slot] = src_dense[src_slot];
2739 s.is_present[dst_slot] = 1;
2746 acc +=
Element(src_dense[src_slot]);
2747 s.dense_buckets[dst_slot] = AffineElement(acc);
2756 info.buckets_padded = 0;
2760 any_nonempty =
true;
2761 const size_t M = hi_d - lo_d + 1;
2762 const uint32_t buckets_padded =
2763 (M == 1) ? 1 : (uint32_t{ 1 } << (32 - __builtin_clz(
static_cast<uint32_t
>(M - 1))));
2767 info.buckets_padded = buckets_padded;
2773 if (!any_nonempty) {
2777 round_parallel_detail::recursive_affine_bucket_reduce_strided<Curve>(
2778 s, s.chunk_infos.data(), windows_in_batch, chunk_outputs.data() + tprime, num_threads);
2780 for (
size_t w = 0; w < windows_in_batch; ++w) {
2781 auto& out = chunk_outputs[(w * num_threads) + tprime];
2782 if (out.empty == 0) {
2783 my_partials[w] = round_parallel_detail::chunk_contribution<Curve>(out);
2796 const size_t reduce_threads = std::min(num_threads, windows_in_batch);
2798 const size_t lo = rid * windows_in_batch / reduce_threads;
2799 const size_t hi = (rid + 1) * windows_in_batch / reduce_threads;
2800 for (
size_t w = lo; w < hi; ++w) {
2801 Element sum = Curve::Group::point_at_infinity;
2802 for (
size_t tid = 0; tid < num_threads; ++tid) {
2803 sum += window_partial_sums[(tid * windows_per_batch) + w];
2805 window_sums[batch_start + w] =
sum;
2813 const size_t B_R = (
size_t{ 1 } << (window_bits - 1)) + 1;
2814 for (
size_t batch_start = 0; batch_start < sched.num_windows; batch_start += windows_per_batch) {
2815 const size_t windows_in_batch = std::min(windows_per_batch, sched.num_windows - batch_start);
2816 run_batch(batch_start, windows_in_batch, B_R);
2822 Element result = (sched.num_windows == 0) ? Curve::Group::point_at_infinity : window_sums[sched.num_windows - 1];
2823 for (
size_t w_rev = sched.num_windows - 1; w_rev > 0; --w_rev) {
2824 const size_t window_bits_w = sched.window_bits_per_window[w_rev - 1];
2825 for (
size_t d = 0; d < window_bits_w; ++d) {
2828 result += window_sums[w_rev - 1];
2835 for (
size_t i : chunk.
range(n_input)) {
2836 input_scalars[i].self_to_montgomery_form();