Skip to content

Commit

Permalink
Merge pull request #41 from oiwn/dev
Browse files Browse the repository at this point in the history
round robin tests
  • Loading branch information
oiwn authored Dec 9, 2024
2 parents 00f82ab + 71299e5 commit 6a36bdf
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 35 deletions.
90 changes: 56 additions & 34 deletions capp-queue/src/backend/redis_rr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,14 @@ impl<D> RedisRoundRobinTaskQueue<D> {
Ok(results.first().map(|(tag, _score)| tag.clone()))
}

// Update tag's timestamp in schedule
async fn update_tag_timestamp(&self, tag: &str) -> Result<(), TaskQueueError> {
let timestamp = self.current_timestamp()?;
// Add a small increment to ensure proper ordering
let score = timestamp as f64 + 0.001;
self.client
.zadd(
self.get_schedule_key(),
vec![(timestamp as f64, tag)],
vec![(score, tag)],
ZAddOptions::default(),
)
.await?;
Expand Down Expand Up @@ -153,51 +154,72 @@ where
let tag = task.payload.get_tag_value().to_string();
let list_key = self.get_list_key(&tag);
let hashmap_key = self.get_hashmap_key();
let schedule_key = self.get_schedule_key();

let mut pipeline = self.client.create_pipeline();

// Add task to list and hashmap
pipeline
.lpush(&list_key, &task.task_id.to_string())
.forget();
pipeline
.hset(&hashmap_key, [(&task.task_id.to_string(), &task_json)])
.forget();

// Ensure tag exists in schedule with current timestamp if it doesn't exist
let timestamp = self.current_timestamp()?;
pipeline
.zadd(
schedule_key,
vec![(timestamp as f64, tag)],
ZAddOptions::default()
.condition(rustis::commands::ZAddCondition::NX), // Only add if not exists
)
.forget();

self.execute_pipeline(pipeline).await
}

async fn pop(&self) -> Result<Task<D>, TaskQueueError> {
let tag = self
.get_next_tag()
.await?
.ok_or(TaskQueueError::QueueEmpty)?;

let list_key = self.get_list_key(&tag);
let hashmap_key = self.get_hashmap_key();

// Try to get task from the selected tag's list
let task_ids: Vec<String> = self
.client
.rpop(&list_key, 1)
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;

if let Some(task_id) = task_ids.first() {
// Get task data from hash
let task_value: String =
self.client.hget(&hashmap_key, task_id).await?;

let task: Task<D> = serde_json::from_str(&task_value)
.map_err(|e| TaskQueueError::SerdeError(e.to_string()))?;

// Update tag's timestamp in schedule
self.update_tag_timestamp(&tag).await?;

Ok(task)
} else {
// If no tasks in this tag's list, remove it from schedule and try again
self.client.zrem(self.get_schedule_key(), tag).await?;

Err(TaskQueueError::QueueEmpty)
// Keep trying until we find a task or exhaust all tags
loop {
let tag = self
.get_next_tag()
.await?
.ok_or(TaskQueueError::QueueEmpty)?;

let list_key = self.get_list_key(&tag);
let hashmap_key = self.get_hashmap_key();

// Try to get task from the selected tag's list
let task_ids: Vec<String> = self
.client
.rpop(&list_key, 1)
.await
.map_err(|e| TaskQueueError::QueueError(e.to_string()))?;

if let Some(task_id) = task_ids.first() {
// Get task data from hash
let task_value: String =
self.client.hget(&hashmap_key, task_id).await?;

let task: Task<D> = serde_json::from_str(&task_value)
.map_err(|e| TaskQueueError::SerdeError(e.to_string()))?;

// Update tag's timestamp in schedule
self.update_tag_timestamp(&tag).await?;

return Ok(task);
}

// No tasks in this tag's list, remove it from schedule and continue
self.client.zrem(self.get_schedule_key(), &tag).await?;

// Check if we still have any tags in schedule
let count = self.client.zcard(self.get_schedule_key()).await?;
if count == 0 {
return Err(TaskQueueError::QueueEmpty);
}
}
}

Expand Down
16 changes: 15 additions & 1 deletion capp-queue/tests/redis_rr_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,21 @@ mod tests {
.expect("Error while establishing redis connection")
}

// Cleanup before each test
async fn cleanup_before_test(test_name: &str) {
let redis = get_redis_connection().await;
let pattern = format!("test-rr-{}*", test_name);
// Get all keys matching our test pattern
let keys: Vec<String> =
redis.keys(&pattern).await.expect("Failed to get keys");
if !keys.is_empty() {
redis.del(keys).await.expect("Failed to delete keys");
}
}

async fn setup_queue(test_name: &str) -> RedisRoundRobinTaskQueue<TestData> {
// cleanup tests if present
cleanup_before_test(test_name).await;
let redis = get_redis_connection().await;
let tags = HashSet::from([
"tag1".to_string(),
Expand Down Expand Up @@ -105,7 +119,7 @@ mod tests {

// Verify tasks are stored properly
let hashmap_len: u64 =
queue.client.hlen(&queue.get_hashmap_key()).await.unwrap() as u64;
queue.client.hlen(queue.get_hashmap_key()).await.unwrap() as u64;
assert_eq!(hashmap_len, 6, "All tasks should be in hashmap");

// Pop tasks and verify round-robin behavior
Expand Down

0 comments on commit 6a36bdf

Please sign in to comment.