Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
bernstein_yang_inverse_wasm.hpp
Go to the documentation of this file.
// 9 × 29-bit-limb state. Included from bernstein_yang_inverse.hpp; uses the
// u64 / i64 / DivstepMatrix names declared there.
//
// Why a different limb size from Native5x64: wasm32 has no native
// 64×64→128 multiply, so an i64 × u64 → __int128 product lowers to a call to
// compiler-rt's __multi3. Instead, the 254-bit state is packed into 9 limbs of
// 29 bits each, so every limb-level product is a 29-bit × ~29-bit signed
// product (≤ ~2^58) that fits in a single WASM i64.mul, with headroom left for
// summing a few such products plus a carry. Choosing the per-iteration BATCH
// as exactly 2 × LIMB_BITS makes the "/ 2^BATCH" at the end of
// apply_divstep_matrix equivalent to dropping the bottom two 29-bit limbs
// (no sub-limb shift on the intermediate).
11
12#pragma once
13
15#include <cstdint>
16
17namespace bb::bernstein_yang {
18
19class Wasm9x29 {
20 public:
21 // Divsteps per matrix application; smaller than Native5x64::BATCH so
22 // the resulting "/ 2^BATCH" is limb-aligned (= drop the bottom two
23 // 29-bit limbs) and no sub-limb shift is needed on the intermediate.
24 static constexpr int BATCH = 58;
25
26 // ⌈735 / 58⌉ = 13. Same convergence-bound logic as Native5x64; one
27 // iter more because BATCH is smaller.
28 static constexpr int NUM_ITERATIONS = 13;
29
30 // |d|, |e| can grow by ~2× + p per matrix application; after 4 iters
31 // they reach ~31p ≈ 2^259, which still fits in the 9 × 29-bit signed
32 // state (capacity ~2^260). Reducing once every 4 iters instead of
33 // every iter saves ~3× reduce_to_canonical calls per inversion.
34 static constexpr int REDUCE_INTERVAL = 4;
35
36 // Worst-case iteration cap inside reduce_to_canonical. After
37 // REDUCE_INTERVAL iters between reductions, |d|, |e| ≤ (2^(REDUCE_INTERVAL+1) - 1)·p,
38 // so reducing requires that many subtractions plus one break iter.
39 static constexpr int REDUCE_TO_CANONICAL_MAX_ITERS = 36;
40 static_assert((1U << (REDUCE_INTERVAL + 1)) <= REDUCE_TO_CANONICAL_MAX_ITERS,
41 "REDUCE_INTERVAL too large for reduce_to_canonical iteration bound");
42
43 Wasm9x29() noexcept
44 : l{}
45 {}
46 explicit Wasm9x29(const uint256_t& x) noexcept
47 {
48 const u64* d = x.data;
49 l[0] = (i64)(d[0] & LIMB_MASK);
50 l[1] = (i64)((d[0] >> 29) & LIMB_MASK);
51 l[2] = (i64)(((d[0] >> 58) & 0x3FULL) | ((d[1] & 0x7FFFFFULL) << 6));
52 l[3] = (i64)((d[1] >> 23) & LIMB_MASK);
53 l[4] = (i64)(((d[1] >> 52) & 0xFFFULL) | ((d[2] & 0x1FFFFULL) << 12));
54 l[5] = (i64)((d[2] >> 17) & LIMB_MASK);
55 l[6] = (i64)(((d[2] >> 46) & 0x3FFFFULL) | ((d[3] & 0x7FFULL) << 18));
56 l[7] = (i64)((d[3] >> 11) & LIMB_MASK);
57 l[8] = (i64)((d[3] >> 40) & 0xFFFFFFULL);
58 }
59 static Wasm9x29 one() noexcept
60 {
61 Wasm9x29 r;
62 r.l[0] = 1;
63 return r;
64 }
65
66 uint256_t to_uint256() const noexcept
67 {
68 return { (u64)l[0] | ((u64)l[1] << 29) | ((u64)l[2] << 58),
69 ((u64)l[2] >> 6) | ((u64)l[3] << 23) | ((u64)l[4] << 52),
70 ((u64)l[4] >> 12) | ((u64)l[5] << 17) | ((u64)l[6] << 46),
71 ((u64)l[6] >> 18) | ((u64)l[7] << 11) | ((u64)l[8] << 40) };
72 }
73 u64 low_64() const noexcept { return (u64)l[0] | ((u64)l[1] << 29) | (((u64)l[2] & 0x3F) << 58); }
74 bool is_zero() const noexcept
75 {
76 i64 a = 0;
77 for (int i = 0; i < N; ++i) {
78 a |= l[i];
79 }
80 return a == 0;
81 }
82 bool is_negative() const noexcept { return l[N - 1] < 0; }
83 void neg() noexcept
84 {
85 for (int i = 0; i < N; ++i) {
86 l[i] = -l[i];
87 }
88 normalise();
89 }
90 void reduce_to_canonical(const Wasm9x29& p) noexcept;
91
92 // See Native5x64 for the batched divstep matrix, matrix application,
93 // and p_inv_mod_2k_from_montgomery_r_inv contracts; the bodies differ only
94 // in limb representation.
95 static DivstepMatrix compute_divstep_matrix(i64& delta, u64 f_lo, u64 g_lo) noexcept;
96 static void apply_divstep_matrix(const DivstepMatrix& m,
97 Wasm9x29& f,
98 Wasm9x29& g,
99 Wasm9x29& d,
100 Wasm9x29& e,
101 const Wasm9x29& p,
102 u64 p_inv_mod_2k) noexcept;
103 static constexpr u64 p_inv_mod_2k_from_montgomery_r_inv(u64 r_inv) noexcept
104 {
105 // r_inv = -p^{-1} mod 2^64, so 0 - r_inv = p^{-1} mod 2^64.
106 return (0ULL - r_inv) & ((1ULL << BATCH) - 1);
107 }
108
109 private:
110 static constexpr int N = 9;
111 static constexpr int LIMB_BITS = 29;
112 static constexpr u64 LIMB_MASK = (1ULL << LIMB_BITS) - 1;
113 i64 l[N]; // top limb carries sign; lower limbs in [0, 2^29) post-normalise
114
115 void normalise() noexcept
116 {
117 i64 c = 0;
118 for (int i = 0; i < N - 1; ++i) {
119 i64 v = l[i] + c;
120 l[i] = v & (i64)LIMB_MASK;
121 c = v >> LIMB_BITS;
122 }
123 l[N - 1] += c;
124 }
125 void add_inplace(const Wasm9x29& b) noexcept
126 {
127 for (int i = 0; i < N; ++i) {
128 l[i] += b.l[i];
129 }
130 normalise();
131 }
132 void sub_inplace(const Wasm9x29& b) noexcept
133 {
134 for (int i = 0; i < N; ++i) {
135 l[i] -= b.l[i];
136 }
137 normalise();
138 }
139};
140
141// Iter cap chosen by the REDUCE_TO_CANONICAL_MAX_ITERS / REDUCE_INTERVAL
142// static_assert above; see those constants for the magnitude argument.
143inline void Wasm9x29::reduce_to_canonical(const Wasm9x29& p) noexcept
144{
145 normalise();
146 for (int it = 0; it < REDUCE_TO_CANONICAL_MAX_ITERS; ++it) {
147 if (is_negative()) {
148 add_inplace(p);
149 continue;
150 }
151 int cmp = 0;
152 for (int i = N - 1; i >= 0; --i) {
153 if (l[i] != p.l[i]) {
154 cmp = l[i] > p.l[i] ? 1 : -1;
155 break;
156 }
157 }
158 if (cmp < 0) {
159 break;
160 }
161 sub_inplace(p);
162 }
163}
164
165inline DivstepMatrix Wasm9x29::compute_divstep_matrix(i64& delta, u64 f_lo, u64 g_lo) noexcept
166{
167 i64 u = 1, v = 0, q = 0, r = 1;
168 for (int i = 0; i < BATCH; ++i) {
169 if (g_lo & 1) {
170 if (delta > 0) {
171 u64 nf = g_lo, ng = (g_lo - f_lo) >> 1;
172 i64 nu = q << 1, nv = r << 1, nq = q - u, nr = r - v;
173 f_lo = nf;
174 g_lo = ng;
175 u = nu;
176 v = nv;
177 q = nq;
178 r = nr;
179 delta = 1 - delta;
180 } else {
181 g_lo = (g_lo + f_lo) >> 1;
182 q = q + u;
183 r = r + v;
184 u <<= 1;
185 v <<= 1;
186 delta = delta + 1;
187 }
188 } else {
189 g_lo >>= 1;
190 u <<= 1;
191 v <<= 1;
192 delta = delta + 1;
193 }
194 }
195 return { u, v, q, r };
196}
197
198// Streamed schoolbook: for each limb position i compute
199// nf_i = u_lo·f_i + v_lo·g_i + u_hi·f_{i-1} + v_hi·g_{i-1} + carry_in
200// (similarly ng, nd, ne), then carry_out = nf_i >> LIMB_BITS, masked low
201// 29 bits land at output position i - 2 (= exact >> BATCH). The de row
202// derives k_d, k_e from the low two limbs up front and folds k·p into the
203// per-limb formula from position 2 onward. No 11-limb intermediate is
204// materialised — the JIT keeps the four running carries in registers.
206 Wasm9x29& f,
207 Wasm9x29& g,
208 Wasm9x29& d,
209 Wasm9x29& e,
210 const Wasm9x29& p,
211 u64 p_inv_mod_2k) noexcept
212{
213 constexpr u64 MASK_BATCH = (1ULL << BATCH) - 1;
214 const i64 u_lo = m.u & (i64)LIMB_MASK, u_hi = m.u >> LIMB_BITS;
215 const i64 v_lo = m.v & (i64)LIMB_MASK, v_hi = m.v >> LIMB_BITS;
216 const i64 q_lo = m.q & (i64)LIMB_MASK, q_hi = m.q >> LIMB_BITS;
217 const i64 r_lo = m.r & (i64)LIMB_MASK, r_hi = m.r >> LIMB_BITS;
218
219 {
220 i64 cf = 0, cg = 0, fp = 0, gp = 0;
221 for (int i = 0; i < N; ++i) {
222 const i64 fi = f.l[i], gi = g.l[i];
223 const i64 nf = u_lo * fi + v_lo * gi + u_hi * fp + v_hi * gp + cf;
224 const i64 ng = q_lo * fi + r_lo * gi + q_hi * fp + r_hi * gp + cg;
225 cf = nf >> LIMB_BITS;
226 cg = ng >> LIMB_BITS;
227 if (i >= 2) {
228 f.l[i - 2] = nf & (i64)LIMB_MASK;
229 g.l[i - 2] = ng & (i64)LIMB_MASK;
230 }
231 fp = fi;
232 gp = gi;
233 }
234 const i64 nf9 = u_hi * fp + v_hi * gp + cf;
235 const i64 ng9 = q_hi * fp + r_hi * gp + cg;
236 f.l[N - 2] = nf9 & (i64)LIMB_MASK;
237 g.l[N - 2] = ng9 & (i64)LIMB_MASK;
238 f.l[N - 1] = nf9 >> LIMB_BITS;
239 g.l[N - 1] = ng9 >> LIMB_BITS;
240 }
241
242 // k_d, k_e (mod 2^BATCH) clear the low BATCH bits of nd, ne; fold k·p
243 // into the streaming pass from position 2 onward.
244 {
245 const i64 d0 = d.l[0], e0 = e.l[0], d1 = d.l[1], e1 = e.l[1];
246 const i64 nd0 = u_lo * d0 + v_lo * e0;
247 const i64 ne0 = q_lo * d0 + r_lo * e0;
248 const i64 nd1 = u_lo * d1 + v_lo * e1 + u_hi * d0 + v_hi * e0;
249 const i64 ne1 = q_lo * d1 + r_lo * e1 + q_hi * d0 + r_hi * e0;
250 const u64 t_d = ((u64)nd0 & LIMB_MASK) | (((u64)(nd1 + (nd0 >> LIMB_BITS)) & LIMB_MASK) << LIMB_BITS);
251 const u64 t_e = ((u64)ne0 & LIMB_MASK) | (((u64)(ne1 + (ne0 >> LIMB_BITS)) & LIMB_MASK) << LIMB_BITS);
252 const u64 k_d = ((0ULL - t_d) * p_inv_mod_2k) & MASK_BATCH;
253 const u64 k_e = ((0ULL - t_e) * p_inv_mod_2k) & MASK_BATCH;
254 const i64 kd_lo = (i64)(k_d & LIMB_MASK), kd_hi = (i64)(k_d >> LIMB_BITS);
255 const i64 ke_lo = (i64)(k_e & LIMB_MASK), ke_hi = (i64)(k_e >> LIMB_BITS);
256 i64 cd = (nd1 + kd_lo * p.l[1] + kd_hi * p.l[0] + ((nd0 + kd_lo * p.l[0]) >> LIMB_BITS)) >> LIMB_BITS;
257 i64 ce = (ne1 + ke_lo * p.l[1] + ke_hi * p.l[0] + ((ne0 + ke_lo * p.l[0]) >> LIMB_BITS)) >> LIMB_BITS;
258
259 i64 dp = d1, ep = e1;
260 for (int i = 2; i < N; ++i) {
261 const i64 di = d.l[i], ei = e.l[i];
262 const i64 nd = u_lo * di + v_lo * ei + u_hi * dp + v_hi * ep + kd_lo * p.l[i] + kd_hi * p.l[i - 1] + cd;
263 const i64 ne = q_lo * di + r_lo * ei + q_hi * dp + r_hi * ep + ke_lo * p.l[i] + ke_hi * p.l[i - 1] + ce;
264 cd = nd >> LIMB_BITS;
265 ce = ne >> LIMB_BITS;
266 d.l[i - 2] = nd & (i64)LIMB_MASK;
267 e.l[i - 2] = ne & (i64)LIMB_MASK;
268 dp = di;
269 ep = ei;
270 }
271 const i64 nd9 = u_hi * dp + v_hi * ep + kd_hi * p.l[N - 1] + cd;
272 const i64 ne9 = q_hi * dp + r_hi * ep + ke_hi * p.l[N - 1] + ce;
273 d.l[N - 2] = nd9 & (i64)LIMB_MASK;
274 e.l[N - 2] = ne9 & (i64)LIMB_MASK;
275 d.l[N - 1] = nd9 >> LIMB_BITS;
276 e.l[N - 1] = ne9 >> LIMB_BITS;
277 }
278}
279
280} // namespace bb::bernstein_yang
void sub_inplace(const Wasm9x29 &b) noexcept
static void apply_divstep_matrix(const DivstepMatrix &m, Wasm9x29 &f, Wasm9x29 &g, Wasm9x29 &d, Wasm9x29 &e, const Wasm9x29 &p, u64 p_inv_mod_2k) noexcept
void reduce_to_canonical(const Wasm9x29 &p) noexcept
Wasm9x29(const uint256_t &x) noexcept
static constexpr u64 p_inv_mod_2k_from_montgomery_r_inv(u64 r_inv) noexcept
static DivstepMatrix compute_divstep_matrix(i64 &delta, u64 f_lo, u64 g_lo) noexcept
void add_inplace(const Wasm9x29 &b) noexcept
FF a
FF b