diff --git a/fsm/example_fsm.go b/fsm/example_fsm.go index 9ba20599e..4373ce93b 100644 --- a/fsm/example_fsm.go +++ b/fsm/example_fsm.go @@ -1,6 +1,7 @@ package fsm import ( + "context" "fmt" ) @@ -90,7 +91,9 @@ type InitStuffRequest struct { } // initFSM is the action for the InitFSM state. -func (e *ExampleFSM) initFSM(eventCtx EventContext) EventType { +func (e *ExampleFSM) initFSM(_ context.Context, eventCtx EventContext, +) EventType { + req, ok := eventCtx.(*InitStuffRequest) if !ok { return e.HandleError( @@ -109,7 +112,9 @@ func (e *ExampleFSM) initFSM(eventCtx EventContext) EventType { } // waitForStuff is an action that waits for stuff to happen. -func (e *ExampleFSM) waitForStuff(eventCtx EventContext) EventType { +func (e *ExampleFSM) waitForStuff(ctx context.Context, eventCtx EventContext, +) EventType { + waitChan, err := e.service.WaitForStuffHappening() if err != nil { return e.HandleError(err) @@ -117,7 +122,7 @@ func (e *ExampleFSM) waitForStuff(eventCtx EventContext) EventType { go func() { <-waitChan - err := e.SendEvent(OnStuffSuccess, nil) + err := e.SendEvent(ctx, OnStuffSuccess, nil) if err != nil { log.Errorf("unable to send event: %v", err) } diff --git a/fsm/example_fsm_test.go b/fsm/example_fsm_test.go index a3e3f05dc..b2b9ebdac 100644 --- a/fsm/example_fsm_test.go +++ b/fsm/example_fsm_test.go @@ -79,6 +79,7 @@ func TestExampleFSM(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { + ctxb := context.Background() respondChan := make(chan string, 1) if req, ok := tc.eventCtx.(*InitStuffRequest); ok { req.respondChan = respondChan @@ -102,7 +103,7 @@ func TestExampleFSM(t *testing.T) { exampleContext.RegisterObserver(cachedObserver) err := exampleContext.SendEvent( - tc.sendEvent, tc.eventCtx, + ctxb, tc.sendEvent, tc.eventCtx, ) require.Equal(t, tc.sendEventErr, err) @@ -195,6 +196,7 @@ func TestExampleFSMFlow(t *testing.T) { t.Run(tc.name, func(t *testing.T) { exampleContext, cachedObserver := getTestContext() + ctxb := context.Background() if tc.storeError != nil { exampleContext.store.(*mockStore). @@ -208,8 +210,7 @@ func TestExampleFSMFlow(t *testing.T) { go func() { err := exampleContext.SendEvent( - OnRequestStuff, - newInitStuffRequest(), + ctxb, OnRequestStuff, newInitStuffRequest(), ) require.NoError(t, err) @@ -273,6 +274,7 @@ func TestObserverAsyncWait(t *testing.T) { service := &mockService{ respondChan: make(chan bool), } + ctxb := context.Background() store := &mockStore{} @@ -282,7 +284,7 @@ func TestObserverAsyncWait(t *testing.T) { t0 := time.Now() timeoutCtx, cancel := context.WithTimeout( - context.Background(), tc.waitTime, + ctxb, tc.waitTime, ) defer cancel() @@ -293,8 +295,7 @@ func TestObserverAsyncWait(t *testing.T) { go func() { err := exampleContext.SendEvent( - OnRequestStuff, - newInitStuffRequest(), + ctxb, OnRequestStuff, newInitStuffRequest(), ) require.NoError(t, err) diff --git a/fsm/fsm.go b/fsm/fsm.go index f1088a767..a16a33567 100644 --- a/fsm/fsm.go +++ b/fsm/fsm.go @@ -1,6 +1,7 @@ package fsm import ( + "context" "errors" "fmt" "sync" @@ -45,7 +46,7 @@ type EventType string type EventContext interface{} // Action represents the action to be executed in a given state. -type Action func(eventCtx EventContext) EventType +type Action func(ctx context.Context, eventCtx EventContext) EventType // Transitions represents a mapping of events and states. type Transitions map[EventType]StateType @@ -95,11 +96,11 @@ type StateMachine struct { // ActionEntryFunc is a function that is called before an action is // executed. - ActionEntryFunc func(Notification) + ActionEntryFunc func(context.Context, Notification) // ActionExitFunc is a function that is called after an action is // executed, it is called with the EventType returned by the action. - ActionExitFunc func(NextEvent EventType) + ActionExitFunc func(ctx context.Context, NextEvent EventType) // LastActionError is an error set by the last action executed. LastActionError error @@ -200,7 +201,9 @@ func (s *StateMachine) getNextState(event EventType) (State, error) { // SendEvent sends an event to the state machine. It returns an error if the // event cannot be processed in the current state. Otherwise, it only returns // nil if the event for the last action is a no-op. -func (s *StateMachine) SendEvent(event EventType, eventCtx EventContext) error { +func (s *StateMachine) SendEvent(ctx context.Context, event EventType, + eventCtx EventContext) error { + s.mutex.Lock() defer s.mutex.Unlock() @@ -235,7 +238,7 @@ func (s *StateMachine) SendEvent(event EventType, eventCtx EventContext) error { // Execute the state machines ActionEntryFunc. if s.ActionEntryFunc != nil { - s.ActionEntryFunc(notification) + s.ActionEntryFunc(ctx, notification) } // Execute the current state's entry function @@ -245,7 +248,7 @@ func (s *StateMachine) SendEvent(event EventType, eventCtx EventContext) error { // Execute the next state's action and loop over again if the // event returned is not a no-op. - nextEvent := state.Action(eventCtx) + nextEvent := state.Action(ctx, eventCtx) // Execute the current state's exit function if state.ExitFunc != nil { @@ -254,7 +257,7 @@ func (s *StateMachine) SendEvent(event EventType, eventCtx EventContext) error { // Execute the state machines ActionExitFunc. if s.ActionExitFunc != nil { - s.ActionExitFunc(nextEvent) + s.ActionExitFunc(ctx, nextEvent) } // If the next event is a no-op, we're done. @@ -304,7 +307,7 @@ func (s *StateMachine) HandleError(err error) EventType { // NoOpAction is a no-op action that can be used by states that don't need to // execute any action. -func NoOpAction(_ EventContext) EventType { +func NoOpAction(_ context.Context, _ EventContext) EventType { return NoOp } diff --git a/fsm/fsm_test.go b/fsm/fsm_test.go index 23361a935..8864503d0 100644 --- a/fsm/fsm_test.go +++ b/fsm/fsm_test.go @@ -1,6 +1,7 @@ package fsm import ( + "context" "errors" "testing" @@ -22,7 +23,7 @@ type TestStateMachineContext struct { func (c *TestStateMachineContext) GetStates() States { return States{ "State1": State{ - Action: func(ctx EventContext) EventType { + Action: func(_ context.Context, ctx EventContext) EventType { return "Event1" }, Transitions: Transitions{ @@ -30,7 +31,7 @@ func (c *TestStateMachineContext) GetStates() States { }, }, "State2": State{ - Action: func(ctx EventContext) EventType { + Action: func(_ context.Context, ctx EventContext) EventType { return "NoOp" }, Transitions: Transitions{}, @@ -39,7 +40,9 @@ func (c *TestStateMachineContext) GetStates() States { } // errorAction returns an error. -func (c *TestStateMachineContext) errorAction(eventCtx EventContext) EventType { +func (c *TestStateMachineContext) errorAction(ctx context.Context, + eventCtx EventContext) EventType { + return c.StateMachine.HandleError(errAction) } @@ -58,9 +61,9 @@ func setupTestStateMachineContext() *TestStateMachineContext { // TestStateMachine_Success tests the state machine with a successful event. func TestStateMachine_Success(t *testing.T) { ctx := setupTestStateMachineContext() - + ctxb := context.Background() // Send an event to the state machine. - err := ctx.SendEvent("Event1", nil) + err := ctx.SendEvent(ctxb, "Event1", nil) require.NoError(t, err) // Check that the state machine has transitioned to the next state. @@ -72,8 +75,9 @@ func TestStateMachine_Success(t *testing.T) { func TestStateMachine_ConfigurationError(t *testing.T) { ctx := setupTestStateMachineContext() ctx.StateMachine.States = nil + ctxb := context.Background() - err := ctx.SendEvent("Event1", nil) + err := ctx.SendEvent(ctxb, "Event1", nil) require.EqualError( t, err, NewErrConfigError("state machine config is nil").Error(), @@ -83,6 +87,7 @@ func TestStateMachine_ConfigurationError(t *testing.T) { // TestStateMachine_ActionError tests the state machine with an action error. func TestStateMachine_ActionError(t *testing.T) { ctx := setupTestStateMachineContext() + ctxb := context.Background() states := ctx.StateMachine.States @@ -99,13 +104,13 @@ func TestStateMachine_ActionError(t *testing.T) { } states["ErrorState"] = State{ - Action: func(ctx EventContext) EventType { + Action: func(_ context.Context, ctx EventContext) EventType { return "NoOp" }, Transitions: Transitions{}, } - err := ctx.SendEvent("Event1", nil) + err := ctx.SendEvent(ctxb, "Event1", nil) // Sending an event to the state machine should not return an error. require.NoError(t, err) diff --git a/instantout/actions.go b/instantout/actions.go index 334898971..eb8a64152 100644 --- a/instantout/actions.go +++ b/instantout/actions.go @@ -65,7 +65,9 @@ type InitInstantOutCtx struct { // InitInstantOutAction is the first action that is executed when the instant // out FSM is started. It will send the instant out request to the server. -func (f *FSM) InitInstantOutAction(eventCtx fsm.EventContext) fsm.EventType { +func (f *FSM) InitInstantOutAction(ctx context.Context, + eventCtx fsm.EventContext) fsm.EventType { + initCtx, ok := eventCtx.(*InitInstantOutCtx) if !ok { return f.HandleError(fsm.ErrInvalidContextType) @@ -86,9 +88,7 @@ func (f *FSM) InitInstantOutAction(eventCtx fsm.EventContext) fsm.EventType { // The requested amount needs to be full reservation amounts. for _, reservationId := range initCtx.reservations { resId := reservationId - res, err := f.cfg.ReservationManager.GetReservation( - f.ctx, resId, - ) + res, err := f.cfg.ReservationManager.GetReservation(ctx, resId) if err != nil { return f.HandleError(err) } @@ -120,7 +120,7 @@ func (f *FSM) InitInstantOutAction(eventCtx fsm.EventContext) fsm.EventType { } // Create the keys for the swap. - keyRes, err := f.cfg.Wallet.DeriveNextKey(f.ctx, KeyFamily) + keyRes, err := f.cfg.Wallet.DeriveNextKey(ctx, KeyFamily) if err != nil { return f.HandleError(err) } @@ -128,7 +128,7 @@ func (f *FSM) InitInstantOutAction(eventCtx fsm.EventContext) fsm.EventType { swapHash := preimage.Hash() // Create a high fee rate so that the htlc will be confirmed quickly. - feeRate, err := f.cfg.Wallet.EstimateFeeRate(f.ctx, urgentConfTarget) + feeRate, err := f.cfg.Wallet.EstimateFeeRate(ctx, urgentConfTarget) if err != nil { f.Infof("error estimating fee rate: %v", err) return f.HandleError(err) @@ -136,7 +136,7 @@ func (f *FSM) InitInstantOutAction(eventCtx fsm.EventContext) fsm.EventType { // Send the instantout request to the server. instantOutResponse, err := f.cfg.InstantOutClient.RequestInstantLoopOut( - f.ctx, + ctx, &swapserverrpc.InstantLoopOutRequest{ ReceiverKey: keyRes.PubKey.SerializeCompressed(), SwapHash: swapHash[:], @@ -151,7 +151,7 @@ func (f *FSM) InitInstantOutAction(eventCtx fsm.EventContext) fsm.EventType { } // Decode the invoice to check if the hash is valid. payReq, err := f.cfg.LndClient.DecodePaymentRequest( - f.ctx, instantOutResponse.SwapInvoice, + ctx, instantOutResponse.SwapInvoice, ) if err != nil { return f.HandleError(err) @@ -170,7 +170,7 @@ func (f *FSM) InitInstantOutAction(eventCtx fsm.EventContext) fsm.EventType { sweepAddress := initCtx.sweepAddress if sweepAddress == nil { sweepAddress, err = f.cfg.Wallet.NextAddr( - f.ctx, lnwallet.DefaultAccountName, + ctx, lnwallet.DefaultAccountName, walletrpc.AddressType_TAPROOT_PUBKEY, false, ) if err != nil { @@ -196,7 +196,7 @@ func (f *FSM) InitInstantOutAction(eventCtx fsm.EventContext) fsm.EventType { sweepAddress: sweepAddress, } - err = f.cfg.Store.CreateInstantLoopOut(f.ctx, instantOut) + err = f.cfg.Store.CreateInstantLoopOut(ctx, instantOut) if err != nil { return f.HandleError(err) } @@ -208,21 +208,23 @@ func (f *FSM) InitInstantOutAction(eventCtx fsm.EventContext) fsm.EventType { // PollPaymentAcceptedAction locks the reservations, sends the payment to the // server and polls the server for the payment status. -func (f *FSM) PollPaymentAcceptedAction(_ fsm.EventContext) fsm.EventType { +func (f *FSM) PollPaymentAcceptedAction(ctx context.Context, + _ fsm.EventContext) fsm.EventType { + // Now that we're doing the swap, we first lock the reservations // so that they can't be used for other swaps. for _, reservation := range f.InstantOut.Reservations { err := f.cfg.ReservationManager.LockReservation( - f.ctx, reservation.ID, + ctx, reservation.ID, ) if err != nil { - return f.handleErrorAndUnlockReservations(err) + return f.handleErrorAndUnlockReservations(ctx, err) } } // Now we send the payment to the server. payChan, paymentErrChan, err := f.cfg.RouterClient.SendPayment( - f.ctx, + ctx, lndclient.SendPaymentRequest{ Invoice: f.InstantOut.swapInvoice, Timeout: defaultSendpaymentTimeout, @@ -232,7 +234,7 @@ func (f *FSM) PollPaymentAcceptedAction(_ fsm.EventContext) fsm.EventType { ) if err != nil { f.Errorf("error sending payment: %v", err) - return f.handleErrorAndUnlockReservations(err) + return f.handleErrorAndUnlockReservations(ctx, err) } // We'll continuously poll the server for the payment status. @@ -246,20 +248,20 @@ func (f *FSM) PollPaymentAcceptedAction(_ fsm.EventContext) fsm.EventType { f.Debugf("payment result: %v", payRes) if payRes.State == lnrpc.Payment_FAILED { return f.handleErrorAndUnlockReservations( - fmt.Errorf("payment failed: %v", + ctx, fmt.Errorf("payment failed: %v", payRes.FailureReason), ) } case err := <-paymentErrChan: f.Errorf("error sending payment: %v", err) - return f.handleErrorAndUnlockReservations(err) + return f.handleErrorAndUnlockReservations(ctx, err) - case <-f.ctx.Done(): - return f.handleErrorAndUnlockReservations(nil) + case <-ctx.Done(): + return f.handleErrorAndUnlockReservations(ctx, nil) case <-timer.C: res, err := f.cfg.InstantOutClient.PollPaymentAccepted( - f.ctx, + ctx, &swapserverrpc.PollPaymentAcceptedRequest{ SwapHash: f.InstantOut.SwapHash[:], }, @@ -267,7 +269,7 @@ func (f *FSM) PollPaymentAcceptedAction(_ fsm.EventContext) fsm.EventType { if err != nil { pollPaymentTries++ if pollPaymentTries > 20 { - return f.handleErrorAndUnlockReservations(err) + return f.handleErrorAndUnlockReservations(ctx, err) } } if res != nil && res.Accepted { @@ -280,74 +282,76 @@ func (f *FSM) PollPaymentAcceptedAction(_ fsm.EventContext) fsm.EventType { // BuildHTLCAction creates the htlc transaction, exchanges nonces with // the server and sends the htlc signatures to the server. -func (f *FSM) BuildHTLCAction(eventCtx fsm.EventContext) fsm.EventType { +func (f *FSM) BuildHTLCAction(ctx context.Context, + eventCtx fsm.EventContext) fsm.EventType { + htlcSessions, htlcClientNonces, err := f.InstantOut.createMusig2Session( - f.ctx, f.cfg.Signer, + ctx, f.cfg.Signer, ) if err != nil { - return f.handleErrorAndUnlockReservations(err) + return f.handleErrorAndUnlockReservations(ctx, err) } f.htlcMusig2Sessions = htlcSessions // Send the server the client nonces. htlcInitRes, err := f.cfg.InstantOutClient.InitHtlcSig( - f.ctx, + ctx, &swapserverrpc.InitHtlcSigRequest{ SwapHash: f.InstantOut.SwapHash[:], HtlcClientNonces: htlcClientNonces, }, ) if err != nil { - return f.handleErrorAndUnlockReservations(err) + return f.handleErrorAndUnlockReservations(ctx, err) } if len(htlcInitRes.HtlcServerNonces) != len(f.InstantOut.Reservations) { return f.handleErrorAndUnlockReservations( - errors.New("invalid number of server nonces"), + ctx, errors.New("invalid number of server nonces"), ) } htlcServerNonces, err := toNonces(htlcInitRes.HtlcServerNonces) if err != nil { - return f.handleErrorAndUnlockReservations(err) + return f.handleErrorAndUnlockReservations(ctx, err) } // Now that our nonces are set, we can create and sign the htlc // transaction. htlcTx, err := f.InstantOut.createHtlcTransaction(f.cfg.Network) if err != nil { - return f.handleErrorAndUnlockReservations(err) + return f.handleErrorAndUnlockReservations(ctx, err) } // Next we'll get our sweep tx signatures. htlcSigs, err := f.InstantOut.signMusig2Tx( - f.ctx, f.cfg.Signer, htlcTx, f.htlcMusig2Sessions, + ctx, f.cfg.Signer, htlcTx, f.htlcMusig2Sessions, htlcServerNonces, ) if err != nil { - return f.handleErrorAndUnlockReservations(err) + return f.handleErrorAndUnlockReservations(ctx, err) } // Send the server the htlc signatures. htlcRes, err := f.cfg.InstantOutClient.PushHtlcSig( - f.ctx, + ctx, &swapserverrpc.PushHtlcSigRequest{ SwapHash: f.InstantOut.SwapHash[:], ClientSigs: htlcSigs, }, ) if err != nil { - return f.handleErrorAndUnlockReservations(err) + return f.handleErrorAndUnlockReservations(ctx, err) } // We can now finalize the htlc transaction. htlcTx, err = f.InstantOut.finalizeMusig2Transaction( - f.ctx, f.cfg.Signer, f.htlcMusig2Sessions, htlcTx, + ctx, f.cfg.Signer, f.htlcMusig2Sessions, htlcTx, htlcRes.ServerSigs, ) if err != nil { - return f.handleErrorAndUnlockReservations(err) + return f.handleErrorAndUnlockReservations(ctx, err) } f.InstantOut.finalizedHtlcTx = htlcTx @@ -359,25 +363,27 @@ func (f *FSM) BuildHTLCAction(eventCtx fsm.EventContext) fsm.EventType { // sweepless sweep transaction and sends the signatures to the server. Finally, // it publishes the sweepless sweep transaction. If any of the steps after // pushing the preimage fail, the htlc timeout transaction will be published. -func (f *FSM) PushPreimageAction(eventCtx fsm.EventContext) fsm.EventType { +func (f *FSM) PushPreimageAction(ctx context.Context, + eventCtx fsm.EventContext) fsm.EventType { + // First we'll create the musig2 context. coopSessions, coopClientNonces, err := f.InstantOut.createMusig2Session( - f.ctx, f.cfg.Signer, + ctx, f.cfg.Signer, ) if err != nil { - return f.handleErrorAndUnlockReservations(err) + return f.handleErrorAndUnlockReservations(ctx, err) } f.sweeplessSweepSessions = coopSessions // Get the feerate for the coop sweep. - feeRate, err := f.cfg.Wallet.EstimateFeeRate(f.ctx, normalConfTarget) + feeRate, err := f.cfg.Wallet.EstimateFeeRate(ctx, normalConfTarget) if err != nil { - return f.handleErrorAndUnlockReservations(err) + return f.handleErrorAndUnlockReservations(ctx, err) } pushPreImageRes, err := f.cfg.InstantOutClient.PushPreimage( - f.ctx, + ctx, &swapserverrpc.PushPreimageRequest{ Preimage: f.InstantOut.swapPreimage[:], ClientNonces: coopClientNonces, @@ -408,7 +414,7 @@ func (f *FSM) PushPreimageAction(eventCtx fsm.EventContext) fsm.EventType { // Next we'll get our sweep tx signatures. _, err = f.InstantOut.signMusig2Tx( - f.ctx, f.cfg.Signer, sweepTx, f.sweeplessSweepSessions, + ctx, f.cfg.Signer, sweepTx, f.sweeplessSweepSessions, coopServerNonces, ) if err != nil { @@ -418,7 +424,7 @@ func (f *FSM) PushPreimageAction(eventCtx fsm.EventContext) fsm.EventType { // Now we'll finalize the sweepless sweep transaction. sweepTx, err = f.InstantOut.finalizeMusig2Transaction( - f.ctx, f.cfg.Signer, f.sweeplessSweepSessions, sweepTx, + ctx, f.cfg.Signer, f.sweeplessSweepSessions, sweepTx, pushPreImageRes.Musig2SweepSigs, ) if err != nil { @@ -430,7 +436,7 @@ func (f *FSM) PushPreimageAction(eventCtx fsm.EventContext) fsm.EventType { f.InstantOut.swapPreimage.Hash()) // Publish the sweepless sweep transaction. - err = f.cfg.Wallet.PublishTransaction(f.ctx, sweepTx, txLabel) + err = f.cfg.Wallet.PublishTransaction(ctx, sweepTx, txLabel) if err != nil { f.LastActionError = err return OnErrorPublishHtlc @@ -446,7 +452,7 @@ func (f *FSM) PushPreimageAction(eventCtx fsm.EventContext) fsm.EventType { // WaitForSweeplessSweepConfirmedAction waits for the sweepless sweep // transaction to be confirmed. -func (f *FSM) WaitForSweeplessSweepConfirmedAction( +func (f *FSM) WaitForSweeplessSweepConfirmedAction(ctx context.Context, eventCtx fsm.EventContext) fsm.EventType { pkscript, err := txscript.PayToAddrScript(f.InstantOut.sweepAddress) @@ -456,7 +462,7 @@ func (f *FSM) WaitForSweeplessSweepConfirmedAction( confChan, confErrChan, err := f.cfg.ChainNotifier. RegisterConfirmationsNtfn( - f.ctx, f.InstantOut.SweepTxHash, pkscript, + ctx, f.InstantOut.SweepTxHash, pkscript, 1, f.InstantOut.initiationHeight, ) if err != nil { @@ -483,11 +489,13 @@ func (f *FSM) WaitForSweeplessSweepConfirmedAction( // PublishHtlcAction publishes the htlc transaction and the htlc sweep // transaction. -func (f *FSM) PublishHtlcAction(eventCtx fsm.EventContext) fsm.EventType { +func (f *FSM) PublishHtlcAction(ctx context.Context, + eventCtx fsm.EventContext) fsm.EventType { + // Publish the htlc transaction. + label := fmt.Sprintf("htlc-%v", f.InstantOut.swapPreimage.Hash()) err := f.cfg.Wallet.PublishTransaction( - f.ctx, f.InstantOut.finalizedHtlcTx, - fmt.Sprintf("htlc-%v", f.InstantOut.swapPreimage.Hash()), + ctx, f.InstantOut.finalizedHtlcTx, label, ) if err != nil { return f.HandleError(err) @@ -499,7 +507,7 @@ func (f *FSM) PublishHtlcAction(eventCtx fsm.EventContext) fsm.EventType { // We'll now wait for the htlc to be confirmed. confChan, confErrChan, err := f.cfg.ChainNotifier. RegisterConfirmationsNtfn( - f.ctx, &txHash, + ctx, &txHash, f.InstantOut.finalizedHtlcTx.TxOut[0].PkScript, 1, f.InstantOut.initiationHeight, ) @@ -518,21 +526,23 @@ func (f *FSM) PublishHtlcAction(eventCtx fsm.EventContext) fsm.EventType { } // PublishHtlcSweepAction publishes the htlc sweep transaction. -func (f *FSM) PublishHtlcSweepAction(eventCtx fsm.EventContext) fsm.EventType { +func (f *FSM) PublishHtlcSweepAction(ctx context.Context, + eventCtx fsm.EventContext) fsm.EventType { + // Create a feerate that will confirm the htlc quickly. - feeRate, err := f.cfg.Wallet.EstimateFeeRate(f.ctx, urgentConfTarget) + feeRate, err := f.cfg.Wallet.EstimateFeeRate(ctx, urgentConfTarget) if err != nil { return f.HandleError(err) } - getInfo, err := f.cfg.LndClient.GetInfo(f.ctx) + getInfo, err := f.cfg.LndClient.GetInfo(ctx) if err != nil { return f.HandleError(err) } // We can immediately publish the htlc sweep transaction. htlcSweepTx, err := f.InstantOut.generateHtlcSweepTx( - f.ctx, f.cfg.Signer, feeRate, f.cfg.Network, getInfo.BlockHeight, + ctx, f.cfg.Signer, feeRate, f.cfg.Network, getInfo.BlockHeight, ) if err != nil { return f.HandleError(err) @@ -540,7 +550,7 @@ func (f *FSM) PublishHtlcSweepAction(eventCtx fsm.EventContext) fsm.EventType { label := fmt.Sprintf("htlc-sweep-%v", f.InstantOut.swapPreimage.Hash()) - err = f.cfg.Wallet.PublishTransaction(f.ctx, htlcSweepTx, label) + err = f.cfg.Wallet.PublishTransaction(ctx, htlcSweepTx, label) if err != nil { log.Errorf("error publishing htlc sweep tx: %v", err) return f.HandleError(err) @@ -555,7 +565,7 @@ func (f *FSM) PublishHtlcSweepAction(eventCtx fsm.EventContext) fsm.EventType { // WaitForHtlcSweepConfirmedAction waits for the htlc sweep transaction to be // confirmed. -func (f *FSM) WaitForHtlcSweepConfirmedAction( +func (f *FSM) WaitForHtlcSweepConfirmedAction(ctx context.Context, eventCtx fsm.EventContext) fsm.EventType { sweepPkScript, err := txscript.PayToAddrScript( @@ -566,7 +576,7 @@ func (f *FSM) WaitForHtlcSweepConfirmedAction( } confChan, confErrChan, err := f.cfg.ChainNotifier.RegisterConfirmationsNtfn( - f.ctx, f.InstantOut.SweepTxHash, sweepPkScript, + ctx, f.InstantOut.SweepTxHash, sweepPkScript, 1, f.InstantOut.initiationHeight, ) if err != nil { @@ -592,10 +602,11 @@ func (f *FSM) WaitForHtlcSweepConfirmedAction( // handleErrorAndUnlockReservations handles an error and unlocks the // reservations. -func (f *FSM) handleErrorAndUnlockReservations(err error) fsm.EventType { +func (f *FSM) handleErrorAndUnlockReservations(ctx context.Context, + err error) fsm.EventType { // We might get here from a canceled context, we create a new context // with a timeout to unlock the reservations. - ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() // Unlock the reservations. @@ -613,7 +624,7 @@ func (f *FSM) handleErrorAndUnlockReservations(err error) fsm.EventType { // release the reservations. This can be done in a goroutine as we // wan't to fail the fsm early. go func() { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() _, cancelErr := f.cfg.InstantOutClient.CancelInstantSwap( ctx, &swapserverrpc.CancelInstantSwapRequest{ diff --git a/instantout/fsm.go b/instantout/fsm.go index 96c6d776a..9cb0c9945 100644 --- a/instantout/fsm.go +++ b/instantout/fsm.go @@ -176,8 +176,6 @@ type Config struct { type FSM struct { *fsm.StateMachine - ctx context.Context - // cfg contains all the services that the reservation manager needs to // operate. cfg *Config @@ -195,24 +193,19 @@ type FSM struct { } // NewFSM creates a new instant out FSM. -func NewFSM(ctx context.Context, cfg *Config, - protocolVersion ProtocolVersion) (*FSM, error) { - +func NewFSM(cfg *Config, protocolVersion ProtocolVersion) (*FSM, error) { instantOut := &InstantOut{ State: fsm.EmptyState, protocolVersion: protocolVersion, } - return NewFSMFromInstantOut(ctx, cfg, instantOut) + return NewFSMFromInstantOut(cfg, instantOut) } // NewFSMFromInstantOut creates a new instantout FSM from an existing instantout // recovered from the database. -func NewFSMFromInstantOut(ctx context.Context, cfg *Config, - instantOut *InstantOut) (*FSM, error) { - +func NewFSMFromInstantOut(cfg *Config, instantOut *InstantOut) (*FSM, error) { instantOutFSM := &FSM{ - ctx: ctx, cfg: cfg, InstantOut: instantOut, } @@ -328,7 +321,9 @@ func (f *FSM) GetV1ReservationStates() fsm.States { // updateInstantOut is called after every action and updates the reservation // in the db. -func (f *FSM) updateInstantOut(notification fsm.Notification) { +func (f *FSM) updateInstantOut(ctx context.Context, + notification fsm.Notification) { + f.Infof("Previous: %v, Event: %v, Next: %v", notification.PreviousState, notification.Event, notification.NextState) @@ -349,7 +344,7 @@ func (f *FSM) updateInstantOut(notification fsm.Notification) { return } - err := f.cfg.Store.UpdateInstantLoopOut(f.ctx, f.InstantOut) + err := f.cfg.Store.UpdateInstantLoopOut(ctx, f.InstantOut) if err != nil { log.Errorf("Error updating instant out: %v", err) return diff --git a/instantout/manager.go b/instantout/manager.go index d9d15fbc8..38207c446 100644 --- a/instantout/manager.go +++ b/instantout/manager.go @@ -111,7 +111,7 @@ func (m *Manager) recoverInstantOuts(ctx context.Context) error { log.Debugf("Recovering instantout %v", instantOut.SwapHash) instantOutFSM, err := NewFSMFromInstantOut( - ctx, m.cfg, instantOut, + m.cfg, instantOut, ) if err != nil { return err @@ -122,7 +122,7 @@ func (m *Manager) recoverInstantOuts(ctx context.Context) error { // As SendEvent can block, we'll start a goroutine to process // the event. go func() { - err := instantOutFSM.SendEvent(OnRecover, nil) + err := instantOutFSM.SendEvent(ctx, OnRecover, nil) if err != nil { log.Errorf("FSM %v Error sending recover "+ "event %v, state: %v", @@ -162,9 +162,7 @@ func (m *Manager) NewInstantOut(ctx context.Context, sweepAddress: sweepAddr, } - instantOut, err := NewFSM( - m.runCtx, m.cfg, ProtocolVersionFullReservation, - ) + instantOut, err := NewFSM(m.cfg, ProtocolVersionFullReservation) if err != nil { m.Unlock() return nil, err @@ -174,7 +172,7 @@ func (m *Manager) NewInstantOut(ctx context.Context, // Start the instantout FSM. go func() { - err := instantOut.SendEvent(OnStart, request) + err := instantOut.SendEvent(m.runCtx, OnStart, request) if err != nil { log.Errorf("Error sending event: %v", err) } diff --git a/instantout/reservation/actions.go b/instantout/reservation/actions.go index 76cc3f067..1d58cd744 100644 --- a/instantout/reservation/actions.go +++ b/instantout/reservation/actions.go @@ -22,16 +22,16 @@ type InitReservationContext struct { // InitAction is the action that is executed when the reservation state machine // is initialized. It creates the reservation in the database and dispatches the // payment to the server. -func (f *FSM) InitAction(eventCtx fsm.EventContext) fsm.EventType { +func (f *FSM) InitAction(ctx context.Context, + eventCtx fsm.EventContext) fsm.EventType { + // Check if the context is of the correct type. reservationRequest, ok := eventCtx.(*InitReservationContext) if !ok { return f.HandleError(fsm.ErrInvalidContextType) } - keyRes, err := f.cfg.Wallet.DeriveNextKey( - f.ctx, KeyFamily, - ) + keyRes, err := f.cfg.Wallet.DeriveNextKey(ctx, KeyFamily) if err != nil { return f.HandleError(err) } @@ -45,7 +45,7 @@ func (f *FSM) InitAction(eventCtx fsm.EventContext) fsm.EventType { ClientKey: keyRes.PubKey.SerializeCompressed(), } - _, err = f.cfg.ReservationClient.OpenReservation(f.ctx, request) + _, err = f.cfg.ReservationClient.OpenReservation(ctx, request) if err != nil { return f.HandleError(err) } @@ -66,7 +66,7 @@ func (f *FSM) InitAction(eventCtx fsm.EventContext) fsm.EventType { f.reservation = reservation // Create the reservation in the database. - err = f.cfg.Store.CreateReservation(f.ctx, reservation) + err = f.cfg.Store.CreateReservation(ctx, reservation) if err != nil { return f.HandleError(err) } @@ -77,13 +77,15 @@ func (f *FSM) InitAction(eventCtx fsm.EventContext) fsm.EventType { // SubscribeToConfirmationAction is the action that is executed when the // reservation is waiting for confirmation. It subscribes to the confirmation // of the reservation transaction. -func (f *FSM) SubscribeToConfirmationAction(_ fsm.EventContext) fsm.EventType { +func (f *FSM) SubscribeToConfirmationAction(ctx context.Context, + _ fsm.EventContext) fsm.EventType { + pkscript, err := f.reservation.GetPkScript() if err != nil { return f.HandleError(err) } - callCtx, cancel := context.WithCancel(f.ctx) + callCtx, cancel := context.WithCancel(ctx) defer cancel() // Subscribe to the confirmation of the reservation transaction. @@ -141,7 +143,7 @@ func (f *FSM) SubscribeToConfirmationAction(_ fsm.EventContext) fsm.EventType { return OnTimedOut } - case <-f.ctx.Done(): + case <-ctx.Done(): return fsm.NoOp } } @@ -150,10 +152,10 @@ func (f *FSM) SubscribeToConfirmationAction(_ fsm.EventContext) fsm.EventType { // AsyncWaitForExpiredOrSweptAction waits for the reservation to be either // expired or swept. This is non-blocking and can be used to wait for the // reservation to expire while expecting other events. -func (f *FSM) AsyncWaitForExpiredOrSweptAction(_ fsm.EventContext, -) fsm.EventType { +func (f *FSM) AsyncWaitForExpiredOrSweptAction(ctx context.Context, + _ fsm.EventContext) fsm.EventType { - notifCtx, cancel := context.WithCancel(f.ctx) + notifCtx, cancel := context.WithCancel(ctx) blockHeightChan, errEpochChan, err := f.cfg.ChainNotifier. RegisterBlockEpochNtfn(notifCtx) @@ -184,13 +186,13 @@ func (f *FSM) AsyncWaitForExpiredOrSweptAction(_ fsm.EventContext, errSpendChan, ) if err != nil { - f.handleAsyncError(err) + f.handleAsyncError(ctx, err) return } if op == fsm.NoOp { return } - err = f.SendEvent(op, nil) + err = f.SendEvent(ctx, op, nil) if err != nil { f.Errorf("Error sending %s event: %v", op, err) } @@ -229,10 +231,10 @@ func (f *FSM) handleSubcriptions(ctx context.Context, } } -func (f *FSM) handleAsyncError(err error) { +func (f *FSM) handleAsyncError(ctx context.Context, err error) { f.LastActionError = err f.Errorf("Error on async action: %v", err) - err2 := f.SendEvent(fsm.OnError, err) + err2 := f.SendEvent(ctx, fsm.OnError, err) if err2 != nil { f.Errorf("Error sending event: %v", err2) } diff --git a/instantout/reservation/actions_test.go b/instantout/reservation/actions_test.go index d89e526db..3b5adfd26 100644 --- a/instantout/reservation/actions_test.go +++ b/instantout/reservation/actions_test.go @@ -144,7 +144,6 @@ func TestInitReservationAction(t *testing.T) { ).Return(tc.mockStoreErr) reservationFSM := &FSM{ - ctx: ctxb, cfg: &Config{ Wallet: mockLnd.WalletKit, ChainNotifier: mockLnd.ChainNotifier, @@ -154,7 +153,7 @@ func TestInitReservationAction(t *testing.T) { StateMachine: &fsm.StateMachine{}, } - event := reservationFSM.InitAction(tc.eventCtx) + event := reservationFSM.InitAction(ctxb, tc.eventCtx) require.Equal(t, tc.expectedEvent, event) } } @@ -227,10 +226,10 @@ func TestSubscribeToConfirmationAction(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { chainNotifier := new(MockChainNotifier) - + ctxb := context.Background() // Create the FSM. r := NewFSMFromReservation( - context.Background(), &Config{ + &Config{ ChainNotifier: chainNotifier, }, &Reservation{ @@ -296,7 +295,7 @@ func TestSubscribeToConfirmationAction(t *testing.T) { } }() - eventType := r.SubscribeToConfirmationAction(nil) + eventType := r.SubscribeToConfirmationAction(ctxb, nil) // Assert that the return value is as expected require.Equal(t, tc.expectedEvent, eventType) @@ -335,10 +334,11 @@ func TestAsyncWaitForExpiredOrSweptAction(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { // Create a mock ChainNotifier and Reservation chainNotifier := new(MockChainNotifier) + ctxb := context.Background() // Define your FSM r := NewFSMFromReservation( - context.Background(), &Config{ + &Config{ ChainNotifier: chainNotifier, }, &Reservation{ @@ -361,7 +361,7 @@ func TestAsyncWaitForExpiredOrSweptAction(t *testing.T) { make(chan error), tc.spendErr, ) - eventType := r.AsyncWaitForExpiredOrSweptAction(nil) + eventType := r.AsyncWaitForExpiredOrSweptAction(ctxb, nil) // Assert that the return value is as expected require.Equal(t, tc.expectedEvent, eventType) }) @@ -415,7 +415,7 @@ func TestHandleSubcriptions(t *testing.T) { // Create the FSM. r := NewFSMFromReservation( - context.Background(), &Config{ + &Config{ ChainNotifier: chainNotifier, }, &Reservation{ diff --git a/instantout/reservation/fsm.go b/instantout/reservation/fsm.go index 86afbbaa2..6bf567d28 100644 --- a/instantout/reservation/fsm.go +++ b/instantout/reservation/fsm.go @@ -40,26 +40,21 @@ type FSM struct { cfg *Config reservation *Reservation - - ctx context.Context } // NewFSM creates a new reservation FSM. -func NewFSM(ctx context.Context, cfg *Config) *FSM { +func NewFSM(cfg *Config) *FSM { reservation := &Reservation{ State: fsm.EmptyState, } - return NewFSMFromReservation(ctx, cfg, reservation) + return NewFSMFromReservation(cfg, reservation) } // NewFSMFromReservation creates a new reservation FSM from an existing // reservation recovered from the database. -func NewFSMFromReservation(ctx context.Context, cfg *Config, - reservation *Reservation) *FSM { - +func NewFSMFromReservation(cfg *Config, reservation *Reservation) *FSM { reservationFsm := &FSM{ - ctx: ctx, cfg: cfg, reservation: reservation, } @@ -206,7 +201,9 @@ func (f *FSM) GetReservationStates() fsm.States { // updateReservation updates the reservation in the database. This function // is called after every new state transition. -func (r *FSM) updateReservation(notification fsm.Notification) { +func (r *FSM) updateReservation(ctx context.Context, + notification fsm.Notification) { + if r.reservation == nil { return } @@ -229,7 +226,7 @@ func (r *FSM) updateReservation(notification fsm.Notification) { return } - err := r.cfg.Store.UpdateReservation(r.ctx, r.reservation) + err := r.cfg.Store.UpdateReservation(ctx, r.reservation) if err != nil { r.Errorf("unable to update reservation: %v", err) } diff --git a/instantout/reservation/manager.go b/instantout/reservation/manager.go index 3a20d113d..f2833c077 100644 --- a/instantout/reservation/manager.go +++ b/instantout/reservation/manager.go @@ -22,8 +22,6 @@ type Manager struct { // activeReservations contains all the active reservationsFSMs. activeReservations map[ID]*FSM - runCtx context.Context - sync.Mutex } @@ -36,13 +34,14 @@ func NewManager(cfg *Config) *Manager { } // Run runs the reservation manager. -func (m *Manager) Run(ctx context.Context, height int32) error { +func (m *Manager) Run(ctx context.Context, height int32, + initChan chan struct{}) error { + log.Debugf("Starting reservation manager") runCtx, cancel := context.WithCancel(ctx) defer cancel() - m.runCtx = runCtx currentHeight := height err := m.RecoverReservations(runCtx) @@ -58,6 +57,9 @@ func (m *Manager) Run(ctx context.Context, height int32) error { ntfnChan := m.cfg.NotificationManager.SubscribeReservations(runCtx) + // Signal that the manager has been initialized. + close(initChan) + for { select { case height := <-newBlockChan: @@ -111,9 +113,7 @@ func (m *Manager) newReservation(ctx context.Context, currentHeight uint32, // Create the reservation state machine. We need to pass in the runCtx // of the reservation manager so that the state machine will keep on // running even if the grpc conte - reservationFSM := NewFSM( - ctx, m.cfg, - ) + reservationFSM := NewFSM(m.cfg) // Add the reservation to the active reservations map. m.Lock() @@ -130,7 +130,7 @@ func (m *Manager) newReservation(ctx context.Context, currentHeight uint32, // Send the init event to the state machine. go func() { - err = reservationFSM.SendEvent(OnServerRequest, initContext) + err = reservationFSM.SendEvent(ctx, OnServerRequest, initContext) if err != nil { log.Errorf("Error sending init event: %v", err) } @@ -171,16 +171,14 @@ func (m *Manager) RecoverReservations(ctx context.Context) error { fsmCtx := context.WithValue(ctx, reservation.ID, nil) - reservationFSM := NewFSMFromReservation( - fsmCtx, m.cfg, reservation, - ) + reservationFSM := NewFSMFromReservation(m.cfg, reservation) m.activeReservations[reservation.ID] = reservationFSM // As SendEvent can block, we'll start a goroutine to process // the event. go func() { - err := reservationFSM.SendEvent(OnRecover, nil) + err := reservationFSM.SendEvent(fsmCtx, OnRecover, nil) if err != nil { log.Errorf("FSM %v Error sending recover "+ "event %v, state: %v", @@ -217,7 +215,7 @@ func (m *Manager) LockReservation(ctx context.Context, id ID) error { } // Try to send the lock event to the reservation. - err := reservation.SendEvent(OnLocked, nil) + err := reservation.SendEvent(ctx, OnLocked, nil) if err != nil { return err } @@ -237,7 +235,7 @@ func (m *Manager) UnlockReservation(ctx context.Context, id ID) error { } // Try to send the unlock event to the reservation. - err := reservation.SendEvent(OnUnlocked, nil) + err := reservation.SendEvent(ctx, OnUnlocked, nil) if err != nil && strings.Contains(err.Error(), "config error") { // If the error is a config error, we can ignore it, as the // reservation is already unlocked. diff --git a/instantout/reservation/manager_test.go b/instantout/reservation/manager_test.go index 1dbb5a349..226ffb172 100644 --- a/instantout/reservation/manager_test.go +++ b/instantout/reservation/manager_test.go @@ -25,12 +25,16 @@ func TestManager(t *testing.T) { testContext := newManagerTestContext(t) + initChan := make(chan struct{}) // Start the manager. go func() { - err := testContext.manager.Run(ctxb, testContext.mockLnd.Height) + err := testContext.manager.Run(ctxb, testContext.mockLnd.Height, initChan) require.NoError(t, err) }() + // We'll now wait for the manager to be initialized. + <-initChan + // Create a new reservation. reservationFSM, err := testContext.manager.newReservation( ctxb, uint32(testContext.mockLnd.Height), diff --git a/loopd/daemon.go b/loopd/daemon.go index 291c6e541..a7db601b9 100644 --- a/loopd/daemon.go +++ b/loopd/daemon.go @@ -648,6 +648,7 @@ func (d *Daemon) initialize(withMacaroonService bool) error { // Start the reservation manager. if d.reservationManager != nil { d.wg.Add(1) + initChan := make(chan struct{}) go func() { defer d.wg.Done() @@ -663,12 +664,25 @@ func (d *Daemon) initialize(withMacaroonService bool) error { defer log.Info("Reservation manager stopped") err = d.reservationManager.Run( - d.mainCtx, int32(getInfo.BlockHeight), + d.mainCtx, int32(getInfo.BlockHeight), initChan, ) if err != nil && !errors.Is(err, context.Canceled) { d.internalErrChan <- err } }() + + // Wait for the reservation server to be ready before starting the + // grpc server. + timeOutCtx, cancel := context.WithTimeout(d.mainCtx, 10*time.Second) + select { + case <-timeOutCtx.Done(): + cancel() + return fmt.Errorf("reservation server not ready: %v", + timeOutCtx.Err()) + + case <-initChan: + cancel() + } } // Start the instant out manager. @@ -701,8 +715,9 @@ func (d *Daemon) initialize(withMacaroonService bool) error { select { case <-timeOutCtx.Done(): cancel() - return fmt.Errorf("reservation server not ready: %v", + return fmt.Errorf("instantout server not ready: %v", timeOutCtx.Err()) + case <-initChan: cancel() }