Skip to content

Commit 536ad5c

Browse files
committed
Protect rmap.SetAndWait against corner cases
Fix all tests
1 parent 76324a5 commit 536ad5c

File tree

3 files changed: 106 additions & 42 deletions

pool/node_test.go

Lines changed: 33 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -131,13 +131,6 @@ func TestJobKeys(t *testing.T) {
131131
for _, job := range jobs {
132132
assert.Contains(t, allJobKeys, job.key, fmt.Sprintf("Job key %s not found in JobKeys", job.key))
133133
}
134-
135-
// Dispatch a job with an existing key to node1
136-
assert.NoError(t, node1.DispatchJob(ctx, "job1", []byte("updated payload")), "Failed to dispatch job with existing key")
137-
138-
// Check that the number of job keys hasn't changed
139-
updatedAllJobKeys := node1.JobKeys()
140-
assert.Equal(t, len(allJobKeys), len(updatedAllJobKeys), "Number of job keys shouldn't change when updating an existing job")
141134
}
142135

143136
func TestJobPayload(t *testing.T) {
@@ -176,14 +169,12 @@ func TestJobPayload(t *testing.T) {
176169
assert.False(t, ok, "Expected false for non-existent job")
177170
assert.Nil(t, payload, "Expected nil payload for non-existent job")
178171

179-
// Update existing job
180-
updatedPayload := []byte("updated payload")
181-
assert.NoError(t, node.DispatchJob(ctx, "job1", updatedPayload), "Failed to update existing job")
182-
183-
// Check if the payload was updated
172+
// Remove existing job
173+
assert.NoError(t, node.StopJob(ctx, "job1"))
174+
// Check if the payload was removed
184175
assert.Eventually(t, func() bool {
185-
payload, ok := node.JobPayload("job1")
186-
return ok && assert.Equal(t, updatedPayload, payload, "Payload was not updated correctly")
176+
_, ok := node.JobPayload("job1")
177+
return !ok
187178
}, max, delay, "Failed to get updated payload for job")
188179
}
189180

@@ -333,30 +324,17 @@ func TestDispatchJobRaceCondition(t *testing.T) {
333324
payload := []byte("test payload")
334325

335326
// Set a stale pending timestamp
336-
staleTS := time.Now().Add(-3 * node1.ackGracePeriod).UnixNano()
337-
_, err := node1.pendingJobsMap.Set(ctx, jobKey, strconv.FormatInt(staleTS, 10))
327+
staleTS := time.Now().Add(-time.Hour).UnixNano()
328+
_, err := node1.pendingJobsMap.SetAndWait(ctx, jobKey, strconv.FormatInt(staleTS, 10))
338329
require.NoError(t, err, "Failed to set stale pending timestamp")
330+
defer func() {
331+
_, err = node1.pendingJobsMap.Delete(ctx, jobKey)
332+
assert.NoError(t, err, "Failed to delete pending timestamp")
333+
}()
339334

340335
// Dispatch should succeed because pending timestamp is in the past
341-
err = node2.DispatchJob(ctx, jobKey, payload)
342-
assert.NoError(t, err, "Dispatch should succeed after pending timeout")
343-
})
344-
345-
t.Run("dispatch cleans up pending entry on failure", func(t *testing.T) {
346-
jobKey := "cleanup-job"
347-
payload := []byte("test payload")
348-
349-
// Corrupt the pool stream to force dispatch failure
350-
err := rdb.Del(ctx, "pulse:stream:"+poolStreamName(node1.PoolName)).Err()
351-
require.NoError(t, err, "Failed to delete pool stream")
352-
353-
// Attempt dispatch (should fail)
354336
err = node1.DispatchJob(ctx, jobKey, payload)
355-
require.Error(t, err, "Expected dispatch to fail")
356-
357-
// Verify pending entry was cleaned up
358-
_, exists := node1.pendingJobsMap.Get(jobKey)
359-
assert.False(t, exists, "Pending entry should be cleaned up after failed dispatch")
337+
assert.NoError(t, err, "Dispatch should succeed after pending timeout")
360338
})
361339

362340
t.Run("dispatch cleans up pending entry on success", func(t *testing.T) {
@@ -386,6 +364,27 @@ func TestDispatchJobRaceCondition(t *testing.T) {
386364
err = node1.DispatchJob(ctx, jobKey, payload)
387365
assert.NoError(t, err, "Dispatch should succeed with invalid pending timestamp")
388366
})
367+
368+
// Keep this test last, it destroys the stream
369+
t.Run("dispatch cleans up pending entry on failure", func(t *testing.T) {
370+
jobKey := "cleanup-job"
371+
payload := []byte("test payload")
372+
373+
// Corrupt the pool stream to force dispatch failure
374+
err := rdb.Del(ctx, "pulse:stream:"+poolStreamName(node1.PoolName)).Err()
375+
require.NoError(t, err, "Failed to delete pool stream")
376+
377+
// Attempt dispatch (should fail)
378+
err = node1.DispatchJob(ctx, jobKey, payload)
379+
require.Error(t, err, "Expected dispatch to fail")
380+
381+
// Verify pending entry was cleaned up
382+
require.Eventually(t, func() bool {
383+
_, exists := node1.pendingJobsMap.Get(jobKey)
384+
return !exists
385+
}, max, delay, "Pending entry should be cleaned up after failed dispatch")
386+
})
387+
389388
}
390389

