|
| 1 | +//! Helpers for undoing partial side effects when their larger operation fails. |
| 2 | +
|
| 3 | +use core::{fmt, mem, ops}; |
| 4 | + |
| 5 | +/// An RAII guard to rollback and undo something on (early) drop. |
| 6 | +/// |
| 7 | +/// Dereferences to its inner `T` and its undo function is given the `T` on |
| 8 | +/// drop. |
| 9 | +/// |
| 10 | +/// When all of the changes that need to happen together have happened, you can |
| 11 | +/// call `Undo::commit` to disable the guard and commit the associated side |
| 12 | +/// effects. |
| 13 | +/// |
| 14 | +/// # Example |
| 15 | +/// |
| 16 | +/// ``` |
| 17 | +/// use std::cell::Cell; |
| 18 | +/// use wasmtime_internal_core::{error::Result, undo::Undo}; |
| 19 | +/// |
| 20 | +/// /// Some big ball of state that must always be coherent. |
| 21 | +/// pub struct Context { |
| 22 | +/// // ... |
| 23 | +/// } |
| 24 | +/// |
| 25 | +/// impl Context { |
| 26 | +/// /// Perform some incremental mutation to `self`, which might not leave |
| 27 | +/// /// it in a valid state unless its whole batch of work is completed. |
| 28 | +/// fn do_thing(&mut self, arg: u32) -> Result<()> { |
| 29 | +/// # let _ = arg; |
| 30 | +/// # todo!() |
| 31 | +/// // ... |
| 32 | +/// } |
| 33 | +/// |
| 34 | +/// /// Undo the side effects of `self.do_thing(arg)` for when we need to |
| 35 | +/// /// roll back mutations. |
| 36 | +/// fn undo_thing(&mut self, arg: u32) { |
| 37 | +/// # let _ = arg; |
| 38 | +/// // ... |
| 39 | +/// } |
| 40 | +/// |
| 41 | +/// /// Call `self.do_thing(arg)` for each `arg` in `args`. |
| 42 | +/// /// |
| 43 | +/// /// However, if any `self.do_thing(arg)` call fails, make sure that |
| 44 | +/// /// we roll back to the original state by calling `self.undo_thing(arg)` |
| 45 | +/// /// for all the `self.do_thing(arg)` calls that already succeeded. This |
| 46 | +/// /// way we never leave `self` in a state where things got half-done. |
| 47 | +/// pub fn do_all_or_nothing(&mut self, args: &[u32]) -> Result<()> { |
| 48 | +/// // Counter for our progress, so that we know how much to work undo upon |
| 49 | +/// // failure. |
| 50 | +/// let num_things_done = Cell::new(0); |
| 51 | +/// |
| 52 | +/// // Wrap the `Context` in an `Undo` that rolls back our side effects if |
| 53 | +/// // we early-exit this function via `?`-propagation or panic unwinding. |
| 54 | +/// let mut ctx = Undo::new(self, |ctx| { |
| 55 | +/// for arg in args.iter().take(num_things_done.get()) { |
| 56 | +/// ctx.undo_thing(*arg); |
| 57 | +/// } |
| 58 | +/// }); |
| 59 | +/// |
| 60 | +/// // Do each piece of work! |
| 61 | +/// for arg in args { |
| 62 | +/// // Note: if this call returns an error that is `?`-propagated or |
| 63 | +/// // triggers unwinding by panicking, then the work performed thus |
| 64 | +/// // far will be rolled back when `ctx` is dropped. |
| 65 | +/// ctx.do_thing(*arg)?; |
| 66 | +/// |
| 67 | +/// // Update how much work has been completed. |
| 68 | +/// num_things_done.set(num_things_done.get() + 1); |
| 69 | +/// } |
| 70 | +/// |
| 71 | +/// // We completed all of the work, so commit the `Undo` guard and |
| 72 | +/// // disable its cleanup function. |
| 73 | +/// Undo::commit(ctx); |
| 74 | +/// |
| 75 | +/// Ok(()) |
| 76 | +/// } |
| 77 | +/// } |
| 78 | +/// ``` |
| 79 | +#[must_use = "`Undo` implicitly runs its undo function on drop; use `Undo::commit(...)` \ |
| 80 | + to disable"] |
| 81 | +pub struct Undo<T, F> |
| 82 | +where |
| 83 | + F: FnOnce(T), |
| 84 | +{ |
| 85 | + inner: mem::ManuallyDrop<T>, |
| 86 | + undo: mem::ManuallyDrop<F>, |
| 87 | +} |
| 88 | + |
| 89 | +impl<T, F> Drop for Undo<T, F> |
| 90 | +where |
| 91 | + F: FnOnce(T), |
| 92 | +{ |
| 93 | + fn drop(&mut self) { |
| 94 | + // Safety: These `ManuallyDrop` fields will not be used again. |
| 95 | + let inner = unsafe { mem::ManuallyDrop::take(&mut self.inner) }; |
| 96 | + let undo = unsafe { mem::ManuallyDrop::take(&mut self.undo) }; |
| 97 | + undo(inner); |
| 98 | + } |
| 99 | +} |
| 100 | + |
| 101 | +impl<T, F> fmt::Debug for Undo<T, F> |
| 102 | +where |
| 103 | + F: FnOnce(T), |
| 104 | + T: fmt::Debug, |
| 105 | +{ |
| 106 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 107 | + f.debug_struct("Undo") |
| 108 | + .field("inner", &self.inner) |
| 109 | + .field("undo", &"..") |
| 110 | + .finish() |
| 111 | + } |
| 112 | +} |
| 113 | + |
| 114 | +impl<T, F> ops::Deref for Undo<T, F> |
| 115 | +where |
| 116 | + F: FnOnce(T), |
| 117 | +{ |
| 118 | + type Target = T; |
| 119 | + |
| 120 | + fn deref(&self) -> &Self::Target { |
| 121 | + &self.inner |
| 122 | + } |
| 123 | +} |
| 124 | + |
| 125 | +impl<T, F> ops::DerefMut for Undo<T, F> |
| 126 | +where |
| 127 | + F: FnOnce(T), |
| 128 | +{ |
| 129 | + fn deref_mut(&mut self) -> &mut Self::Target { |
| 130 | + &mut self.inner |
| 131 | + } |
| 132 | +} |
| 133 | + |
| 134 | +impl<T, F> Undo<T, F> |
| 135 | +where |
| 136 | + F: FnOnce(T), |
| 137 | +{ |
| 138 | + /// Create a new `Undo` guard. |
| 139 | + /// |
| 140 | + /// This guard will wrap the given `inner` object and call `undo(inner)` |
| 141 | + /// when dropped, unless the guard is disabled via `Undo::commit`. |
| 142 | + pub fn new(inner: T, undo: F) -> Self { |
| 143 | + Self { |
| 144 | + inner: mem::ManuallyDrop::new(inner), |
| 145 | + undo: mem::ManuallyDrop::new(undo), |
| 146 | + } |
| 147 | + } |
| 148 | + |
| 149 | + /// Disable this `Undo` and return its inner value. |
| 150 | + /// |
| 151 | + /// This `Undo`'s cleanup function will never be called. |
| 152 | + pub fn commit(guard: Self) -> T { |
| 153 | + let mut guard = mem::ManuallyDrop::new(guard); |
| 154 | + |
| 155 | + // Safety: These `ManuallyDrop` fields will not be used again. |
| 156 | + unsafe { |
| 157 | + // Make sure to drop `undo`, even though we aren't calling it, to |
| 158 | + // avoid leaking closed-over `Arc`s, for example. |
| 159 | + mem::ManuallyDrop::drop(&mut guard.undo); |
| 160 | + |
| 161 | + mem::ManuallyDrop::take(&mut guard.inner) |
| 162 | + } |
| 163 | + } |
| 164 | +} |
| 165 | + |
| 166 | +#[cfg(all(test, feature = "std"))] |
| 167 | +mod tests { |
| 168 | + use super::*; |
| 169 | + use crate::error::{Result, ensure}; |
| 170 | + use core::{cell::Cell, cmp}; |
| 171 | + use std::{panic, string::ToString}; |
| 172 | + |
| 173 | + #[derive(Default)] |
| 174 | + struct Counter { |
| 175 | + value: u32, |
| 176 | + max_value_seen: u32, |
| 177 | + } |
| 178 | + |
| 179 | + impl Counter { |
| 180 | + fn inc(&mut self, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> { |
| 181 | + f(self)?; |
| 182 | + self.value += 1; |
| 183 | + self.max_value_seen = cmp::max(self.max_value_seen, self.value); |
| 184 | + Ok(()) |
| 185 | + } |
| 186 | + |
| 187 | + fn dec(&mut self) { |
| 188 | + self.value -= 1; |
| 189 | + } |
| 190 | + |
| 191 | + fn inc_n(&mut self, n: u32, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> { |
| 192 | + let i = Cell::new(0); |
| 193 | + |
| 194 | + let mut counter = Undo::new(self, |counter| { |
| 195 | + for _ in 0..i.get() { |
| 196 | + counter.dec(); |
| 197 | + } |
| 198 | + }); |
| 199 | + |
| 200 | + for _ in 0..n { |
| 201 | + counter.inc(&mut f)?; |
| 202 | + i.set(i.get() + 1); |
| 203 | + } |
| 204 | + |
| 205 | + Undo::commit(counter); |
| 206 | + Ok(()) |
| 207 | + } |
| 208 | + } |
| 209 | + |
| 210 | + #[test] |
| 211 | + fn error_propagation() { |
| 212 | + let mut counter = Counter::default(); |
| 213 | + let result = counter.inc_n(10, |c| { |
| 214 | + ensure!(c.value < 5, "uh oh"); |
| 215 | + Ok(()) |
| 216 | + }); |
| 217 | + assert_eq!(result.unwrap_err().to_string(), "uh oh"); |
| 218 | + assert_eq!(counter.value, 0); |
| 219 | + assert_eq!(counter.max_value_seen, 5); |
| 220 | + } |
| 221 | + |
| 222 | + #[test] |
| 223 | + fn panic_unwind() { |
| 224 | + let mut counter = Counter::default(); |
| 225 | + let result = panic::catch_unwind(panic::AssertUnwindSafe(|| { |
| 226 | + counter.inc_n(10, |c| { |
| 227 | + assert!(c.value < 5); |
| 228 | + Ok(()) |
| 229 | + }) |
| 230 | + })); |
| 231 | + assert!(result.is_err()); |
| 232 | + assert_eq!(counter.value, 0); |
| 233 | + assert_eq!(counter.max_value_seen, 5); |
| 234 | + } |
| 235 | + |
| 236 | + #[test] |
| 237 | + fn commit() { |
| 238 | + let mut counter = Counter::default(); |
| 239 | + let result = counter.inc_n(10, |_| Ok(())); |
| 240 | + assert!(result.is_ok()); |
| 241 | + assert_eq!(counter.value, 10); |
| 242 | + assert_eq!(counter.max_value_seen, 10); |
| 243 | + } |
| 244 | +} |
0 commit comments