-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathcelery.go
300 lines (258 loc) · 8.42 KB
/
celery.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
// Package celery helps to work with Celery (place tasks in queues and execute them).
package celery
import (
"context"
"fmt"
"runtime/debug"
"time"
"github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/google/uuid"
"golang.org/x/sync/errgroup"
"github.com/marselester/gopher-celery/protocol"
"github.com/marselester/gopher-celery/redis"
)
// TaskF represents a Celery task implemented by the client.
// The task's args and kwargs are passed in p, and ctx carries the task's
// Python path under ContextKeyTaskName. ctx is derived from the context
// given to App.Run.
// The error doesn't affect anything, it's logged though. A panic inside the
// task is recovered and logged as a task error as well.
type TaskF func(ctx context.Context, p *TaskParam) error
// Middleware is a chainable behavior modifier for tasks.
// It receives the next TaskF and returns a TaskF that wraps it,
// for example, a caller can collect task metrics around the call to next.
type Middleware func(next TaskF) TaskF
// Broker is responsible for receiving and sending task messages.
// For example, it knows how to read a message from a given queue in Redis.
// The messages can be in different formats depending on Celery protocol version.
type Broker interface {
	// Send puts a message into a queue.
	// Note, the method is safe to call concurrently.
	Send(msg []byte, queue string) error
	// Observe sets the queues from which the tasks should be received.
	// Note, the method is not concurrency safe.
	Observe(queues []string)
	// Receive returns a raw message from one of the queues.
	// It blocks until there is a message available for consumption.
	// NOTE(review): App treats a nil message with a nil error as
	// "nothing received yet" and simply calls Receive again, so an
	// implementation may presumably return early (e.g. on a timeout).
	// Note, the method is not concurrency safe.
	Receive() ([]byte, error)
}
// AsyncParam represents parameters for sending a task message
// with App.ApplyAsync.
type AsyncParam struct {
	// Args is a list of arguments.
	// It will be an empty list if not provided.
	Args []interface{}
	// Kwargs is a dictionary of keyword arguments.
	// It will be an empty dictionary if not provided.
	Kwargs map[string]interface{}
	// Expires is an expiration date.
	// If not provided the message will never expire.
	// App workers skip messages that are already expired when received.
	Expires time.Time
}
// NewApp creates a Celery app configured by the given options.
// Options are applied in order on top of the defaults: a no-op logger,
// the json message serializer, protocol v2 when producing tasks,
// and DefaultMaxWorkers workers.
// If no option sets a broker, a Redis broker is used, assumed to run on localhost.
func NewApp(options ...Option) *App {
	app := &App{
		task:      make(map[string]TaskF),
		taskQueue: make(map[string]string),
	}
	app.conf = Config{
		logger:     log.NewNopLogger(),
		registry:   protocol.NewSerializerRegistry(),
		mime:       protocol.MimeJSON,
		protocol:   protocol.V2,
		maxWorkers: DefaultMaxWorkers,
	}

	for _, apply := range options {
		apply(&app.conf)
	}

	// Fall back to the default broker only when no option provided one.
	if app.conf.broker == nil {
		app.conf.broker = redis.NewBroker()
	}

	return app
}
// App is a Celery app to produce or consume tasks asynchronously.
// Create it with NewApp, register tasks with Register, then call Run
// to consume them, or use ApplyAsync/Delay to produce them.
type App struct {
	// conf represents app settings.
	conf Config
	// task maps a Celery task path to a task itself, e.g.,
	// "myproject.apps.myapp.tasks.mytask": TaskF.
	// It's written by Register and read without locking while tasks run,
	// which is why tasks mustn't be registered once Run has started.
	task map[string]TaskF
	// taskQueue helps to determine which queue a task belongs to, e.g.,
	// "myproject.apps.myapp.tasks.mytask": "important".
	// Run observes the queues listed here.
	taskQueue map[string]string
}
// Register associates the task with given Python path and queue.
// For example, when "myproject.apps.myapp.tasks.mytask"
// is seen in "important" queue, the TaskF task is executed.
// Registering the same path again replaces both its task and its queue.
//
// Note, the method is not concurrency safe.
// The tasks mustn't be registered after the app starts processing tasks.
func (a *App) Register(path, queue string, task TaskF) {
	a.taskQueue[path] = queue
	a.task[path] = task
}
// ApplyAsync sends a task message, i.e., it places the task associated with
// given Python path into queue, using the args, kwargs and expiration from p.
// A nil p is treated as an empty AsyncParam: no arguments and no expiration.
// Each message gets a new UUID as its task ID.
func (a *App) ApplyAsync(path, queue string, p *AsyncParam) error {
	// Without this guard a nil p would panic when its fields are read below.
	if p == nil {
		p = &AsyncParam{}
	}

	m := protocol.Task{
		ID:      uuid.NewString(),
		Name:    path,
		Args:    p.Args,
		Kwargs:  p.Kwargs,
		Expires: p.Expires,
	}
	rawMsg, err := a.conf.registry.Encode(queue, a.conf.mime, a.conf.protocol, &m)
	if err != nil {
		return fmt.Errorf("failed to encode task message: %w", err)
	}
	if err = a.conf.broker.Send(rawMsg, queue); err != nil {
		return fmt.Errorf("failed to send task message to broker: %w", err)
	}
	return nil
}
// Delay is a shortcut to send a task message,
// i.e., it places the task associated with given Python path into queue.
// Only positional args are sent; the message has no kwargs and never expires.
func (a *App) Delay(path, queue string, args ...interface{}) error {
	task := &protocol.Task{
		ID:   uuid.NewString(),
		Name: path,
		Args: args,
	}

	encoded, err := a.conf.registry.Encode(queue, a.conf.mime, a.conf.protocol, task)
	if err != nil {
		return fmt.Errorf("failed to encode task message: %w", err)
	}

	err = a.conf.broker.Send(encoded, queue)
	if err != nil {
		return fmt.Errorf("failed to send task message to broker: %w", err)
	}
	return nil
}
// Run launches the workers that process the tasks received from the broker.
// Every queue that has at least one registered task is observed.
// The call is blocking until ctx is cancelled or the broker fails to receive
// a message, in which case the wrapped broker error is returned.
// The caller mustn't register any new tasks at this point.
//
// When maxWorkers is 1 or less, tasks are processed one by one (see syncRun).
// Otherwise one goroutine fetches messages and up to maxWorkers goroutines
// execute the tasks concurrently.
func (a *App) Run(ctx context.Context) error {
	// Several tasks can share a queue, so each queue is observed only once.
	seen := make(map[string]struct{}, len(a.taskQueue))
	qq := make([]string, 0, len(a.taskQueue))
	for _, q := range a.taskQueue {
		if _, ok := seen[q]; ok {
			continue
		}
		seen[q] = struct{}{}
		qq = append(qq, q)
	}
	a.conf.broker.Observe(qq)
	level.Debug(a.conf.logger).Log("msg", "observing queues", "queues", qq)

	// Tasks are processed concurrently only if there are multiple workers.
	if a.conf.maxWorkers <= 1 {
		return a.syncRun(ctx)
	}

	g, ctx := errgroup.WithContext(ctx)
	// There will be at most maxWorkers goroutines processing tasks, and one fetching them.
	g.SetLimit(a.conf.maxWorkers + 1)

	// One goroutine fetching and decoding tasks from queues
	// shouldn't be a bottleneck since the worker goroutines
	// usually take seconds/minutes to complete.
	msgs := make(chan *protocol.Task, 1)
	g.Go(func() error {
		defer close(msgs)
		for {
			// Stop fetching tasks.
			if ctx.Err() != nil {
				return nil
			}
			rawMsg, err := a.conf.broker.Receive()
			if err != nil {
				return fmt.Errorf("failed to receive a raw task message: %w", err)
			}
			// No messages in the broker so far.
			if rawMsg == nil {
				continue
			}
			m, err := a.conf.registry.Decode(rawMsg)
			if err != nil {
				level.Error(a.conf.logger).Log("msg", "failed to decode task message", "rawmsg", rawMsg, "err", err)
				continue
			}
			// The dispatch loop below stops reading msgs once ctx is cancelled,
			// so a plain send could block forever on a full channel and
			// g.Wait would never return.
			select {
			case msgs <- m:
			case <-ctx.Done():
				return nil
			}
		}
	})

	// Start a worker when there is a task.
	// The dispatch loop runs on the caller's goroutine so that every g.Go call
	// happens before g.Wait. Calling g.Go from a separate, untracked goroutine
	// could race with g.Wait and start tasks after Run has returned.
	for m := range msgs {
		level.Debug(a.conf.logger).Log("msg", "task received", "name", m.Name)
		if a.task[m.Name] == nil {
			level.Debug(a.conf.logger).Log("msg", "unregistered task", "name", m.Name)
			continue
		}
		if m.IsExpired() {
			level.Debug(a.conf.logger).Log("msg", "task message expired", "name", m.Name)
			continue
		}
		// Stop processing tasks. A message still buffered in msgs is dropped.
		if ctx.Err() != nil {
			break
		}

		m := m
		g.Go(func() error {
			if err := a.executeTask(ctx, m); err != nil {
				level.Error(a.conf.logger).Log("msg", "task failed", "taskmsg", m, "err", err)
			} else {
				level.Debug(a.conf.logger).Log("msg", "task succeeded", "name", m.Name)
			}
			// Task errors are only logged; they must not stop the other workers.
			return nil
		})
	}

	return g.Wait()
}
// syncRun processes tasks one by one on the caller's goroutine.
// Note, it doesn't fetch a new task until the current one is finished.
// ctx is checked before every fetch: once it's cancelled syncRun returns nil.
// A failure to receive a message is returned as a wrapped error.
func (a *App) syncRun(ctx context.Context) error {
	for ctx.Err() == nil {
		rawMsg, err := a.conf.broker.Receive()
		if err != nil {
			return fmt.Errorf("failed to receive a raw task message: %w", err)
		}
		// A nil message means the broker had nothing for us yet.
		if rawMsg == nil {
			continue
		}

		m, err := a.conf.registry.Decode(rawMsg)
		if err != nil {
			level.Error(a.conf.logger).Log("msg", "failed to decode task message", "rawmsg", rawMsg, "err", err)
			continue
		}
		level.Debug(a.conf.logger).Log("msg", "task received", "name", m.Name)

		switch {
		case a.task[m.Name] == nil:
			level.Debug(a.conf.logger).Log("msg", "unregistered task", "name", m.Name)
		case m.IsExpired():
			level.Debug(a.conf.logger).Log("msg", "task message expired", "name", m.Name)
		default:
			if taskErr := a.executeTask(ctx, m); taskErr != nil {
				level.Error(a.conf.logger).Log("msg", "task failed", "taskmsg", m, "err", taskErr)
			} else {
				level.Debug(a.conf.logger).Log("msg", "task succeeded", "name", m.Name)
			}
		}
	}
	return nil
}
// contextKey is the unexported type of this package's context keys.
// Keys of a distinct type can't collide with context keys defined elsewhere.
type contextKey int

const (
	// ContextKeyTaskName is a context key to access task names.
	// The value stored under it is the task's Python path as a string,
	// e.g., name, _ := ctx.Value(ContextKeyTaskName).(string).
	ContextKeyTaskName contextKey = iota
)
// executeTask runs the registered task for m, wrapped in the client's
// middlewares if any were configured, passing it the message's args and kwargs.
// The task name is stored in the task's context under ContextKeyTaskName.
// A panic anywhere in the task (or its middlewares) is recovered and returned
// as an error that includes the stack trace.
func (a *App) executeTask(ctx context.Context, m *protocol.Task) (err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		err = fmt.Errorf("unexpected task error: %v: %s", r, debug.Stack())
	}()

	fn := a.task[m.Name]
	if chain := a.conf.chain; chain != nil {
		fn = chain(fn)
	}

	taskCtx := context.WithValue(ctx, ContextKeyTaskName, m.Name)
	return fn(taskCtx, NewTaskParam(m.Args, m.Kwargs))
}