diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 1ac13781934f..6e538cdde366 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -22,3 +22,4 @@ pub mod alloc; pub mod error; pub mod math; pub mod slab; +pub mod undo; diff --git a/crates/core/src/undo.rs b/crates/core/src/undo.rs new file mode 100644 index 000000000000..c9bd5c3b19f5 --- /dev/null +++ b/crates/core/src/undo.rs @@ -0,0 +1,244 @@ +//! Helpers for undoing partial side effects when their larger operation fails. + +use core::{fmt, mem, ops}; + +/// An RAII guard to rollback and undo something on (early) drop. +/// +/// Dereferences to its inner `T` and its undo function is given the `T` on +/// drop. +/// +/// When all of the changes that need to happen together have happened, you can +/// call `Undo::commit` to disable the guard and commit the associated side +/// effects. +/// +/// # Example +/// +/// ``` +/// use std::cell::Cell; +/// use wasmtime_internal_core::{error::Result, undo::Undo}; +/// +/// /// Some big ball of state that must always be coherent. +/// pub struct Context { +/// // ... +/// } +/// +/// impl Context { +/// /// Perform some incremental mutation to `self`, which might not leave +/// /// it in a valid state unless its whole batch of work is completed. +/// fn do_thing(&mut self, arg: u32) -> Result<()> { +/// # let _ = arg; +/// # todo!() +/// // ... +/// } +/// +/// /// Undo the side effects of `self.do_thing(arg)` for when we need to +/// /// roll back mutations. +/// fn undo_thing(&mut self, arg: u32) { +/// # let _ = arg; +/// // ... +/// } +/// +/// /// Call `self.do_thing(arg)` for each `arg` in `args`. +/// /// +/// /// However, if any `self.do_thing(arg)` call fails, make sure that +/// /// we roll back to the original state by calling `self.undo_thing(arg)` +/// /// for all the `self.do_thing(arg)` calls that already succeeded. This +/// /// way we never leave `self` in a state where things got half-done. +/// pub fn do_all_or_nothing(&mut self, args: &[u32]) -> Result<()> { +/// // Counter for our progress, so that we know how much to work undo upon +/// // failure. +/// let num_things_done = Cell::new(0); +/// +/// // Wrap the `Context` in an `Undo` that rolls back our side effects if +/// // we early-exit this function via `?`-propagation or panic unwinding. +/// let mut ctx = Undo::new(self, |ctx| { +/// for arg in args.iter().take(num_things_done.get()) { +/// ctx.undo_thing(*arg); +/// } +/// }); +/// +/// // Do each piece of work! +/// for arg in args { +/// // Note: if this call returns an error that is `?`-propagated or +/// // triggers unwinding by panicking, then the work performed thus +/// // far will be rolled back when `ctx` is dropped. +/// ctx.do_thing(*arg)?; +/// +/// // Update how much work has been completed. +/// num_things_done.set(num_things_done.get() + 1); +/// } +/// +/// // We completed all of the work, so commit the `Undo` guard and +/// // disable its cleanup function. +/// Undo::commit(ctx); +/// +/// Ok(()) +/// } +/// } +/// ``` +#[must_use = "`Undo` implicitly runs its undo function on drop; use `Undo::commit(...)` \ + to disable"] +pub struct Undo +where + F: FnOnce(T), +{ + inner: mem::ManuallyDrop, + undo: mem::ManuallyDrop, +} + +impl Drop for Undo +where + F: FnOnce(T), +{ + fn drop(&mut self) { + // Safety: These `ManuallyDrop` fields will not be used again. + let inner = unsafe { mem::ManuallyDrop::take(&mut self.inner) }; + let undo = unsafe { mem::ManuallyDrop::take(&mut self.undo) }; + undo(inner); + } +} + +impl fmt::Debug for Undo +where + F: FnOnce(T), + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Undo") + .field("inner", &self.inner) + .field("undo", &"..") + .finish() + } +} + +impl ops::Deref for Undo +where + F: FnOnce(T), +{ + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl ops::DerefMut for Undo +where + F: FnOnce(T), +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl Undo +where + F: FnOnce(T), +{ + /// Create a new `Undo` guard. + /// + /// This guard will wrap the given `inner` object and call `undo(inner)` + /// when dropped, unless the guard is disabled via `Undo::commit`. + pub fn new(inner: T, undo: F) -> Self { + Self { + inner: mem::ManuallyDrop::new(inner), + undo: mem::ManuallyDrop::new(undo), + } + } + + /// Disable this `Undo` and return its inner value. + /// + /// This `Undo`'s cleanup function will never be called. + pub fn commit(guard: Self) -> T { + let mut guard = mem::ManuallyDrop::new(guard); + + // Safety: These `ManuallyDrop` fields will not be used again. + unsafe { + // Make sure to drop `undo`, even though we aren't calling it, to + // avoid leaking closed-over `Arc`s, for example. + mem::ManuallyDrop::drop(&mut guard.undo); + + mem::ManuallyDrop::take(&mut guard.inner) + } + } +} + +#[cfg(all(test, feature = "std"))] +mod tests { + use super::*; + use crate::error::{Result, ensure}; + use core::{cell::Cell, cmp}; + use std::{panic, string::ToString}; + + #[derive(Default)] + struct Counter { + value: u32, + max_value_seen: u32, + } + + impl Counter { + fn inc(&mut self, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> { + f(self)?; + self.value += 1; + self.max_value_seen = cmp::max(self.max_value_seen, self.value); + Ok(()) + } + + fn dec(&mut self) { + self.value -= 1; + } + + fn inc_n(&mut self, n: u32, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> { + let i = Cell::new(0); + + let mut counter = Undo::new(self, |counter| { + for _ in 0..i.get() { + counter.dec(); + } + }); + + for _ in 0..n { + counter.inc(&mut f)?; + i.set(i.get() + 1); + } + + Undo::commit(counter); + Ok(()) + } + } + + #[test] + fn error_propagation() { + let mut counter = Counter::default(); + let result = counter.inc_n(10, |c| { + ensure!(c.value < 5, "uh oh"); + Ok(()) + }); + assert_eq!(result.unwrap_err().to_string(), "uh oh"); + assert_eq!(counter.value, 0); + assert_eq!(counter.max_value_seen, 5); + } + + #[test] + fn panic_unwind() { + let mut counter = Counter::default(); + let result = panic::catch_unwind(panic::AssertUnwindSafe(|| { + counter.inc_n(10, |c| { + assert!(c.value < 5); + Ok(()) + }) + })); + assert!(result.is_err()); + assert_eq!(counter.value, 0); + assert_eq!(counter.max_value_seen, 5); + } + + #[test] + fn commit() { + let mut counter = Counter::default(); + let result = counter.inc_n(10, |_| Ok(())); + assert!(result.is_ok()); + assert_eq!(counter.value, 10); + assert_eq!(counter.max_value_seen, 10); + } +} diff --git a/crates/environ/src/lib.rs b/crates/environ/src/lib.rs index b37dd9dd9196..29806dd0b3ea 100644 --- a/crates/environ/src/lib.rs +++ b/crates/environ/src/lib.rs @@ -88,6 +88,8 @@ pub use wasmtime_core::error; #[cfg(feature = "anyhow")] pub use self::error::ToWasmtimeResult; +pub use wasmtime_core::{alloc::PanicOnOom, undo::Undo}; + // Only for use with `bindgen!`-generated code. #[doc(hidden)] #[cfg(feature = "anyhow")]