Skip to content

Commit

Permalink
Rework Lua hooks:
Browse files Browse the repository at this point in the history
- Support global hooks inherited by new threads
- Support thread hooks, where each thread can have its own hook

This should also allow to enable hooks for async calls.
Related to #489 #347
  • Loading branch information
khvzak committed Feb 7, 2025
1 parent a89800b commit 74b4601
Show file tree
Hide file tree
Showing 8 changed files with 253 additions and 91 deletions.
57 changes: 39 additions & 18 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use crate::util::{
use crate::value::{Nil, Value};

#[cfg(not(feature = "luau"))]
use crate::hook::HookTriggers;
use crate::{hook::HookTriggers, types::HookKind};

#[cfg(any(feature = "luau", doc))]
use crate::{buffer::Buffer, chunk::Compiler};
Expand Down Expand Up @@ -501,6 +501,26 @@ impl Lua {
}
}

/// Sets or replaces a global hook function that will periodically be called as Lua code
/// executes.
///
/// All new threads created (by mlua) after this call will use the global hook function.
///
/// For more information see [`Lua::set_hook`].
#[cfg(not(feature = "luau"))]
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
pub fn set_global_hook<F>(&self, triggers: HookTriggers, callback: F) -> Result<()>
where
F: Fn(&Lua, Debug) -> Result<VmState> + MaybeSend + 'static,
{
let lua = self.lock();
unsafe {
(*lua.extra.get()).hook_triggers = triggers;
(*lua.extra.get()).hook_callback = Some(Box::new(callback));
lua.set_thread_hook(lua.state(), HookKind::Global)
}
}

/// Sets a hook function that will periodically be called as Lua code executes.
///
/// When exactly the hook function is called depends on the contents of the `triggers`
Expand All @@ -511,12 +531,10 @@ impl Lua {
/// limited form of execution limits by setting [`HookTriggers.every_nth_instruction`] and
/// erroring once an instruction limit has been reached.
///
/// This method sets a hook function for the current thread of this Lua instance.
/// This method sets a hook function for the *current* thread of this Lua instance.
/// If you want to set a hook function for another thread (coroutine), use
/// [`Thread::set_hook`] instead.
///
/// Please note you cannot have more than one hook function set at a time for this Lua instance.
///
/// # Example
///
/// Shows each line number of code being executed by the Lua interpreter.
Expand All @@ -541,33 +559,36 @@ impl Lua {
/// [`HookTriggers.every_nth_instruction`]: crate::HookTriggers::every_nth_instruction
#[cfg(not(feature = "luau"))]
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F)
pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F) -> Result<()>
where
F: Fn(&Lua, Debug) -> Result<VmState> + MaybeSend + 'static,
{
let lua = self.lock();
unsafe { lua.set_thread_hook(lua.state(), triggers, callback) };
unsafe { lua.set_thread_hook(lua.state(), HookKind::Thread(triggers, Box::new(callback))) }
}

