diff --git a/client.go b/client.go index 4581f0f..0bf6e46 100644 --- a/client.go +++ b/client.go @@ -213,29 +213,24 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) { SwapHash: swp.Hash, LastUpdate: swp.LastUpdateTime(), } - scriptVersion := GetHtlcScriptVersion( - swp.Contract.ProtocolVersion, - ) - - outputType := swap.HtlcP2WSH - if scriptVersion == swap.HtlcV3 { - outputType = swap.HtlcP2TR - } - htlc, err := swap.NewHtlc( - scriptVersion, - swp.Contract.CltvExpiry, swp.Contract.SenderKey, - swp.Contract.ReceiverKey, swp.Hash, - outputType, s.lndServices.ChainParams, + htlc, err := GetHtlc( + swp.Hash, &swp.Contract.SwapContract, + s.lndServices.ChainParams, ) if err != nil { return nil, err } - if outputType == swap.HtlcP2TR { - swapInfo.HtlcAddressP2TR = htlc.Address - } else { + switch htlc.OutputType { + case swap.HtlcP2WSH: swapInfo.HtlcAddressP2WSH = htlc.Address + + case swap.HtlcP2TR: + swapInfo.HtlcAddressP2TR = htlc.Address + + default: + return nil, swap.ErrInvalidOutputType } swaps = append(swaps, swapInfo) @@ -250,34 +245,23 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) { LastUpdate: swp.LastUpdateTime(), } - scriptVersion := GetHtlcScriptVersion( - swp.Contract.SwapContract.ProtocolVersion, + htlc, err := GetHtlc( + swp.Hash, &swp.Contract.SwapContract, + s.lndServices.ChainParams, ) + if err != nil { + return nil, err + } - if scriptVersion == swap.HtlcV3 { - htlcP2TR, err := swap.NewHtlc( - swap.HtlcV3, swp.Contract.CltvExpiry, - swp.Contract.SenderKey, swp.Contract.ReceiverKey, - swp.Hash, swap.HtlcP2TR, - s.lndServices.ChainParams, - ) - if err != nil { - return nil, err - } + switch htlc.OutputType { + case swap.HtlcP2WSH: + swapInfo.HtlcAddressP2WSH = htlc.Address - swapInfo.HtlcAddressP2TR = htlcP2TR.Address - } else { - htlcP2WSH, err := swap.NewHtlc( - swap.HtlcV2, swp.Contract.CltvExpiry, - swp.Contract.SenderKey, swp.Contract.ReceiverKey, - swp.Hash, swap.HtlcP2WSH, - s.lndServices.ChainParams, - ) - if err != nil { - return nil, err - } + case swap.HtlcP2TR: + swapInfo.HtlcAddressP2TR = htlc.Address - swapInfo.HtlcAddressP2WSH = htlcP2WSH.Address + default: + return nil, swap.ErrInvalidOutputType } swaps = append(swaps, swapInfo) diff --git a/client_test.go b/client_test.go index 7c8f88b..3bca073 100644 --- a/client_test.go +++ b/client_test.go @@ -284,16 +284,26 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed, // Assert that the loopout htlc equals to the expected one. scriptVersion := GetHtlcScriptVersion(protocolVersion) + var htlc *swap.Htlc - outputType := swap.HtlcP2TR - if scriptVersion != swap.HtlcV3 { - outputType = swap.HtlcP2WSH + switch scriptVersion { + case swap.HtlcV2: + htlc, err = swap.NewHtlcV2( + pendingSwap.Contract.CltvExpiry, senderKey, + receiverKey, hash, &chaincfg.TestNet3Params, + ) + + case swap.HtlcV3: + htlc, err = swap.NewHtlcV3( + pendingSwap.Contract.CltvExpiry, senderKey, + receiverKey, senderKey, receiverKey, hash, + &chaincfg.TestNet3Params, + ) + + default: + t.Fatalf(swap.ErrInvalidScriptVersion.Error()) } - htlc, err := swap.NewHtlc( - scriptVersion, pendingSwap.Contract.CltvExpiry, senderKey, - receiverKey, hash, outputType, &chaincfg.TestNet3Params, - ) require.NoError(t, err) require.Equal(t, htlc.PkScript, confIntent.PkScript) diff --git a/loopd/view.go b/loopd/view.go index f99fdef..3a7b600 100644 --- a/loopd/view.go +++ b/loopd/view.go @@ -7,7 +7,6 @@ import ( "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop" "github.com/lightninglabs/loop/loopdb" - "github.com/lightninglabs/loop/swap" ) // view prints all swaps currently in the database. @@ -49,24 +48,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error { } for _, s := range swaps { - scriptVersion := loop.GetHtlcScriptVersion( - s.Contract.ProtocolVersion, - ) - - var outputType swap.HtlcOutputType - switch scriptVersion { - case swap.HtlcV2: - outputType = swap.HtlcP2WSH - - case swap.HtlcV3: - outputType = swap.HtlcP2TR - } - htlc, err := swap.NewHtlc( - loop.GetHtlcScriptVersion(s.Contract.ProtocolVersion), - s.Contract.CltvExpiry, - s.Contract.SenderKey, - s.Contract.ReceiverKey, - s.Hash, outputType, chainParams, + htlc, err := loop.GetHtlc( + s.Hash, &s.Contract.SwapContract, chainParams, ) if err != nil { return err @@ -77,7 +60,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error { s.Contract.InitiationTime, s.Contract.InitiationHeight, ) fmt.Printf(" Preimage: %v\n", s.Contract.Preimage) - fmt.Printf(" Htlc address: %v\n", htlc.Address) + fmt.Printf(" Htlc address (%s): %v\n", htlc.OutputType, + htlc.Address) fmt.Printf(" Uncharge channels: %v\n", s.Contract.OutgoingChanSet) @@ -113,12 +97,8 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error { } for _, s := range swaps { - htlc, err := swap.NewHtlc( - loop.GetHtlcScriptVersion(s.Contract.ProtocolVersion), - s.Contract.CltvExpiry, - s.Contract.SenderKey, - s.Contract.ReceiverKey, - s.Hash, swap.HtlcP2WSH, chainParams, + htlc, err := loop.GetHtlc( + s.Hash, &s.Contract.SwapContract, chainParams, ) if err != nil { return err @@ -129,7 +109,8 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error { s.Contract.InitiationTime, s.Contract.InitiationHeight, ) fmt.Printf(" Preimage: %v\n", s.Contract.Preimage) - fmt.Printf(" Htlc address: %v\n", htlc.Address) + fmt.Printf(" Htlc address (%s): %v\n", htlc.OutputType, + htlc.Address) fmt.Printf(" Amt: %v, Expiry: %v\n", s.Contract.AmountRequested, s.Contract.CltvExpiry, ) diff --git a/loopin.go b/loopin.go index 8b3e49b..62d2a42 100644 --- a/loopin.go +++ b/loopin.go @@ -404,26 +404,26 @@ func validateLoopInContract(lnd *lndclient.LndServices, // initHtlcs creates and updates the native and nested segwit htlcs // of the loopInSwap. func (s *loopInSwap) initHtlcs() error { - if IsTaprootSwap(&s.SwapContract) { - htlcP2TR, err := s.swapKit.getHtlc(swap.HtlcP2TR) - if err != nil { - return err - } + htlc, err := GetHtlc( + s.hash, &s.SwapContract, s.swapKit.lnd.ChainParams, + ) + if err != nil { + return err + } - s.swapKit.log.Infof("Htlc address (P2TR): %v", htlcP2TR.Address) - s.htlcP2TR = htlcP2TR + switch htlc.OutputType { + case swap.HtlcP2WSH: + s.htlcP2WSH = htlc - return nil - } + case swap.HtlcP2TR: + s.htlcP2TR = htlc - htlcP2WSH, err := s.swapKit.getHtlc(swap.HtlcP2WSH) - if err != nil { - return err + default: + return fmt.Errorf("invalid output type") } - // Log htlc addresses for debugging. - s.swapKit.log.Infof("Htlc address (P2WSH): %v", htlcP2WSH.Address) - s.htlcP2WSH = htlcP2WSH + s.swapKit.log.Infof("Htlc address (%s): %v", htlc.OutputType, + htlc.Address) return nil } diff --git a/loopin_test.go b/loopin_test.go index e4a317b..53f2b35 100644 --- a/loopin_test.go +++ b/loopin_test.go @@ -455,21 +455,32 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool, pendSwap.Loop.Events[0].Cost = cost } - scriptVersion := GetHtlcScriptVersion(storedVersion) + var ( + htlc *swap.Htlc + err error + ) - outputType := swap.HtlcP2WSH - if scriptVersion == swap.HtlcV3 { - outputType = swap.HtlcP2TR + switch GetHtlcScriptVersion(storedVersion) { + case swap.HtlcV2: + htlc, err = swap.NewHtlcV2( + contract.CltvExpiry, contract.SenderKey, + contract.ReceiverKey, testPreimage.Hash(), + cfg.lnd.ChainParams, + ) + + case swap.HtlcV3: + htlc, err = swap.NewHtlcV3( + contract.CltvExpiry, contract.SenderKey, + contract.ReceiverKey, contract.SenderKey, + contract.ReceiverKey, testPreimage.Hash(), + cfg.lnd.ChainParams, + ) + + default: + t.Fatalf("unknown HTLC script version") } - htlc, err := swap.NewHtlc( - scriptVersion, contract.CltvExpiry, contract.SenderKey, - contract.ReceiverKey, testPreimage.Hash(), outputType, - cfg.lnd.ChainParams, - ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) err = ctx.store.CreateLoopIn(testPreimage.Hash(), contract) if err != nil { diff --git a/loopout.go b/loopout.go index 30185e6..6b03bfd 100644 --- a/loopout.go +++ b/loopout.go @@ -201,21 +201,17 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig, swapKit.lastUpdateTime = initiationTime - scriptVersion := GetHtlcScriptVersion(loopdb.CurrentProtocolVersion()) - outputType := swap.HtlcP2TR - if scriptVersion != swap.HtlcV3 { - // Default to using P2WSH for legacy htlcs. - outputType = swap.HtlcP2WSH - } - // Create the htlc. - htlc, err := swapKit.getHtlc(outputType) + htlc, err := GetHtlc( + swapKit.hash, swapKit.contract, swapKit.lnd.ChainParams, + ) if err != nil { return nil, err } // Log htlc address for debugging. - swapKit.log.Infof("Htlc address: %v", htlc.Address) + swapKit.log.Infof("Htlc address (%s): %v", htlc.OutputType, + htlc.Address) // Obtain the payment addr since we'll need it later for routing plugin // recommendation and possibly for cancel. @@ -261,15 +257,10 @@ func resumeLoopOutSwap(reqContext context.Context, cfg *swapConfig, hash, swap.TypeOut, cfg, &pend.Contract.SwapContract, ) - scriptVersion := GetHtlcScriptVersion(pend.Contract.ProtocolVersion) - outputType := swap.HtlcP2TR - if scriptVersion != swap.HtlcV3 { - // Default to using P2WSH for legacy htlcs. - outputType = swap.HtlcP2WSH - } - // Create the htlc. - htlc, err := swapKit.getHtlc(outputType) + htlc, err := GetHtlc( + swapKit.hash, swapKit.contract, swapKit.lnd.ChainParams, + ) if err != nil { return nil, err } diff --git a/swap.go b/swap.go index a084e95..764af79 100644 --- a/swap.go +++ b/swap.go @@ -4,6 +4,7 @@ import ( "context" "time" + "github.com/btcsuite/btcd/chaincfg" "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/swap" @@ -67,14 +68,28 @@ func IsTaprootSwap(swapContract *loopdb.SwapContract) bool { return GetHtlcScriptVersion(swapContract.ProtocolVersion) == swap.HtlcV3 } -// getHtlc composes and returns the on-chain swap script. -func (s *swapKit) getHtlc(outputType swap.HtlcOutputType) (*swap.Htlc, error) { - return swap.NewHtlc( - GetHtlcScriptVersion(s.contract.ProtocolVersion), - s.contract.CltvExpiry, s.contract.SenderKey, - s.contract.ReceiverKey, s.hash, outputType, - s.swapConfig.lnd.ChainParams, - ) +// GetHtlc composes and returns the on-chain swap script. +func GetHtlc(hash lntypes.Hash, contract *loopdb.SwapContract, + chainParams *chaincfg.Params) (*swap.Htlc, error) { + + switch GetHtlcScriptVersion(contract.ProtocolVersion) { + case swap.HtlcV2: + return swap.NewHtlcV2( + contract.CltvExpiry, contract.SenderKey, + contract.ReceiverKey, hash, + chainParams, + ) + + case swap.HtlcV3: + return swap.NewHtlcV3( + contract.CltvExpiry, contract.SenderKey, + contract.ReceiverKey, contract.SenderKey, + contract.ReceiverKey, hash, + chainParams, + ) + } + + return nil, swap.ErrInvalidScriptVersion } // swapInfo constructs and returns a filled SwapInfo from diff --git a/swap/htlc.go b/swap/htlc.go index 1e174cc..8a1101a 100644 --- a/swap/htlc.go +++ b/swap/htlc.go @@ -114,16 +114,16 @@ var ( // QuoteHtlcP2WSH is a template script just used for sweep fee // estimation. - QuoteHtlcP2WSH, _ = NewHtlc( - HtlcV2, ^int32(0), dummyPubKey, dummyPubKey, quoteHash, - HtlcP2WSH, &chaincfg.MainNetParams, + QuoteHtlcP2WSH, _ = NewHtlcV2( + ^int32(0), dummyPubKey, dummyPubKey, quoteHash, + &chaincfg.MainNetParams, ) // QuoteHtlcP2TR is a template script just used for sweep fee // estimation. - QuoteHtlcP2TR, _ = NewHtlc( - HtlcV3, ^int32(0), dummyPubKey, dummyPubKey, quoteHash, - HtlcP2TR, &chaincfg.MainNetParams, + QuoteHtlcP2TR, _ = NewHtlcV3( + ^int32(0), dummyPubKey, dummyPubKey, dummyPubKey, dummyPubKey, + quoteHash, &chaincfg.MainNetParams, ) // ErrInvalidScriptVersion is returned when an unknown htlc version @@ -135,6 +135,10 @@ var ( // selected for a v2 script. ErrInvalidOutputSelected = fmt.Errorf("taproot output selected for " + "non taproot htlc") + + // ErrInvalidOutputType is returned when an unknown output type is + // associated with a certain swap htlc. + ErrInvalidOutputType = fmt.Errorf("invalid htlc output type") ) // String returns the string value of HtlcOutputType. @@ -151,38 +155,54 @@ func (h HtlcOutputType) String() string { } } -// NewHtlc returns a new instance. For v3 scripts, an internal pubkey generated -// by both participants must be provided. -func NewHtlc(version ScriptVersion, cltvExpiry int32, - senderKey, receiverKey [33]byte, hash lntypes.Hash, - outputType HtlcOutputType, chainParams *chaincfg.Params) (*Htlc, error) { +// NewHtlcV2 returns a new V2 (P2WSH) HTLC instance. +func NewHtlcV2(cltvExpiry int32, senderKey, receiverKey [33]byte, + hash lntypes.Hash, chainParams *chaincfg.Params) (*Htlc, error) { - var ( - err error - htlc HtlcScript + htlc, err := newHTLCScriptV2( + cltvExpiry, senderKey, receiverKey, hash, ) - switch version { - case HtlcV2: - htlc, err = newHTLCScriptV2( - cltvExpiry, senderKey, receiverKey, hash, - ) - - case HtlcV3: - htlc, err = newHTLCScriptV3( - cltvExpiry, senderKey, receiverKey, hash, - ) + if err != nil { + return nil, err + } - default: - return nil, ErrInvalidScriptVersion + address, pkScript, sigScript, err := htlc.lockingConditions( + HtlcP2WSH, chainParams, + ) + if err != nil { + return nil, fmt.Errorf("could not get address: %w", err) } + return &Htlc{ + HtlcScript: htlc, + Hash: hash, + Version: HtlcV2, + PkScript: pkScript, + OutputType: HtlcP2WSH, + ChainParams: chainParams, + Address: address, + SigScript: sigScript, + }, nil +} + +// NewHtlcV3 returns a new V3 HTLC (P2TR) instance. Internal pubkey generated +// by both participants must be provided. +func NewHtlcV3(cltvExpiry int32, senderInternalKey, receiverInternalKey, + senderKey, receiverKey [33]byte, hash lntypes.Hash, + chainParams *chaincfg.Params) (*Htlc, error) { + + htlc, err := newHTLCScriptV3( + cltvExpiry, senderInternalKey, receiverInternalKey, + senderKey, receiverKey, hash, + ) + if err != nil { return nil, err } address, pkScript, sigScript, err := htlc.lockingConditions( - outputType, chainParams, + HtlcP2TR, chainParams, ) if err != nil { return nil, fmt.Errorf("could not get address: %w", err) @@ -191,9 +211,9 @@ func NewHtlc(version ScriptVersion, cltvExpiry int32, return &Htlc{ HtlcScript: htlc, Hash: hash, - Version: version, + Version: HtlcV3, PkScript: pkScript, - OutputType: outputType, + OutputType: HtlcP2TR, ChainParams: chainParams, Address: address, SigScript: sigScript, @@ -481,7 +501,8 @@ type HtlcScriptV3 struct { } // newHTLCScriptV3 constructs a HtlcScipt with the HTLC V3 taproot script. -func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte, +func newHTLCScriptV3(cltvExpiry int32, senderInternalKey, receiverInternalKey, + senderHtlcKey, receiverHtlcKey [33]byte, swapHash lntypes.Hash) (*HtlcScriptV3, error) { senderPubKey, err := schnorr.ParsePubKey(senderHtlcKey[1:]) @@ -494,13 +515,6 @@ func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte, return nil, err } - aggregateKey, _, _, err := musig2.AggregateKeys( - []*btcec.PublicKey{senderPubKey, receiverPubKey}, true, - ) - if err != nil { - return nil, err - } - // Create our success path script, we'll use this separately // to generate the success path leaf. successPathScript, err := GenSuccessPathScript( @@ -527,6 +541,31 @@ func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte, rootHash := tree.RootNode.TapHash() + // Parse the pub keys used in the internal aggregate key. They are + // optional and may just be the same keys that are used for the script + // paths. + senderInternalPubKey, err := schnorr.ParsePubKey(senderInternalKey[1:]) + if err != nil { + return nil, err + } + + receiverInternalPubKey, err := schnorr.ParsePubKey( + receiverInternalKey[1:], + ) + if err != nil { + return nil, err + } + + // Calculate the internal aggregate key. + aggregateKey, _, _, err := musig2.AggregateKeys( + []*btcec.PublicKey{ + senderInternalPubKey, receiverInternalPubKey, + }, true, + ) + if err != nil { + return nil, err + } + // Calculate top level taproot key. taprootKey := txscript.ComputeTaprootOutputKey( aggregateKey.PreTweakedKey, rootHash[:], diff --git a/swap/htlc_test.go b/swap/htlc_test.go index 96b2f16..dd309e5 100644 --- a/swap/htlc_test.go +++ b/swap/htlc_test.go @@ -134,9 +134,9 @@ func TestHtlcV2(t *testing.T) { hash := sha256.Sum256(testPreimage[:]) // Create the htlc. - htlc, err := NewHtlc( - HtlcV2, testCltvExpiry, senderKey, receiverKey, hash, - HtlcP2WSH, &chaincfg.MainNetParams, + htlc, err := NewHtlcV2( + testCltvExpiry, senderKey, receiverKey, hash, + &chaincfg.MainNetParams, ) require.NoError(t, err) @@ -287,10 +287,9 @@ func TestHtlcV2(t *testing.T) { bogusKey := [33]byte{0xb, 0xa, 0xd} // Create the htlc with the bogus key. - htlc, err = NewHtlc( - HtlcV2, testCltvExpiry, - bogusKey, receiverKey, hash, - HtlcP2WSH, &chaincfg.MainNetParams, + htlc, err = NewHtlcV2( + testCltvExpiry, bogusKey, receiverKey, + hash, &chaincfg.MainNetParams, ) require.NoError(t, err) @@ -357,9 +356,9 @@ func TestHtlcV3(t *testing.T) { copy(receiverKey[:], receiverPubKey.SerializeCompressed()) copy(senderKey[:], senderPubKey.SerializeCompressed()) - htlc, err := NewHtlc( - HtlcV3, cltvExpiry, senderKey, receiverKey, - hashedPreimage, HtlcP2TR, &chaincfg.MainNetParams, + htlc, err := NewHtlcV3( + cltvExpiry, senderKey, receiverKey, senderKey, receiverKey, + hashedPreimage, &chaincfg.MainNetParams, ) require.NoError(t, err) @@ -540,10 +539,10 @@ func TestHtlcV3(t *testing.T) { bogusKey.SerializeCompressed(), ) - htlc, err := NewHtlc( - HtlcV3, cltvExpiry, bogusKeyBytes, - receiverKey, hashedPreimage, HtlcP2TR, - &chaincfg.MainNetParams, + htlc, err := NewHtlcV3( + cltvExpiry, senderKey, + receiverKey, bogusKeyBytes, receiverKey, + hashedPreimage, &chaincfg.MainNetParams, ) require.NoError(t, err)