Skip to content

Commit 50de404

Browse files
committed
rpcwallet: allow remote signer to reconnect
Allow the remote signer to reconnect to the wallet after disconnecting, as long as the remote signer reconnects within the timeout limit. This is not a complete solution to the problem to allow the watch-only node to stay online when the remote signer is disconnected, but is more fault-tolerant than the current implementation as it allows the remote to be temporarily disconnected.
1 parent dcd0209 commit 50de404

File tree

2 files changed

+345
-51
lines changed

2 files changed

+345
-51
lines changed

lnwallet/rpcwallet/sign_coordinator.go

Lines changed: 83 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import (
1212
"github.com/lightningnetwork/lnd/lnrpc/walletrpc"
1313
"github.com/lightningnetwork/lnd/lnutils"
1414
"google.golang.org/grpc"
15+
"google.golang.org/grpc/codes"
16+
"google.golang.org/grpc/status"
1517
)
1618

1719
var (
@@ -70,16 +72,20 @@ type SignCoordinator struct {
7072
// signer has errored, and we can no longer process responses.
7173
receiveErrChan chan error
7274

73-
// doneReceiving is closed when either party terminates and signals to
75+
// disconnected is closed when either party terminates and signals to
7476
// any pending requests that we'll no longer process the response for
7577
// that request.
76-
doneReceiving chan struct{}
78+
disconnected chan struct{}
7779

7880
// quit is closed when lnd is shutting down.
7981
quit chan struct{}
8082

81-
// clientConnected is sent over when the remote signer connects.
82-
clientConnected chan struct{}
83+
// clientReady is closed and sent over when the remote signer is
84+
// connected and ready to accept requests (after the initial handshake).
85+
clientReady chan struct{}
86+
87+
// clientConnected is true if a remote signer is currently connected.
88+
clientConnected bool
8389

8490
// requestTimeout is the maximum time we will wait for a response from
8591
// the remote signer.
@@ -107,11 +113,14 @@ func NewSignCoordinator(requestTimeout time.Duration,
107113
s := &SignCoordinator{
108114
responses: respsMap,
109115
receiveErrChan: make(chan error, 1),
110-
doneReceiving: make(chan struct{}),
111-
clientConnected: make(chan struct{}),
116+
clientReady: make(chan struct{}),
117+
clientConnected: false,
112118
quit: make(chan struct{}),
113119
requestTimeout: requestTimeout,
114120
connectionTimeout: connectionTimeout,
121+
// Note that the disconnected channel is not initialized here,
122+
// as no code listens to it until the Run method has been called
123+
// and set the field.
115124
}
116125

117126
// We initialize the atomic nextRequestID to the handshakeRequestID, as
@@ -132,25 +141,33 @@ func (s *SignCoordinator) Run(stream StreamServer) error {
132141
s.mu.Unlock()
133142
return ErrShuttingDown
134143

135-
case <-s.doneReceiving:
136-
s.mu.Unlock()
137-
return ErrNotConnected
138-
139144
default:
140145
}
141146

147+
if s.clientConnected {
148+
// If we already have a stream, we error out as we can only have
149+
// one connection at a time.
150+
return ErrMultipleConnections
151+
}
152+
142153
s.wg.Add(1)
143154
defer s.wg.Done()
144155

145-
// If we already have a stream, we error out as we can only have one
146-
// connection throughout the lifetime of the SignCoordinator.
147-
if s.stream != nil {
148-
s.mu.Unlock()
149-
return ErrMultipleConnections
150-
}
156+
s.clientConnected = true
157+
defer func() {
158+
s.mu.Lock()
159+
defer s.mu.Unlock()
160+
161+
// When `Run` returns, we set the clientConnected field to false
162+
// to allow a new remote signer connection to be set up.
163+
s.clientConnected = false
164+
}()
151165

152166
s.stream = stream
153167

168+
s.disconnected = make(chan struct{})
169+
defer close(s.disconnected)
170+
154171
s.mu.Unlock()
155172

156173
// The handshake must be completed before we can start sending requests
@@ -160,8 +177,18 @@ func (s *SignCoordinator) Run(stream StreamServer) error {
160177
return err
161178
}
162179

163-
log.Infof("Remote signer connected")
164-
close(s.clientConnected)
180+
log.Infof("Remote signer connected and ready")
181+
182+
close(s.clientReady)
183+
defer func() {
184+
s.mu.Lock()
185+
defer s.mu.Unlock()
186+
187+
// We create a new clientReady channel, once this function
188+
// has exited, to ensure that a new remote signer connection can
189+
// be set up.
190+
s.clientReady = make(chan struct{})
191+
}()
165192

166193
// Now let's start the main receiving loop, which will receive all
167194
// responses to our requests from the remote signer!
@@ -179,9 +206,6 @@ func (s *SignCoordinator) Run(stream StreamServer) error {
179206

180207
case <-s.quit:
181208
return ErrShuttingDown
182-
183-
case <-s.doneReceiving:
184-
return ErrNotConnected
185209
}
186210
}
187211

@@ -364,10 +388,6 @@ func (s *SignCoordinator) handshake(stream StreamServer) error {
364388
func (s *SignCoordinator) StartReceiving() {
365389
defer s.wg.Done()
366390

367-
// Signals to any ongoing requests that the remote signer is no longer
368-
// connected.
369-
defer close(s.doneReceiving)
370-
371391
for {
372392
resp, err := s.stream.Recv()
373393
if err != nil {
@@ -426,8 +446,16 @@ func (s *SignCoordinator) StartReceiving() {
426446
// signer does not connect within the configured connection timeout, or if the
427447
// passed context is canceled, an error is returned.
428448
func (s *SignCoordinator) WaitUntilConnected(ctx context.Context) error {
449+
// As the Run method will redefine the clientReady channel once it
450+
// returns, we need copy the pointer to the current clientReady channel
451+
// to ensure that we're waiting for the correct channel, and to avoid
452+
// a data race.
453+
s.mu.Lock()
454+
currentClientReady := s.clientReady
455+
s.mu.Unlock()
456+
429457
select {
430-
case <-s.clientConnected:
458+
case <-currentClientReady:
431459
return nil
432460

433461
case <-s.quit:
@@ -438,9 +466,6 @@ func (s *SignCoordinator) WaitUntilConnected(ctx context.Context) error {
438466

439467
case <-time.After(s.connectionTimeout):
440468
return ErrConnectTimeout
441-
442-
case <-s.doneReceiving:
443-
return ErrNotConnected
444469
}
445470
}
446471

@@ -524,7 +549,7 @@ func (s *SignCoordinator) getResponse(ctx context.Context,
524549

525550
return resp, nil
526551

527-
case <-s.doneReceiving:
552+
case <-s.disconnected:
528553
log.Debugf("Stopped waiting for remote signer response for "+
529554
"request ID %d as the stream has been closed",
530555
requestID)
@@ -848,8 +873,30 @@ func processRequest[R comparable](ctx context.Context, s *SignCoordinator,
848873

849874
log.Tracef("Request content: %v", formatSignCoordinatorMsg(&req))
850875

876+
// reprocessOnDisconnect is a helper function that will be used to
877+
// resend the request if the remote signer disconnects, through which
878+
// we will wait for it to reconnect within the configured timeout, and
879+
// then resend the request.
880+
reprocessOnDisconnect := func() (R, error) {
881+
log.Debugf("Remote signer disconnected while waiting for "+
882+
"response for request ID %d. Retrying request...",
883+
reqID)
884+
885+
return processRequest[R](
886+
ctx, s, generateRequest, extractResponse,
887+
)
888+
}
889+
851890
err = s.stream.Send(&req)
852891
if err != nil {
892+
st, isStatusError := status.FromError(err)
893+
if isStatusError && st.Code() == codes.Unavailable {
894+
// If the stream was closed due to the remote signer
895+
// disconnecting, we will retry to process the request
896+
// if the remote signer reconnects.
897+
return reprocessOnDisconnect()
898+
}
899+
853900
return zero, err
854901
}
855902

@@ -860,7 +907,12 @@ func processRequest[R comparable](ctx context.Context, s *SignCoordinator,
860907
// cancelled/timed out.
861908
resp, err = s.getResponse(ctx, reqID)
862909

863-
if err != nil {
910+
if errors.Is(err, ErrNotConnected) {
911+
// If the remote signer disconnected while we were waiting for
912+
// the response, we will retry to process the request if the
913+
// remote signer reconnects.
914+
return reprocessOnDisconnect()
915+
} else if err != nil {
864916
return zero, err
865917
}
866918

0 commit comments

Comments
 (0)