From bdb4b773ed82599858798f40f9132d9303266a01 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Wed, 16 Nov 2022 19:01:28 +0100 Subject: [PATCH] swap: refactor htlc construction to allow passing of internal keys This commit is a refactor of how we construct htlcs to make it possible to pass in internal keys for the sender and receiver when creating P2TR htlcs. Furthermore the commit also cleans up constructors to not pass in script versions and output types to make the code more readable. --- client.go | 64 ++++++++++---------------- client_test.go | 24 +++++++--- loopd/view.go | 35 ++++---------- loopin.go | 30 ++++++------ loopin_test.go | 35 +++++++++----- loopout.go | 25 ++++------ swap.go | 31 +++++++++---- swap/htlc.go | 113 +++++++++++++++++++++++++++++++--------------- swap/htlc_test.go | 27 ++++++----- 9 files changed, 207 insertions(+), 177 deletions(-) diff --git a/client.go b/client.go index 4581f0f..0bf6e46 100644 --- a/client.go +++ b/client.go @@ -213,29 +213,24 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) { SwapHash: swp.Hash, LastUpdate: swp.LastUpdateTime(), } - scriptVersion := GetHtlcScriptVersion( - swp.Contract.ProtocolVersion, - ) - - outputType := swap.HtlcP2WSH - if scriptVersion == swap.HtlcV3 { - outputType = swap.HtlcP2TR - } - htlc, err := swap.NewHtlc( - scriptVersion, - swp.Contract.CltvExpiry, swp.Contract.SenderKey, - swp.Contract.ReceiverKey, swp.Hash, - outputType, s.lndServices.ChainParams, + htlc, err := GetHtlc( + swp.Hash, &swp.Contract.SwapContract, + s.lndServices.ChainParams, ) if err != nil { return nil, err } - if outputType == swap.HtlcP2TR { - swapInfo.HtlcAddressP2TR = htlc.Address - } else { + switch htlc.OutputType { + case swap.HtlcP2WSH: swapInfo.HtlcAddressP2WSH = htlc.Address + + case swap.HtlcP2TR: + swapInfo.HtlcAddressP2TR = htlc.Address + + default: + return nil, swap.ErrInvalidOutputType } swaps = append(swaps, swapInfo) @@ -250,34 +245,23 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) { LastUpdate: swp.LastUpdateTime(), } - scriptVersion := GetHtlcScriptVersion( - swp.Contract.SwapContract.ProtocolVersion, + htlc, err := GetHtlc( + swp.Hash, &swp.Contract.SwapContract, + s.lndServices.ChainParams, ) + if err != nil { + return nil, err + } - if scriptVersion == swap.HtlcV3 { - htlcP2TR, err := swap.NewHtlc( - swap.HtlcV3, swp.Contract.CltvExpiry, - swp.Contract.SenderKey, swp.Contract.ReceiverKey, - swp.Hash, swap.HtlcP2TR, - s.lndServices.ChainParams, - ) - if err != nil { - return nil, err - } + switch htlc.OutputType { + case swap.HtlcP2WSH: + swapInfo.HtlcAddressP2WSH = htlc.Address - swapInfo.HtlcAddressP2TR = htlcP2TR.Address - } else { - htlcP2WSH, err := swap.NewHtlc( - swap.HtlcV2, swp.Contract.CltvExpiry, - swp.Contract.SenderKey, swp.Contract.ReceiverKey, - swp.Hash, swap.HtlcP2WSH, - s.lndServices.ChainParams, - ) - if err != nil { - return nil, err - } + case swap.HtlcP2TR: + swapInfo.HtlcAddressP2TR = htlc.Address - swapInfo.HtlcAddressP2WSH = htlcP2WSH.Address + default: + return nil, swap.ErrInvalidOutputType } swaps = append(swaps, swapInfo) diff --git a/client_test.go b/client_test.go index 7c8f88b..3bca073 100644 --- a/client_test.go +++ b/client_test.go @@ -284,16 +284,26 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed, // Assert that the loopout htlc equals to the expected one. scriptVersion := GetHtlcScriptVersion(protocolVersion) + var htlc *swap.Htlc - outputType := swap.HtlcP2TR - if scriptVersion != swap.HtlcV3 { - outputType = swap.HtlcP2WSH + switch scriptVersion { + case swap.HtlcV2: + htlc, err = swap.NewHtlcV2( + pendingSwap.Contract.CltvExpiry, senderKey, + receiverKey, hash, &chaincfg.TestNet3Params, + ) + + case swap.HtlcV3: + htlc, err = swap.NewHtlcV3( + pendingSwap.Contract.CltvExpiry, senderKey, + receiverKey, senderKey, receiverKey, hash, + &chaincfg.TestNet3Params, + ) + + default: + t.Fatalf(swap.ErrInvalidScriptVersion.Error()) } - htlc, err := swap.NewHtlc( - scriptVersion, pendingSwap.Contract.CltvExpiry, senderKey, - receiverKey, hash, outputType, &chaincfg.TestNet3Params, - ) require.NoError(t, err) require.Equal(t, htlc.PkScript, confIntent.PkScript) diff --git a/loopd/view.go b/loopd/view.go index f99fdef..3a7b600 100644 --- a/loopd/view.go +++ b/loopd/view.go @@ -7,7 +7,6 @@ import ( "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop" "github.com/lightninglabs/loop/loopdb" - "github.com/lightninglabs/loop/swap" ) // view prints all swaps currently in the database. @@ -49,24 +48,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error { } for _, s := range swaps { - scriptVersion := loop.GetHtlcScriptVersion( - s.Contract.ProtocolVersion, - ) - - var outputType swap.HtlcOutputType - switch scriptVersion { - case swap.HtlcV2: - outputType = swap.HtlcP2WSH - - case swap.HtlcV3: - outputType = swap.HtlcP2TR - } - htlc, err := swap.NewHtlc( - loop.GetHtlcScriptVersion(s.Contract.ProtocolVersion), - s.Contract.CltvExpiry, - s.Contract.SenderKey, - s.Contract.ReceiverKey, - s.Hash, outputType, chainParams, + htlc, err := loop.GetHtlc( + s.Hash, &s.Contract.SwapContract, chainParams, ) if err != nil { return err @@ -77,7 +60,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error { s.Contract.InitiationTime, s.Contract.InitiationHeight, ) fmt.Printf(" Preimage: %v\n", s.Contract.Preimage) - fmt.Printf(" Htlc address: %v\n", htlc.Address) + fmt.Printf(" Htlc address (%s): %v\n", htlc.OutputType, + htlc.Address) fmt.Printf(" Uncharge channels: %v\n", s.Contract.OutgoingChanSet) @@ -113,12 +97,8 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error { } for _, s := range swaps { - htlc, err := swap.NewHtlc( - loop.GetHtlcScriptVersion(s.Contract.ProtocolVersion), - s.Contract.CltvExpiry, - s.Contract.SenderKey, - s.Contract.ReceiverKey, - s.Hash, swap.HtlcP2WSH, chainParams, + htlc, err := loop.GetHtlc( + s.Hash, &s.Contract.SwapContract, chainParams, ) if err != nil { return err @@ -129,7 +109,8 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error { s.Contract.InitiationTime, s.Contract.InitiationHeight, ) fmt.Printf(" Preimage: %v\n", s.Contract.Preimage) - fmt.Printf(" Htlc address: %v\n", htlc.Address) + fmt.Printf(" Htlc address (%s): %v\n", htlc.OutputType, + htlc.Address) fmt.Printf(" Amt: %v, Expiry: %v\n", s.Contract.AmountRequested, s.Contract.CltvExpiry, ) diff --git a/loopin.go b/loopin.go index 8b3e49b..62d2a42 100644 --- a/loopin.go +++ b/loopin.go @@ -404,26 +404,26 @@ func validateLoopInContract(lnd *lndclient.LndServices, // initHtlcs creates and updates the native and nested segwit htlcs // of the loopInSwap. func (s *loopInSwap) initHtlcs() error { - if IsTaprootSwap(&s.SwapContract) { - htlcP2TR, err := s.swapKit.getHtlc(swap.HtlcP2TR) - if err != nil { - return err - } + htlc, err := GetHtlc( + s.hash, &s.SwapContract, s.swapKit.lnd.ChainParams, + ) + if err != nil { + return err + } - s.swapKit.log.Infof("Htlc address (P2TR): %v", htlcP2TR.Address) - s.htlcP2TR = htlcP2TR + switch htlc.OutputType { + case swap.HtlcP2WSH: + s.htlcP2WSH = htlc - return nil - } + case swap.HtlcP2TR: + s.htlcP2TR = htlc - htlcP2WSH, err := s.swapKit.getHtlc(swap.HtlcP2WSH) - if err != nil { - return err + default: + return fmt.Errorf("invalid output type") } - // Log htlc addresses for debugging. - s.swapKit.log.Infof("Htlc address (P2WSH): %v", htlcP2WSH.Address) - s.htlcP2WSH = htlcP2WSH + s.swapKit.log.Infof("Htlc address (%s): %v", htlc.OutputType, + htlc.Address) return nil } diff --git a/loopin_test.go b/loopin_test.go index e4a317b..53f2b35 100644 --- a/loopin_test.go +++ b/loopin_test.go @@ -455,21 +455,32 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool, pendSwap.Loop.Events[0].Cost = cost } - scriptVersion := GetHtlcScriptVersion(storedVersion) + var ( + htlc *swap.Htlc + err error + ) - outputType := swap.HtlcP2WSH - if scriptVersion == swap.HtlcV3 { - outputType = swap.HtlcP2TR + switch GetHtlcScriptVersion(storedVersion) { + case swap.HtlcV2: + htlc, err = swap.NewHtlcV2( + contract.CltvExpiry, contract.SenderKey, + contract.ReceiverKey, testPreimage.Hash(), + cfg.lnd.ChainParams, + ) + + case swap.HtlcV3: + htlc, err = swap.NewHtlcV3( + contract.CltvExpiry, contract.SenderKey, + contract.ReceiverKey, contract.SenderKey, + contract.ReceiverKey, testPreimage.Hash(), + cfg.lnd.ChainParams, + ) + + default: + t.Fatalf("unknown HTLC script version") } - htlc, err := swap.NewHtlc( - scriptVersion, contract.CltvExpiry, contract.SenderKey, - contract.ReceiverKey, testPreimage.Hash(), outputType, - cfg.lnd.ChainParams, - ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) err = ctx.store.CreateLoopIn(testPreimage.Hash(), contract) if err != nil { diff --git a/loopout.go b/loopout.go index 30185e6..6b03bfd 100644 --- a/loopout.go +++ b/loopout.go @@ -201,21 +201,17 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig, swapKit.lastUpdateTime = initiationTime - scriptVersion := GetHtlcScriptVersion(loopdb.CurrentProtocolVersion()) - outputType := swap.HtlcP2TR - if scriptVersion != swap.HtlcV3 { - // Default to using P2WSH for legacy htlcs. - outputType = swap.HtlcP2WSH - } - // Create the htlc. - htlc, err := swapKit.getHtlc(outputType) + htlc, err := GetHtlc( + swapKit.hash, swapKit.contract, swapKit.lnd.ChainParams, + ) if err != nil { return nil, err } // Log htlc address for debugging. - swapKit.log.Infof("Htlc address: %v", htlc.Address) + swapKit.log.Infof("Htlc address (%s): %v", htlc.OutputType, + htlc.Address) // Obtain the payment addr since we'll need it later for routing plugin // recommendation and possibly for cancel. @@ -261,15 +257,10 @@ func resumeLoopOutSwap(reqContext context.Context, cfg *swapConfig, hash, swap.TypeOut, cfg, &pend.Contract.SwapContract, ) - scriptVersion := GetHtlcScriptVersion(pend.Contract.ProtocolVersion) - outputType := swap.HtlcP2TR - if scriptVersion != swap.HtlcV3 { - // Default to using P2WSH for legacy htlcs. - outputType = swap.HtlcP2WSH - } - // Create the htlc. - htlc, err := swapKit.getHtlc(outputType) + htlc, err := GetHtlc( + swapKit.hash, swapKit.contract, swapKit.lnd.ChainParams, + ) if err != nil { return nil, err } diff --git a/swap.go b/swap.go index a084e95..764af79 100644 --- a/swap.go +++ b/swap.go @@ -4,6 +4,7 @@ import ( "context" "time" + "github.com/btcsuite/btcd/chaincfg" "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/swap" @@ -67,14 +68,28 @@ func IsTaprootSwap(swapContract *loopdb.SwapContract) bool { return GetHtlcScriptVersion(swapContract.ProtocolVersion) == swap.HtlcV3 } -// getHtlc composes and returns the on-chain swap script. -func (s *swapKit) getHtlc(outputType swap.HtlcOutputType) (*swap.Htlc, error) { - return swap.NewHtlc( - GetHtlcScriptVersion(s.contract.ProtocolVersion), - s.contract.CltvExpiry, s.contract.SenderKey, - s.contract.ReceiverKey, s.hash, outputType, - s.swapConfig.lnd.ChainParams, - ) +// GetHtlc composes and returns the on-chain swap script. +func GetHtlc(hash lntypes.Hash, contract *loopdb.SwapContract, + chainParams *chaincfg.Params) (*swap.Htlc, error) { + + switch GetHtlcScriptVersion(contract.ProtocolVersion) { + case swap.HtlcV2: + return swap.NewHtlcV2( + contract.CltvExpiry, contract.SenderKey, + contract.ReceiverKey, hash, + chainParams, + ) + + case swap.HtlcV3: + return swap.NewHtlcV3( + contract.CltvExpiry, contract.SenderKey, + contract.ReceiverKey, contract.SenderKey, + contract.ReceiverKey, hash, + chainParams, + ) + } + + return nil, swap.ErrInvalidScriptVersion } // swapInfo constructs and returns a filled SwapInfo from diff --git a/swap/htlc.go b/swap/htlc.go index 1e174cc..8a1101a 100644 --- a/swap/htlc.go +++ b/swap/htlc.go @@ -114,16 +114,16 @@ var ( // QuoteHtlcP2WSH is a template script just used for sweep fee // estimation. - QuoteHtlcP2WSH, _ = NewHtlc( - HtlcV2, ^int32(0), dummyPubKey, dummyPubKey, quoteHash, - HtlcP2WSH, &chaincfg.MainNetParams, + QuoteHtlcP2WSH, _ = NewHtlcV2( + ^int32(0), dummyPubKey, dummyPubKey, quoteHash, + &chaincfg.MainNetParams, ) // QuoteHtlcP2TR is a template script just used for sweep fee // estimation. - QuoteHtlcP2TR, _ = NewHtlc( - HtlcV3, ^int32(0), dummyPubKey, dummyPubKey, quoteHash, - HtlcP2TR, &chaincfg.MainNetParams, + QuoteHtlcP2TR, _ = NewHtlcV3( + ^int32(0), dummyPubKey, dummyPubKey, dummyPubKey, dummyPubKey, + quoteHash, &chaincfg.MainNetParams, ) // ErrInvalidScriptVersion is returned when an unknown htlc version @@ -135,6 +135,10 @@ var ( // selected for a v2 script. ErrInvalidOutputSelected = fmt.Errorf("taproot output selected for " + "non taproot htlc") + + // ErrInvalidOutputType is returned when an unknown output type is + // associated with a certain swap htlc. + ErrInvalidOutputType = fmt.Errorf("invalid htlc output type") ) // String returns the string value of HtlcOutputType. @@ -151,38 +155,54 @@ func (h HtlcOutputType) String() string { } } -// NewHtlc returns a new instance. For v3 scripts, an internal pubkey generated -// by both participants must be provided. -func NewHtlc(version ScriptVersion, cltvExpiry int32, - senderKey, receiverKey [33]byte, hash lntypes.Hash, - outputType HtlcOutputType, chainParams *chaincfg.Params) (*Htlc, error) { +// NewHtlcV2 returns a new V2 (P2WSH) HTLC instance. +func NewHtlcV2(cltvExpiry int32, senderKey, receiverKey [33]byte, + hash lntypes.Hash, chainParams *chaincfg.Params) (*Htlc, error) { - var ( - err error - htlc HtlcScript + htlc, err := newHTLCScriptV2( + cltvExpiry, senderKey, receiverKey, hash, ) - switch version { - case HtlcV2: - htlc, err = newHTLCScriptV2( - cltvExpiry, senderKey, receiverKey, hash, - ) - - case HtlcV3: - htlc, err = newHTLCScriptV3( - cltvExpiry, senderKey, receiverKey, hash, - ) + if err != nil { + return nil, err + } - default: - return nil, ErrInvalidScriptVersion + address, pkScript, sigScript, err := htlc.lockingConditions( + HtlcP2WSH, chainParams, + ) + if err != nil { + return nil, fmt.Errorf("could not get address: %w", err) } + return &Htlc{ + HtlcScript: htlc, + Hash: hash, + Version: HtlcV2, + PkScript: pkScript, + OutputType: HtlcP2WSH, + ChainParams: chainParams, + Address: address, + SigScript: sigScript, + }, nil +} + +// NewHtlcV3 returns a new V3 HTLC (P2TR) instance. Internal pubkey generated +// by both participants must be provided. +func NewHtlcV3(cltvExpiry int32, senderInternalKey, receiverInternalKey, + senderKey, receiverKey [33]byte, hash lntypes.Hash, + chainParams *chaincfg.Params) (*Htlc, error) { + + htlc, err := newHTLCScriptV3( + cltvExpiry, senderInternalKey, receiverInternalKey, + senderKey, receiverKey, hash, + ) + if err != nil { return nil, err } address, pkScript, sigScript, err := htlc.lockingConditions( - outputType, chainParams, + HtlcP2TR, chainParams, ) if err != nil { return nil, fmt.Errorf("could not get address: %w", err) @@ -191,9 +211,9 @@ func NewHtlc(version ScriptVersion, cltvExpiry int32, return &Htlc{ HtlcScript: htlc, Hash: hash, - Version: version, + Version: HtlcV3, PkScript: pkScript, - OutputType: outputType, + OutputType: HtlcP2TR, ChainParams: chainParams, Address: address, SigScript: sigScript, @@ -481,7 +501,8 @@ type HtlcScriptV3 struct { } // newHTLCScriptV3 constructs a HtlcScipt with the HTLC V3 taproot script. -func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte, +func newHTLCScriptV3(cltvExpiry int32, senderInternalKey, receiverInternalKey, + senderHtlcKey, receiverHtlcKey [33]byte, swapHash lntypes.Hash) (*HtlcScriptV3, error) { senderPubKey, err := schnorr.ParsePubKey(senderHtlcKey[1:]) @@ -494,13 +515,6 @@ func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte, return nil, err } - aggregateKey, _, _, err := musig2.AggregateKeys( - []*btcec.PublicKey{senderPubKey, receiverPubKey}, true, - ) - if err != nil { - return nil, err - } - // Create our success path script, we'll use this separately // to generate the success path leaf. successPathScript, err := GenSuccessPathScript( @@ -527,6 +541,31 @@ func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte, rootHash := tree.RootNode.TapHash() + // Parse the pub keys used in the internal aggregate key. They are + // optional and may just be the same keys that are used for the script + // paths. + senderInternalPubKey, err := schnorr.ParsePubKey(senderInternalKey[1:]) + if err != nil { + return nil, err + } + + receiverInternalPubKey, err := schnorr.ParsePubKey( + receiverInternalKey[1:], + ) + if err != nil { + return nil, err + } + + // Calculate the internal aggregate key. + aggregateKey, _, _, err := musig2.AggregateKeys( + []*btcec.PublicKey{ + senderInternalPubKey, receiverInternalPubKey, + }, true, + ) + if err != nil { + return nil, err + } + // Calculate top level taproot key. taprootKey := txscript.ComputeTaprootOutputKey( aggregateKey.PreTweakedKey, rootHash[:], diff --git a/swap/htlc_test.go b/swap/htlc_test.go index 96b2f16..dd309e5 100644 --- a/swap/htlc_test.go +++ b/swap/htlc_test.go @@ -134,9 +134,9 @@ func TestHtlcV2(t *testing.T) { hash := sha256.Sum256(testPreimage[:]) // Create the htlc. - htlc, err := NewHtlc( - HtlcV2, testCltvExpiry, senderKey, receiverKey, hash, - HtlcP2WSH, &chaincfg.MainNetParams, + htlc, err := NewHtlcV2( + testCltvExpiry, senderKey, receiverKey, hash, + &chaincfg.MainNetParams, ) require.NoError(t, err) @@ -287,10 +287,9 @@ func TestHtlcV2(t *testing.T) { bogusKey := [33]byte{0xb, 0xa, 0xd} // Create the htlc with the bogus key. - htlc, err = NewHtlc( - HtlcV2, testCltvExpiry, - bogusKey, receiverKey, hash, - HtlcP2WSH, &chaincfg.MainNetParams, + htlc, err = NewHtlcV2( + testCltvExpiry, bogusKey, receiverKey, + hash, &chaincfg.MainNetParams, ) require.NoError(t, err) @@ -357,9 +356,9 @@ func TestHtlcV3(t *testing.T) { copy(receiverKey[:], receiverPubKey.SerializeCompressed()) copy(senderKey[:], senderPubKey.SerializeCompressed()) - htlc, err := NewHtlc( - HtlcV3, cltvExpiry, senderKey, receiverKey, - hashedPreimage, HtlcP2TR, &chaincfg.MainNetParams, + htlc, err := NewHtlcV3( + cltvExpiry, senderKey, receiverKey, senderKey, receiverKey, + hashedPreimage, &chaincfg.MainNetParams, ) require.NoError(t, err) @@ -540,10 +539,10 @@ func TestHtlcV3(t *testing.T) { bogusKey.SerializeCompressed(), ) - htlc, err := NewHtlc( - HtlcV3, cltvExpiry, bogusKeyBytes, - receiverKey, hashedPreimage, HtlcP2TR, - &chaincfg.MainNetParams, + htlc, err := NewHtlcV3( + cltvExpiry, senderKey, + receiverKey, bogusKeyBytes, receiverKey, + hashedPreimage, &chaincfg.MainNetParams, ) require.NoError(t, err)