Skip to content

Commit

Permalink
Merge pull request #50 from oiwn/dev
Browse files Browse the repository at this point in the history
add postgres backend
  • Loading branch information
oiwn authored Dec 21, 2024
2 parents 9d625dd + 373a54e commit 9896fde
Show file tree
Hide file tree
Showing 8 changed files with 529 additions and 1 deletion.
3 changes: 3 additions & 0 deletions capp-queue/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@ rustis = { version = "0.13", features = ["tokio-runtime"], optional = true }
mongodb = { version = "3", optional = true }
futures = { version = "0.3", optional = true }
futures-util = { version = "0.3", optional = true }
sqlx = { version = "0.8", features = [ "runtime-tokio", "postgres", "json", "uuid", "chrono" ], optional = true }

[dev-dependencies]
dotenvy = "0.15"
tokio = { version = "1.41", features = ["full", "test-util"] }
serial_test = "3"

[features]
redis = ["dep:tokio", "dep:rustis"]
mongodb = ["dep:tokio", "dep:futures", "dep:futures-util", "dep:mongodb"]
postgres = ["dep:tokio", "dep:sqlx"]
4 changes: 4 additions & 0 deletions capp-queue/src/backend.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
pub mod memory;
#[cfg(feature = "mongodb")]
pub mod mongodb;
#[cfg(feature = "postgres")]
pub mod postgres;
#[cfg(feature = "redis")]
pub mod redis;
#[cfg(feature = "redis")]
Expand All @@ -9,6 +11,8 @@ pub mod redis_rr;
pub use memory::InMemoryTaskQueue;
#[cfg(feature = "mongodb")]
pub use mongodb::MongoTaskQueue;
#[cfg(feature = "postgres")]
pub use postgres::PostgresTaskQueue;
#[cfg(feature = "redis")]
pub use redis::RedisTaskQueue;
#[cfg(feature = "redis")]
Expand Down
240 changes: 240 additions & 0 deletions capp-queue/src/backend/postgres.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
use crate::{Task, TaskId, TaskQueue, TaskQueueError, TaskSerializer, TaskStatus};
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use sqlx::types::chrono::{DateTime, Utc};
use sqlx::{PgPool, Pool, Postgres};
use std::marker::PhantomData;

/// PostgreSQL-backed implementation of the task queue.
///
/// `D` is the task payload type; `S` is the [`TaskSerializer`] used to
/// convert tasks to/from the JSON stored in the `tasks` table.
pub struct PostgresTaskQueue<D, S>
where
    S: TaskSerializer,
{
    // Shared sqlx connection pool. Public so callers can run their own
    // queries against the same database if needed.
    pub pool: PgPool,
    // Zero-sized marker tying the unused type parameters to the struct.
    _marker: PhantomData<(D, S)>,
}

impl<D, S> PostgresTaskQueue<D, S>
where
    D: Send + Sync + 'static,
    S: TaskSerializer + Send + Sync,
{
    /// Connects to PostgreSQL and builds a queue backed by a connection pool.
    ///
    /// `connection_string` is a standard Postgres URL, e.g.
    /// `postgres://user:pass@localhost/db`.
    ///
    /// # Errors
    /// Returns [`TaskQueueError::PostgresError`] if the connection cannot
    /// be established.
    pub async fn new(connection_string: &str) -> Result<Self, TaskQueueError> {
        // Point-free error mapping, consistent with the other methods in
        // this file (clippy::redundant_closure).
        let pool = PgPool::connect(connection_string)
            .await
            .map_err(TaskQueueError::PostgresError)?;

        Ok(Self {
            pool,
            _marker: PhantomData,
        })
    }
}

