Skip to content

Commit

Permalink
Merge pull request #51 from oiwn/dev
Browse files Browse the repository at this point in the history
fix postgres backend
  • Loading branch information
oiwn authored Dec 22, 2024
2 parents 9896fde + b403d47 commit 5798f25
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 15 deletions.
36 changes: 26 additions & 10 deletions capp-queue/src/backend/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,32 @@ use crate::{Task, TaskId, TaskQueue, TaskQueueError, TaskSerializer, TaskStatus}
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use sqlx::types::chrono::{DateTime, Utc};
use sqlx::{PgPool, Pool, Postgres};
use sqlx::PgPool;
use std::marker::PhantomData;

#[derive(sqlx::Type)]
#[sqlx(type_name = "task_status", rename_all = "PascalCase")]
pub enum PostgresTaskStatus {
Queued,
InProgress,
Completed,
Failed,
DeadLetter,
}

// Add conversion between our TaskStatus and PostgresTaskStatus
impl From<TaskStatus> for PostgresTaskStatus {
fn from(status: TaskStatus) -> Self {
match status {
TaskStatus::Queued => PostgresTaskStatus::Queued,
TaskStatus::InProgress => PostgresTaskStatus::InProgress,
TaskStatus::Completed => PostgresTaskStatus::Completed,
TaskStatus::Failed => PostgresTaskStatus::Failed,
TaskStatus::DeadLetter => PostgresTaskStatus::DeadLetter,
}
}
}

pub struct PostgresTaskQueue<D, S>
where
S: TaskSerializer,
Expand Down Expand Up @@ -193,18 +216,11 @@ where
.await
.map_err(TaskQueueError::PostgresError)?;

// Serialize task data to JSON
let task_bytes = S::serialize_task(task)?;
let payload: serde_json::Value = serde_json::from_slice(&task_bytes)
.map_err(|e| TaskQueueError::Serialization(e.to_string()))?;

let status = match task.status {
TaskStatus::Queued => "Queued",
TaskStatus::InProgress => "InProgress",
TaskStatus::Completed => "Completed",
TaskStatus::Failed => "Failed",
TaskStatus::DeadLetter => "DeadLetter",
};
let status: PostgresTaskStatus = task.status.clone().into();

let result = sqlx::query(
r#"
Expand All @@ -219,7 +235,7 @@ where
"#,
)
.bind(payload)
.bind(status)
.bind(status) // Now we're binding a proper Postgres enum type
.bind(task.started_at.map(|t| DateTime::<Utc>::from(t)))
.bind(task.finished_at.map(|t| DateTime::<Utc>::from(t)))
.bind(task.retries as i32)
Expand Down
9 changes: 4 additions & 5 deletions capp-queue/tests/postgres_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,17 +198,16 @@ mod tests {
// Verify changes were saved correctly using regular query
let row = sqlx::query(
r#"
SELECT payload, status
FROM tasks
WHERE id = $1
"#,
SELECT payload, status::text as status
FROM tasks
WHERE id = $1
"#,
)
.bind(task.task_id.get())
.fetch_one(&queue.pool)
.await
.expect("Failed to fetch task");

// Get payload and status from row
let payload: serde_json::Value = row.get("payload");
let status: String = row.get("status");

Expand Down

0 comments on commit 5798f25

Please sign in to comment.