From b85aa8c27119270cac6dd8880092fb538f768e3d Mon Sep 17 00:00:00 2001 From: Anton Mashko Date: Wed, 26 Jun 2024 15:45:28 +0300 Subject: [PATCH] fix: context create fucntions --- taskq.go | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/taskq.go b/taskq.go index 2e81090..a12a17d 100644 --- a/taskq.go +++ b/taskq.go @@ -28,6 +28,11 @@ type TaskQ struct { workerCount int32 workers chan worker + // TaskContextFunc is a function that returns a context for a task.Do execution function. + TaskContextFunc func(t Task) context.Context + // OnDequeueContextFunc is a function that returns a context for Queue.Dequeue function. + DequeueContextFunc func() context.Context + // OnEnqueueError is a function that is called when an error occurs during the Dequeue operation. OnDequeueError func(ctx context.Context, workerID uint64, err error) } @@ -56,7 +61,7 @@ func NewWithQueue(limit int, q Queue) *TaskQ { } } -func (t *TaskQ) triggerDequeue(ctx context.Context) bool { +func (t *TaskQ) triggerDequeue() bool { if atomic.LoadInt32(&t.isRunning) != 1 { return false } @@ -67,9 +72,15 @@ func (t *TaskQ) triggerDequeue(ctx context.Context) bool { return false } atomic.AddInt32(&t.workerCount, 1) - go func(ctx context.Context, w worker) { + go func(w worker) { + var qCtx context.Context for atomic.LoadInt32(&t.isStopped) != 1 { - task, err := t.queue.Dequeue(ctx) + if t.DequeueContextFunc != nil { + qCtx = t.DequeueContextFunc() + } else { + qCtx = context.Background() + } + task, err := t.queue.Dequeue(qCtx) if err != nil { if err == EmptyQueue { break @@ -77,14 +88,20 @@ func (t *TaskQ) triggerDequeue(ctx context.Context) bool { if t.OnDequeueError == nil { panic(err) } - t.OnDequeueError(ctx, w.id, err) + t.OnDequeueError(qCtx, w.id, err) break } - processTask(ctx, task) + var tCtx context.Context + if t.TaskContextFunc != nil { + tCtx = t.TaskContextFunc(task) + } else { + tCtx = context.Background() + } + processTask(tCtx, task) } t.workers <- w // return worker to pool atomic.AddInt32(&t.workerCount, -1) - }(ctx, w) + }(w) return true default: return false @@ -105,14 +122,14 @@ func (t *TaskQ) Enqueue(ctx context.Context, task Task) (int64, error) { return -1, err } - t.triggerDequeue(ctx) + t.triggerDequeue() return id, nil } func (t *TaskQ) triggerFreeWorkers(ctx context.Context) { count := len(t.workers) for i := 0; i < count; i++ { - if !t.triggerDequeue(ctx) { + if !t.triggerDequeue() { return } }