From c1f8abba9e2c40c04ca5a862a5210d1dbebd4793 Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Sun, 9 Feb 2025 15:08:52 +0000 Subject: [PATCH] Do not allow recursive warnings (Lua 5.4) --- src/state.rs | 13 +++++++------ src/types.rs | 4 ++-- tests/tests.rs | 7 +++++++ 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/state.rs b/src/state.rs index a39f22cc..1054fbfc 100644 --- a/src/state.rs +++ b/src/state.rs @@ -703,18 +703,19 @@ impl Lua { unsafe extern "C-unwind" fn warn_proc(ud: *mut c_void, msg: *const c_char, tocont: c_int) { let extra = ud as *mut ExtraData; callback_error_ext((*extra).raw_lua().state(), extra, |extra, _| { - let cb = mlua_expect!( - (*extra).warn_callback.as_ref(), - "no warning callback set in warn_proc" - ); + let warn_callback = (*extra).warn_callback.clone(); + let warn_callback = mlua_expect!(warn_callback, "no warning callback set in warn_proc"); + if XRc::strong_count(&warn_callback) > 2 { + return Ok(()); + } let msg = StdString::from_utf8_lossy(CStr::from_ptr(msg).to_bytes()); - cb((*extra).lua(), &msg, tocont != 0) + warn_callback((*extra).lua(), &msg, tocont != 0) }); } let lua = self.lock(); unsafe { - (*lua.extra.get()).warn_callback = Some(Box::new(callback)); + (*lua.extra.get()).warn_callback = Some(XRc::new(callback)); ffi::lua_setwarnf(lua.state(), Some(warn_proc), lua.extra.get() as *mut c_void); } } diff --git a/src/types.rs b/src/types.rs index 35257847..537e1feb 100644 --- a/src/types.rs +++ b/src/types.rs @@ -91,10 +91,10 @@ pub(crate) type InterruptCallback = XRc Result + Send>; pub(crate) type InterruptCallback = XRc Result>; #[cfg(all(feature = "send", feature = "lua54"))] -pub(crate) type WarnCallback = Box Result<()> + Send>; +pub(crate) type WarnCallback = XRc Result<()> + Send>; #[cfg(all(not(feature = "send"), feature = "lua54"))] -pub(crate) type WarnCallback = Box Result<()>>; +pub(crate) type WarnCallback = XRc Result<()>>; /// A trait that adds `Send` requirement if `send` feature is enabled. #[cfg(feature = "send")] diff --git a/tests/tests.rs b/tests/tests.rs index af236e0c..01f706d7 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1289,6 +1289,13 @@ fn test_warnings() -> Result<()> { if matches!(*cause, Error::RuntimeError(ref err) if err == "warning error") )); + // Recursive warning + lua.set_warning_function(|lua, _, _| { + lua.warning("inner", false); + Ok(()) + }); + lua.warning("hello", false); + Ok(()) }