-
-
Notifications
You must be signed in to change notification settings - Fork 139
perf: branchless square root #264
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
Hey @marmitar, thank you very much for this PR. I will review it during the weekend! |
6bf6caf to
fc31231
Compare
|
@PaulRBerg I have another even more optimized implementation using De Bruijn sequences. It's a bit more involved, but the gas costs goes to 346. I could propose that instead, if you prefer. De Bruijn implementation/// @notice Calculates the square root of x using the Babylonian method.
///
/// @dev See https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method.
///
/// Notes:
/// - If x is not a perfect square, the result is rounded down.
/// - Credits to OpenZeppelin for the explanations in comments below.
///
/// @param x The uint256 number for which to calculate the square root.
/// @return result The result as a uint256.
/// @custom:smtchecker abstract-function-nondet
function sqrt(uint256 x) pure returns (uint256 result) {
// For our first guess, we find the most significant *byte* of x and use its value and position
// to approximate the square root of x.
//
// For this, we want to find $k \in [0,255]$ and $n \in {0,8,...,248}$ such that $x \approx k 2^n$.
// We can find $n$ by doing five steps of the `msb()` algorithm ($n = 8 floor(msb(x) / 8)$), and
// then we also have $k = floor(x / 2^n)$.
//
// Once we have those values, the square root can be approximated by $sqrt(x) \approx sqrt(k 2^n) =
// sqrt(k) 2^{n/2}$. For $sqrt(k)$, we use a lookup table that fits in a 32-byte word, which means
// that we'll need to use the top 5 bits of $k$ for indexing, instead of the full 8 bits, so
// $i = k >> 3$. Because of this, each position in the table must have the average square root for
// all bytes that it covers:
//
// $$
// table[i] = round(1/8 sum_{t=0}^7 sqrt(8i+t))
// $$
//
// The table is encoded big-endian so `byte(i, table)` returns entry `i`. This process will produce
// a good initial guess for $sqrt(x)$, with at least one correct bit.
assembly ("memory-safe") {
let n := shl(7, lt(0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, x))
n := or(n, shl(6, lt(0xFFFFFFFFFFFFFFFF, shr(n, x))))
n := or(n, shl(5, lt(0xFFFFFFFF, shr(n, x))))
n := or(n, shl(4, lt(0xFFFF, shr(n, x))))
n := or(n, shl(3, lt(0xFF, shr(n, x))))
let table := 0x02030405060707080809090A0A0A0B0B0B0C0C0C0D0D0D0E0E0E0F0F0F0F1010
let i := shr(3, shr(n, x))
result := shl(shr(1, n), byte(i, table))
}
// At this point, `result` is an estimation with at least one bit of precision. We know the true value has at
// most 128 bits, since it is the square root of a uint256. Newton's method converges quadratically (precision
// doubles at every iteration). We thus need at most 7 iterations to turn our partial result with one bit of
// precision into the expected uint128 result.
assembly ("memory-safe") {
// note: division by zero in EVM returns zero
result := shr(1, add(result, div(x, result)))
result := shr(1, add(result, div(x, result)))
result := shr(1, add(result, div(x, result)))
result := shr(1, add(result, div(x, result)))
result := shr(1, add(result, div(x, result)))
result := shr(1, add(result, div(x, result)))
result := shr(1, add(result, div(x, result)))
// If x is not a perfect square, round the result toward zero.
result := sub(result, gt(result, div(x, result)))
}
} |
|
Thanks @marmitar. I suggest putting the De Bruihn implementation in another branch so we can compare the complexity. I would also suggest keeping the original implementation in the tests directory so that we can compare it to the new implementation from this PR (for regression testing). |
|
Oh, I just noticed that there are no fuzz tests for For For Also, that's what integer square root implementations do, see Python's |
|
Yeah, fuzz those for those functions would be helpful. It would be good idea to use established mathematical properties to test out those functions, but I would still keep a copy of the original implementations simply because a lot of PRBMath users are using them now, and in this way, we can stress test both the old version and the new version. We could do something like: Where That said, I can understand if you don't have the time to write the tests like this. Feel free to implement the simple tests for the new version, and I can handle the rest later. |
fc31231 to
57b01fa
Compare
|
Ok, I'll do that. In the meantime, I was tracking this gas variance on |
|
Thanks @marmitar, will review this over the weekend. Yeah, gas golfing is difficult with the latest versions of Solidity, especially when |
57b01fa to
b5a3acc
Compare
|
I opened #265 with property-based tests, rebased this on top of that one, and added the regression tests here. |
b5a3acc to
dc36e28
Compare
|
tyvm @marmitar. Apologies for the delay - I got overrun this weekend with life admin. I will review this week! |
|
Oh, no need to hurry. Do it when you have time. |
dc36e28 to
79a6727
Compare
Estimated gas reduction from 798.1 to 407 gas.
79a6727 to
a0dff59
Compare
|
Rebased it to main and updated the reference implementation of |
Branchless Square Root
Optimized
sqrtimplementation in Yul that avoids implicit and explicit branches, which are relatively expensive on the EVM. It eliminates Solidity's division checks, reuses and improvesmsb, and expresses conditionals arithmetically. The result is 49% lower gas cost on average and a constant 407 gas per call.Estimated Gas Cost
uint256Code Transformations
Branchless log2 (-40 gas on average)
The prior code branched once per region in
uint256, costing ~14 gas (PUSH/JUMPI/JUMPDEST) per branch plus 24-30 gas when taken. About half of those branches fire on average, so switching tomsb()already saves gas, even thoughmsb()does one extra iteration (msb0), whichsqrtdiscards.Slight regression at this step: inputs that previously short-circuited (including values around 1e18) go from ~700 to ~754 gas because early-outs are gone. Acceptable, because this change enables the
msb()optimization below, which cuts 60 gas and is required to reach the final 407 gas.Reorder instructions on
msb(-75 gas)msb()previously updtedxas it progressed. On a stack machine (EVM) this forces extraSWAP/DUPtraffic. We can computex >> resultevery iteration using the same number ofshr(). Saves 75 gas.Branchless "perfect square" condition (-42 gas on average)
The final iterate is always either$\lfloor\sqrt{x}\rfloor$ or $\lfloor\sqrt{x}\rfloor + 1$ . So we replace
result = min(result, x / result)with a boolean subtract:result -= (result > x / result) ? 1 : 0. Yul implementation drops gas from 679-687 to 641.Unchecked division (-196 gas)
Solidity still inserts a division-by-zero check, even in
uncheckedblocks (Checked or Unchecked Arithmetic). The only way to avoid it is to use assembly directly, which yields the biggest win here: -196 gas.Skip condition for zero (-38 gas on average)
With division in Yul, we drop the$x = 0$ regresses by +333. This is intentional: inputs are expected to be non-trivial most of the time ($\geq 90\%$ ), so the average gas cost improves.
x == 0branch and rely on the EVM semanticsdiv(a, 0) = 0(EVM Codes - DIV). This saves 38 gas for every nonzero input;Optimized inlined
msb(-21 gas, not implemented)Inlining
msb()intosqrtand skippingmsb0saves 21 gas, but the readability hit isn't worth it IMO, so I left it out. I can bring it back if you like it.