Skip to content

Commit

Permalink
Merge pull request #47 from oiwn/dev
Browse files Browse the repository at this point in the history
adding tests for mongodb backend
  • Loading branch information
oiwn authored Dec 15, 2024
2 parents a5bc3ce + c4ba0c5 commit cea874e
Show file tree
Hide file tree
Showing 6 changed files with 300 additions and 303 deletions.
5 changes: 4 additions & 1 deletion capp-queue/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ edition = "2021"

[dependencies]
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
tokio = { workspace = true, optional = true }
thiserror = { workspace = true }
serde = { workspace = true }
Expand All @@ -13,10 +14,12 @@ async-trait = { workspace = true }
uuid = { workspace = true }
rustis = { version = "0.13", features = ["tokio-runtime"], optional = true }
mongodb = { version = "3", optional = true }
futures = { version = "0.3", optional = true }
futures-util = { version = "0.3", optional = true }

[dev-dependencies]
dotenvy = "0.15"

[features]
redis = ["dep:tokio", "dep:rustis"]
mongodb = ["dep:tokio", "dep:mongodb"]
mongodb = ["dep:tokio", "dep:futures", "dep:futures-util", "dep:mongodb"]
138 changes: 56 additions & 82 deletions capp-queue/src/backend/mongodb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,17 @@ where
connection_string: &str,
queue_name: &str,
) -> Result<Self, TaskQueueError> {
let client_options = ClientOptions::parse(connection_string)
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;
let client_options = ClientOptions::parse(connection_string).await?;

let client = Client::with_options(client_options)
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;
let client = Client::with_options(client_options.clone())?;

let db = client.database("task_queue");
// Get database name from URI or use default
let db_name = client_options
.default_database
.as_ref()
.expect("No database specified in MongoDB URI");

let db = client.database(db_name);
let tasks_collection =
db.collection::<Task<D>>(&format!("{}_tasks", queue_name));
let dlq_collection =
Expand All @@ -47,10 +50,8 @@ where
.options(IndexOptions::builder().unique(true).build())
.build();

tasks_collection
.create_index(index_model)
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;
tasks_collection.create_index(index_model.clone()).await?;
dlq_collection.create_index(index_model).await?;

Ok(Self {
client,
Expand All @@ -60,41 +61,6 @@ where
}
}

impl<D> MongoTaskQueue<D>
where
D: Clone + Serialize + DeserializeOwned + Send + Sync + 'static,
{
// Helper method to execute the nack transaction
async fn execute_nack_transaction(
&self,
task: &Task<D>,
session: &mut ClientSession,
) -> mongodb::error::Result<()> {
// Move to DLQ
self.dlq_collection
.insert_one(task)
.session(&mut *session)
.await?;

// Remove from main queue
self.tasks_collection
.delete_one(doc! { "task_id": task.task_id.to_string() })
.session(&mut *session)
.await?;

// Commit with retry logic for unknown commit results
loop {
let result = session.commit_transaction().await;
if let Err(ref error) = result {
if error.contains_label(UNKNOWN_TRANSACTION_COMMIT_RESULT) {
continue;
}
}
return result;
}
}
}

#[async_trait]
impl<D> TaskQueue<D> for MongoTaskQueue<D>
where
Expand All @@ -107,21 +73,12 @@ where
+ 'static,
{
async fn push(&self, task: &Task<D>) -> Result<(), TaskQueueError> {
self.tasks_collection
.insert_one(task)
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;

self.tasks_collection.insert_one(task).await?;
Ok(())
}

async fn pop(&self) -> Result<Task<D>, TaskQueueError> {
match self
.tasks_collection
.find_one_and_delete(doc! {})
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?
{
match self.tasks_collection.find_one_and_delete(doc! {}).await? {
Some(task) => Ok(task),
None => Err(TaskQueueError::QueueEmpty),
}
Expand All @@ -131,41 +88,24 @@ where
let result = self
.tasks_collection
.delete_one(doc! { "task_id": task_id.to_string() })
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;

.await?;
if result.deleted_count == 0 {
return Err(TaskQueueError::TaskNotFound(*task_id));
}

Ok(())
}

async fn nack(&self, task: &Task<D>) -> Result<(), TaskQueueError> {
let mut session = self
.client
.start_session()
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;

// Configure transaction options with majority read/write concerns
session
.start_transaction()
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;

// Execute transaction with retry logic for transient errors
let mut session = self.client.start_session().await?; // Convert to MongodbError

session.start_transaction().await?; // Convert to MongodbError

while let Err(error) =
self.execute_nack_transaction(task, &mut session).await
{
if !error.contains_label(TRANSIENT_TRANSACTION_ERROR) {
return Err(TaskQueueError::QueueError(error.to_string()));
return Err(TaskQueueError::MongodbError(error));
}
// Retry transaction on transient errors
session
.start_transaction()
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;
session.start_transaction().await?; // Convert to MongodbError
}

Ok(())
Expand All @@ -174,9 +114,43 @@ where
async fn set(&self, task: &Task<D>) -> Result<(), TaskQueueError> {
self.tasks_collection
.replace_one(doc! { "task_id": task.task_id.to_string() }, task)
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;
.await?; // Convert to MongodbError

Ok(())
}
}

impl<D> MongoTaskQueue<D>
where
D: Clone + Serialize + DeserializeOwned + Send + Sync + 'static,
{
// Helper method to execute the nack transaction
async fn execute_nack_transaction(
&self,
task: &Task<D>,
session: &mut ClientSession,
) -> mongodb::error::Result<()> {
// Move to DLQ
self.dlq_collection
.insert_one(task)
.session(&mut *session)
.await?;

// Remove from main queue
self.tasks_collection
.delete_one(doc! { "task_id": task.task_id.to_string() })
.session(&mut *session)
.await?;

// Commit with retry logic for unknown commit results
loop {
let result = session.commit_transaction().await;
if let Err(ref error) = result {
if error.contains_label(UNKNOWN_TRANSACTION_COMMIT_RESULT) {
continue;
}
}
return result;
}
}
}
3 changes: 3 additions & 0 deletions capp-queue/src/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ pub enum TaskQueueError {
#[cfg(feature = "redis")]
#[error("Redis error")]
RedisError(#[from] rustis::Error),
#[cfg(feature = "mongodb")]
#[error("Mongodb Error")]
MongodbError(#[from] mongodb::error::Error),
}

#[async_trait]
Expand Down
16 changes: 15 additions & 1 deletion capp-queue/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use serde::{Deserialize, Serialize};
use std::time::SystemTime;
use uuid::Uuid;

#[cfg(feature = "mongodb")]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum TaskStatus {
Queued,
Expand Down Expand Up @@ -123,6 +123,20 @@ impl std::fmt::Display for TaskId {
}
}

#[cfg(feature = "mongodb")]
impl From<TaskId> for mongodb::bson::Uuid {
fn from(id: TaskId) -> Self {
mongodb::bson::Uuid::from_bytes(*id.get().as_bytes())
}
}

#[cfg(feature = "mongodb")]
impl From<mongodb::bson::Uuid> for TaskId {
fn from(uuid: mongodb::bson::Uuid) -> Self {
TaskId(uuid::Uuid::from_bytes(uuid.bytes()))
}
}

#[cfg(test)]
mod tests {
use core::panic;
Expand Down
Loading

0 comments on commit cea874e

Please sign in to comment.