Skip to content

Commit

Permalink
Wrap hooks in refcounter instead of box, same as previously.
Browse files Browse the repository at this point in the history
This allows to obtain independent clone of closure.
Ensure that stack is always clean when running thread hook.
  • Loading branch information
khvzak committed Feb 8, 2025
1 parent 5bbd23e commit d1a587f
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 32 deletions.
10 changes: 4 additions & 6 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ impl Lua {
let lua = self.lock();
unsafe {
(*lua.extra.get()).hook_triggers = triggers;
(*lua.extra.get()).hook_callback = Some(Box::new(callback));
(*lua.extra.get()).hook_callback = Some(XRc::new(callback));
lua.set_thread_hook(lua.state(), HookKind::Global)
}
}
Expand Down Expand Up @@ -564,7 +564,7 @@ impl Lua {
F: Fn(&Lua, Debug) -> Result<VmState> + MaybeSend + 'static,
{
let lua = self.lock();
unsafe { lua.set_thread_hook(lua.state(), HookKind::Thread(triggers, Box::new(callback))) }
unsafe { lua.set_thread_hook(lua.state(), HookKind::Thread(triggers, XRc::new(callback))) }
}

/// Removes a global hook previously set by [`Lua::set_global_hook`].
Expand Down Expand Up @@ -644,8 +644,6 @@ impl Lua {
where
F: Fn(&Lua) -> Result<VmState> + MaybeSend + 'static,
{
use std::rc::Rc;

unsafe extern "C-unwind" fn interrupt_proc(state: *mut ffi::lua_State, gc: c_int) {
if gc >= 0 {
// We don't support GC interrupts since they cannot survive Lua exceptions
Expand All @@ -654,7 +652,7 @@ impl Lua {
let result = callback_error_ext(state, ptr::null_mut(), move |extra, _| {
let interrupt_cb = (*extra).interrupt_callback.clone();
let interrupt_cb = mlua_expect!(interrupt_cb, "no interrupt callback set in interrupt_proc");
if Rc::strong_count(&interrupt_cb) > 2 {
if XRc::strong_count(&interrupt_cb) > 2 {
return Ok(VmState::Continue); // Don't allow recursion
}
let _guard = StateGuard::new((*extra).raw_lua(), state);
Expand All @@ -671,7 +669,7 @@ impl Lua {
// Set interrupt callback
let lua = self.lock();
unsafe {
(*lua.extra.get()).interrupt_callback = Some(Rc::new(callback));
(*lua.extra.get()).interrupt_callback = Some(XRc::new(callback));
(*ffi::lua_callbacks(lua.main_state())).interrupt = Some(interrupt_proc);
}
}
Expand Down
39 changes: 18 additions & 21 deletions src/state/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,15 +410,12 @@ impl RawLua {

unsafe extern "C-unwind" fn global_hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) {
let status = callback_error_ext(state, ptr::null_mut(), move |extra, _| {
let rawlua = (*extra).raw_lua();
let _guard = StateGuard::new(rawlua, state);
let debug = Debug::new(rawlua, ar);
match (*extra).hook_callback.take() {
Some(hook_cb) => {
// Temporary obtain ownership of the hook callback
let result = hook_cb((*extra).lua(), debug);
(*extra).hook_callback = Some(hook_cb);
result
match (*extra).hook_callback.clone() {
Some(hook_callback) => {
let rawlua = (*extra).raw_lua();
let _guard = StateGuard::new(rawlua, state);
let debug = Debug::new(rawlua, ar);
hook_callback((*extra).lua(), debug)
}
None => {
ffi::lua_sethook(state, None, 0, 0);
Expand All @@ -430,11 +427,17 @@ impl RawLua {
}

unsafe extern "C-unwind" fn hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) {
let top = ffi::lua_gettop(state);
let mut hook_callback_ptr = ptr::null();
ffi::luaL_checkstack(state, 3, ptr::null());
ffi::lua_getfield(state, ffi::LUA_REGISTRYINDEX, HOOKS_KEY);
ffi::lua_pushthread(state);
if ffi::lua_rawget(state, -2) != ffi::LUA_TUSERDATA {
ffi::lua_pop(state, 2);
if ffi::lua_getfield(state, ffi::LUA_REGISTRYINDEX, HOOKS_KEY) == ffi::LUA_TTABLE {
ffi::lua_pushthread(state);
if ffi::lua_rawget(state, -2) == ffi::LUA_TUSERDATA {
hook_callback_ptr = get_internal_userdata::<HookCallback>(state, -1, ptr::null());
}
}
ffi::lua_settop(state, top);
if hook_callback_ptr.is_null() {
ffi::lua_sethook(state, None, 0, 0);
return;
}
Expand All @@ -443,13 +446,8 @@ impl RawLua {
let rawlua = (*extra).raw_lua();
let _guard = StateGuard::new(rawlua, state);
let debug = Debug::new(rawlua, ar);
match get_internal_userdata::<HookCallback>(state, -1, ptr::null()).as_ref() {
Some(hook_cb) => hook_cb((*extra).lua(), debug),
None => {
ffi::lua_sethook(state, None, 0, 0);
Ok(VmState::Continue)
}
}
let hook_callback = (*hook_callback_ptr).clone();
hook_callback((*extra).lua(), debug)
});
process_status(state, (*ar).event, status)
}
Expand Down Expand Up @@ -482,7 +480,6 @@ impl RawLua {

ffi::lua_pushthread(thread_state);
ffi::lua_xmove(thread_state, state, 1); // key (thread)
let callback: HookCallback = Box::new(callback);
let _ = push_internal_userdata(state, callback, false); // value (hook callback)
ffi::lua_rawset(state, -3); // hooktable[thread] = hook callback
})?;
Expand Down
7 changes: 6 additions & 1 deletion src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,12 @@ impl Thread {
F: Fn(&crate::Lua, Debug) -> Result<crate::VmState> + crate::MaybeSend + 'static,
{
let lua = self.0.lua.lock();
unsafe { lua.set_thread_hook(self.state(), HookKind::Thread(triggers, Box::new(callback))) }
unsafe {
lua.set_thread_hook(
self.state(),
HookKind::Thread(triggers, crate::types::XRc::new(callback)),
)
}
}

/// Removes any hook function from this thread.
Expand Down
8 changes: 4 additions & 4 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,16 @@ pub(crate) enum HookKind {
}

#[cfg(all(feature = "send", not(feature = "luau")))]
pub(crate) type HookCallback = Box<dyn Fn(&Lua, Debug) -> Result<VmState> + Send>;
pub(crate) type HookCallback = XRc<dyn Fn(&Lua, Debug) -> Result<VmState> + Send>;

#[cfg(all(not(feature = "send"), not(feature = "luau")))]
pub(crate) type HookCallback = Box<dyn Fn(&Lua, Debug) -> Result<VmState>>;
pub(crate) type HookCallback = XRc<dyn Fn(&Lua, Debug) -> Result<VmState>>;

#[cfg(all(feature = "send", feature = "luau"))]
pub(crate) type InterruptCallback = std::rc::Rc<dyn Fn(&Lua) -> Result<VmState> + Send>;
pub(crate) type InterruptCallback = XRc<dyn Fn(&Lua) -> Result<VmState> + Send>;

#[cfg(all(not(feature = "send"), feature = "luau"))]
pub(crate) type InterruptCallback = std::rc::Rc<dyn Fn(&Lua) -> Result<VmState>>;
pub(crate) type InterruptCallback = XRc<dyn Fn(&Lua) -> Result<VmState>>;

#[cfg(all(feature = "send", feature = "lua54"))]
pub(crate) type WarnCallback = Box<dyn Fn(&Lua, &str, bool) -> Result<()> + Send>;
Expand Down

0 comments on commit d1a587f

Please sign in to comment.