diff --git a/nativelink-config/src/schedulers.rs b/nativelink-config/src/schedulers.rs index bd5627da1..588b6021d 100644 --- a/nativelink-config/src/schedulers.rs +++ b/nativelink-config/src/schedulers.rs @@ -100,6 +100,12 @@ pub struct SimpleScheduler { #[serde(default, deserialize_with = "convert_duration_with_shellexpand")] pub retain_completed_for_s: u32, + /// Mark operations as completed with error if no client has updated them + /// within this duration. + /// Default: 60 (seconds) + #[serde(default, deserialize_with = "convert_duration_with_shellexpand")] + pub client_action_timeout_s: u64, + /// Remove workers from pool once the worker has not responded in this /// amount of time in seconds. /// Default: 5 (seconds) diff --git a/nativelink-scheduler/src/awaited_action_db/awaited_action.rs b/nativelink-scheduler/src/awaited_action_db/awaited_action.rs index 8e9090267..9267c96fb 100644 --- a/nativelink-scheduler/src/awaited_action_db/awaited_action.rs +++ b/nativelink-scheduler/src/awaited_action_db/awaited_action.rs @@ -66,6 +66,10 @@ pub struct AwaitedAction { #[metric(help = "The last time the worker updated the AwaitedAction")] last_worker_updated_timestamp: SystemTime, + /// The last time the client sent a keepalive message. + #[metric(help = "The last time the client sent a keepalive message")] + last_client_keepalive_timestamp: SystemTime, + /// Worker that is currently running this action, None if unassigned. #[metric(help = "The worker id of the AwaitedAction")] worker_id: Option, @@ -103,6 +107,7 @@ impl AwaitedAction { sort_key, attempts: 0, last_worker_updated_timestamp: now, + last_client_keepalive_timestamp: now, worker_id: None, state, } @@ -144,25 +149,33 @@ impl AwaitedAction { self.last_worker_updated_timestamp } - pub(crate) fn keep_alive(&mut self, now: SystemTime) { + pub(crate) fn worker_keep_alive(&mut self, now: SystemTime) { self.last_worker_updated_timestamp = now; } + pub(crate) fn last_client_keepalive_timestamp(&self) -> SystemTime { + self.last_client_keepalive_timestamp + } + pub(crate) fn update_client_keep_alive(&mut self, now: SystemTime) { + self.last_client_keepalive_timestamp = now; + } + + pub(crate) fn set_client_operation_id(&mut self, client_operation_id: OperationId) { + Arc::make_mut(&mut self.state).client_operation_id = client_operation_id; + } + /// Sets the worker id that is currently processing this action. pub(crate) fn set_worker_id(&mut self, new_maybe_worker_id: Option, now: SystemTime) { if self.worker_id != new_maybe_worker_id { self.worker_id = new_maybe_worker_id; - self.keep_alive(now); + self.worker_keep_alive(now); } } - /// Sets the current state of the action and notifies subscribers. - /// Returns true if the state was set, false if there are no subscribers. - pub fn set_state(&mut self, mut state: Arc, now: Option) { + /// Sets the current state of the action and updates the last worker updated timestamp. + pub fn worker_set_state(&mut self, mut state: Arc, now: SystemTime) { std::mem::swap(&mut self.state, &mut state); - if let Some(now) = now { - self.keep_alive(now); - } + self.worker_keep_alive(now); } } diff --git a/nativelink-scheduler/src/memory_awaited_action_db.rs b/nativelink-scheduler/src/memory_awaited_action_db.rs index 394b9549c..8e0102ba4 100644 --- a/nativelink-scheduler/src/memory_awaited_action_db.rs +++ b/nativelink-scheduler/src/memory_awaited_action_db.rs @@ -201,18 +201,14 @@ where // At this stage we know that this event is a client request, so we need // to populate the client_operation_id. let mut awaited_action = self.awaited_action_rx.borrow().clone(); - let mut state = awaited_action.state().as_ref().clone(); - state.client_operation_id = client_operation_id; - awaited_action.set_state(Arc::new(state), None); + awaited_action.set_client_operation_id(client_operation_id); Ok(awaited_action) } async fn borrow(&self) -> Result { let mut awaited_action = self.awaited_action_rx.borrow().clone(); if let Some(client_info) = self.client_info.as_ref() { - let mut state = awaited_action.state().as_ref().clone(); - state.client_operation_id = client_info.client_operation_id.clone(); - awaited_action.set_state(Arc::new(state), None); + awaited_action.set_client_operation_id(client_info.client_operation_id.clone()); } Ok(awaited_action) } diff --git a/nativelink-scheduler/src/simple_scheduler.rs b/nativelink-scheduler/src/simple_scheduler.rs index c344d2c29..ee4ce0491 100644 --- a/nativelink-scheduler/src/simple_scheduler.rs +++ b/nativelink-scheduler/src/simple_scheduler.rs @@ -44,6 +44,11 @@ use crate::worker_scheduler::WorkerScheduler; /// If this changes, remember to change the documentation in the config. const DEFAULT_WORKER_TIMEOUT_S: u64 = 5; +/// Mark operations as completed with error if no client has updated them +/// within this duration. +/// If this changes, remember to change the documentation in the config. +const DEFAULT_CLIENT_ACTION_TIMEOUT_S: u64 = 60; + /// Default times a job can retry before failing. /// If this changes, remember to change the documentation in the config. const DEFAULT_MAX_JOB_RETRIES: usize = 3; @@ -324,6 +329,11 @@ impl SimpleScheduler { worker_timeout_s = DEFAULT_WORKER_TIMEOUT_S; } + let mut client_action_timeout_s = scheduler_cfg.client_action_timeout_s; + if client_action_timeout_s == 0 { + client_action_timeout_s = DEFAULT_CLIENT_ACTION_TIMEOUT_S; + } + let mut max_job_retries = scheduler_cfg.max_job_retries; if max_job_retries == 0 { max_job_retries = DEFAULT_MAX_JOB_RETRIES; @@ -333,6 +343,7 @@ impl SimpleScheduler { let state_manager = SimpleSchedulerStateManager::new( max_job_retries, Duration::from_secs(worker_timeout_s), + Duration::from_secs(client_action_timeout_s), awaited_action_db, now_fn, ); diff --git a/nativelink-scheduler/src/simple_scheduler_state_manager.rs b/nativelink-scheduler/src/simple_scheduler_state_manager.rs index 573a9acf2..dce783623 100644 --- a/nativelink-scheduler/src/simple_scheduler_state_manager.rs +++ b/nativelink-scheduler/src/simple_scheduler_state_manager.rs @@ -19,7 +19,7 @@ use std::time::{Duration, SystemTime}; use async_lock::Mutex; use async_trait::async_trait; -use futures::{future, stream, FutureExt, StreamExt, TryStreamExt}; +use futures::{stream, StreamExt, TryStreamExt}; use nativelink_error::{make_err, Code, Error, ResultExt}; use nativelink_metric::MetricsComponent; use nativelink_util::action_messages::{ @@ -60,71 +60,6 @@ impl ActionStateResult for ErrorActionStateResult { } } -fn apply_filter_predicate(awaited_action: &AwaitedAction, filter: &OperationFilter) -> bool { - // Note: The caller must filter `client_operation_id`. - - if let Some(operation_id) = &filter.operation_id { - if operation_id != awaited_action.operation_id() { - return false; - } - } - - if filter.worker_id.is_some() && filter.worker_id != awaited_action.worker_id() { - return false; - } - - { - if let Some(filter_unique_key) = &filter.unique_key { - match &awaited_action.action_info().unique_qualifier { - ActionUniqueQualifier::Cachable(unique_key) => { - if filter_unique_key != unique_key { - return false; - } - } - ActionUniqueQualifier::Uncachable(_) => { - return false; - } - } - } - if let Some(action_digest) = filter.action_digest { - if action_digest != awaited_action.action_info().digest() { - return false; - } - } - } - - { - let last_worker_update_timestamp = awaited_action.last_worker_updated_timestamp(); - if let Some(worker_update_before) = filter.worker_update_before { - if worker_update_before < last_worker_update_timestamp { - return false; - } - } - if let Some(completed_before) = filter.completed_before { - if awaited_action.state().stage.is_finished() - && completed_before < last_worker_update_timestamp - { - return false; - } - } - if filter.stages != OperationStageFlags::Any { - let stage_flag = match awaited_action.state().stage { - ActionStage::Unknown => OperationStageFlags::Any, - ActionStage::CacheCheck => OperationStageFlags::CacheCheck, - ActionStage::Queued => OperationStageFlags::Queued, - ActionStage::Executing => OperationStageFlags::Executing, - ActionStage::Completed(_) => OperationStageFlags::Completed, - ActionStage::CompletedFromCache(_) => OperationStageFlags::Completed, - }; - if !filter.stages.intersects(stage_flag) { - return false; - } - } - } - - true -} - struct ClientActionStateResult where U: AwaitedActionSubscriber, @@ -329,6 +264,11 @@ where )] no_event_action_timeout: Duration, + /// Mark operation as timed out if the worker has not updated in this duration. + /// This is used to prevent operations from being stuck in the queue forever + /// if it is not being processed by any worker. + client_action_timeout: Duration, + // A lock to ensure only one timeout operation is running at a time // on this service. timeout_operation_mux: Mutex<()>, @@ -352,6 +292,7 @@ where pub fn new( max_job_retries: usize, no_event_action_timeout: Duration, + client_action_timeout: Duration, action_db: T, now_fn: NowFn, ) -> Arc { @@ -359,12 +300,111 @@ where action_db, max_job_retries, no_event_action_timeout, + client_action_timeout, timeout_operation_mux: Mutex::new(()), weak_self: weak_self.clone(), now_fn, }) } + async fn apply_filter_predicate( + &self, + awaited_action: &AwaitedAction, + filter: &OperationFilter, + ) -> bool { + // Note: The caller must filter `client_operation_id`. + + if awaited_action.last_client_keepalive_timestamp() + self.client_action_timeout + < (self.now_fn)().now() + { + if !awaited_action.state().stage.is_finished() { + let mut state = awaited_action.state().as_ref().clone(); + state.stage = ActionStage::Completed(ActionResult { + error: Some(make_err!( + Code::DeadlineExceeded, + "Operation timed out {} seconds of having no more clients listening", + self.client_action_timeout.as_secs_f32(), + )), + ..ActionResult::default() + }); + let mut new_awaited_action = awaited_action.clone(); + new_awaited_action.worker_set_state(Arc::new(state), (self.now_fn)().now()); + if let Err(err) = self + .action_db + .update_awaited_action(new_awaited_action) + .await + { + event!( + Level::WARN, + "Failed to update action to timed out state after client keepalive timeout. This is ok if multiple schedulers tried to set the state at the same time: {err}", + ); + } + } + return false; + } + + if let Some(operation_id) = &filter.operation_id { + if operation_id != awaited_action.operation_id() { + return false; + } + } + + if filter.worker_id.is_some() && filter.worker_id != awaited_action.worker_id() { + return false; + } + + { + if let Some(filter_unique_key) = &filter.unique_key { + match &awaited_action.action_info().unique_qualifier { + ActionUniqueQualifier::Cachable(unique_key) => { + if filter_unique_key != unique_key { + return false; + } + } + ActionUniqueQualifier::Uncachable(_) => { + return false; + } + } + } + if let Some(action_digest) = filter.action_digest { + if action_digest != awaited_action.action_info().digest() { + return false; + } + } + } + + { + let last_worker_update_timestamp = awaited_action.last_worker_updated_timestamp(); + if let Some(worker_update_before) = filter.worker_update_before { + if worker_update_before < last_worker_update_timestamp { + return false; + } + } + if let Some(completed_before) = filter.completed_before { + if awaited_action.state().stage.is_finished() + && completed_before < last_worker_update_timestamp + { + return false; + } + } + if filter.stages != OperationStageFlags::Any { + let stage_flag = match awaited_action.state().stage { + ActionStage::Unknown => OperationStageFlags::Any, + ActionStage::CacheCheck => OperationStageFlags::CacheCheck, + ActionStage::Queued => OperationStageFlags::Queued, + ActionStage::Executing => OperationStageFlags::Executing, + ActionStage::Completed(_) => OperationStageFlags::Completed, + ActionStage::CompletedFromCache(_) => OperationStageFlags::Completed, + }; + if !filter.stages.intersects(stage_flag) { + return false; + } + } + } + + true + } + /// Let the scheduler know that an operation has timed out from /// the client side (ie: worker has not updated in a while). async fn timeout_operation_id(&self, operation_id: &OperationId) -> Result<(), Error> { @@ -488,7 +528,7 @@ where let stage = match &update { UpdateOperationType::KeepAlive => { - awaited_action.keep_alive((self.now_fn)().now()); + awaited_action.worker_keep_alive((self.now_fn)().now()); return self .action_db .update_awaited_action(awaited_action) @@ -531,7 +571,7 @@ where } else { awaited_action.set_worker_id(maybe_worker_id.copied(), now); } - awaited_action.set_state( + awaited_action.worker_set_state( Arc::new(ActionState { stage, // Client id is not known here, it is the responsibility of @@ -540,7 +580,7 @@ where client_operation_id: operation_id.clone(), action_digest: awaited_action.action_info().digest(), }), - Some(now), + now, ); let update_action_result = self @@ -615,7 +655,7 @@ where .borrow() .await .err_tip(|| "In SimpleSchedulerStateManager::filter_operations")?; - if !apply_filter_predicate(&awaited_action, &filter) { + if !self.apply_filter_predicate(&awaited_action, &filter).await { return Ok(Box::pin(stream::empty())); } return Ok(Box::pin(stream::once(async move { @@ -635,7 +675,7 @@ where .borrow() .await .err_tip(|| "In SimpleSchedulerStateManager::filter_operations")?; - if !apply_filter_predicate(&awaited_action, &filter) { + if !self.apply_filter_predicate(&awaited_action, &filter).await { return Ok(Box::pin(stream::empty())); } return Ok(Box::pin(stream::once(async move { @@ -659,11 +699,13 @@ where Ok((awaited_action_subscriber, awaited_action)) }) .try_filter_map(|(subscriber, awaited_action)| { - if apply_filter_predicate(&awaited_action, &filter) { - future::ready(Ok(Some((subscriber, awaited_action.sort_key())))) - .left_future() - } else { - future::ready(Result::<_, Error>::Ok(None)).right_future() + let filter = filter.clone(); + async move { + if self.apply_filter_predicate(&awaited_action, &filter).await { + Ok(Some((subscriber, awaited_action.sort_key()))) + } else { + Ok(None) + } } }) .try_collect() @@ -703,10 +745,13 @@ where Ok((awaited_action_subscriber, awaited_action)) }) .try_filter_map(move |(subscriber, awaited_action)| { - if apply_filter_predicate(&awaited_action, &filter) { - future::ready(Ok(Some(subscriber))).left_future() - } else { - future::ready(Result::<_, Error>::Ok(None)).right_future() + let filter = filter.clone(); + async move { + if self.apply_filter_predicate(&awaited_action, &filter).await { + Ok(Some(subscriber)) + } else { + Ok(None) + } } }) .map(move |result| -> Box { diff --git a/nativelink-scheduler/src/store_awaited_action_db.rs b/nativelink-scheduler/src/store_awaited_action_db.rs index ce28cc92f..296d8a5c2 100644 --- a/nativelink-scheduler/src/store_awaited_action_db.rs +++ b/nativelink-scheduler/src/store_awaited_action_db.rs @@ -14,8 +14,9 @@ use std::borrow::Cow; use std::ops::Bound; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, Weak}; -use std::time::{Duration, SystemTime}; +use std::time::Duration; use bytes::Bytes; use futures::{Stream, TryStreamExt}; @@ -24,6 +25,7 @@ use nativelink_metric::MetricsComponent; use nativelink_util::action_messages::{ ActionInfo, ActionStage, ActionUniqueQualifier, OperationId, }; +use nativelink_util::instant_wrapper::InstantWrapper; use nativelink_util::spawn; use nativelink_util::store_trait::{ FalseValue, SchedulerCurrentVersionProvider, SchedulerIndexProvider, SchedulerStore, @@ -41,42 +43,55 @@ use crate::awaited_action_db::{ type ClientOperationId = OperationId; +/// Duration to wait before sending client keep alive messages. +const CLIENT_KEEPALIVE_DURATION: Duration = Duration::from_secs(10); + +/// Maximum number of retries to update client keep alive. +const MAX_RETRIES_FOR_CLIENT_KEEPALIVE: u32 = 8; + enum OperationSubscriberState { Unsubscribed, Subscribed(Sub), } -pub struct OperationSubscriber { +pub struct OperationSubscriber I> { maybe_client_operation_id: Option, subscription_key: OperationIdToAwaitedAction<'static>, weak_store: Weak, state: OperationSubscriberState< ::Subscription, >, - now_fn: fn() -> SystemTime, + last_known_keepalive_ts: AtomicU64, + now_fn: NowFn, } -impl OperationSubscriber { +impl OperationSubscriber +where + S: SchedulerStore, + I: InstantWrapper, + NowFn: Fn() -> I, +{ fn new( maybe_client_operation_id: Option, subscription_key: OperationIdToAwaitedAction<'static>, weak_store: Weak, - now_fn: fn() -> SystemTime, + now_fn: NowFn, ) -> Self { Self { maybe_client_operation_id, subscription_key, weak_store, + last_known_keepalive_ts: AtomicU64::new(0), state: OperationSubscriberState::Unsubscribed, now_fn, } } - async fn get_awaited_action(&self) -> Result { - let store = self - .weak_store - .upgrade() - .err_tip(|| "Store gone in OperationSubscriber::get_awaited_action")?; - let key = self.subscription_key.borrow(); + async fn inner_get_awaited_action( + store: &S, + key: OperationIdToAwaitedAction<'_>, + maybe_client_operation_id: Option, + last_known_keepalive_ts: &AtomicU64, + ) -> Result { let mut awaited_action = store .get_and_decode(key.borrow()) .await @@ -87,16 +102,39 @@ impl OperationSubscriber { "Could not find AwaitedAction for the given operation id {key:?}", ) })?; - if let Some(client_operation_id) = &self.maybe_client_operation_id { - let mut state = awaited_action.state().as_ref().clone(); - state.client_operation_id = client_operation_id.clone(); - awaited_action.set_state(Arc::new(state), Some((self.now_fn)())); + if let Some(client_operation_id) = maybe_client_operation_id { + awaited_action.set_client_operation_id(client_operation_id); } + last_known_keepalive_ts.store( + awaited_action + .last_client_keepalive_timestamp() + .unix_timestamp(), + Ordering::Release, + ); Ok(awaited_action) } + + async fn get_awaited_action(&self) -> Result { + let store = self + .weak_store + .upgrade() + .err_tip(|| "Store gone in OperationSubscriber::get_awaited_action")?; + Self::inner_get_awaited_action( + store.as_ref(), + self.subscription_key.borrow(), + self.maybe_client_operation_id.clone(), + &self.last_known_keepalive_ts, + ) + .await + } } -impl AwaitedActionSubscriber for OperationSubscriber { +impl AwaitedActionSubscriber for OperationSubscriber +where + S: SchedulerStore, + I: InstantWrapper, + NowFn: Fn() -> I + Send + Sync + 'static, +{ async fn changed(&mut self) -> Result { let store = self .weak_store @@ -117,13 +155,63 @@ impl AwaitedActionSubscriber for OperationSubscriber { } OperationSubscriberState::Subscribed(subscription) => subscription, }; - subscription - .changed() - .await - .err_tip(|| "In OperationSubscriber::changed")?; - self.get_awaited_action() - .await - .err_tip(|| "In OperationSubscriber::changed") + + let changed_fut = subscription.changed(); + tokio::pin!(changed_fut); + loop { + let mut retries = 0; + loop { + let last_known_keepalive_ts = self.last_known_keepalive_ts.load(Ordering::Acquire); + if I::from_secs(last_known_keepalive_ts).elapsed() <= CLIENT_KEEPALIVE_DURATION { + break; // We are still within the keep alive duration. + } + if retries > MAX_RETRIES_FOR_CLIENT_KEEPALIVE { + return Err(make_err!( + Code::Aborted, + "Could not update client keep alive for AwaitedAction", + )); + } + let mut awaited_action = Self::inner_get_awaited_action( + store.as_ref(), + self.subscription_key.borrow(), + self.maybe_client_operation_id.clone(), + &self.last_known_keepalive_ts, + ) + .await + .err_tip(|| "In OperationSubscriber::changed")?; + awaited_action.update_client_keep_alive((self.now_fn)().now()); + let update_res = inner_update_awaited_action(store.as_ref(), awaited_action) + .await + .err_tip(|| "In OperationSubscriber::changed"); + if update_res.is_ok() { + break; + } + retries += 1; + // Wait a tick before retrying. + (self.now_fn)().sleep(Duration::from_millis(100)).await; + } + let sleep_fut = (self.now_fn)().sleep(CLIENT_KEEPALIVE_DURATION); + tokio::select! { + result = &mut changed_fut => { + result?; + break; + } + _ = sleep_fut => { + // If we haven't received any updates for a while, we should + // let the database know that we are still listening to prevent + // the action from being dropped. + } + } + } + + Self::inner_get_awaited_action( + store.as_ref(), + self.subscription_key.borrow(), + self.maybe_client_operation_id.clone(), + &self.last_known_keepalive_ts, + ) + .await + .err_tip(|| "In OperationSubscriber::changed") } async fn borrow(&self) -> Result { @@ -294,19 +382,54 @@ impl SchedulerStoreDataProvider for UpdateClientIdToOperationId { } } +async fn inner_update_awaited_action( + store: &impl SchedulerStore, + mut new_awaited_action: AwaitedAction, +) -> Result<(), Error> { + let operation_id = new_awaited_action.operation_id().clone(); + if new_awaited_action.state().client_operation_id != operation_id { + // Just in case the client_operation_id was set to something else + // we put it back to the underlying operation_id. + new_awaited_action.set_client_operation_id(operation_id.clone()); + } + let maybe_version = store + .update_data(UpdateOperationIdToAwaitedAction(new_awaited_action)) + .await + .err_tip(|| "In RedisAwaitedActionDb::update_awaited_action")?; + if maybe_version.is_none() { + return Err(make_err!( + Code::Aborted, + "Could not update AwaitedAction because the version did not match for {operation_id:?}", + )); + } + Ok(()) +} + #[derive(MetricsComponent)] -pub struct StoreAwaitedActionDb OperationId> { +pub struct StoreAwaitedActionDb +where + S: SchedulerStore, + F: Fn() -> OperationId, + I: InstantWrapper, + NowFn: Fn() -> I, +{ store: Arc, - now_fn: fn() -> SystemTime, + now_fn: NowFn, operation_id_creator: F, _pull_task_change_subscriber_spawn: JoinHandleDropGuard<()>, } -impl OperationId> StoreAwaitedActionDb { +impl StoreAwaitedActionDb +where + S: SchedulerStore, + F: Fn() -> OperationId, + I: InstantWrapper, + NowFn: Fn() -> I + Send + Sync + Clone + 'static, +{ pub fn new( store: Arc, task_change_publisher: Arc, - now_fn: fn() -> SystemTime, + now_fn: NowFn, operation_id_creator: F, ) -> Result { let mut subscription = store @@ -354,7 +477,7 @@ impl OperationId> StoreAwaitedActionDb { // removed the ability to upgrade priorities of actions. // we should add priority upgrades back in. _priority: i32, - ) -> Result>, Error> { + ) -> Result>, Error> { match unique_qualifier { ActionUniqueQualifier::Cachable(_) => {} ActionUniqueQualifier::Uncachable(_) => return Ok(None), @@ -383,7 +506,7 @@ impl OperationId> StoreAwaitedActionDb { Some(client_operation_id.clone()), OperationIdToAwaitedAction(Cow::Owned(operation_id.clone())), Arc::downgrade(&self.store), - self.now_fn, + self.now_fn.clone(), ))) } None => Ok(None), @@ -393,7 +516,7 @@ impl OperationId> StoreAwaitedActionDb { async fn inner_get_awaited_action_by_id( &self, client_operation_id: &ClientOperationId, - ) -> Result>, Error> { + ) -> Result>, Error> { let maybe_operation_id = self .store .get_and_decode(ClientIdToOperationId(client_operation_id)) @@ -406,15 +529,19 @@ impl OperationId> StoreAwaitedActionDb { Some(client_operation_id.clone()), OperationIdToAwaitedAction(Cow::Owned(operation_id)), Arc::downgrade(&self.store), - self.now_fn, + self.now_fn.clone(), ))) } } -impl OperationId + Send + Sync + Unpin + 'static> AwaitedActionDb - for StoreAwaitedActionDb +impl AwaitedActionDb for StoreAwaitedActionDb +where + S: SchedulerStore, + F: Fn() -> OperationId + Send + Sync + Unpin + 'static, + I: InstantWrapper, + NowFn: Fn() -> I + Send + Sync + Unpin + Clone + 'static, { - type Subscriber = OperationSubscriber; + type Subscriber = OperationSubscriber; async fn get_awaited_action_by_id( &self, @@ -432,24 +559,12 @@ impl OperationId + Send + Sync + Unpin + 'static> None, OperationIdToAwaitedAction(Cow::Owned(operation_id.clone())), Arc::downgrade(&self.store), - self.now_fn, + self.now_fn.clone(), ))) } async fn update_awaited_action(&self, new_awaited_action: AwaitedAction) -> Result<(), Error> { - let operation_id = new_awaited_action.operation_id().clone(); - let maybe_version = self - .store - .update_data(UpdateOperationIdToAwaitedAction(new_awaited_action)) - .await - .err_tip(|| "In RedisAwaitedActionDb::update_awaited_action")?; - if maybe_version.is_none() { - return Err(make_err!( - Code::Aborted, - "Could not update AwaitedAction because the version did not match for {operation_id:?}", - )); - } - Ok(()) + inner_update_awaited_action(self.store.as_ref(), new_awaited_action).await } async fn add_action( @@ -472,7 +587,7 @@ impl OperationId + Send + Sync + Unpin + 'static> let new_operation_id = (self.operation_id_creator)(); let awaited_action = - AwaitedAction::new(new_operation_id.clone(), action_info, (self.now_fn)()); + AwaitedAction::new(new_operation_id.clone(), action_info, (self.now_fn)().now()); debug_assert!( ActionStage::Queued == awaited_action.state().stage, "Expected action to be queued" @@ -500,7 +615,7 @@ impl OperationId + Send + Sync + Unpin + 'static> Some(client_operation_id), OperationIdToAwaitedAction(Cow::Owned(new_operation_id)), Arc::downgrade(&self.store), - self.now_fn, + self.now_fn.clone(), )) } @@ -541,7 +656,7 @@ impl OperationId + Send + Sync + Unpin + 'static> None, OperationIdToAwaitedAction(Cow::Owned(awaited_action.operation_id().clone())), Arc::downgrade(&self.store), - self.now_fn, + self.now_fn.clone(), ) })) } @@ -559,7 +674,7 @@ impl OperationId + Send + Sync + Unpin + 'static> None, OperationIdToAwaitedAction(Cow::Owned(awaited_action.operation_id().clone())), Arc::downgrade(&self.store), - self.now_fn, + self.now_fn.clone(), ) })) } diff --git a/nativelink-scheduler/tests/redis_store_awaited_action_db_test.rs b/nativelink-scheduler/tests/redis_store_awaited_action_db_test.rs index c7f761c70..4eea3ca43 100644 --- a/nativelink-scheduler/tests/redis_store_awaited_action_db_test.rs +++ b/nativelink-scheduler/tests/redis_store_awaited_action_db_test.rs @@ -38,6 +38,7 @@ use nativelink_util::action_messages::{ }; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; +use nativelink_util::instant_wrapper::MockInstantWrapped; use nativelink_util::store_trait::{SchedulerStore, SchedulerSubscriptionManager}; use parking_lot::Mutex; use pretty_assertions::assert_eq; @@ -188,7 +189,7 @@ async fn add_action_smoke_test() -> Result<(), Error> { let mut new_awaited_action = worker_awaited_action.clone(); let mut new_state = new_awaited_action.state().as_ref().clone(); new_state.stage = ActionStage::Executing; - new_awaited_action.set_state(Arc::new(new_state), Some(MockSystemTime::now().into())); + new_awaited_action.worker_set_state(Arc::new(new_state), MockSystemTime::now().into()); new_awaited_action }; @@ -428,7 +429,7 @@ async fn add_action_smoke_test() -> Result<(), Error> { let awaited_action_db = StoreAwaitedActionDb::new( store.clone(), notifier.clone(), - || MockSystemTime::now().into(), + MockInstantWrapped::default, move || WORKER_OPERATION_ID.into(), ) .unwrap();