#[async_trait]
impl<D, S> TaskQueue<D> for PostgresTaskQueue<D, S>
where
    D: std::fmt::Debug
        + Clone
        + Serialize
        + DeserializeOwned
        + Send
        + Sync
        + 'static,
    S: TaskSerializer + Send + Sync,
{
    /// Inserts a new task in `Queued` state.
    async fn push(&self, task: &Task<D>) -> Result<(), TaskQueueError> {
        // Serialize the whole task, then re-parse as a JSON value so sqlx
        // can bind it to the JSONB `payload` column.
        let task_bytes = S::serialize_task(task)?;
        let payload: serde_json::Value = serde_json::from_slice(&task_bytes)
            .map_err(|e| TaskQueueError::Serialization(e.to_string()))?;
        let queued_at = DateTime::<Utc>::from(task.queued_at);

        sqlx::query!(
            r#"
            INSERT INTO tasks (id, payload, status, queued_at)
            VALUES ($1, $2, 'Queued', $3)
            "#,
            task.task_id.get(),
            payload,
            queued_at,
        )
        .execute(&self.pool)
        .await
        .map_err(TaskQueueError::PostgresError)?;

        Ok(())
    }

    /// Atomically claims the oldest `Queued` task and marks it `InProgress`.
    ///
    /// Uses `FOR UPDATE SKIP LOCKED` so concurrent workers never claim the
    /// same row and never block each other.
    async fn pop(&self) -> Result<Task<D>, TaskQueueError> {
        let mut tx = self
            .pool
            .begin()
            .await
            .map_err(TaskQueueError::PostgresError)?;

        let result = async {
            // Select and row-lock the oldest queued task; SKIP LOCKED lets
            // other workers pass over rows already claimed in-flight.
            let row = sqlx::query!(
                r#"
                SELECT id, payload, queued_at, started_at, finished_at, retries, error_msg
                FROM tasks
                WHERE status = 'Queued'
                ORDER BY queued_at ASC
                FOR UPDATE SKIP LOCKED
                LIMIT 1
                "#
            )
            .fetch_optional(&mut *tx)
            .await
            .map_err(TaskQueueError::PostgresError)?;

            match row {
                Some(row) => {
                    // Mark the claimed row InProgress inside the same
                    // transaction.
                    let started_at = Utc::now();
                    sqlx::query!(
                        r#"
                        UPDATE tasks
                        SET status = 'InProgress', started_at = $1
                        WHERE id = $2
                        "#,
                        started_at,
                        row.id,
                    )
                    .execute(&mut *tx)
                    .await
                    .map_err(TaskQueueError::PostgresError)?;

                    // The JSONB payload holds the serialized task; round-trip
                    // it back to bytes for the serializer.
                    let task_bytes = serde_json::to_vec(&row.payload)
                        .map_err(|e| TaskQueueError::Serialization(e.to_string()))?;

                    // The stored payload was serialized at push time, so
                    // mirror the DB-side status change on the in-memory task.
                    let mut task = S::deserialize_task(&task_bytes)?;
                    task.set_in_progress();
                    task.started_at = Some(started_at.into());

                    Ok(task)
                }
                None => Err(TaskQueueError::QueueEmpty),
            }
        }
        .await;

        // Commit on success, roll back (releasing the row lock) on error.
        match result {
            Ok(task) => {
                tx.commit().await.map_err(TaskQueueError::PostgresError)?;
                Ok(task)
            }
            Err(e) => {
                tx.rollback().await.map_err(TaskQueueError::PostgresError)?;
                Err(e)
            }
        }
    }

    /// Acknowledges a task by deleting its row.
    ///
    /// # Errors
    /// [`TaskQueueError::TaskNotFound`] if no row with `task_id` exists.
    async fn ack(&self, task_id: &TaskId) -> Result<(), TaskQueueError> {
        let mut tx = self
            .pool
            .begin()
            .await
            .map_err(TaskQueueError::PostgresError)?;

        let result =
            sqlx::query!(r#"DELETE FROM tasks WHERE id = $1"#, task_id.get(),)
                .execute(&mut *tx)
                .await
                .map_err(TaskQueueError::PostgresError)?;

        if result.rows_affected() == 0 {
            tx.rollback().await.map_err(TaskQueueError::PostgresError)?;
            return Err(TaskQueueError::TaskNotFound(*task_id));
        }

        tx.commit().await.map_err(TaskQueueError::PostgresError)?;
        Ok(())
    }

    /// Moves a failed task into the dead-letter queue and removes it from
    /// the active `tasks` table, atomically.
    ///
    /// # Errors
    /// [`TaskQueueError::TaskNotFound`] if the task is not present in
    /// `tasks` (the DLQ insert is rolled back in that case).
    async fn nack(&self, task: &Task<D>) -> Result<(), TaskQueueError> {
        let mut tx = self
            .pool
            .begin()
            .await
            .map_err(TaskQueueError::PostgresError)?;

        // First copy the task into the DLQ …
        let task_bytes = S::serialize_task(task)?;
        let payload: serde_json::Value = serde_json::from_slice(&task_bytes)
            .map_err(|e| TaskQueueError::Serialization(e.to_string()))?;

        sqlx::query!(
            r#"INSERT INTO dlq (id, payload, error_msg) VALUES ($1, $2, $3)"#,
            task.task_id.get(),
            payload,
            task.error_msg.as_deref(),
        )
        .execute(&mut *tx)
        .await
        .map_err(TaskQueueError::PostgresError)?;

        // … then remove it from the active table. If the task row does not
        // exist, roll back so the DLQ insert above is undone too — otherwise
        // a nack for an unknown task would commit a ghost DLQ entry
        // (consistent with ack/set, which report TaskNotFound).
        let result =
            sqlx::query!(r#"DELETE FROM tasks WHERE id = $1"#, task.task_id.get(),)
                .execute(&mut *tx)
                .await
                .map_err(TaskQueueError::PostgresError)?;

        if result.rows_affected() == 0 {
            tx.rollback().await.map_err(TaskQueueError::PostgresError)?;
            return Err(TaskQueueError::TaskNotFound(task.task_id));
        }

        tx.commit().await.map_err(TaskQueueError::PostgresError)?;
        Ok(())
    }

    /// Overwrites the stored state of an existing task (payload, status,
    /// timestamps, retry count, error message).
    ///
    /// # Errors
    /// [`TaskQueueError::TaskNotFound`] if no row with the task's id exists.
    async fn set(&self, task: &Task<D>) -> Result<(), TaskQueueError> {
        let mut tx = self
            .pool
            .begin()
            .await
            .map_err(TaskQueueError::PostgresError)?;

        // Serialize task data to JSON for the JSONB column.
        let task_bytes = S::serialize_task(task)?;
        let payload: serde_json::Value = serde_json::from_slice(&task_bytes)
            .map_err(|e| TaskQueueError::Serialization(e.to_string()))?;

        // Map the enum to the textual status values used by the schema.
        let status = match task.status {
            TaskStatus::Queued => "Queued",
            TaskStatus::InProgress => "InProgress",
            TaskStatus::Completed => "Completed",
            TaskStatus::Failed => "Failed",
            TaskStatus::DeadLetter => "DeadLetter",
        };

        // Unchecked query (runtime binds) because `status` is a dynamic
        // string the query! macro cannot type-check against the enum column.
        let result = sqlx::query(
            r#"
            UPDATE tasks
            SET payload = $1,
                status = $2,
                started_at = $3,
                finished_at = $4,
                retries = $5,
                error_msg = $6
            WHERE id = $7
            "#,
        )
        .bind(payload)
        .bind(status)
        .bind(task.started_at.map(DateTime::<Utc>::from))
        .bind(task.finished_at.map(DateTime::<Utc>::from))
        .bind(task.retries as i32)
        .bind(&task.error_msg)
        .bind(task.task_id.get())
        .execute(&mut *tx)
        .await
        .map_err(TaskQueueError::PostgresError)?;

        if result.rows_affected() == 0 {
            tx.rollback().await.map_err(TaskQueueError::PostgresError)?;
            return Err(TaskQueueError::TaskNotFound(task.task_id));
        }

        tx.commit().await.map_err(TaskQueueError::PostgresError)?;
        Ok(())
    }
}
5 changes: 5 additions & 0 deletions capp-queue/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ pub mod task;
pub use crate::backend::InMemoryTaskQueue;
#[cfg(feature = "mongodb")]
pub use crate::backend::MongoTaskQueue;
#[cfg(feature = "postgres")]
pub use crate::backend::PostgresTaskQueue;
#[cfg(feature = "redis")]
pub use crate::backend::{RedisRoundRobinTaskQueue, RedisTaskQueue};
pub use crate::queue::{AbstractTaskQueue, HasTagKey, TaskQueue};
Expand Down Expand Up @@ -34,4 +36,7 @@ pub enum TaskQueueError {
#[cfg(feature = "mongodb")]
#[error("Mongodb Error")]
MongodbError(#[from] mongodb::error::Error),
#[cfg(feature = "postgres")]
#[error("Postgres Error")]
PostgresError(#[from] sqlx::Error),
}
Loading

0 comments on commit 9896fde

Please sign in to comment.