Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ pub mod alloc;
pub mod error;
pub mod math;
pub mod slab;
pub mod undo;
237 changes: 237 additions & 0 deletions crates/core/src/undo.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
//! Helpers for undoing partial side effects when their larger operation fails.

use core::{fmt, mem, ops};

/// An RAII guard to rollback and undo something on (early) drop.
///
/// Dereferences to its inner `T` and its undo function is given the `T` on
/// drop.
///
/// When all of the changes that need to happen together have happened, you can
/// call `Undo::commit` to disable the guard and commit the associated side
/// effects.
///
/// # Example
///
/// ```
/// use std::cell::Cell;
/// use wasmtime_internal_core::{error::Result, undo::Undo};
///
/// /// Some big ball of state that must always be coherent.
/// pub struct Context {
/// // ...
/// }
///
/// impl Context {
/// /// Perform some incremental mutation to `self`, which might not leave
/// /// it in a valid state unless its whole batch of work is completed.
/// fn do_thing(&mut self, arg: u32) -> Result<()> {
/// # let _ = arg;
/// # todo!()
/// // ...
/// }
///
/// /// Undo the side effects of `self.do_thing(arg)` for when we need to
/// /// roll back mutations.
/// fn undo_thing(&mut self, arg: u32) {
/// # let _ = arg;
/// // ...
/// }
///
/// /// Call `self.do_thing(arg)` for each `arg` in `args`.
/// ///
/// /// However, if any `self.do_thing(arg)` call fails, make sure that
/// /// we roll back to the original state by calling `self.undo_thing(arg)`
/// /// for all the `self.do_thing(arg)` calls that already succeeded. This
/// /// way we never leave `self` in a state where things got half-done.
/// pub fn do_all_or_nothing(&mut self, args: &[u32]) -> Result<()> {
/// // Counter for our progress, so that we know how much to work undo upon
/// // failure.
/// let num_things_done = Cell::new(0);
///
/// // Wrap the `Context` in an `Undo` that rolls back our side effects if
/// // we early-exit this function via `?`-propagation or panic unwinding.
/// let mut ctx = Undo::new(self, |ctx| {
/// for arg in args.iter().take(num_things_done.get()) {
/// ctx.undo_thing(*arg);
/// }
/// });
///
/// // Do each piece of work!
/// for arg in args {
/// // Note: if this call returns an error that is `?`-propagated or
/// // triggers unwinding by panicking, then the work performed thus
/// // far will be rolled back when `ctx` is dropped.
/// ctx.do_thing(*arg)?;
///
/// // Update how much work has been completed.
/// num_things_done.set(num_things_done.get() + 1);
/// }
///
/// // We completed all of the work, so commit the `Undo` guard and
/// // disable its cleanup function.
/// Undo::commit(ctx);
///
/// Ok(())
/// }
/// }
/// ```
#[must_use = "`Undo` implicitly runs its undo function on drop; use `Undo::commit(...)` \
to disable"]
pub struct Undo<T, F>
where
F: FnOnce(T),
{
inner: Option<T>,
undo: Option<F>,
}

impl<T, F> Drop for Undo<T, F>
where
F: FnOnce(T),
{
fn drop(&mut self) {
if let Some(inner) = self.inner.take() {
let undo = self.undo.take().unwrap();
undo(inner);
}
}
}

impl<T, F> fmt::Debug for Undo<T, F>
where
F: FnOnce(T),
T: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Undo")
.field("inner", &self.inner)
.field("undo", &"..")
.finish()
}
}

impl<T, F> ops::Deref for Undo<T, F>
where
F: FnOnce(T),
{
type Target = T;

fn deref(&self) -> &Self::Target {
self.inner.as_ref().unwrap()
}
}

impl<T, F> ops::DerefMut for Undo<T, F>
where
F: FnOnce(T),
{
fn deref_mut(&mut self) -> &mut Self::Target {
self.inner.as_mut().unwrap()
}
}

impl<T, F> Undo<T, F>
where
F: FnOnce(T),
{
/// Create a new `Undo` guard.
///
/// This guard will wrap the given `inner` object and call `undo(inner)`
/// when dropped, unless the guard is disabled via `Undo::commit`.
pub fn new(inner: T, undo: F) -> Self {
Self {
inner: Some(inner),
undo: Some(undo),
}
}

/// Disable this `Undo` and return its inner value.
///
/// This `Undo`'s cleanup function will never be called.
pub fn commit(mut guard: Self) -> T {
let inner = guard.inner.take().unwrap();
mem::forget(guard);
inner
}
}

#[cfg(all(test, feature = "std"))]
mod tests {
use super::*;
use crate::error::{Result, ensure};
use core::{cell::Cell, cmp};
use std::{panic, string::ToString};

#[derive(Default)]
struct Counter {
value: u32,
max_value_seen: u32,
}

impl Counter {
fn inc(&mut self, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> {
f(self)?;
self.value += 1;
self.max_value_seen = cmp::max(self.max_value_seen, self.value);
Ok(())
}

fn dec(&mut self) {
self.value -= 1;
}

fn inc_n(&mut self, n: u32, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> {
let i = Cell::new(0);

let mut counter = Undo::new(self, |counter| {
for _ in 0..i.get() {
counter.dec();
}
});

for _ in 0..n {
counter.inc(&mut f)?;
i.set(i.get() + 1);
}

Undo::commit(counter);
Ok(())
}
}

#[test]
fn error_propagation() {
let mut counter = Counter::default();
let result = counter.inc_n(10, |c| {
ensure!(c.value < 5, "uh oh");
Ok(())
});
assert_eq!(result.unwrap_err().to_string(), "uh oh");
assert_eq!(counter.value, 0);
assert_eq!(counter.max_value_seen, 5);
}

#[test]
fn panic_unwind() {
let mut counter = Counter::default();
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
counter.inc_n(10, |c| {
assert!(c.value < 5);
Ok(())
})
}));
assert!(result.is_err());
assert_eq!(counter.value, 0);
assert_eq!(counter.max_value_seen, 5);
}

#[test]
fn commit() {
let mut counter = Counter::default();
let result = counter.inc_n(10, |_| Ok(()));
assert!(result.is_ok());
assert_eq!(counter.value, 10);
assert_eq!(counter.max_value_seen, 10);
}
}
2 changes: 2 additions & 0 deletions crates/environ/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ pub use wasmtime_core::error;
#[cfg(feature = "anyhow")]
pub use self::error::ToWasmtimeResult;

pub use wasmtime_core::{alloc::PanicOnOom, undo::Undo};

// Only for use with `bindgen!`-generated code.
#[doc(hidden)]
#[cfg(feature = "anyhow")]
Expand Down