diff --git a/benches/chained_spawn.rs b/benches/chained_spawn.rs index da4a756..4ac5a71 100644 --- a/benches/chained_spawn.rs +++ b/benches/chained_spawn.rs @@ -37,17 +37,17 @@ mod yatp_future { use criterion::*; use std::sync::mpsc; use yatp::task::future::TaskCell; - use yatp::Remote; + use yatp::Handle; pub fn chained_spawn(b: &mut Bencher<'_>, iter_count: usize) { let pool = yatp::Builder::new("chained_spawn").build_future_pool(); - fn iter(remote: Remote, done_tx: mpsc::SyncSender<()>, n: usize) { + fn iter(handle: Handle, done_tx: mpsc::SyncSender<()>, n: usize) { if n == 0 { done_tx.send(()).unwrap(); } else { - let s2 = remote.clone(); - remote.spawn(async move { + let s2 = handle.clone(); + handle.spawn(async move { iter(s2, done_tx, n - 1); }); } @@ -57,9 +57,9 @@ mod yatp_future { b.iter(move || { let done_tx = done_tx.clone(); - let remote = pool.remote().clone(); + let handle = pool.handle().clone(); pool.spawn(async move { - iter(remote, done_tx, iter_count); + iter(handle, done_tx, iter_count); }); done_rx.recv().unwrap(); diff --git a/benches/ping_pong.rs b/benches/ping_pong.rs index 59ee155..38dff79 100644 --- a/benches/ping_pong.rs +++ b/benches/ping_pong.rs @@ -65,20 +65,20 @@ mod yatp_future { let rem = rem.clone(); rem.store(ping_count, Ordering::Relaxed); - let remote = pool.remote().clone(); + let handle = pool.handle().clone(); pool.spawn(async move { for _ in 0..ping_count { let rem = rem.clone(); let done_tx = done_tx.clone(); - let remote2 = remote.clone(); + let handle2 = handle.clone(); - remote.spawn(async move { + handle.spawn(async move { let (tx1, rx1) = oneshot::channel(); let (tx2, rx2) = oneshot::channel(); - remote2.spawn(async move { + handle2.spawn(async move { rx1.await.unwrap(); tx2.send(()).unwrap(); }); diff --git a/src/lib.rs b/src/lib.rs index c53c1ce..90dcf42 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,4 +9,4 @@ pub mod pool; pub mod queue; pub mod task; -pub use self::pool::{Builder, Remote, ThreadPool}; +pub use self::pool::{Builder, Handle, ThreadPool}; diff --git a/src/pool.rs b/src/pool.rs index d3ace9d..2c897de 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -12,7 +12,7 @@ mod worker; pub use self::builder::{Builder, SchedConfig}; pub use self::runner::{CloneRunnerBuilder, Runner, RunnerBuilder}; -pub use self::spawn::{build_spawn, Local, Remote}; +pub use self::spawn::{build_spawn, Handle, Local}; use crate::queue::{TaskCell, WithExtras}; use std::mem; @@ -20,37 +20,37 @@ use std::sync::Mutex; use std::thread::JoinHandle; /// A generic thread pool. -pub struct ThreadPool { - remote: Remote, +pub struct ThreadPool { + handle: Handle, threads: Mutex>>, } -impl ThreadPool { +impl ThreadPool { /// Spawns the task into the thread pool. /// /// If the pool is shutdown, it becomes no-op. pub fn spawn(&self, t: impl WithExtras) { - self.remote.spawn(t); + self.handle.spawn(t); } /// Shutdowns the pool. /// /// Closes the queue and wait for all threads to exit. pub fn shutdown(&self) { - self.remote.stop(); + self.handle.stop(); let mut threads = mem::replace(&mut *self.threads.lock().unwrap(), Vec::new()); for j in threads.drain(..) { j.join().unwrap(); } } - /// Get a remote queue for spawning tasks without owning the thread pool. - pub fn remote(&self) -> &Remote { - &self.remote + /// Get a handle for spawning tasks without owning the thread pool. + pub fn handle(&self) -> &Handle { + &self.handle } } -impl Drop for ThreadPool { +impl Drop for ThreadPool { /// Will shutdown the thread pool if it has not. fn drop(&mut self) { self.shutdown(); diff --git a/src/pool/builder.rs b/src/pool/builder.rs index 9b2fa0e..20c137d 100644 --- a/src/pool/builder.rs +++ b/src/pool/builder.rs @@ -2,8 +2,8 @@ use crate::pool::spawn::QueueCore; use crate::pool::worker::WorkerThread; -use crate::pool::{CloneRunnerBuilder, Local, Remote, Runner, RunnerBuilder, ThreadPool}; -use crate::queue::{self, LocalQueue, QueueType, TaskCell}; +use crate::pool::{CloneRunnerBuilder, Handle, Local, Runner, RunnerBuilder, ThreadPool}; +use crate::queue::{self, LocalQueueBuilder, QueueType, TaskCell}; use crate::task::{callback, future}; use std::sync::{Arc, Mutex}; use std::thread; @@ -49,10 +49,10 @@ impl Default for SchedConfig { pub struct LazyBuilder { builder: Builder, core: Arc>, - local_queues: Vec>, + local_queue_builders: Vec>, } -impl LazyBuilder { +impl LazyBuilder { /// Sets the name prefix of threads. The thread name will follow the /// format "prefix-index". pub fn name(mut self, name_prefix: impl Into) -> LazyBuilder { @@ -61,7 +61,7 @@ impl LazyBuilder { } } -impl LazyBuilder +impl LazyBuilder where T: TaskCell + Send + 'static, { @@ -76,25 +76,26 @@ where F::Runner: Runner + Send + 'static, { let mut threads = Vec::with_capacity(self.builder.sched_config.max_thread_count); - for (i, local_queue) in self.local_queues.into_iter().enumerate() { + for (i, queue_builder) in self.local_queue_builders.into_iter().enumerate() { let runner = factory.build(); let name = format!("{}-{}", self.builder.name_prefix, i); let mut builder = thread::Builder::new().name(name); if let Some(size) = self.builder.stack_size { builder = builder.stack_size(size) } - let local = Local::new(i + 1, local_queue, self.core.clone()); - let thd = WorkerThread::new(local, runner); + let core = self.core.clone(); threads.push( builder .spawn(move || { + let local = Local::new(i + 1, queue_builder(), core); + let thd = WorkerThread::new(local, runner); thd.run(); }) .unwrap(), ); } ThreadPool { - remote: Remote::new(self.core.clone()), + handle: Handle::new(self.core.clone()), threads: Mutex::new(threads), } } @@ -183,9 +184,9 @@ impl Builder { /// In some cases, especially building up a large application, a task /// scheduler is required before spawning new threads. You can use this /// to separate the construction and starting. - pub fn freeze(&self) -> (Remote, LazyBuilder) + pub fn freeze(&self) -> (Handle, LazyBuilder) where - T: TaskCell + Send, + T: TaskCell + Send + 'static, { self.freeze_with_queue(QueueType::SingleLevel) } @@ -199,20 +200,21 @@ impl Builder { /// In some cases, especially building up a large application, a task /// scheduler is required before spawning new threads. You can use this /// to separate the construction and starting. - pub fn freeze_with_queue(&self, queue_type: QueueType) -> (Remote, LazyBuilder) + pub fn freeze_with_queue(&self, queue_type: QueueType) -> (Handle, LazyBuilder) where - T: TaskCell + Send, + T: TaskCell + Send + 'static, { assert!(self.sched_config.min_thread_count <= self.sched_config.max_thread_count); - let (injector, local_queues) = queue::build(queue_type, self.sched_config.max_thread_count); + let (injector, local_queue_builders) = + queue::build(queue_type, self.sched_config.max_thread_count); let core = Arc::new(QueueCore::new(injector, self.sched_config.clone())); ( - Remote::new(core.clone()), + Handle::new(core.clone()), LazyBuilder { builder: self.clone(), core, - local_queues, + local_queue_builders, }, ) } diff --git a/src/pool/spawn.rs b/src/pool/spawn.rs index 5249757..4c3d434 100644 --- a/src/pool/spawn.rs +++ b/src/pool/spawn.rs @@ -5,8 +5,10 @@ //! tasks waiting to be handled. use crate::pool::SchedConfig; -use crate::queue::{Extras, LocalQueue, Pop, TaskCell, TaskInjector, WithExtras}; +use crate::queue::{Extras, LocalInjector, LocalQueue, Pop, TaskCell, TaskInjector, WithExtras}; use parking_lot_core::{ParkResult, ParkToken, UnparkToken}; +use std::cell::Cell; +use std::ptr; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -31,7 +33,7 @@ pub fn is_shutdown(cnt: usize) -> bool { /// The core of queues. /// /// Every thread pool instance should have one and only `QueueCore`. It's -/// saved in an `Arc` and shared between all worker threads and remote handles. +/// saved in an `Arc` and shared between all worker threads and handles. pub(crate) struct QueueCore { global_queue: TaskInjector, active_workers: AtomicUsize, @@ -119,7 +121,7 @@ impl QueueCore { } } -impl QueueCore { +impl QueueCore { /// Pushes the task to global queue. /// /// `source` is used to trace who triggers the action. @@ -133,21 +135,58 @@ impl QueueCore { } } +thread_local! { + static LOCAL_INJECTOR: Cell = Cell::new(TlsLocalInjector::uninit()); +} + +#[derive(Copy, Clone)] +struct TlsLocalInjector { + pool_id: usize, + injector_ptr: *mut (), +} + +impl TlsLocalInjector { + const fn uninit() -> TlsLocalInjector { + TlsLocalInjector { + pool_id: 0, + injector_ptr: ptr::null_mut(), + } + } +} + /// Submits tasks to associated thread pool. /// -/// Note that thread pool can be shutdown and dropped even not all remotes are +/// Note that thread pool can be shutdown and dropped even not all handles are /// dropped. -pub struct Remote { +pub struct Handle { core: Arc>, } -impl Remote { - pub(crate) fn new(core: Arc>) -> Remote { - Remote { core } +impl Handle { + pub(crate) fn new(core: Arc>) -> Handle { + Handle { core } } /// Submits a task to the thread pool. pub fn spawn(&self, task: impl WithExtras) { + let t = task.with_extras(|| self.core.default_extras()); + LOCAL_INJECTOR.with(|c| { + let tls_injector = c.get(); + if tls_injector.pool_id == &*self.core as *const _ as usize { + unsafe { + Box::leak(Box::from_raw( + tls_injector.injector_ptr as *mut LocalInjector, + )) + .push(t); + } + } else { + self.core.push(0, t); + } + }) + } + + /// Spawns a task to the remote queue. + pub fn spawn_remote(&self, task: impl WithExtras) { let t = task.with_extras(|| self.core.default_extras()); self.core.push(0, t); } @@ -157,21 +196,21 @@ impl Remote { } } -impl Clone for Remote { - fn clone(&self) -> Remote { - Remote { +impl Clone for Handle { + fn clone(&self) -> Handle { + Handle { core: self.core.clone(), } } } -/// Note that implements of Runner assumes `Remote` is `Sync` and `Send`. +/// Note that implements of Runner assumes `handle` is `Sync` and `Send`. /// So we need to use assert trait to ensure the constraint at compile time /// to avoid future breaks. trait AssertSync: Sync {} -impl AssertSync for Remote {} +impl AssertSync for Handle {} trait AssertSend: Send {} -impl AssertSend for Remote {} +impl AssertSend for Handle {} /// Spawns tasks to the associated thread pool. /// @@ -184,8 +223,14 @@ pub struct Local { core: Arc>, } -impl Local { +impl Local { pub(crate) fn new(id: usize, local_queue: LocalQueue, core: Arc>) -> Local { + let local_injector = Box::new(local_queue.local_injector()); + let tls_injector = TlsLocalInjector { + pool_id: &*core as *const _ as usize, + injector_ptr: Box::into_raw(local_injector) as *mut (), + }; + LOCAL_INJECTOR.with(|c| c.set(tls_injector)); Local { id, local_queue, @@ -205,9 +250,9 @@ impl Local { self.core.push(self.id, t); } - /// Gets a remote so that tasks can be spawned from other threads. - pub fn remote(&self) -> Remote { - Remote { + /// Gets a handle so that tasks can be spawned from other threads. + pub fn handle(&self) -> Handle { + Handle { core: self.core.clone(), } } @@ -255,16 +300,16 @@ impl Local { } } -/// Building remotes and locals from the given queue and configuration. +/// Building handles and locals from the given queue and configuration. /// /// This is only for tests purpose so that a thread pool doesn't have to be /// spawned to test a Runner. pub fn build_spawn( queue_type: impl Into, config: SchedConfig, -) -> (Remote, Vec>) +) -> (Handle, Vec>) where - T: TaskCell + Send, + T: TaskCell + Send + 'static, { let queue_type = queue_type.into(); let (global, locals) = crate::queue::build(queue_type, config.max_thread_count); @@ -272,8 +317,8 @@ where let l = locals .into_iter() .enumerate() - .map(|(i, l)| Local::new(i + 1, l, core.clone())) + .map(|(i, builder)| Local::new(i + 1, builder(), core.clone())) .collect(); - let g = Remote::new(core); + let g = Handle::new(core); (g, l) } diff --git a/src/pool/tests.rs b/src/pool/tests.rs index 4e9ca67..a5192f9 100644 --- a/src/pool/tests.rs +++ b/src/pool/tests.rs @@ -78,21 +78,21 @@ fn test_basic() { } #[test] -fn test_remote() { - let pool = Builder::new("test_remote") +fn test_handle() { + let pool = Builder::new("test_handle") .max_thread_count(4) .build_callback_pool(); - // Remote should work just like pool. - let remote = pool.remote(); + // Handle should work just like pool. + let handle = pool.handle(); let (tx, rx) = mpsc::channel(); let t = tx.clone(); - remote.spawn(move |_: &mut Handle<'_>| t.send(1).unwrap()); + handle.spawn(move |_: &mut Handle<'_>| t.send(1).unwrap()); assert_eq!(Ok(1), rx.recv_timeout(Duration::from_millis(500))); // Shutdown should stop processing tasks. pool.shutdown(); - remote.spawn(move |_: &mut Handle<'_>| tx.send(2).unwrap()); + handle.spawn(move |_: &mut Handle<'_>| tx.send(2).unwrap()); let res = rx.recv_timeout(Duration::from_millis(500)); assert_eq!(res, Err(mpsc::RecvTimeoutError::Timeout)); } diff --git a/src/pool/worker.rs b/src/pool/worker.rs index 9ccd94e..5787043 100644 --- a/src/pool/worker.rs +++ b/src/pool/worker.rs @@ -17,7 +17,7 @@ impl WorkerThread { impl WorkerThread where - T: TaskCell + Send, + T: TaskCell + Send + 'static, R: Runner, { #[inline] @@ -55,7 +55,7 @@ where mod tests { use super::*; use crate::pool::spawn::*; - use crate::queue::QueueType; + use crate::queue::{self, QueueType}; use crate::task::callback; use std::sync::*; use std::time::*; @@ -121,9 +121,13 @@ mod tests { }; let metrics = r.metrics.clone(); let mut expected_metrics = Metrics::default(); - let (injector, mut locals) = build_spawn(QueueType::SingleLevel, Default::default()); - let th = WorkerThread::new(locals.remove(0), r); - let handle = std::thread::spawn(move || { + let (injector, mut local_builders) = queue::build(QueueType::SingleLevel, num_cpus::get()); + let core = Arc::new(QueueCore::new(injector, Default::default())); + let handle = Handle::new(core.clone()); + let local_builder = local_builders.remove(0); + let join_handle = std::thread::spawn(move || { + let local = Local::new(1, local_builder(), core); + let th = WorkerThread::new(local, r); th.run(); }); rx.recv_timeout(Duration::from_secs(1)).unwrap(); @@ -131,15 +135,15 @@ mod tests { expected_metrics.pause = 1; assert_eq!(expected_metrics, *metrics.lock().unwrap()); - injector.spawn(move |_: &mut callback::Handle<'_>| {}); + handle.spawn(move |_: &mut callback::Handle<'_>| {}); rx.recv_timeout(Duration::from_secs(1)).unwrap(); expected_metrics.pause = 2; expected_metrics.handle = 1; expected_metrics.resume = 1; assert_eq!(expected_metrics, *metrics.lock().unwrap()); - injector.stop(); - handle.join().unwrap(); + handle.stop(); + join_handle.join().unwrap(); expected_metrics.resume = 2; expected_metrics.end = 1; assert_eq!(expected_metrics, *metrics.lock().unwrap()); diff --git a/src/queue.rs b/src/queue.rs index 0f49bfb..4f9fb7d 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -44,7 +44,7 @@ enum InjectorInner { Multilevel(multilevel::TaskInjector), } -impl TaskInjector { +impl TaskInjector { /// Pushes a task to the queue. pub fn push(&self, task_cell: T) { match &self.0 { @@ -59,6 +59,22 @@ impl TaskInjector { InjectorInner::Multilevel(_) => Extras::multilevel_default(), } } + + #[cfg(test)] + pub fn into_single_level(self) -> single_level::TaskInjector { + match self.0 { + InjectorInner::SingleLevel(inj) => inj, + _ => unreachable!(), + } + } + + #[cfg(test)] + pub fn into_multilevel(self) -> multilevel::TaskInjector { + match self.0 { + InjectorInner::Multilevel(inj) => inj, + _ => unreachable!(), + } + } } /// Popped task cell from a task queue. @@ -106,6 +122,43 @@ impl LocalQueue { LocalQueueInner::Multilevel(_) => Extras::multilevel_default(), } } + + pub fn local_injector(&self) -> LocalInjector { + match &self.0 { + LocalQueueInner::SingleLevel(q) => LocalInjector::SingleLevel(q.local_injector()), + LocalQueueInner::Multilevel(q) => LocalInjector::Multilevel(q.local_injector()), + } + } + + #[cfg(test)] + pub fn into_single_level(self) -> single_level::LocalQueue { + match self.0 { + LocalQueueInner::SingleLevel(q) => q, + _ => unreachable!(), + } + } + + #[cfg(test)] + pub fn into_multilevel(self) -> multilevel::LocalQueue { + match self.0 { + LocalQueueInner::Multilevel(q) => q, + _ => unreachable!(), + } + } +} + +pub(crate) enum LocalInjector { + SingleLevel(single_level::LocalInjector), + Multilevel(multilevel::LocalInjector), +} + +impl LocalInjector { + pub(crate) fn push(&self, task: T) { + match self { + LocalInjector::SingleLevel(inj) => inj.push(task), + LocalInjector::Multilevel(inj) => inj.push(task), + } + } } /// Supported available queues. @@ -130,21 +183,14 @@ impl From for QueueType { } } -pub(crate) fn build(ty: QueueType, local_num: usize) -> (TaskInjector, Vec>) { +pub(crate) fn build( + ty: QueueType, + local_num: usize, +) -> (TaskInjector, Vec>) { match ty { - QueueType::SingleLevel => single_level(local_num), + QueueType::SingleLevel => single_level::create(local_num), QueueType::Multilevel(b) => b.build(local_num), } } -/// Creates a task queue that allows given number consumers. -fn single_level(local_num: usize) -> (TaskInjector, Vec>) { - let (injector, locals) = single_level::create(local_num); - ( - TaskInjector(InjectorInner::SingleLevel(injector)), - locals - .into_iter() - .map(|i| LocalQueue(LocalQueueInner::SingleLevel(i))) - .collect(), - ) -} +pub(crate) type LocalQueueBuilder = Box LocalQueue + Send>; diff --git a/src/queue/multilevel.rs b/src/queue/multilevel.rs index ff7f1d2..89bbdc2 100644 --- a/src/queue/multilevel.rs +++ b/src/queue/multilevel.rs @@ -6,7 +6,7 @@ //! The task queue requires that the accompanying [`MultilevelRunner`] must be //! used to collect necessary information. -use super::{Pop, TaskCell}; +use super::{LocalQueueBuilder, Pop, TaskCell}; use crate::metrics::*; use crate::pool::{Local, Runner, RunnerBuilder}; @@ -16,6 +16,7 @@ use prometheus::local::LocalIntCounter; use prometheus::{Gauge, IntCounter}; use rand::prelude::*; use std::cell::Cell; +use std::rc::Rc; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst}; use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; @@ -81,7 +82,7 @@ where /// The local queue of a multilevel task queue. pub(crate) struct LocalQueue { - local_queue: Worker, + local_queue: Rc>, level_injectors: Arc<[Injector; LEVEL_NUM]>, stealers: Vec>, manager: Arc, @@ -163,6 +164,25 @@ where } None } + + pub(super) fn local_injector(&self) -> LocalInjector { + LocalInjector { + manager: self.manager.clone(), + worker: self.local_queue.clone(), + } + } +} + +pub(crate) struct LocalInjector { + manager: Arc, + worker: Rc>, +} + +impl LocalInjector { + pub(super) fn push(&self, mut task: T) { + self.manager.prepare_before_push(&mut task); + self.worker.push(task); + } } /// The runner builder for multilevel task queues. @@ -527,7 +547,11 @@ impl Builder { } } - fn build_raw(self, local_num: usize) -> (TaskInjector, Vec>) { + /// Creates the injector and local queues of the multilevel task queue. + pub(crate) fn build( + self, + local_num: usize, + ) -> (super::TaskInjector, Vec>) { let level_injectors: Arc<[Injector; LEVEL_NUM]> = Arc::new([Injector::new(), Injector::new(), Injector::new()]); let workers: Vec<_> = iter::repeat_with(Worker::new_lifo) @@ -537,48 +561,37 @@ impl Builder { let locals = workers .into_iter() .enumerate() - .map(|(self_index, local_queue)| { - let mut stealers: Vec<_> = stealers - .iter() - .enumerate() - .filter(|(index, _)| *index != self_index) - .map(|(_, stealer)| stealer.clone()) - .collect(); - // Steal with a random start to avoid imbalance. - stealers.shuffle(&mut thread_rng()); - LocalQueue { - local_queue, - level_injectors: level_injectors.clone(), - stealers, - manager: self.manager.clone(), - } - }) + .map( + |(self_index, local_queue)| -> Box super::LocalQueue + Send> { + let mut stealers: Vec<_> = stealers + .iter() + .enumerate() + .filter(|(index, _)| *index != self_index) + .map(|(_, stealer)| stealer.clone()) + .collect(); + // Steal with a random start to avoid imbalance. + stealers.shuffle(&mut thread_rng()); + let manager = self.manager.clone(); + let level_injectors = level_injectors.clone(); + Box::new(move || { + super::LocalQueue(super::LocalQueueInner::Multilevel(LocalQueue { + local_queue: Rc::new(local_queue), + level_injectors, + stealers, + manager, + })) + }) + }, + ) .collect(); - ( - TaskInjector { + super::TaskInjector(super::InjectorInner::Multilevel(TaskInjector { level_injectors, manager: self.manager, - }, + })), locals, ) } - - /// Creates the injector and local queues of the multilevel task queue. - pub(crate) fn build( - self, - local_num: usize, - ) -> (super::TaskInjector, Vec>) { - let (injector, locals) = self.build_raw(local_num); - let local_queues = locals - .into_iter() - .map(|local| super::LocalQueue(super::LocalQueueInner::Multilevel(local))) - .collect(); - ( - super::TaskInjector(super::InjectorInner::Multilevel(injector)), - local_queues, - ) - } } thread_local!(static RECENT_NOW: Cell = Cell::new(Instant::now())); @@ -691,7 +704,8 @@ mod tests { const SLEEP_DUR: Duration = Duration::from_millis(5); let builder = Builder::new(Config::default()); - let (injector, mut locals) = builder.build(1); + let (injector, locals) = builder.build(1); + let mut locals: Vec<_> = locals.into_iter().map(|b| b()).collect(); injector.push(MockTask::new(0, Extras::multilevel_default())); thread::sleep(SLEEP_DUR); let schedule_time = locals[0].pop().unwrap().schedule_time; @@ -704,7 +718,8 @@ mod tests { Config::default() .level_time_threshold([Duration::from_millis(1), Duration::from_millis(100)]), ); - let (injector, _) = builder.build_raw(1); + let (injector, _) = builder.build(1); + let injector = injector.into_multilevel(); // Running time is 50us. It should be pushed to level 0. let extras = Extras { @@ -775,7 +790,9 @@ mod tests { #[test] fn test_pop_by_stealing_injector() { let builder = Builder::new(Config::default()); - let (injector, mut locals) = builder.build(3); + let (injector, locals) = builder.build(3); + let injector = injector.into_multilevel(); + let mut locals: Vec<_> = locals.into_iter().map(|b| b()).collect(); for i in 0..100 { injector.push(MockTask::new(i, Extras::multilevel_default())); } @@ -789,7 +806,9 @@ mod tests { #[test] fn test_pop_by_steal_others() { let builder = Builder::new(Config::default()); - let (injector, mut locals) = builder.build_raw(3); + let (injector, locals) = builder.build(3); + let injector = injector.into_multilevel(); + let mut locals: Vec<_> = locals.into_iter().map(|b| b().into_multilevel()).collect(); for i in 0..50 { injector.push(MockTask::new(i, Extras::multilevel_default())); } @@ -819,9 +838,10 @@ mod tests { let sum = Arc::new(AtomicU64::new(0)); let handles: Vec<_> = locals .into_iter() - .map(|mut consumer| { + .map(|builder| { let sum = sum.clone(); thread::spawn(move || { + let mut consumer = builder(); while let Some(pop) = consumer.pop() { sum.fetch_add(pop.task_cell.sleep_ms, SeqCst); } @@ -839,10 +859,10 @@ mod tests { let builder = Builder::new(Config::default()); let mut runner_builder = builder.runner_builder(MockRunnerBuilder); let manager = builder.manager.clone(); - let (remote, mut locals) = build_spawn(builder, Default::default()); + let (handle, mut locals) = build_spawn(builder, Default::default()); let mut runner = runner_builder.build(); - remote.spawn(MockTask::new(100, Extras::new_multilevel(1, None))); + handle.spawn(MockTask::new(100, Extras::new_multilevel(1, None))); if let Some(Pop { task_cell, .. }) = locals[0].pop() { assert!(runner.handle(&mut locals[0], task_cell)); } diff --git a/src/queue/single_level.rs b/src/queue/single_level.rs index 46d0fce..0a6daa6 100644 --- a/src/queue/single_level.rs +++ b/src/queue/single_level.rs @@ -5,11 +5,12 @@ //! The instant when the task cell is pushed into the queue is recorded //! in the extras. -use super::{Pop, TaskCell}; +use super::{LocalQueueBuilder, Pop, TaskCell}; use crossbeam_deque::{Injector, Steal, Stealer, Worker}; use rand::prelude::*; use std::iter; +use std::rc::Rc; use std::sync::Arc; use std::time::Instant; @@ -43,7 +44,7 @@ where /// The local queue of a single level work stealing task queue. pub struct LocalQueue { - local_queue: Worker, + local_queue: Rc>, injector: Arc>, stealers: Vec>, } @@ -102,10 +103,25 @@ where } None } + + pub(super) fn local_injector(&self) -> LocalInjector { + LocalInjector(self.local_queue.clone()) + } +} + +pub(crate) struct LocalInjector(Rc>); + +impl LocalInjector { + pub(super) fn push(&self, mut task: T) { + set_schedule_time(&mut task); + self.0.push(task); + } } /// Creates a single level work stealing task queue with `local_num` local queues. -pub fn create(local_num: usize) -> (TaskInjector, Vec>) { +pub(crate) fn create( + local_num: usize, +) -> (super::TaskInjector, Vec>) { let injector = Arc::new(Injector::new()); let workers: Vec<_> = iter::repeat_with(Worker::new_lifo) .take(local_num) @@ -114,24 +130,32 @@ pub fn create(local_num: usize) -> (TaskInjector, Vec>) { let local_queues = workers .into_iter() .enumerate() - .map(|(self_index, local_queue)| { - let mut stealers: Vec<_> = stealers - .iter() - .enumerate() - .filter(|(index, _)| *index != self_index) - .map(|(_, stealer)| stealer.clone()) - .collect(); - // Steal with a random start to avoid imbalance. - stealers.shuffle(&mut thread_rng()); - LocalQueue { - local_queue, - injector: injector.clone(), - stealers, - } - }) + .map( + |(self_index, local_queue)| -> Box super::LocalQueue + Send> { + let mut stealers: Vec<_> = stealers + .iter() + .enumerate() + .filter(|(index, _)| *index != self_index) + .map(|(_, stealer)| stealer.clone()) + .collect(); + // Steal with a random start to avoid imbalance. + stealers.shuffle(&mut thread_rng()); + let injector = injector.clone(); + Box::new(move || { + super::LocalQueue(super::LocalQueueInner::SingleLevel(LocalQueue { + local_queue: Rc::new(local_queue), + injector, + stealers, + })) + }) + }, + ) .collect(); - (TaskInjector(injector), local_queues) + ( + super::TaskInjector(super::InjectorInner::SingleLevel(TaskInjector(injector))), + local_queues, + ) } #[cfg(test)] @@ -168,7 +192,8 @@ mod tests { fn test_schedule_time_is_set() { const SLEEP_DUR: Duration = Duration::from_millis(5); - let (injector, mut locals) = super::create(1); + let (injector, locals) = super::create(1); + let mut locals: Vec<_> = locals.into_iter().map(|b| b()).collect(); injector.push(MockCell::new(0)); thread::sleep(SLEEP_DUR); let schedule_time = locals[0].pop().unwrap().schedule_time; @@ -177,7 +202,8 @@ mod tests { #[test] fn test_pop_by_stealing_injector() { - let (injector, mut locals) = super::create(3); + let (injector, locals) = super::create(3); + let mut locals: Vec<_> = locals.into_iter().map(|b| b()).collect(); for i in 0..100 { injector.push(MockCell::new(i)); } @@ -190,7 +216,12 @@ mod tests { #[test] fn test_pop_by_steal_others() { - let (injector, mut locals) = super::create(3); + let (injector, locals) = super::create(3); + let injector = injector.into_single_level(); + let mut locals: Vec<_> = locals + .into_iter() + .map(|b| b().into_single_level()) + .collect(); for i in 0..50 { injector.push(MockCell::new(i)); } @@ -215,9 +246,10 @@ mod tests { let sum = Arc::new(AtomicI32::new(0)); let handles: Vec<_> = locals .into_iter() - .map(|mut consumer| { + .map(|builder| { let sum = sum.clone(); thread::spawn(move || { + let mut consumer = builder(); while let Some(pop) = consumer.pop() { sum.fetch_add(pop.task_cell.value, Ordering::SeqCst); } diff --git a/src/task/future.rs b/src/task/future.rs index 58d6fdb..f91715d 100644 --- a/src/task/future.rs +++ b/src/task/future.rs @@ -2,7 +2,7 @@ //! A [`Future`]. -use crate::pool::{Local, Remote}; +use crate::pool::{Handle, Local}; use crate::queue::{Extras, WithExtras}; use std::borrow::Cow; @@ -21,7 +21,7 @@ const DEFAULT_REPOLL_LIMIT: usize = 5; struct TaskExtras { extras: Extras, - remote: Option>, + handle: Option>, } /// A [`Future`] task. @@ -74,7 +74,7 @@ impl TaskCell { future: UnsafeCell::new(Box::pin(future)), extras: UnsafeCell::new(TaskExtras { extras, - remote: None, + handle: None, }), })) } @@ -160,54 +160,19 @@ unsafe fn task_cell(task: *const Task) -> TaskCell { #[inline] unsafe fn clone_task(task: *const Task) -> TaskCell { let task_cell = task_cell(task); - let extras = { &mut *task_cell.0.extras.get() }; - if extras.remote.is_none() { - LOCAL.with(|l| { - extras.remote = Some((&*l.get()).remote()); - }) - } mem::forget(task_cell.0.clone()); task_cell } -thread_local! { - /// Local queue reference that is set before polling and unset after polled. - static LOCAL: Cell<*mut Local> = Cell::new(std::ptr::null_mut()); -} - unsafe fn wake_task(task: Cow<'_, Arc>, reschedule: bool) { - LOCAL.with(|ptr| { - if ptr.get().is_null() { - // It's out of polling process, has to be spawn to global queue. - // It needs to clone to make it safe as it's unclear whether `self` - // is still used inside method `spawn` after `TaskCell` is dropped. - (*task.as_ref().extras.get()) - .remote - .as_ref() - .expect("remote should exist!!!") - .spawn(TaskCell(task.clone().into_owned())); - } else if reschedule { - // It's requested explicitly to schedule to global queue. - (*ptr.get()).spawn_remote(TaskCell(task.into_owned())); - } else { - // Otherwise spawns to local queue for best locality. - (*ptr.get()).spawn(TaskCell(task.into_owned())); - } - }) -} - -struct Scope<'a>(&'a mut Local); - -impl<'a> Scope<'a> { - fn new(l: &'a mut Local) -> Scope<'a> { - LOCAL.with(|c| c.set(l)); - Scope(l) - } -} - -impl<'a> Drop for Scope<'a> { - fn drop(&mut self) { - LOCAL.with(|c| c.set(std::ptr::null_mut())); + let handle = (*task.extras.get()) + .handle + .as_ref() + .expect("handle should exist!!!"); + if !reschedule { + handle.spawn(TaskCell(task.into_owned())); + } else { + handle.spawn_remote(TaskCell(task.into_owned())); } } @@ -243,7 +208,6 @@ impl crate::pool::Runner for Runner { type TaskCell = TaskCell; fn handle(&mut self, local: &mut Local, task_cell: TaskCell) -> bool { - let _scope = Scope::new(local); let task = task_cell.0; unsafe { let waker = ManuallyDrop::new(waker(&*task)); @@ -255,14 +219,11 @@ impl crate::pool::Runner for Runner { task.status.store(COMPLETED, SeqCst); return true; } - let extras = { &mut *task.extras.get() }; - if extras.remote.is_none() { - // It's possible to avoid assigning remote in some cases, but it requires - // at least one atomic load to detect such situation. So here just assign - // it to make things simple. - LOCAL.with(|l| { - extras.remote = Some((&*l.get()).remote()); - }) + { + let extras = { &mut *task.extras.get() }; + if extras.handle.is_none() { + extras.handle = Some(local.handle()); + } } match task.status.compare_exchange(POLLING, IDLE, SeqCst, SeqCst) { Ok(_) => return false, @@ -322,16 +283,16 @@ mod tests { struct MockLocal { runner: Rc>, - remote: Remote, + handle: Handle, locals: Vec>, } impl MockLocal { fn new(runner: Runner) -> MockLocal { - let (remote, locals) = build_spawn(QueueType::SingleLevel, Default::default()); + let (handle, locals) = build_spawn(QueueType::SingleLevel, Default::default()); MockLocal { runner: Rc::new(RefCell::new(runner)), - remote, + handle, locals, } } @@ -390,7 +351,7 @@ mod tests { WakeLater::new(waker_tx.clone()).await; res_tx.send(2).unwrap(); }; - local.remote.spawn(fut); + local.handle.spawn(fut); local.handle_once(); assert_eq!(res_rx.recv().unwrap(), 1); @@ -451,7 +412,7 @@ mod tests { PendingOnce::new().await; res_tx.send(2).unwrap(); }; - local.remote.spawn(fut); + local.handle.spawn(fut); local.handle_once(); assert_eq!(res_rx.recv().unwrap(), 1); @@ -472,7 +433,7 @@ mod tests { PendingOnce::new().await; res_tx.send(4).unwrap(); }; - local.remote.spawn(fut); + local.handle.spawn(fut); local.handle_once(); assert_eq!(res_rx.recv().unwrap(), 1); @@ -496,7 +457,7 @@ mod tests { PendingOnce::new().await; res_tx.send(3).unwrap(); }; - local.remote.spawn(fut); + local.handle.spawn(fut); local.handle_once(); assert_eq!(res_rx.recv().unwrap(), 1);