Skip to content

Commit 7da8d64

Browse files
committed
Add checkBalanceOf
1 parent 115220c commit 7da8d64

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

src/utils/SafeTransferLib.sol

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,27 @@ library SafeTransferLib {
405405
}
406406
}
407407

408+
/// @dev Performs a `token.balanceOf(account)` check.
409+
/// `implemented` denotes whether the `token` does not implement `balanceOf`.
410+
/// `amount` is zero if the `token` does not implement `balanceOf`.
411+
function checkBalanceOf(address token, address account)
412+
internal
413+
view
414+
returns (bool implemented, uint256 amount)
415+
{
416+
/// @solidity memory-safe-assembly
417+
assembly {
418+
mstore(0x14, account) // Store the `account` argument.
419+
mstore(0x00, 0x70a08231000000000000000000000000) // `balanceOf(address)`.
420+
implemented :=
421+
and( // The arguments of `and` are evaluated from right to left.
422+
gt(returndatasize(), 0x1f), // At least 32 bytes returned.
423+
staticcall(gas(), token, 0x10, 0x24, 0x20, 0x20)
424+
)
425+
amount := mul(mload(0x20), implemented)
426+
}
427+
}
428+
408429
/// @dev Returns the total supply of the `token`.
409430
/// Reverts if the token does not exist or does not implement `totalSupply()`.
410431
function totalSupply(address token) internal view returns (uint256 result) {

test/SafeTransferLib.t.sol

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,23 @@ contract SafeTransferLibTest is SoladyTest {
386386
assertEq(SafeTransferLib.balanceOf(address(erc20), _brutalized(address(this))), amount);
387387
}
388388

389+
function testCheckBalanceOfNonImplemented() public {
390+
(bool implemented,) = SafeTransferLib.checkBalanceOf(address(0), _brutalized(address(this)));
391+
assertFalse(implemented);
392+
}
393+
394+
function testCheckBalanceOf(address to, uint256 amount) public {
395+
uint256 originalBalance = erc20.balanceOf(address(this));
396+
while (originalBalance < amount) amount = _random();
397+
while (to == address(this)) to = _randomHashedAddress();
398+
399+
SafeTransferLib.safeTransfer(address(erc20), _brutalized(to), originalBalance - amount);
400+
(bool implemented, uint256 retrievedAmount) =
401+
SafeTransferLib.checkBalanceOf(address(erc20), _brutalized(address(this)));
402+
assertEq(retrievedAmount, amount);
403+
assertTrue(implemented);
404+
}
405+
389406
function testTransferAllWithStandardERC20() public {
390407
SafeTransferLib.safeTransferAll(address(erc20), address(1));
391408
}

0 commit comments

Comments
 (0)