Skip to content

Commit e5ade6a

Browse files
committed
sweepbatcher: close the quit channel when the batcher is shutting down
1 parent c01e801 commit e5ade6a

File tree

3 files changed

+32
-20
lines changed

3 files changed

+32
-20
lines changed

sweepbatcher/sweep_batch.go

+14-4
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,11 @@ type batch struct {
197197
// main event loop.
198198
callLeave chan struct{}
199199

200-
// quit signals that the batch must stop.
200+
// stopped signals that the batch has stopped.
201+
stopped chan struct{}
202+
203+
// quit is owned by the parent batcher and signals that the batch must
204+
// stop.
201205
quit chan struct{}
202206

203207
// wallet is the wallet client used to create and publish the batch
@@ -261,6 +265,7 @@ type batchKit struct {
261265
purger Purger
262266
store BatcherStore
263267
log btclog.Logger
268+
quit chan struct{}
264269
}
265270

266271
// scheduleNextCall schedules the next call to the batch handler's main event
@@ -270,6 +275,9 @@ func (b *batch) scheduleNextCall() (func(), error) {
270275
case b.callEnter <- struct{}{}:
271276

272277
case <-b.quit:
278+
return func() {}, ErrBatcherShuttingDown
279+
280+
case <-b.stopped:
273281
return func() {}, ErrBatchShuttingDown
274282
}
275283

@@ -293,7 +301,8 @@ func NewBatch(cfg batchConfig, bk batchKit) *batch {
293301
errChan: make(chan error, 1),
294302
callEnter: make(chan struct{}),
295303
callLeave: make(chan struct{}),
296-
quit: make(chan struct{}),
304+
stopped: make(chan struct{}),
305+
quit: bk.quit,
297306
batchTxid: bk.batchTxid,
298307
wallet: bk.wallet,
299308
chainNotifier: bk.chainNotifier,
@@ -320,7 +329,8 @@ func NewBatchFromDB(cfg batchConfig, bk batchKit) *batch {
320329
errChan: make(chan error, 1),
321330
callEnter: make(chan struct{}),
322331
callLeave: make(chan struct{}),
323-
quit: make(chan struct{}),
332+
stopped: make(chan struct{}),
333+
quit: bk.quit,
324334
batchTxid: bk.batchTxid,
325335
batchPkScript: bk.batchPkScript,
326336
rbfCache: bk.rbfCache,
@@ -447,7 +457,7 @@ func (b *batch) Run(ctx context.Context) error {
447457
runCtx, cancel := context.WithCancel(ctx)
448458
defer func() {
449459
cancel()
450-
close(b.quit)
460+
close(b.stopped)
451461
b.wg.Wait()
452462
}()
453463

sweepbatcher/sweep_batcher.go

+3
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ func (b *Batcher) Run(ctx context.Context) error {
216216
runCtx, cancel := context.WithCancel(ctx)
217217
defer func() {
218218
cancel()
219+
close(b.quit)
219220

220221
for _, batch := range b.batches {
221222
batch.Wait()
@@ -379,6 +380,7 @@ func (b *Batcher) spinUpBatch(ctx context.Context) (*batch, error) {
379380
verifySchnorrSig: b.VerifySchnorrSig,
380381
purger: b.AddSweep,
381382
store: b.store,
383+
quit: b.quit,
382384
}
383385

384386
batch := NewBatch(cfg, batchKit)
@@ -461,6 +463,7 @@ func (b *Batcher) spinUpBatchFromDB(ctx context.Context, batch *batch) error {
461463
purger: b.AddSweep,
462464
store: b.store,
463465
log: batchPrefixLogger(fmt.Sprintf("%d", batch.id)),
466+
quit: b.quit,
464467
}
465468

466469
cfg := batchConfig{

sweepbatcher/sweep_batcher_test.go

+15-16
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package sweepbatcher
22

33
import (
44
"context"
5-
"strings"
5+
"errors"
66
"testing"
77
"time"
88

@@ -43,6 +43,15 @@ var dummyNotifier = SpendNotifier{
4343
QuitChan: make(chan bool, ntfnBufferSize),
4444
}
4545

46+
func checkBatcherError(t *testing.T, err error) {
47+
if !errors.Is(err, context.Canceled) &&
48+
!errors.Is(err, ErrBatcherShuttingDown) &&
49+
!errors.Is(err, ErrBatchShuttingDown) {
50+
51+
require.NoError(t, err)
52+
}
53+
}
54+
4655
// TestSweepBatcherBatchCreation tests that sweep requests enter the expected
4756
// batch based on their timeout distance.
4857
func TestSweepBatcherBatchCreation(t *testing.T) {
@@ -60,9 +69,7 @@ func TestSweepBatcherBatchCreation(t *testing.T) {
6069
testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, store)
6170
go func() {
6271
err := batcher.Run(ctx)
63-
if !strings.Contains(err.Error(), "context canceled") {
64-
require.NoError(t, err)
65-
}
72+
checkBatcherError(t, err)
6673
}()
6774

6875
// Create a sweep request.
@@ -215,9 +222,7 @@ func TestSweepBatcherSimpleLifecycle(t *testing.T) {
215222
testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, store)
216223
go func() {
217224
err := batcher.Run(ctx)
218-
if !strings.Contains(err.Error(), "context canceled") {
219-
require.NoError(t, err)
220-
}
225+
checkBatcherError(t, err)
221226
}()
222227

223228
// Create a sweep request.
@@ -354,9 +359,7 @@ func TestSweepBatcherSweepReentry(t *testing.T) {
354359
testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, store)
355360
go func() {
356361
err := batcher.Run(ctx)
357-
if !strings.Contains(err.Error(), "context canceled") {
358-
require.NoError(t, err)
359-
}
362+
checkBatcherError(t, err)
360363
}()
361364

362365
// Create some sweep requests with timeouts not too far away, in order
@@ -561,9 +564,7 @@ func TestSweepBatcherNonWalletAddr(t *testing.T) {
561564
testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, store)
562565
go func() {
563566
err := batcher.Run(ctx)
564-
if !strings.Contains(err.Error(), "context canceled") {
565-
require.NoError(t, err)
566-
}
567+
checkBatcherError(t, err)
567568
}()
568569

569570
// Create a sweep request.
@@ -727,9 +728,7 @@ func TestSweepBatcherComposite(t *testing.T) {
727728
testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, store)
728729
go func() {
729730
err := batcher.Run(ctx)
730-
if !strings.Contains(err.Error(), "context canceled") {
731-
require.NoError(t, err)
732-
}
731+
checkBatcherError(t, err)
733732
}()
734733

735734
// Create a sweep request.

0 commit comments

Comments
 (0)