Skip to content

Commit 11ab596

Browse files
authored
Merge pull request #497 from bhandras/taproot-musig2
multi: upgrade to using P2TR htlcs and added support for MuSig2 loopout sweep
2 parents 490fb35 + 5f4b34d commit 11ab596

38 files changed

+2571
-925
lines changed

client.go

Lines changed: 159 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"sync/atomic"
1010
"time"
1111

12+
"github.com/btcsuite/btcd/btcec/v2"
13+
"github.com/btcsuite/btcd/btcec/v2/schnorr"
1214
"github.com/btcsuite/btcd/btcutil"
1315
"github.com/lightninglabs/aperture/lsat"
1416
"github.com/lightninglabs/lndclient"
@@ -158,6 +160,18 @@ func NewClient(dbDir string, cfg *ClientConfig) (*Client, func(), error) {
158160
totalPaymentTimeout: cfg.TotalPaymentTimeout,
159161
maxPaymentRetries: cfg.MaxPaymentRetries,
160162
cancelSwap: swapServerClient.CancelLoopOutSwap,
163+
verifySchnorrSig: func(pubKey *btcec.PublicKey, hash, sig []byte) error {
164+
schnorrSig, err := schnorr.ParseSignature(sig)
165+
if err != nil {
166+
return err
167+
}
168+
169+
if !schnorrSig.Verify(hash, pubKey) {
170+
return fmt.Errorf("invalid signature")
171+
}
172+
173+
return nil
174+
},
161175
})
162176