391390
func TestNotifyWorker(t *testing.T) {

rmap/map.go

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ type (
2626
hashkey string // Redis hash key
2727
msgch <-chan *redis.Message // channel to receive map updates
2828
chans []chan EventKind // channels to send notifications
29-
ichan chan string // internal channel to send set notifications
29+
ichan chan setNotification // internal channel to send set notifications
3030
done chan struct{} // channel to signal shutdown
3131
wait sync.WaitGroup // wait for read goroutine to exit
3232
logger pulse.Logger // logger
@@ -52,6 +52,12 @@ type (
5252

5353
// EventKind is the type of map event.
5454
EventKind int
55+
56+
// setNotification is the type of internal notification sent when a key is set.
57+
setNotification struct {
58+
key string
59+
value string
60+
}
5561
)
5662

5763
const (
@@ -86,7 +92,7 @@ func Join(ctx context.Context, name string, rdb *redis.Client, opts ...MapOption
8692
Name: name,
8793
chankey: fmt.Sprintf("map:%s:updates", name),
8894
hashkey: fmt.Sprintf("map:%s:content", name),
89-
ichan: make(chan string, 1),
95+
ichan: make(chan setNotification, 100),
9096
done: make(chan struct{}),
9197
logger: o.Logger.WithPrefix("map", name),
9298
rdb: rdb,
@@ -231,11 +237,11 @@ func (sm *Map) SetAndWait(ctx context.Context, key, value string) (string, error
231237
select {
232238
case <-ctx.Done():
233239
return "", ctx.Err()
234-
case val, ok := <-sm.ichan:
240+
case ev, ok := <-sm.ichan:
235241
if !ok {
236242
return "", fmt.Errorf("pulse map: %s is stopped", sm.Name)
237243
}
238-
if val == value {
244+
if ev.key == key && ev.value == value {
239245
return prev, nil
240246
}
241247
}
@@ -562,10 +568,7 @@ func (sm *Map) run() {
562568
continue
563569
}
564570
sm.content[key] = val
565-
select {
566-
case sm.ichan <- val:
567-
default:
568-
}
571+
sm.ichan <- setNotification{key: key, value: val}
569572
sm.logger.Debug("set", "key", key, "val", val)
570573
}
571574
for _, c := range sm.chans {

rmap/map_test.go

Lines changed: 62 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@ package rmap
33
import (
44
"bytes"
55
"context"
6+
"fmt"
67
"os"
78
"strings"
89
"sync"
@@ -136,6 +137,67 @@ func TestMapLocal(t *testing.T) {
136137
cleanup(t, m)
137138
}
138139

140+
func TestSetAndWait(t *testing.T) {
141+
rdb := redis.NewClient(&redis.Options{
142+
Addr: "localhost:6379",
143+
Password: redisPwd,
144+
})
145+
var buf Buffer
146+
ctx := context.Background()
147+
ctx = log.Context(ctx, log.WithOutput(&buf))
148+
log.FlushAndDisableBuffering(ctx)
149+
150+
// Join or create a replicated map
151+
m, err := Join(ctx, "test", rdb)
152+
if err != nil {
153+
if strings.Contains(err.Error(), "WRONGPASS") {
154+
t.Fatal("Unexpected Redis password error (did you set REDIS_PASSWORD?)")
155+
} else if strings.Contains(err.Error(), "connection refused") {
156+
t.Fatal("Unexpected Redis connection error (is Redis running?)")
157+
}
158+
}
159+
assert.NoError(t, err)
160+
assert.NotNil(t, m)
161+
assert.NoError(t, m.Reset(ctx))
162+
163+
// Test SetAndWait with new key
164+
old, err := m.SetAndWait(ctx, "key1", "value1")
165+
assert.NoError(t, err)
166+
assert.Equal(t, "", old)
167+
v, ok := m.Get("key1")
168+
assert.True(t, ok)
169+
assert.Equal(t, "value1", v)
170+
171+
// Test SetAndWait with existing key
172+
old, err = m.SetAndWait(ctx, "key1", "value2")
173+
assert.NoError(t, err)
174+
assert.Equal(t, "value1", old)
175+
v, ok = m.Get("key1")
176+
assert.True(t, ok)
177+
assert.Equal(t, "value2", v)
178+
179+
// Test many Set then SetAndWait then many Set
180+
for i := 0; i < 20; i++ {
181+
_, err := m.Set(ctx, fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i))
182+
assert.NoError(t, err)
183+
}
184+
_, err = m.SetAndWait(ctx, "key", "value")
185+
for i := 0; i < 20; i++ {
186+
_, err := m.Set(ctx, fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i))
187+
assert.NoError(t, err)
188+
}
189+
assert.NoError(t, err)
190+
191+
// Test SetAndWait with canceled context
192+
ctx2, cancel := context.WithCancel(ctx)
193+
cancel()
194+
_, err = m.SetAndWait(ctx2, "key2", "value3")
195+
assert.ErrorIs(t, err, context.Canceled)
196+
197+
// Cleanup
198+
cleanup(t, m)
199+
}
200+
139201
func TestReadAfterClose(t *testing.T) {
140202
rdb := redis.NewClient(&redis.Options{
141203
Addr: "localhost:6379",

0 commit comments

Comments
 (0)