From 0fe48f3a9987aef2a34fd5437904cf560c59a2fa Mon Sep 17 00:00:00 2001 From: Radiant <69520693+RadiantUwU@users.noreply.github.com> Date: Fri, 24 Jan 2025 21:19:27 +0200 Subject: [PATCH] Implement thread creation deletion event callback. --- src/lib.rs | 2 ++ src/prelude.rs | 9 +++--- src/state.rs | 70 ++++++++++++++++++++++++++++++++++++++++++++++ src/state/extra.rs | 4 +++ src/state/raw.rs | 20 ++++++++++++- src/types.rs | 23 +++++++++++++++ 6 files changed, 123 insertions(+), 5 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a404594c..4b68bf45 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -112,6 +112,8 @@ pub use crate::thread::{Thread, ThreadStatus}; pub use crate::traits::{ FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, LuaNativeFn, LuaNativeFnMut, ObjectLike, }; +#[cfg(feature = "luau")] +pub use crate::types::ThreadEventInfo; pub use crate::types::{ AppDataRef, AppDataRefMut, Either, Integer, LightUserData, MaybeSend, Number, RegistryKey, VmState, }; diff --git a/src/prelude.rs b/src/prelude.rs index 68ba8f2f..54480159 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -20,13 +20,14 @@ pub use crate::{ #[doc(no_inline)] pub use crate::HookTriggers as LuaHookTriggers; -#[cfg(feature = "luau")] -#[doc(no_inline)] -pub use crate::{CoverageInfo as LuaCoverageInfo, Vector as LuaVector}; - #[cfg(feature = "async")] #[doc(no_inline)] pub use crate::{AsyncThread as LuaAsyncThread, LuaNativeAsyncFn}; +#[cfg(feature = "luau")] +#[doc(no_inline)] +pub use crate::{ + CoverageInfo as LuaCoverageInfo, ThreadEventInfo as LuaThreadEventInfo, Vector as LuaVector, +}; #[cfg(feature = "serialize")] #[doc(no_inline)] diff --git a/src/state.rs b/src/state.rs index 35d9e4ec..ec569e73 100644 --- a/src/state.rs +++ b/src/state.rs @@ -23,6 +23,10 @@ use crate::types::{ AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, LuaType, MaybeSend, Number, ReentrantMutex, ReentrantMutexGuard, RegistryKey, VmState, XRc, XWeak, }; + +#[cfg(any(feature = "luau", doc))] +#[cfg_attr(docsrs, doc(cfg(feature = "luau")))] +use crate::types::ThreadEventInfo; use crate::userdata::{AnyUserData, UserData, UserDataProxy, UserDataRegistry, UserDataStorage}; use crate::util::{ assert_stack, check_stack, protect_lua_closure, push_string, push_table, rawset_field, StackGuard, @@ -671,6 +675,72 @@ impl Lua { } } + /// Sets a callback that will be called by Luau whenever a thread is created/destroyed. + /// + /// Often used for keeping track of threads. + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn set_thread_event_callback(&self, callback: F) + where + F: Fn(&Lua, ThreadEventInfo) -> Result<()> + MaybeSend + 'static, + { + use std::rc::Rc; + + unsafe extern "C-unwind" fn userthread_proc(parent: *mut ffi::lua_State, state: *mut ffi::lua_State) { + callback_error_ext(state, ptr::null_mut(), move |extra, _| { + let raw_lua: &RawLua = (*extra).raw_lua(); + let _guard = StateGuard::new(raw_lua, state); + + let userthread_cb = (*extra).userthread_callback.clone(); + let userthread_cb = + mlua_expect!(userthread_cb, "no userthread callback set in userthread_proc"); + if parent.is_null() { + raw_lua.push(Value::Nil).unwrap(); + } else { + raw_lua.push_ref_thread(parent).unwrap(); + } + if parent.is_null() { + let event_info = ThreadEventInfo::Destroyed(state.cast_const().cast()); + let main_state = raw_lua.main_state(); + if main_state == state { + return Ok(()); // Don't process Destroyed event on main thread. + } + let main_extra = ExtraData::get(main_state); + let main_raw_lua: &RawLua = (*main_extra).raw_lua(); + let _guard = StateGuard::new(main_raw_lua, state); + userthread_cb((*main_extra).lua(), event_info) + } else { + raw_lua.push_ref_thread(parent).unwrap(); + let event_info = match raw_lua.pop_value() { + Value::Thread(thr) => ThreadEventInfo::Created(thr), + _ => unimplemented!(), + }; + userthread_cb((*extra).lua(), event_info) + } + }); + } + + // Set interrupt callback + let lua = self.lock(); + unsafe { + (*lua.extra.get()).userthread_callback = Some(Rc::new(callback)); + (*ffi::lua_callbacks(lua.main_state())).userthread = Some(userthread_proc); + } + } + + /// Removes any thread event function previously set by `set_thread_event_callback`. + /// + /// This function has no effect if a callback was not previously set. + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn remove_thread_event_callback(&self) { + let lua = self.lock(); + unsafe { + (*lua.extra.get()).userthread_callback = None; + (*ffi::lua_callbacks(lua.main_state())).userthread = None; + } + } + /// Sets the warning function to be used by Lua to emit warnings. /// /// Requires `feature = "lua54"` diff --git a/src/state/extra.rs b/src/state/extra.rs index d1823b5c..2937b106 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -80,6 +80,8 @@ pub(crate) struct ExtraData { pub(super) warn_callback: Option, #[cfg(feature = "luau")] pub(super) interrupt_callback: Option, + #[cfg(feature = "luau")] + pub(super) userthread_callback: Option, #[cfg(feature = "luau")] pub(super) sandboxed: bool, @@ -177,6 +179,8 @@ impl ExtraData { #[cfg(feature = "luau")] interrupt_callback: None, #[cfg(feature = "luau")] + userthread_callback: None, + #[cfg(feature = "luau")] sandboxed: false, #[cfg(feature = "luau")] compiler: None, diff --git a/src/state/raw.rs b/src/state/raw.rs index 0731f846..a77e395c 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -64,7 +64,10 @@ impl Drop for RawLua { } let mem_state = MemoryState::get(self.main_state()); - + #[cfg(feature = "luau")] // Fixes a crash during shutdown + { + (*ffi::lua_callbacks(self.main_state())).userthread = None; + } ffi::lua_close(self.main_state()); // Deallocate `MemoryState` @@ -556,6 +559,21 @@ impl RawLua { value.push_into_stack(self) } + pub(crate) unsafe fn push_ref_thread(&self, ref_thread: *mut ffi::lua_State) -> Result<()> { + let state = self.state(); + check_stack(state, 1)?; + let _sg = StackGuard::new(ref_thread); + check_stack(ref_thread, 1)?; + + if self.unlikely_memory_error() { + ffi::lua_pushthread(ref_thread) + } else { + protect_lua!(ref_thread, 0, 1, |ref_thread| ffi::lua_pushthread(ref_thread))? + }; + ffi::lua_xmove(ref_thread, self.state(), 1); + Ok(()) + } + /// Pushes a `Value` (by reference) onto the Lua stack. /// /// Uses 2 stack spaces, does not call `checkstack`. diff --git a/src/types.rs b/src/types.rs index afeb239d..1d6a12b5 100644 --- a/src/types.rs +++ b/src/types.rs @@ -7,6 +7,9 @@ use crate::error::Result; use crate::hook::Debug; use crate::state::{ExtraData, Lua, RawLua}; +#[cfg(any(feature = "luau", doc))] +use crate::thread::Thread; + // Re-export mutex wrappers pub(crate) use sync::{ArcReentrantMutexGuard, ReentrantMutex, ReentrantMutexGuard, XRc, XWeak}; @@ -73,6 +76,20 @@ pub enum VmState { Yield, } +/// Information about a thread event. +/// +/// For creating a thread, it contains the thread that created it. +/// +/// This is useful for tracking the origin of all threads. +#[cfg(any(feature = "luau", doc))] +#[cfg_attr(docsrs, doc(cfg(feature = "luau")))] +pub enum ThreadEventInfo { + /// When a thread is created, it contains the thread that created it. + Created(Thread), + /// When a thread is destroyed, it returns its .to_pointer representation. + Destroyed(*const c_void), +} + #[cfg(all(feature = "send", not(feature = "luau")))] pub(crate) type HookCallback = Rc Result + Send>; @@ -85,6 +102,12 @@ pub(crate) type InterruptCallback = Rc Result + Send>; #[cfg(all(not(feature = "send"), feature = "luau"))] pub(crate) type InterruptCallback = Rc Result>; +#[cfg(all(feature = "send", feature = "luau"))] +pub(crate) type ThreadEventCallback = Rc Result<()> + Send>; + +#[cfg(all(not(feature = "send"), feature = "luau"))] +pub(crate) type ThreadEventCallback = Rc Result<()>>; + #[cfg(all(feature = "send", feature = "lua54"))] pub(crate) type WarnCallback = Box Result<()> + Send>;