163177
client := &Client{
@@ -192,56 +206,92 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) {
192206
swaps := make([]*SwapInfo, 0, len(loopInSwaps)+len(loopOutSwaps))
193207

194208
for _, swp := range loopOutSwaps {
209+
swapInfo := &SwapInfo{
210+
SwapType: swap.TypeOut,
211+
SwapContract: swp.Contract.SwapContract,
212+
SwapStateData: swp.State(),
213+
SwapHash: swp.Hash,
214+
LastUpdate: swp.LastUpdateTime(),
215+
}
216+
scriptVersion := GetHtlcScriptVersion(
217+
swp.Contract.ProtocolVersion,
218+
)
219+
220+
outputType := swap.HtlcP2WSH
221+
if scriptVersion == swap.HtlcV3 {
222+
outputType = swap.HtlcP2TR
223+
}
224+
195225
htlc, err := swap.NewHtlc(
196-
GetHtlcScriptVersion(swp.Contract.ProtocolVersion),
226+
scriptVersion,
197227
swp.Contract.CltvExpiry, swp.Contract.SenderKey,
198-
swp.Contract.ReceiverKey, swp.Hash, swap.HtlcP2WSH,
199-
s.lndServices.ChainParams,
228+
swp.Contract.ReceiverKey, swp.Hash,
229+
outputType, s.lndServices.ChainParams,
200230
)
201231
if err != nil {
202232
return nil, err
203233
}
204234

205-
swaps = append(swaps, &SwapInfo{
206-
SwapType: swap.TypeOut,
207-
SwapContract: swp.Contract.SwapContract,
208-
SwapStateData: swp.State(),
209-
SwapHash: swp.Hash,
210-
LastUpdate: swp.LastUpdateTime(),
211-
HtlcAddressP2WSH: htlc.Address,
212-
})
235+
if outputType == swap.HtlcP2TR {
236+
swapInfo.HtlcAddressP2TR = htlc.Address
237+
} else {
238+
swapInfo.HtlcAddressP2WSH = htlc.Address
239+
}
240+
241+
swaps = append(swaps, swapInfo)
213242
}
214243

215244
for _, swp := range loopInSwaps {
216-
htlcNP2WSH, err := swap.NewHtlc(
217-
GetHtlcScriptVersion(swp.Contract.ProtocolVersion),
218-
swp.Contract.CltvExpiry, swp.Contract.SenderKey,
219-
swp.Contract.ReceiverKey, swp.Hash, swap.HtlcNP2WSH,
220-
s.lndServices.ChainParams,
221-
)
222-
if err != nil {
223-
return nil, err
245+
swapInfo := &SwapInfo{
246+
SwapType: swap.TypeIn,
247+
SwapContract: swp.Contract.SwapContract,
248+
SwapStateData: swp.State(),
249+
SwapHash: swp.Hash,
250+
LastUpdate: swp.LastUpdateTime(),
224251
}
225252

226-
htlcP2WSH, err := swap.NewHtlc(
227-
GetHtlcScriptVersion(swp.Contract.ProtocolVersion),
228-
swp.Contract.CltvExpiry, swp.Contract.SenderKey,
229-
swp.Contract.ReceiverKey, swp.Hash, swap.HtlcP2WSH,
230-
s.lndServices.ChainParams,
253+
scriptVersion := GetHtlcScriptVersion(
254+
swp.Contract.SwapContract.ProtocolVersion,
231255
)
232-
if err != nil {
233-
return nil, err
256+
257+
if scriptVersion == swap.HtlcV3 {
258+
htlcP2TR, err := swap.NewHtlc(
259+
swap.HtlcV3, swp.Contract.CltvExpiry,
260+
swp.Contract.SenderKey, swp.Contract.ReceiverKey,
261+
swp.Hash, swap.HtlcP2TR,
262+
s.lndServices.ChainParams,
263+
)
264+
if err != nil {
265+
return nil, err
266+
}
267+
268+
swapInfo.HtlcAddressP2TR = htlcP2TR.Address
269+
} else {
270+
htlcNP2WSH, err := swap.NewHtlc(
271+
swap.HtlcV1, swp.Contract.CltvExpiry,
272+
swp.Contract.SenderKey, swp.Contract.ReceiverKey,
273+
swp.Hash, swap.HtlcNP2WSH,
274+
s.lndServices.ChainParams,
275+
)
276+
if err != nil {
277+
return nil, err
278+
}
279+
280+
htlcP2WSH, err := swap.NewHtlc(
281+
swap.HtlcV2, swp.Contract.CltvExpiry,
282+
swp.Contract.SenderKey, swp.Contract.ReceiverKey,
283+
swp.Hash, swap.HtlcP2WSH,
284+
s.lndServices.ChainParams,
285+
)
286+
if err != nil {
287+
return nil, err
288+
}
289+
290+
swapInfo.HtlcAddressP2WSH = htlcP2WSH.Address
291+
swapInfo.HtlcAddressNP2WSH = htlcNP2WSH.Address
234292
}
235293

236-
swaps = append(swaps, &SwapInfo{
237-
SwapType: swap.TypeIn,
238-
SwapContract: swp.Contract.SwapContract,
239-
SwapStateData: swp.State(),
240-
SwapHash: swp.Hash,
241-
LastUpdate: swp.LastUpdateTime(),
242-
HtlcAddressP2WSH: htlcP2WSH.Address,
243-
HtlcAddressNP2WSH: htlcNP2WSH.Address,
244-
})
294+
swaps = append(swaps, swapInfo)
245295
}
246296

247297
return swaps, nil
@@ -405,9 +455,9 @@ func (s *Client) LoopOut(globalCtx context.Context,
405455
// Return hash so that the caller can identify this swap in the updates
406456
// stream.
407457
return &LoopOutSwapInfo{
408-
SwapHash: swap.hash,
409-
HtlcAddressP2WSH: swap.htlc.Address,
410-
ServerMessage: initResult.serverMessage,
458+
SwapHash: swap.hash,
459+
HtlcAddress: swap.htlc.Address,
460+
ServerMessage: initResult.serverMessage,
411461
}, nil
412462
}
413463

@@ -463,7 +513,23 @@ func (s *Client) LoopOutQuote(ctx context.Context,
463513

464514
log.Infof("Offchain swap destination: %x", quote.SwapPaymentDest)
465515

466-
swapFee := quote.SwapFee
516+
minerFee, err := s.getLoopOutSweepFee(ctx, request.SweepConfTarget)
517+
if err != nil {
518+
return nil, err
519+
}
520+
521+
return &LoopOutQuote{
522+
SwapFee: quote.SwapFee,
523+
MinerFee: minerFee,
524+
PrepayAmount: quote.PrepayAmount,
525+
SwapPaymentDest: quote.SwapPaymentDest,
526+
}, nil
527+
}
528+
529+
// getLoopOutSweepFee is a helper method to estimate the loop out htlc sweep
530+
// fee to a p2wsh address.
531+
func (s *Client) getLoopOutSweepFee(ctx context.Context, confTarget int32) (
532+
btcutil.Amount, error) {
467533

468534
// Generate dummy p2wsh address for fee estimation. The p2wsh address
469535
// type is chosen because it adds the most weight of all output types
@@ -473,23 +539,21 @@ func (s *Client) LoopOutQuote(ctx context.Context,
473539
wsh[:], s.lndServices.ChainParams,
474540
)
475541
if err != nil {
476-
return nil, err
542+
return 0, err
477543
}
478544

479-
minerFee, err := s.sweeper.GetSweepFee(
480-
ctx, swap.QuoteHtlc.AddSuccessToEstimator,
481-
p2wshAddress, request.SweepConfTarget,
545+
scriptVersion := GetHtlcScriptVersion(
546+
loopdb.CurrentProtocolVersion(),
482547
)
483-
if err != nil {
484-
return nil, err
548+
549+
htlc := swap.QuoteHtlcP2TR
550+
if scriptVersion != swap.HtlcV3 {
551+
htlc = swap.QuoteHtlcP2WSH
485552
}
486553

487-
return &LoopOutQuote{
488-
SwapFee: swapFee,
489-
MinerFee: minerFee,
490-
PrepayAmount: quote.PrepayAmount,
491-
SwapPaymentDest: quote.SwapPaymentDest,
492-
}, nil
554+
return s.sweeper.GetSweepFee(
555+
ctx, htlc.AddSuccessToEstimator, p2wshAddress, confTarget,
556+
)
493557
}
494558

495559
// LoopOutTerms returns the terms on which the server executes swaps.
@@ -546,11 +610,17 @@ func (s *Client) LoopIn(globalCtx context.Context,
546610
// Return hash so that the caller can identify this swap in the updates
547611
// stream.
548612
swapInfo := &LoopInSwapInfo{
549-
SwapHash: swap.hash,
550-
HtlcAddressP2WSH: swap.htlcP2WSH.Address,
551-
HtlcAddressNP2WSH: swap.htlcNP2WSH.Address,
552-
ServerMessage: initResult.serverMessage,
613+
SwapHash: swap.hash,
614+
ServerMessage: initResult.serverMessage,
553615
}
616+
617+
if loopdb.CurrentProtocolVersion() < loopdb.ProtocolVersionHtlcV3 {
618+
swapInfo.HtlcAddressNP2WSH = swap.htlcNP2WSH.Address
619+
swapInfo.HtlcAddressP2WSH = swap.htlcP2WSH.Address
620+
} else {
621+
swapInfo.HtlcAddressP2TR = swap.htlcP2TR.Address
622+
}
623+
554624
return swapInfo, nil
555625
}
556626

@@ -626,7 +696,7 @@ func (s *Client) LoopInQuote(ctx context.Context,
626696
//
627697
// TODO(guggero): Thread through error code from lnd to avoid string
628698
// matching.
629-
minerFee, err := s.lndServices.Client.EstimateFeeToP2WSH(
699+
minerFee, err := s.estimateFee(
630700
ctx, request.Amount, request.HtlcConfTarget,
631701
)
632702
if err != nil && strings.Contains(err.Error(), "insufficient funds") {
@@ -647,6 +717,39 @@ func (s *Client) LoopInQuote(ctx context.Context,
647717
}, nil
648718
}
649719

720+
// estimateFee is a helper method to estimate the total fee for paying the
721+
// passed amount with the given conf target. It'll assume taproot destination
722+
// if the protocol version indicates that we're using taproot htlcs.
723+
func (s *Client) estimateFee(ctx context.Context, amt btcutil.Amount,
724+
confTarget int32) (btcutil.Amount, error) {
725+
726+
var (
727+
address btcutil.Address
728+
err error
729+
)
730+
// Generate a dummy address for fee estimation.
731+
witnessProg := [32]byte{}
732+
733+
scriptVersion := GetHtlcScriptVersion(
734+
loopdb.CurrentProtocolVersion(),
735+
)
736+
737+
if scriptVersion != swap.HtlcV3 {
738+
address, err = btcutil.NewAddressWitnessScriptHash(
739+
witnessProg[:], s.lndServices.ChainParams,
740+
)
741+
} else {
742+
address, err = btcutil.NewAddressTaproot(
743+
witnessProg[:], s.lndServices.ChainParams,
744+
)
745+
}
746+
if err != nil {
747+
return 0, err
748+
}
749+
750+
return s.lndServices.Client.EstimateFee(ctx, address, amt, confTarget)
751+
}
752+
650753
// LoopInTerms returns the terms on which the server executes swaps.
651754
func (s *Client) LoopInTerms(ctx context.Context) (
652755
*LoopInTerms, error) {

0 commit comments

Comments
 (0)