diff --git a/client_test.go b/client_test.go index ba3ff6e..a6297dd 100644 --- a/client_test.go +++ b/client_test.go @@ -13,7 +13,6 @@ import ( "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/swap" "github.com/lightninglabs/loop/test" - "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" "github.com/stretchr/testify/require" @@ -277,30 +276,13 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed, // Expect client to register for our expected number of confirmations. confIntent := ctx.AssertRegisterConf(preimageRevealed, int32(confs)) - // Assert that the loopout htlc equals to the expected one. - scriptVersion := GetHtlcScriptVersion(protocolVersion) - var htlc *swap.Htlc - - switch scriptVersion { - case swap.HtlcV2: - htlc, err = swap.NewHtlcV2( - pendingSwap.Contract.CltvExpiry, senderKey, - receiverKey, hash, &chaincfg.TestNet3Params, - ) - - case swap.HtlcV3: - htlc, err = swap.NewHtlcV3( - input.MuSig2Version040, - pendingSwap.Contract.CltvExpiry, senderKey, - receiverKey, senderKey, receiverKey, hash, - &chaincfg.TestNet3Params, - ) - - default: - t.Fatalf(swap.ErrInvalidScriptVersion.Error()) - } - + htlc, err := GetHtlc( + hash, &pendingSwap.Contract.SwapContract, + &chaincfg.TestNet3Params, + ) require.NoError(t, err) + + // Assert that the loopout htlc equals to the expected one. require.Equal(t, htlc.PkScript, confIntent.PkScript) signalSwapPaymentResult(nil) @@ -319,7 +301,7 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed, func(r error) {}, func(r error) {}, preimageRevealed, - confIntent, scriptVersion, + confIntent, GetHtlcScriptVersion(protocolVersion), ) } diff --git a/loopin_test.go b/loopin_test.go index ca3b707..de4452b 100644 --- a/loopin_test.go +++ b/loopin_test.go @@ -8,10 +8,8 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/lightninglabs/loop/loopdb" - "github.com/lightninglabs/loop/swap" "github.com/lightninglabs/loop/test" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/input" invpkg "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/require" @@ -450,37 +448,10 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool, pendSwap.Loop.Events[0].Cost = cost } - var ( - htlc *swap.Htlc - err error + htlc, err := GetHtlc( + testPreimage.Hash(), &contract.SwapContract, + cfg.lnd.ChainParams, ) - - switch GetHtlcScriptVersion(storedVersion) { - case swap.HtlcV2: - htlc, err = swap.NewHtlcV2( - contract.CltvExpiry, - contract.HtlcKeys.SenderScriptKey, - contract.HtlcKeys.ReceiverScriptKey, - testPreimage.Hash(), - cfg.lnd.ChainParams, - ) - - case swap.HtlcV3: - htlc, err = swap.NewHtlcV3( - input.MuSig2Version040, - contract.CltvExpiry, - contract.HtlcKeys.SenderInternalPubKey, - contract.HtlcKeys.ReceiverInternalPubKey, - contract.HtlcKeys.SenderScriptKey, - contract.HtlcKeys.ReceiverScriptKey, - testPreimage.Hash(), - cfg.lnd.ChainParams, - ) - - default: - t.Fatalf("unknown HTLC script version") - } - require.NoError(t, err) err = ctx.store.CreateLoopIn(testPreimage.Hash(), contract) diff --git a/loopout.go b/loopout.go index ff9bd92..eca55bc 100644 --- a/loopout.go +++ b/loopout.go @@ -1355,12 +1355,28 @@ func (s *loopOutSwap) createMuSig2SweepTxn( return nil, err } - signers := [][]byte{ - s.HtlcKeys.SenderInternalPubKey[1:], - s.HtlcKeys.ReceiverInternalPubKey[1:], + var ( + signers [][]byte + muSig2Verion input.MuSig2Version + ) + + // Depending on the MuSig2 version we either pass 32 byte Schnorr + // public keys or normal 33 byte public keys. + if s.ProtocolVersion >= loopdb.ProtocolVersionMuSig2 { + muSig2Verion = input.MuSig2Version100RC2 + signers = [][]byte{ + s.HtlcKeys.SenderInternalPubKey[:], + s.HtlcKeys.ReceiverInternalPubKey[:], + } + } else { + muSig2Verion = input.MuSig2Version040 + signers = [][]byte{ + s.HtlcKeys.SenderInternalPubKey[1:], + s.HtlcKeys.ReceiverInternalPubKey[1:], + } } - htlc, ok := s.htlc.HtlcScript.(*swap.HtlcScriptV3) + htlcScript, ok := s.htlc.HtlcScript.(*swap.HtlcScriptV3) if !ok { return nil, fmt.Errorf("non taproot htlc") } @@ -1368,9 +1384,8 @@ func (s *loopOutSwap) createMuSig2SweepTxn( // Now we're creating a local MuSig2 session using the receiver key's // key locator and the htlc's root hash. musig2SessionInfo, err := s.lnd.Signer.MuSig2CreateSession( - ctx, input.MuSig2Version040, - &s.HtlcKeys.ClientScriptKeyLocator, signers, - lndclient.MuSig2TaprootTweakOpt(htlc.RootHash[:], false), + ctx, muSig2Verion, &s.HtlcKeys.ClientScriptKeyLocator, signers, + lndclient.MuSig2TaprootTweakOpt(htlcScript.RootHash[:], false), ) if err != nil { return nil, err @@ -1434,7 +1449,7 @@ func (s *loopOutSwap) createMuSig2SweepTxn( // To be sure that we're good, parse and validate that the combined // signature is indeed valid for the sig hash and the internal pubkey. err = s.executeConfig.verifySchnorrSig( - htlc.TaprootKey, sigHash, finalSig, + htlcScript.TaprootKey, sigHash, finalSig, ) if err != nil { return nil, err diff --git a/swap.go b/swap.go index caa5d21..d816990 100644 --- a/swap.go +++ b/swap.go @@ -82,8 +82,15 @@ func GetHtlc(hash lntypes.Hash, contract *loopdb.SwapContract, ) case swap.HtlcV3: + // Swaps that implement the new MuSig2 protocol will be expected + // to use the 1.0RC2 MuSig2 key derivation scheme. + muSig2Version := input.MuSig2Version040 + if contract.ProtocolVersion >= loopdb.ProtocolVersionMuSig2 { + muSig2Version = input.MuSig2Version100RC2 + } + return swap.NewHtlcV3( - input.MuSig2Version040, + muSig2Version, contract.CltvExpiry, contract.HtlcKeys.SenderInternalPubKey, contract.HtlcKeys.ReceiverInternalPubKey,