|
2 | 2 |
|
3 | 3 | pragma solidity ^0.8.0; |
4 | 4 |
|
5 | | -import { TransientLib, tuint256 } from "./Transient.sol"; |
6 | | - |
7 | | -struct ReentrancyGuard { |
8 | | - tuint256 _raw; |
9 | | -} |
10 | | - |
11 | | -library ReentrancyGuardLib { |
12 | | - using TransientLib for tuint256; |
| 5 | +import { TransientLock, TransientLockLib } from "./TransientLock.sol"; |
| 6 | + |
| 7 | +/// @dev Base contract with reentrancy guard functionality using transient storage. |
| 8 | +/// |
| 9 | +/// Use private _lock defined in this contract: |
| 10 | +/// ```solidity |
| 11 | +/// function swap(...) external nonReentrant { |
| 12 | +/// function doMagic(...) external onlyNonReentrantCall { |
| 13 | +/// ``` |
| 14 | +/// |
| 15 | +/// Or use your own locks for more flexibility: |
| 16 | +/// ```solidity |
| 17 | +/// Lock private _myLock; |
| 18 | +/// function swap(...) external nonReentrantLock(_myLock) { |
| 19 | +/// function doMagic(...) external onlyNonReentrantCallLock(_myLock) { |
| 20 | +/// ``` |
| 21 | +/// |
| 22 | +abstract contract ReentrancyGuard { |
| 23 | + using TransientLockLib for TransientLock; |
| 24 | + |
| 25 | + error MissingNonReentrantModifier(); |
| 26 | + |
| 27 | + TransientLock private _lock; |
| 28 | + |
| 29 | + modifier nonReentrant { |
| 30 | + _lock.lock(); |
| 31 | + _; |
| 32 | + _lock.unlock(); |
| 33 | + } |
13 | 34 |
|
14 | | - error ReentrantCallDetected(); |
15 | | - error EnterLeaveDisbalance(); |
| 35 | + modifier onlyNonReentrantCall { |
| 36 | + if (!_inNonReentrantCall()) revert MissingNonReentrantModifier(); |
| 37 | + _; |
| 38 | + } |
16 | 39 |
|
17 | | - function enter(ReentrancyGuard storage self) internal { |
18 | | - if (self._raw.inc() != 1) revert ReentrantCallDetected(); |
| 40 | + modifier nonReentrantLock(TransientLock storage lock) { |
| 41 | + lock.lock(); |
| 42 | + _; |
| 43 | + lock.unlock(); |
19 | 44 | } |
20 | 45 |
|
21 | | - function enterNoIncrement(ReentrancyGuard storage self) internal view { |
22 | | - if (self._raw.tload() != 0) revert ReentrantCallDetected(); |
| 46 | + modifier onlyNonReentrantCallLock(TransientLock storage lock) { |
| 47 | + if (!lock.isLocked()) revert MissingNonReentrantModifier(); |
| 48 | + _; |
23 | 49 | } |
24 | 50 |
|
25 | | - function leave(ReentrancyGuard storage self) internal { |
26 | | - self._raw.dec(EnterLeaveDisbalance.selector); |
| 51 | + function _inNonReentrantCall() internal view returns (bool) { |
| 52 | + return _lock.isLocked(); |
27 | 53 | } |
28 | 54 | } |
0 commit comments