Skip to content

Commit

Permalink
Merge pull request #46 from oiwn/dev
Browse files Browse the repository at this point in the history
add failing tests for mongodb
  • Loading branch information
oiwn authored Dec 14, 2024
2 parents 23526e6 + e16c67e commit a5bc3ce
Show file tree
Hide file tree
Showing 4 changed files with 396 additions and 68 deletions.
149 changes: 85 additions & 64 deletions capp-queue/src/backend/mongodb.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
use async_trait::async_trait;
use mongodb::{
bson::{doc, Document},
options::{ClientOptions, FindOneAndDeleteOptions},
Client, Collection,
bson::doc,
error::TRANSIENT_TRANSACTION_ERROR,
error::UNKNOWN_TRANSACTION_COMMIT_RESULT,
options::{ClientOptions, IndexOptions},
Client, ClientSession, Collection, IndexModel,
};
use serde::{de::DeserializeOwned, Serialize};
use std::marker::PhantomData;

use crate::queue::{TaskQueue, TaskQueueError};
use crate::task::{Task, TaskId};

pub struct MongoTaskQueue<D> {
pub struct MongoTaskQueue<D: Clone>
where
D: Send + Sync + 'static,
{
pub client: Client,
pub tasks_collection: Collection<Document>,
pub dlq_collection: Collection<Document>,
_marker: PhantomData<D>,
pub tasks_collection: Collection<Task<D>>,
pub dlq_collection: Collection<Task<D>>,
}

impl<D> MongoTaskQueue<D> {
impl<D> MongoTaskQueue<D>
where
D: Clone + Serialize + DeserializeOwned + Send + Sync + 'static,
{
pub async fn new(
connection_string: &str,
queue_name: &str,
Expand All @@ -30,24 +36,65 @@ impl<D> MongoTaskQueue<D> {
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;

let db = client.database("task_queue");
let tasks_collection = db.collection(&format!("{}_tasks", queue_name));
let dlq_collection = db.collection(&format!("{}_dlq", queue_name));
let tasks_collection =
db.collection::<Task<D>>(&format!("{}_tasks", queue_name));
let dlq_collection =
db.collection::<Task<D>>(&format!("{}_dlq", queue_name));

// Create index on task_id
let index_model = IndexModel::builder()
.keys(doc! { "task_id": 1 })
.options(IndexOptions::builder().unique(true).build())
.build();

// Create indexes
tasks_collection
.create_index(doc! { "task_id": 1 }, None)
.create_index(index_model)
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;

Ok(Self {
client,
tasks_collection,
dlq_collection,
_marker: PhantomData,
})
}
}

impl<D> MongoTaskQueue<D>
where
D: Clone + Serialize + DeserializeOwned + Send + Sync + 'static,
{
// Helper method to execute the nack transaction
async fn execute_nack_transaction(
&self,
task: &Task<D>,
session: &mut ClientSession,
) -> mongodb::error::Result<()> {
// Move to DLQ
self.dlq_collection
.insert_one(task)
.session(&mut *session)
.await?;

// Remove from main queue
self.tasks_collection
.delete_one(doc! { "task_id": task.task_id.to_string() })
.session(&mut *session)
.await?;

// Commit with retry logic for unknown commit results
loop {
let result = session.commit_transaction().await;
if let Err(ref error) = result {
if error.contains_label(UNKNOWN_TRANSACTION_COMMIT_RESULT) {
continue;
}
}
return result;
}
}
}

#[async_trait]
impl<D> TaskQueue<D> for MongoTaskQueue<D>
where
Expand All @@ -60,40 +107,30 @@ where
+ 'static,
{
async fn push(&self, task: &Task<D>) -> Result<(), TaskQueueError> {
let task_doc = mongodb::bson::to_document(&task)
.map_err(|e| TaskQueueError::SerdeError(e.to_string()))?;

self.tasks_collection
.insert_one(task_doc, None)
.insert_one(task)
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;

Ok(())
}

async fn pop(&self) -> Result<Task<D>, TaskQueueError> {
let options = FindOneAndDeleteOptions::default();

let result = self
match self
.tasks_collection
.find_one_and_delete(doc! {}, options)
.find_one_and_delete(doc! {})
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;

match result {
Some(doc) => {
let task: Task<D> = mongodb::bson::from_document(doc)
.map_err(|e| TaskQueueError::SerdeError(e.to_string()))?;
Ok(task)
}
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?
{
Some(task) => Ok(task),
None => Err(TaskQueueError::QueueEmpty),
}
}

async fn ack(&self, task_id: &TaskId) -> Result<(), TaskQueueError> {
let result = self
.tasks_collection
.delete_one(doc! { "task_id": task_id.to_string() }, None)
.delete_one(doc! { "task_id": task_id.to_string() })
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;

Expand All @@ -105,54 +142,38 @@ where
}

async fn nack(&self, task: &Task<D>) -> Result<(), TaskQueueError> {
let task_doc = mongodb::bson::to_document(&task)
.map_err(|e| TaskQueueError::SerdeError(e.to_string()))?;

// Start session for transaction
let mut session = self
.client
.start_session(None)
.start_session()
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;

// Configure transaction options with majority read/write concerns
session
.start_transaction(None)
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;

// Move to DLQ and remove from main queue
self.dlq_collection
.insert_one_with_session(task_doc, None, &mut session)
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;

self.tasks_collection
.delete_one_with_session(
doc! { "task_id": task.task_id.to_string() },
None,
&mut session,
)
.start_transaction()
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;

session
.commit_transaction()
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;
// Execute transaction with retry logic for transient errors
while let Err(error) =
self.execute_nack_transaction(task, &mut session).await
{
if !error.contains_label(TRANSIENT_TRANSACTION_ERROR) {
return Err(TaskQueueError::QueueError(error.to_string()));
}
// Retry transaction on transient errors
session
.start_transaction()
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;
}

Ok(())
}

async fn set(&self, task: &Task<D>) -> Result<(), TaskQueueError> {
let task_doc = mongodb::bson::to_document(&task)
.map_err(|e| TaskQueueError::SerdeError(e.to_string()))?;

self.tasks_collection
.replace_one(
doc! { "task_id": task.task_id.to_string() },
task_doc,
None,
)
.replace_one(doc! { "task_id": task.task_id.to_string() }, task)
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;

Expand Down
Loading

0 comments on commit a5bc3ce

Please sign in to comment.