@@ -12,6 +12,8 @@ import (
1212 "github.com/lightningnetwork/lnd/lnrpc/walletrpc"
1313 "github.com/lightningnetwork/lnd/lnutils"
1414 "google.golang.org/grpc"
15+ "google.golang.org/grpc/codes"
16+ "google.golang.org/grpc/status"
1517)
1618
1719var (
@@ -70,16 +72,20 @@ type SignCoordinator struct {
7072 // signer has errored, and we can no longer process responses.
7173 receiveErrChan chan error
7274
73- // doneReceiving is closed when either party terminates and signals to
75+ // disconnected is closed when either party terminates and signals to
7476 // any pending requests that we'll no longer process the response for
7577 // that request.
76- doneReceiving chan struct {}
78+ disconnected chan struct {}
7779
7880 // quit is closed when lnd is shutting down.
7981 quit chan struct {}
8082
81- // clientConnected is sent over when the remote signer connects.
82- clientConnected chan struct {}
83+ // clientReady is closed and sent over when the remote signer is
84+ // connected and ready to accept requests (after the initial handshake).
85+ clientReady chan struct {}
86+
87+ // clientConnected is true if a remote signer is currently connected.
88+ clientConnected bool
8389
8490 // requestTimeout is the maximum time we will wait for a response from
8591 // the remote signer.
@@ -107,11 +113,14 @@ func NewSignCoordinator(requestTimeout time.Duration,
107113 s := & SignCoordinator {
108114 responses : respsMap ,
109115 receiveErrChan : make (chan error , 1 ),
110- doneReceiving : make (chan struct {}),
111- clientConnected : make ( chan struct {}) ,
116+ clientReady : make (chan struct {}),
117+ clientConnected : false ,
112118 quit : make (chan struct {}),
113119 requestTimeout : requestTimeout ,
114120 connectionTimeout : connectionTimeout ,
121+ // Note that the disconnected channel is not initialized here,
122+ // as no code listens to it until the Run method has been called
123+ // and set the field.
115124 }
116125
117126 // We initialize the atomic nextRequestID to the handshakeRequestID, as
@@ -132,25 +141,33 @@ func (s *SignCoordinator) Run(stream StreamServer) error {
132141 s .mu .Unlock ()
133142 return ErrShuttingDown
134143
135- case <- s .doneReceiving :
136- s .mu .Unlock ()
137- return ErrNotConnected
138-
139144 default :
140145 }
141146
147+ if s .clientConnected {
148+ // If we already have a stream, we error out as we can only have
149+ // one connection at a time.
150+ return ErrMultipleConnections
151+ }
152+
142153 s .wg .Add (1 )
143154 defer s .wg .Done ()
144155
145- // If we already have a stream, we error out as we can only have one
146- // connection throughout the lifetime of the SignCoordinator.
147- if s .stream != nil {
148- s .mu .Unlock ()
149- return ErrMultipleConnections
150- }
156+ s .clientConnected = true
157+ defer func () {
158+ s .mu .Lock ()
159+ defer s .mu .Unlock ()
160+
161+ // When `Run` returns, we set the clientConnected field to false
162+ // to allow a new remote signer connection to be set up.
163+ s .clientConnected = false
164+ }()
151165
152166 s .stream = stream
153167
168+ s .disconnected = make (chan struct {})
169+ defer close (s .disconnected )
170+
154171 s .mu .Unlock ()
155172
156173 // The handshake must be completed before we can start sending requests
@@ -160,8 +177,18 @@ func (s *SignCoordinator) Run(stream StreamServer) error {
160177 return err
161178 }
162179
163- log .Infof ("Remote signer connected" )
164- close (s .clientConnected )
180+ log .Infof ("Remote signer connected and ready" )
181+
182+ close (s .clientReady )
183+ defer func () {
184+ s .mu .Lock ()
185+ defer s .mu .Unlock ()
186+
187+ // We create a new clientReady channel, once this function
188+ // has exited, to ensure that a new remote signer connection can
189+ // be set up.
190+ s .clientReady = make (chan struct {})
191+ }()
165192
166193 // Now let's start the main receiving loop, which will receive all
167194 // responses to our requests from the remote signer!
@@ -179,9 +206,6 @@ func (s *SignCoordinator) Run(stream StreamServer) error {
179206
180207 case <- s .quit :
181208 return ErrShuttingDown
182-
183- case <- s .doneReceiving :
184- return ErrNotConnected
185209 }
186210}
187211
@@ -364,10 +388,6 @@ func (s *SignCoordinator) handshake(stream StreamServer) error {
364388func (s * SignCoordinator ) StartReceiving () {
365389 defer s .wg .Done ()
366390
367- // Signals to any ongoing requests that the remote signer is no longer
368- // connected.
369- defer close (s .doneReceiving )
370-
371391 for {
372392 resp , err := s .stream .Recv ()
373393 if err != nil {
@@ -426,8 +446,16 @@ func (s *SignCoordinator) StartReceiving() {
426446// signer does not connect within the configured connection timeout, or if the
427447// passed context is canceled, an error is returned.
428448func (s * SignCoordinator ) WaitUntilConnected (ctx context.Context ) error {
449+ // As the Run method will redefine the clientReady channel once it
450+ // returns, we need copy the pointer to the current clientReady channel
451+ // to ensure that we're waiting for the correct channel, and to avoid
452+ // a data race.
453+ s .mu .Lock ()
454+ currentClientReady := s .clientReady
455+ s .mu .Unlock ()
456+
429457 select {
430- case <- s . clientConnected :
458+ case <- currentClientReady :
431459 return nil
432460
433461 case <- s .quit :
@@ -438,9 +466,6 @@ func (s *SignCoordinator) WaitUntilConnected(ctx context.Context) error {
438466
439467 case <- time .After (s .connectionTimeout ):
440468 return ErrConnectTimeout
441-
442- case <- s .doneReceiving :
443- return ErrNotConnected
444469 }
445470}
446471
@@ -524,7 +549,7 @@ func (s *SignCoordinator) getResponse(ctx context.Context,
524549
525550 return resp , nil
526551
527- case <- s .doneReceiving :
552+ case <- s .disconnected :
528553 log .Debugf ("Stopped waiting for remote signer response for " +
529554 "request ID %d as the stream has been closed" ,
530555 requestID )
@@ -848,8 +873,30 @@ func processRequest[R comparable](ctx context.Context, s *SignCoordinator,
848873
849874 log .Tracef ("Request content: %v" , formatSignCoordinatorMsg (& req ))
850875
876+ // reprocessOnDisconnect is a helper function that will be used to
877+ // resend the request if the remote signer disconnects, through which
878+ // we will wait for it to reconnect within the configured timeout, and
879+ // then resend the request.
880+ reprocessOnDisconnect := func () (R , error ) {
881+ log .Debugf ("Remote signer disconnected while waiting for " +
882+ "response for request ID %d. Retrying request..." ,
883+ reqID )
884+
885+ return processRequest [R ](
886+ ctx , s , generateRequest , extractResponse ,
887+ )
888+ }
889+
851890 err = s .stream .Send (& req )
852891 if err != nil {
892+ st , isStatusError := status .FromError (err )
893+ if isStatusError && st .Code () == codes .Unavailable {
894+ // If the stream was closed due to the remote signer
895+ // disconnecting, we will retry to process the request
896+ // if the remote signer reconnects.
897+ return reprocessOnDisconnect ()
898+ }
899+
853900 return zero , err
854901 }
855902
@@ -860,7 +907,12 @@ func processRequest[R comparable](ctx context.Context, s *SignCoordinator,
860907 // cancelled/timed out.
861908 resp , err = s .getResponse (ctx , reqID )
862909
863- if err != nil {
910+ if errors .Is (err , ErrNotConnected ) {
911+ // If the remote signer disconnected while we were waiting for
912+ // the response, we will retry to process the request if the
913+ // remote signer reconnects.
914+ return reprocessOnDisconnect ()
915+ } else if err != nil {
864916 return zero , err
865917 }
866918
0 commit comments