/// Removes any hook previously set by [`Lua::set_hook`] or [`Thread::set_hook`].
/// Removes a global hook previously set by [`Lua::set_global_hook`].
///
/// This function has no effect if a hook was not previously set.
#[cfg(not(feature = "luau"))]
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
pub fn remove_hook(&self) {
pub fn remove_global_hook(&self) {
let lua = self.lock();
unsafe {
let state = lua.state();
ffi::lua_sethook(state, None, 0, 0);
match lua.main_state {
Some(main_state) if state != main_state.as_ptr() => {
// If main_state is different from state, remove hook from it too
ffi::lua_sethook(main_state.as_ptr(), None, 0, 0);
}
_ => {}
};
(*lua.extra.get()).hook_callback = None;
(*lua.extra.get()).hook_thread = ptr::null_mut();
(*lua.extra.get()).hook_triggers = HookTriggers::default();
}
}

/// Removes any hook from the current thread.
///
/// This function has no effect if a hook was not previously set.
#[cfg(not(feature = "luau"))]
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
pub fn remove_hook(&self) {
let lua = self.lock();
unsafe {
ffi::lua_sethook(lua.state(), None, 0, 0);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/state/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub(crate) struct ExtraData {
#[cfg(not(feature = "luau"))]
pub(super) hook_callback: Option<crate::types::HookCallback>,
#[cfg(not(feature = "luau"))]
pub(super) hook_thread: *mut ffi::lua_State,
pub(super) hook_triggers: crate::hook::HookTriggers,
#[cfg(feature = "lua54")]
pub(super) warn_callback: Option<crate::types::WarnCallback>,
#[cfg(feature = "luau")]
Expand Down Expand Up @@ -171,7 +171,7 @@ impl ExtraData {
#[cfg(not(feature = "luau"))]
hook_callback: None,
#[cfg(not(feature = "luau"))]
hook_thread: ptr::null_mut(),
hook_triggers: Default::default(),
#[cfg(feature = "lua54")]
warn_callback: None,
#[cfg(feature = "luau")]
Expand Down
135 changes: 101 additions & 34 deletions src/state/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ use super::extra::ExtraData;
use super::{Lua, LuaOptions, WeakLua};

#[cfg(not(feature = "luau"))]
use crate::hook::{Debug, HookTriggers};
use crate::{
hook::Debug,
types::{HookCallback, HookKind, VmState},
};

#[cfg(feature = "async")]
use {
Expand Down Expand Up @@ -186,6 +189,8 @@ impl RawLua {
init_internal_metatable::<XRc<UnsafeCell<ExtraData>>>(state, None)?;
init_internal_metatable::<Callback>(state, None)?;
init_internal_metatable::<CallbackUpvalue>(state, None)?;
#[cfg(not(feature = "luau"))]
init_internal_metatable::<HookCallback>(state, None)?;
#[cfg(feature = "async")]
{
init_internal_metatable::<AsyncCallback>(state, None)?;
Expand Down Expand Up @@ -373,42 +378,22 @@ impl RawLua {
status
}

/// Sets a 'hook' function for a thread (coroutine).
/// Sets a hook for a thread (coroutine).
#[cfg(not(feature = "luau"))]
pub(crate) unsafe fn set_thread_hook<F>(
pub(crate) unsafe fn set_thread_hook(
&self,
state: *mut ffi::lua_State,
triggers: HookTriggers,
callback: F,
) where
F: Fn(&Lua, Debug) -> Result<crate::VmState> + MaybeSend + 'static,
{
use crate::types::VmState;
use std::rc::Rc;
thread_state: *mut ffi::lua_State,
hook: HookKind,
) -> Result<()> {
// Key to store hooks in the registry
const HOOKS_KEY: *const c_char = cstr!("__mlua_hooks");

unsafe extern "C-unwind" fn hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) {
let extra = ExtraData::get(state);
if (*extra).hook_thread != state {
// Hook was destined for a different thread, ignore
ffi::lua_sethook(state, None, 0, 0);
return;
}
let result = callback_error_ext(state, extra, move |extra, _| {
let hook_cb = (*extra).hook_callback.clone();
let hook_cb = mlua_expect!(hook_cb, "no hook callback set in hook_proc");
if Rc::strong_count(&hook_cb) > 2 {
return Ok(VmState::Continue); // Don't allow recursion
}
let rawlua = (*extra).raw_lua();
let _guard = StateGuard::new(rawlua, state);
let debug = Debug::new(rawlua, ar);
hook_cb((*extra).lua(), debug)
});
match result {
unsafe fn process_status(state: *mut ffi::lua_State, event: c_int, status: VmState) {
match status {
VmState::Continue => {}
VmState::Yield => {
// Only count and line events can yield
if (*ar).event == ffi::LUA_HOOKCOUNT || (*ar).event == ffi::LUA_HOOKLINE {
if event == ffi::LUA_HOOKCOUNT || event == ffi::LUA_HOOKLINE {
#[cfg(any(feature = "lua54", feature = "lua53"))]
if ffi::lua_isyieldable(state) != 0 {
ffi::lua_yield(state, 0);
Expand All @@ -423,9 +408,86 @@ impl RawLua {
}
}

(*self.extra.get()).hook_callback = Some(Rc::new(callback));
(*self.extra.get()).hook_thread = state; // Mark for what thread the hook is set
ffi::lua_sethook(state, Some(hook_proc), triggers.mask(), triggers.count());
unsafe extern "C-unwind" fn global_hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) {
let status = callback_error_ext(state, ptr::null_mut(), move |extra, _| {
let rawlua = (*extra).raw_lua();
let debug = Debug::new(rawlua, ar);
match (*extra).hook_callback.take() {
Some(hook_cb) => {
// Temporary obtain ownership of the hook callback
let result = hook_cb((*extra).lua(), debug);
(*extra).hook_callback = Some(hook_cb);
result
}
None => {
ffi::lua_sethook(state, None, 0, 0);
Ok(VmState::Continue)
}
}
});
process_status(state, (*ar).event, status);
}

unsafe extern "C-unwind" fn hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) {
ffi::luaL_checkstack(state, 3, ptr::null());
ffi::lua_getfield(state, ffi::LUA_REGISTRYINDEX, HOOKS_KEY);
ffi::lua_pushthread(state);
if ffi::lua_rawget(state, -2) != ffi::LUA_TUSERDATA {
ffi::lua_pop(state, 2);
ffi::lua_sethook(state, None, 0, 0);
return;
}

let status = callback_error_ext(state, ptr::null_mut(), |extra, _| {
let rawlua = (*extra).raw_lua();
let debug = Debug::new(rawlua, ar);
match get_internal_userdata::<HookCallback>(state, -1, ptr::null()).as_ref() {
Some(hook_cb) => hook_cb((*extra).lua(), debug),
None => {
ffi::lua_sethook(state, None, 0, 0);
Ok(VmState::Continue)
}
}
});
process_status(state, (*ar).event, status)
}

let (triggers, callback) = match hook {
HookKind::Global if (*self.extra.get()).hook_callback.is_none() => {
return Ok(());
}
HookKind::Global => {
let triggers = (*self.extra.get()).hook_triggers;
let (mask, count) = (triggers.mask(), triggers.count());
ffi::lua_sethook(thread_state, Some(global_hook_proc), mask, count);
return Ok(());
}
HookKind::Thread(triggers, callback) => (triggers, callback),
};

// Hooks for threads stored in the registry (in a weak table)
let state = self.state();
let _sg = StackGuard::new(state);
check_stack(state, 3)?;
protect_lua!(state, 0, 0, |state| {
if ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, HOOKS_KEY) == 0 {
// Table just created, initialize it
ffi::lua_pushliteral(state, "k");
ffi::lua_setfield(state, -2, cstr!("__mode")); // hooktable.__mode = "k"
ffi::lua_pushvalue(state, -1);
ffi::lua_setmetatable(state, -2); // metatable(hooktable) = hooktable
}

ffi::lua_pushthread(thread_state);
ffi::lua_xmove(thread_state, state, 1); // key (thread)
let callback: HookCallback = Box::new(callback);
let _ = push_internal_userdata(state, callback, false); // value (hook callback)
ffi::lua_rawset(state, -3); // hooktable[thread] = hook callback
})?;

ffi::lua_sethook(thread_state, Some(hook_proc), triggers.mask(), triggers.count());

Ok(())
}

/// See [`Lua::create_string`]
Expand Down Expand Up @@ -497,6 +559,11 @@ impl RawLua {
} else {
protect_lua!(state, 0, 1, |state| ffi::lua_newthread(state))?
};

// Inherit global hook if set
#[cfg(not(feature = "luau"))]
self.set_thread_hook(thread_state, HookKind::Global)?;

let thread = Thread(self.pop_ref(), thread_state);
ffi::lua_xpush(self.ref_thread(), thread_state, func.0.index);
Ok(thread)
Expand Down
20 changes: 15 additions & 5 deletions src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::util::{check_stack, error_traceback_thread, pop_error, StackGuard};
#[cfg(not(feature = "luau"))]
use crate::{
hook::{Debug, HookTriggers},
types::MaybeSend,
types::HookKind,
};

#[cfg(feature = "async")]
Expand Down Expand Up @@ -262,16 +262,26 @@ impl Thread {
/// Sets a hook function that will periodically be called as Lua code executes.
///
/// This function is similar or [`Lua::set_hook`] except that it sets for the thread.
/// To remove a hook call [`Lua::remove_hook`].
/// You can have multiple hooks for different threads.
///
/// To remove a hook call [`Thread::remove_hook`].
#[cfg(not(feature = "luau"))]
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F)
pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F) -> Result<()>
where
F: Fn(&crate::Lua, Debug) -> Result<crate::VmState> + MaybeSend + 'static,
F: Fn(&crate::Lua, Debug) -> Result<crate::VmState> + crate::MaybeSend + 'static,
{
let lua = self.0.lua.lock();
unsafe { lua.set_thread_hook(self.state(), HookKind::Thread(triggers, Box::new(callback))) }
}

/// Removes any hook function from this thread.
#[cfg(not(feature = "luau"))]
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
pub fn remove_hook(&self) {
let _lua = self.0.lua.lock();
unsafe {
lua.set_thread_hook(self.state(), triggers, callback);
ffi::lua_sethook(self.state(), None, 0, 0);
}
}

Expand Down
17 changes: 11 additions & 6 deletions src/types.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use std::cell::UnsafeCell;
use std::os::raw::{c_int, c_void};
use std::rc::Rc;

use crate::error::Result;
#[cfg(not(feature = "luau"))]
use crate::hook::Debug;
use crate::hook::{Debug, HookTriggers};
use crate::state::{ExtraData, Lua, RawLua};

// Re-export mutex wrappers
Expand Down Expand Up @@ -73,17 +72,23 @@ pub enum VmState {
Yield,
}

#[cfg(not(feature = "luau"))]
pub(crate) enum HookKind {
Global,
Thread(HookTriggers, HookCallback),
}

#[cfg(all(feature = "send", not(feature = "luau")))]
pub(crate) type HookCallback = Rc<dyn Fn(&Lua, Debug) -> Result<VmState> + Send>;
pub(crate) type HookCallback = Box<dyn Fn(&Lua, Debug) -> Result<VmState> + Send>;

#[cfg(all(not(feature = "send"), not(feature = "luau")))]
pub(crate) type HookCallback = Rc<dyn Fn(&Lua, Debug) -> Result<VmState>>;
pub(crate) type HookCallback = Box<dyn Fn(&Lua, Debug) -> Result<VmState>>;

#[cfg(all(feature = "send", feature = "luau"))]
pub(crate) type InterruptCallback = Rc<dyn Fn(&Lua) -> Result<VmState> + Send>;
pub(crate) type InterruptCallback = std::rc::Rc<dyn Fn(&Lua) -> Result<VmState> + Send>;

#[cfg(all(not(feature = "send"), feature = "luau"))]
pub(crate) type InterruptCallback = Rc<dyn Fn(&Lua) -> Result<VmState>>;
pub(crate) type InterruptCallback = std::rc::Rc<dyn Fn(&Lua) -> Result<VmState>>;

#[cfg(all(feature = "send", feature = "lua54"))]
pub(crate) type WarnCallback = Box<dyn Fn(&Lua, &str, bool) -> Result<()> + Send>;
Expand Down
Loading

0 comments on commit 74b4601

Please sign in to comment.