diff --git a/client_test.go b/client_test.go index 3ca2995..8539a88 100644 --- a/client_test.go +++ b/client_test.go @@ -13,6 +13,7 @@ import ( "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/swap" "github.com/lightninglabs/loop/test" + "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" "github.com/stretchr/testify/require" @@ -285,6 +286,7 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed, case swap.HtlcV3: htlc, err = swap.NewHtlcV3( + input.MuSig2Version040, pendingSwap.Contract.CltvExpiry, senderKey, receiverKey, senderKey, receiverKey, hash, &chaincfg.TestNet3Params, diff --git a/loopin_test.go b/loopin_test.go index a75e5ac..7bf3b8f 100644 --- a/loopin_test.go +++ b/loopin_test.go @@ -11,6 +11,7 @@ import ( "github.com/lightninglabs/loop/swap" "github.com/lightninglabs/loop/test" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/input" invpkg "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/require" @@ -459,6 +460,7 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool, case swap.HtlcV3: htlc, err = swap.NewHtlcV3( + input.MuSig2Version040, contract.CltvExpiry, contract.SenderKey, contract.ReceiverKey, contract.SenderKey, contract.ReceiverKey, testPreimage.Hash(), diff --git a/swap.go b/swap.go index 3253277..6e2d947 100644 --- a/swap.go +++ b/swap.go @@ -8,6 +8,7 @@ import ( "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/swap" + "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" ) @@ -82,6 +83,7 @@ func GetHtlc(hash lntypes.Hash, contract *loopdb.SwapContract, case swap.HtlcV3: return swap.NewHtlcV3( + input.MuSig2Version040, contract.CltvExpiry, contract.SenderKey, contract.ReceiverKey, contract.SenderKey, contract.ReceiverKey, hash, diff --git a/swap/htlc.go b/swap/htlc.go index 3c0f411..daa7ef3 100644 --- a/swap/htlc.go +++ b/swap/htlc.go @@ -122,8 +122,8 @@ var ( // QuoteHtlcP2TR is a template script just used for sweep fee // estimation. QuoteHtlcP2TR, _ = NewHtlcV3( - ^int32(0), dummyPubKey, dummyPubKey, dummyPubKey, dummyPubKey, - quoteHash, &chaincfg.MainNetParams, + input.MuSig2Version100RC2, ^int32(0), dummyPubKey, dummyPubKey, + dummyPubKey, dummyPubKey, quoteHash, &chaincfg.MainNetParams, ) // ErrInvalidScriptVersion is returned when an unknown htlc version @@ -188,13 +188,13 @@ func NewHtlcV2(cltvExpiry int32, senderKey, receiverKey [33]byte, // NewHtlcV3 returns a new V3 HTLC (P2TR) instance. Internal pubkey generated // by both participants must be provided. -func NewHtlcV3(cltvExpiry int32, senderInternalKey, receiverInternalKey, - senderKey, receiverKey [33]byte, hash lntypes.Hash, - chainParams *chaincfg.Params) (*Htlc, error) { +func NewHtlcV3(muSig2Version input.MuSig2Version, cltvExpiry int32, + senderInternalKey, receiverInternalKey, senderKey, receiverKey [33]byte, + hash lntypes.Hash, chainParams *chaincfg.Params) (*Htlc, error) { htlc, err := newHTLCScriptV3( - cltvExpiry, senderInternalKey, receiverInternalKey, - senderKey, receiverKey, hash, + muSig2Version, cltvExpiry, senderInternalKey, + receiverInternalKey, senderKey, receiverKey, hash, ) if err != nil { @@ -504,17 +504,37 @@ type HtlcScriptV3 struct { RootHash chainhash.Hash } +// parsePubKey will parse a serialized public key into a btcec.PublicKey +// depending on the passed MuSig2 version. +func parsePubKey(muSig2Version input.MuSig2Version, key [33]byte) ( + *btcec.PublicKey, error) { + + // Make sure that we have the correct public keys depending on the + // MuSig2 version. + switch muSig2Version { + case input.MuSig2Version100RC2: + return btcec.ParsePubKey(key[:]) + + case input.MuSig2Version040: + return schnorr.ParsePubKey(key[1:]) + + default: + return nil, fmt.Errorf("unsupported MuSig2 version: %v", + muSig2Version) + } +} + // newHTLCScriptV3 constructs a HtlcScipt with the HTLC V3 taproot script. -func newHTLCScriptV3(cltvExpiry int32, senderInternalKey, receiverInternalKey, - senderHtlcKey, receiverHtlcKey [33]byte, - swapHash lntypes.Hash) (*HtlcScriptV3, error) { +func newHTLCScriptV3(muSig2Version input.MuSig2Version, cltvExpiry int32, + senderInternalKey, receiverInternalKey, senderHtlcKey, + receiverHtlcKey [33]byte, swapHash lntypes.Hash) (*HtlcScriptV3, error) { - senderPubKey, err := schnorr.ParsePubKey(senderHtlcKey[1:]) + senderPubKey, err := parsePubKey(muSig2Version, senderHtlcKey) if err != nil { return nil, err } - receiverPubKey, err := schnorr.ParsePubKey(receiverHtlcKey[1:]) + receiverPubKey, err := parsePubKey(muSig2Version, receiverHtlcKey) if err != nil { return nil, err } @@ -548,38 +568,41 @@ func newHTLCScriptV3(cltvExpiry int32, senderInternalKey, receiverInternalKey, // Parse the pub keys used in the internal aggregate key. They are // optional and may just be the same keys that are used for the script // paths. - senderInternalPubKey, err := schnorr.ParsePubKey(senderInternalKey[1:]) + senderInternalPubKey, err := parsePubKey( + muSig2Version, senderInternalKey, + ) if err != nil { return nil, err } - receiverInternalPubKey, err := schnorr.ParsePubKey( - receiverInternalKey[1:], + receiverInternalPubKey, err := parsePubKey( + muSig2Version, receiverInternalKey, ) if err != nil { return nil, err } + var aggregateKey *musig2.AggregateKey // Calculate the internal aggregate key. - aggregateKey, _, _, err := musig2.AggregateKeys( + aggregateKey, err = input.MuSig2CombineKeys( + muSig2Version, []*btcec.PublicKey{ senderInternalPubKey, receiverInternalPubKey, - }, true, + }, + true, + &input.MuSig2Tweaks{ + TaprootTweak: rootHash[:], + }, ) if err != nil { return nil, err } - // Calculate top level taproot key. - taprootKey := txscript.ComputeTaprootOutputKey( - aggregateKey.PreTweakedKey, rootHash[:], - ) - return &HtlcScriptV3{ timeoutScript: timeoutPathScript, successScript: successPathScript, InternalPubKey: aggregateKey.PreTweakedKey, - TaprootKey: taprootKey, + TaprootKey: aggregateKey.FinalKey, RootHash: rootHash, }, nil } diff --git a/swap/htlc_test.go b/swap/htlc_test.go index 09f9264..8c39fc1 100644 --- a/swap/htlc_test.go +++ b/swap/htlc_test.go @@ -335,6 +335,19 @@ func TestHtlcV2(t *testing.T) { // TestHtlcV3 tests the HTLC V3 script success and timeout spend cases. func TestHtlcV3(t *testing.T) { + versions := map[string]input.MuSig2Version{ + "MuSig2 0.4": input.MuSig2Version040, + "MuSig2 1.0RC2": input.MuSig2Version100RC2, + } + + for name, version := range versions { + t.Run(name, func(t *testing.T) { + testHtlcV3(t, version) + }) + } +} + +func testHtlcV3(t *testing.T, muSig2Version input.MuSig2Version) { var ( receiverKey [33]byte senderKey [33]byte @@ -355,8 +368,8 @@ func TestHtlcV3(t *testing.T) { copy(senderKey[:], senderPubKey.SerializeCompressed()) htlc, err := NewHtlcV3( - cltvExpiry, senderKey, receiverKey, senderKey, receiverKey, - hashedPreimage, &chaincfg.MainNetParams, + muSig2Version, cltvExpiry, senderKey, receiverKey, senderKey, + receiverKey, hashedPreimage, &chaincfg.MainNetParams, ) require.NoError(t, err) @@ -538,7 +551,7 @@ func TestHtlcV3(t *testing.T) { ) htlc, err := NewHtlcV3( - cltvExpiry, senderKey, + muSig2Version, cltvExpiry, senderKey, receiverKey, bogusKeyBytes, receiverKey, hashedPreimage, &chaincfg.MainNetParams, )