From bdb4b773ed82599858798f40f9132d9303266a01 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Wed, 16 Nov 2022 19:01:28 +0100 Subject: [PATCH 1/2] swap: refactor htlc construction to allow passing of internal keys This commit is a refactor of how we construct htlcs to make it possible to pass in internal keys for the sender and receiver when creating P2TR htlcs. Furthermore the commit also cleans up constructors to not pass in script versions and output types to make the code more readable. --- client.go | 64 ++++++++++---------------- client_test.go | 24 +++++++--- loopd/view.go | 35 ++++---------- loopin.go | 30 ++++++------ loopin_test.go | 35 +++++++++----- loopout.go | 25 ++++------ swap.go | 31 +++++++++---- swap/htlc.go | 113 +++++++++++++++++++++++++++++++--------------- swap/htlc_test.go | 27 ++++++----- 9 files changed, 207 insertions(+), 177 deletions(-) diff --git a/client.go b/client.go index 4581f0f..0bf6e46 100644 --- a/client.go +++ b/client.go @@ -213,29 +213,24 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) { SwapHash: swp.Hash, LastUpdate: swp.LastUpdateTime(), } - scriptVersion := GetHtlcScriptVersion( - swp.Contract.ProtocolVersion, - ) - - outputType := swap.HtlcP2WSH - if scriptVersion == swap.HtlcV3 { - outputType = swap.HtlcP2TR - } - htlc, err := swap.NewHtlc( - scriptVersion, - swp.Contract.CltvExpiry, swp.Contract.SenderKey, - swp.Contract.ReceiverKey, swp.Hash, - outputType, s.lndServices.ChainParams, + htlc, err := GetHtlc( + swp.Hash, &swp.Contract.SwapContract, + s.lndServices.ChainParams, ) if err != nil { return nil, err } - if outputType == swap.HtlcP2TR { - swapInfo.HtlcAddressP2TR = htlc.Address - } else { + switch htlc.OutputType { + case swap.HtlcP2WSH: swapInfo.HtlcAddressP2WSH = htlc.Address + + case swap.HtlcP2TR: + swapInfo.HtlcAddressP2TR = htlc.Address + + default: + return nil, swap.ErrInvalidOutputType } swaps = append(swaps, swapInfo) @@ -250,34 +245,23 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) { LastUpdate: swp.LastUpdateTime(), } - scriptVersion := GetHtlcScriptVersion( - swp.Contract.SwapContract.ProtocolVersion, + htlc, err := GetHtlc( + swp.Hash, &swp.Contract.SwapContract, + s.lndServices.ChainParams, ) + if err != nil { + return nil, err + } - if scriptVersion == swap.HtlcV3 { - htlcP2TR, err := swap.NewHtlc( - swap.HtlcV3, swp.Contract.CltvExpiry, - swp.Contract.SenderKey, swp.Contract.ReceiverKey, - swp.Hash, swap.HtlcP2TR, - s.lndServices.ChainParams, - ) - if err != nil { - return nil, err - } + switch htlc.OutputType { + case swap.HtlcP2WSH: + swapInfo.HtlcAddressP2WSH = htlc.Address - swapInfo.HtlcAddressP2TR = htlcP2TR.Address - } else { - htlcP2WSH, err := swap.NewHtlc( - swap.HtlcV2, swp.Contract.CltvExpiry, - swp.Contract.SenderKey, swp.Contract.ReceiverKey, - swp.Hash, swap.HtlcP2WSH, - s.lndServices.ChainParams, - ) - if err != nil { - return nil, err - } + case swap.HtlcP2TR: + swapInfo.HtlcAddressP2TR = htlc.Address - swapInfo.HtlcAddressP2WSH = htlcP2WSH.Address + default: + return nil, swap.ErrInvalidOutputType } swaps = append(swaps, swapInfo) diff --git a/client_test.go b/client_test.go index 7c8f88b..3bca073 100644 --- a/client_test.go +++ b/client_test.go @@ -284,16 +284,26 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed, // Assert that the loopout htlc equals to the expected one. scriptVersion := GetHtlcScriptVersion(protocolVersion) + var htlc *swap.Htlc - outputType := swap.HtlcP2TR - if scriptVersion != swap.HtlcV3 { - outputType = swap.HtlcP2WSH + switch scriptVersion { + case swap.HtlcV2: + htlc, err = swap.NewHtlcV2( + pendingSwap.Contract.CltvExpiry, senderKey, + receiverKey, hash, &chaincfg.TestNet3Params, + ) + + case swap.HtlcV3: + htlc, err = swap.NewHtlcV3( + pendingSwap.Contract.CltvExpiry, senderKey, + receiverKey, senderKey, receiverKey, hash, + &chaincfg.TestNet3Params, + ) + + default: + t.Fatalf(swap.ErrInvalidScriptVersion.Error()) } - htlc, err := swap.NewHtlc( - scriptVersion, pendingSwap.Contract.CltvExpiry, senderKey, - receiverKey, hash, outputType, &chaincfg.TestNet3Params, - ) require.NoError(t, err) require.Equal(t, htlc.PkScript, confIntent.PkScript) diff --git a/loopd/view.go b/loopd/view.go index f99fdef..3a7b600 100644 --- a/loopd/view.go +++ b/loopd/view.go @@ -7,7 +7,6 @@ import ( "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop" "github.com/lightninglabs/loop/loopdb" - "github.com/lightninglabs/loop/swap" ) // view prints all swaps currently in the database. @@ -49,24 +48,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error { } for _, s := range swaps { - scriptVersion := loop.GetHtlcScriptVersion( - s.Contract.ProtocolVersion, - ) - - var outputType swap.HtlcOutputType - switch scriptVersion { - case swap.HtlcV2: - outputType = swap.HtlcP2WSH - - case swap.HtlcV3: - outputType = swap.HtlcP2TR - } - htlc, err := swap.NewHtlc( - loop.GetHtlcScriptVersion(s.Contract.ProtocolVersion), - s.Contract.CltvExpiry, - s.Contract.SenderKey, - s.Contract.ReceiverKey, - s.Hash, outputType, chainParams, + htlc, err := loop.GetHtlc( + s.Hash, &s.Contract.SwapContract, chainParams, ) if err != nil { return err @@ -77,7 +60,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error { s.Contract.InitiationTime, s.Contract.InitiationHeight, ) fmt.Printf(" Preimage: %v\n", s.Contract.Preimage) - fmt.Printf(" Htlc address: %v\n", htlc.Address) + fmt.Printf(" Htlc address (%s): %v\n", htlc.OutputType, + htlc.Address) fmt.Printf(" Uncharge channels: %v\n", s.Contract.OutgoingChanSet) @@ -113,12 +97,8 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error { } for _, s := range swaps { - htlc, err := swap.NewHtlc( - loop.GetHtlcScriptVersion(s.Contract.ProtocolVersion), - s.Contract.CltvExpiry, - s.Contract.SenderKey, - s.Contract.ReceiverKey, - s.Hash, swap.HtlcP2WSH, chainParams, + htlc, err := loop.GetHtlc( + s.Hash, &s.Contract.SwapContract, chainParams, ) if err != nil { return err @@ -129,7 +109,8 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error { s.Contract.InitiationTime, s.Contract.InitiationHeight, ) fmt.Printf(" Preimage: %v\n", s.Contract.Preimage) - fmt.Printf(" Htlc address: %v\n", htlc.Address) + fmt.Printf(" Htlc address (%s): %v\n", htlc.OutputType, + htlc.Address) fmt.Printf(" Amt: %v, Expiry: %v\n", s.Contract.AmountRequested, s.Contract.CltvExpiry, ) diff --git a/loopin.go b/loopin.go index 8b3e49b..62d2a42 100644 --- a/loopin.go +++ b/loopin.go @@ -404,26 +404,26 @@ func validateLoopInContract(lnd *lndclient.LndServices, // initHtlcs creates and updates the native and nested segwit htlcs // of the loopInSwap. func (s *loopInSwap) initHtlcs() error { - if IsTaprootSwap(&s.SwapContract) { - htlcP2TR, err := s.swapKit.getHtlc(swap.HtlcP2TR) - if err != nil { - return err - } + htlc, err := GetHtlc( + s.hash, &s.SwapContract, s.swapKit.lnd.ChainParams, + ) + if err != nil { + return err + } - s.swapKit.log.Infof("Htlc address (P2TR): %v", htlcP2TR.Address) - s.htlcP2TR = htlcP2TR + switch htlc.OutputType { + case swap.HtlcP2WSH: + s.htlcP2WSH = htlc - return nil - } + case swap.HtlcP2TR: + s.htlcP2TR = htlc - htlcP2WSH, err := s.swapKit.getHtlc(swap.HtlcP2WSH) - if err != nil { - return err + default: + return fmt.Errorf("invalid output type") } - // Log htlc addresses for debugging. - s.swapKit.log.Infof("Htlc address (P2WSH): %v", htlcP2WSH.Address) - s.htlcP2WSH = htlcP2WSH + s.swapKit.log.Infof("Htlc address (%s): %v", htlc.OutputType, + htlc.Address) return nil } diff --git a/loopin_test.go b/loopin_test.go index e4a317b..53f2b35 100644 --- a/loopin_test.go +++ b/loopin_test.go @@ -455,21 +455,32 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool, pendSwap.Loop.Events[0].Cost = cost } - scriptVersion := GetHtlcScriptVersion(storedVersion) + var ( + htlc *swap.Htlc + err error + ) - outputType := swap.HtlcP2WSH - if scriptVersion == swap.HtlcV3 { - outputType = swap.HtlcP2TR + switch GetHtlcScriptVersion(storedVersion) { + case swap.HtlcV2: + htlc, err = swap.NewHtlcV2( + contract.CltvExpiry, contract.SenderKey, + contract.ReceiverKey, testPreimage.Hash(), + cfg.lnd.ChainParams, + ) + + case swap.HtlcV3: + htlc, err = swap.NewHtlcV3( + contract.CltvExpiry, contract.SenderKey, + contract.ReceiverKey, contract.SenderKey, + contract.ReceiverKey, testPreimage.Hash(), + cfg.lnd.ChainParams, + ) + + default: + t.Fatalf("unknown HTLC script version") } - htlc, err := swap.NewHtlc( - scriptVersion, contract.CltvExpiry, contract.SenderKey, - contract.ReceiverKey, testPreimage.Hash(), outputType, - cfg.lnd.ChainParams, - ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) err = ctx.store.CreateLoopIn(testPreimage.Hash(), contract) if err != nil { diff --git a/loopout.go b/loopout.go index 30185e6..6b03bfd 100644 --- a/loopout.go +++ b/loopout.go @@ -201,21 +201,17 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig, swapKit.lastUpdateTime = initiationTime - scriptVersion := GetHtlcScriptVersion(loopdb.CurrentProtocolVersion()) - outputType := swap.HtlcP2TR - if scriptVersion != swap.HtlcV3 { - // Default to using P2WSH for legacy htlcs. - outputType = swap.HtlcP2WSH - } - // Create the htlc. - htlc, err := swapKit.getHtlc(outputType) + htlc, err := GetHtlc( + swapKit.hash, swapKit.contract, swapKit.lnd.ChainParams, + ) if err != nil { return nil, err } // Log htlc address for debugging. - swapKit.log.Infof("Htlc address: %v", htlc.Address) + swapKit.log.Infof("Htlc address (%s): %v", htlc.OutputType, + htlc.Address) // Obtain the payment addr since we'll need it later for routing plugin // recommendation and possibly for cancel. @@ -261,15 +257,10 @@ func resumeLoopOutSwap(reqContext context.Context, cfg *swapConfig, hash, swap.TypeOut, cfg, &pend.Contract.SwapContract, ) - scriptVersion := GetHtlcScriptVersion(pend.Contract.ProtocolVersion) - outputType := swap.HtlcP2TR - if scriptVersion != swap.HtlcV3 { - // Default to using P2WSH for legacy htlcs. - outputType = swap.HtlcP2WSH - } - // Create the htlc. - htlc, err := swapKit.getHtlc(outputType) + htlc, err := GetHtlc( + swapKit.hash, swapKit.contract, swapKit.lnd.ChainParams, + ) if err != nil { return nil, err } diff --git a/swap.go b/swap.go index a084e95..764af79 100644 --- a/swap.go +++ b/swap.go @@ -4,6 +4,7 @@ import ( "context" "time" + "github.com/btcsuite/btcd/chaincfg" "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/swap" @@ -67,14 +68,28 @@ func IsTaprootSwap(swapContract *loopdb.SwapContract) bool { return GetHtlcScriptVersion(swapContract.ProtocolVersion) == swap.HtlcV3 } -// getHtlc composes and returns the on-chain swap script. -func (s *swapKit) getHtlc(outputType swap.HtlcOutputType) (*swap.Htlc, error) { - return swap.NewHtlc( - GetHtlcScriptVersion(s.contract.ProtocolVersion), - s.contract.CltvExpiry, s.contract.SenderKey, - s.contract.ReceiverKey, s.hash, outputType, - s.swapConfig.lnd.ChainParams, - ) +// GetHtlc composes and returns the on-chain swap script. +func GetHtlc(hash lntypes.Hash, contract *loopdb.SwapContract, + chainParams *chaincfg.Params) (*swap.Htlc, error) { + + switch GetHtlcScriptVersion(contract.ProtocolVersion) { + case swap.HtlcV2: + return swap.NewHtlcV2( + contract.CltvExpiry, contract.SenderKey, + contract.ReceiverKey, hash, + chainParams, + ) + + case swap.HtlcV3: + return swap.NewHtlcV3( + contract.CltvExpiry, contract.SenderKey, + contract.ReceiverKey, contract.SenderKey, + contract.ReceiverKey, hash, + chainParams, + ) + } + + return nil, swap.ErrInvalidScriptVersion } // swapInfo constructs and returns a filled SwapInfo from diff --git a/swap/htlc.go b/swap/htlc.go index 1e174cc..8a1101a 100644 --- a/swap/htlc.go +++ b/swap/htlc.go @@ -114,16 +114,16 @@ var ( // QuoteHtlcP2WSH is a template script just used for sweep fee // estimation. - QuoteHtlcP2WSH, _ = NewHtlc( - HtlcV2, ^int32(0), dummyPubKey, dummyPubKey, quoteHash, - HtlcP2WSH, &chaincfg.MainNetParams, + QuoteHtlcP2WSH, _ = NewHtlcV2( + ^int32(0), dummyPubKey, dummyPubKey, quoteHash, + &chaincfg.MainNetParams, ) // QuoteHtlcP2TR is a template script just used for sweep fee // estimation. - QuoteHtlcP2TR, _ = NewHtlc( - HtlcV3, ^int32(0), dummyPubKey, dummyPubKey, quoteHash, - HtlcP2TR, &chaincfg.MainNetParams, + QuoteHtlcP2TR, _ = NewHtlcV3( + ^int32(0), dummyPubKey, dummyPubKey, dummyPubKey, dummyPubKey, + quoteHash, &chaincfg.MainNetParams, ) // ErrInvalidScriptVersion is returned when an unknown htlc version @@ -135,6 +135,10 @@ var ( // selected for a v2 script. ErrInvalidOutputSelected = fmt.Errorf("taproot output selected for " + "non taproot htlc") + + // ErrInvalidOutputType is returned when an unknown output type is + // associated with a certain swap htlc. + ErrInvalidOutputType = fmt.Errorf("invalid htlc output type") ) // String returns the string value of HtlcOutputType. @@ -151,38 +155,54 @@ func (h HtlcOutputType) String() string { } } -// NewHtlc returns a new instance. For v3 scripts, an internal pubkey generated -// by both participants must be provided. -func NewHtlc(version ScriptVersion, cltvExpiry int32, - senderKey, receiverKey [33]byte, hash lntypes.Hash, - outputType HtlcOutputType, chainParams *chaincfg.Params) (*Htlc, error) { +// NewHtlcV2 returns a new V2 (P2WSH) HTLC instance. +func NewHtlcV2(cltvExpiry int32, senderKey, receiverKey [33]byte, + hash lntypes.Hash, chainParams *chaincfg.Params) (*Htlc, error) { - var ( - err error - htlc HtlcScript + htlc, err := newHTLCScriptV2( + cltvExpiry, senderKey, receiverKey, hash, ) - switch version { - case HtlcV2: - htlc, err = newHTLCScriptV2( - cltvExpiry, senderKey, receiverKey, hash, - ) - - case HtlcV3: - htlc, err = newHTLCScriptV3( - cltvExpiry, senderKey, receiverKey, hash, - ) + if err != nil { + return nil, err + } - default: - return nil, ErrInvalidScriptVersion + address, pkScript, sigScript, err := htlc.lockingConditions( + HtlcP2WSH, chainParams, + ) + if err != nil { + return nil, fmt.Errorf("could not get address: %w", err) } + return &Htlc{ + HtlcScript: htlc, + Hash: hash, + Version: HtlcV2, + PkScript: pkScript, + OutputType: HtlcP2WSH, + ChainParams: chainParams, + Address: address, + SigScript: sigScript, + }, nil +} + +// NewHtlcV3 returns a new V3 HTLC (P2TR) instance. Internal pubkey generated +// by both participants must be provided. +func NewHtlcV3(cltvExpiry int32, senderInternalKey, receiverInternalKey, + senderKey, receiverKey [33]byte, hash lntypes.Hash, + chainParams *chaincfg.Params) (*Htlc, error) { + + htlc, err := newHTLCScriptV3( + cltvExpiry, senderInternalKey, receiverInternalKey, + senderKey, receiverKey, hash, + ) + if err != nil { return nil, err } address, pkScript, sigScript, err := htlc.lockingConditions( - outputType, chainParams, + HtlcP2TR, chainParams, ) if err != nil { return nil, fmt.Errorf("could not get address: %w", err) @@ -191,9 +211,9 @@ func NewHtlc(version ScriptVersion, cltvExpiry int32, return &Htlc{ HtlcScript: htlc, Hash: hash, - Version: version, + Version: HtlcV3, PkScript: pkScript, - OutputType: outputType, + OutputType: HtlcP2TR, ChainParams: chainParams, Address: address, SigScript: sigScript, @@ -481,7 +501,8 @@ type HtlcScriptV3 struct { } // newHTLCScriptV3 constructs a HtlcScipt with the HTLC V3 taproot script. -func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte, +func newHTLCScriptV3(cltvExpiry int32, senderInternalKey, receiverInternalKey, + senderHtlcKey, receiverHtlcKey [33]byte, swapHash lntypes.Hash) (*HtlcScriptV3, error) { senderPubKey, err := schnorr.ParsePubKey(senderHtlcKey[1:]) @@ -494,13 +515,6 @@ func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte, return nil, err } - aggregateKey, _, _, err := musig2.AggregateKeys( - []*btcec.PublicKey{senderPubKey, receiverPubKey}, true, - ) - if err != nil { - return nil, err - } - // Create our success path script, we'll use this separately // to generate the success path leaf. successPathScript, err := GenSuccessPathScript( @@ -527,6 +541,31 @@ func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte, rootHash := tree.RootNode.TapHash() + // Parse the pub keys used in the internal aggregate key. They are + // optional and may just be the same keys that are used for the script + // paths. + senderInternalPubKey, err := schnorr.ParsePubKey(senderInternalKey[1:]) + if err != nil { + return nil, err + } + + receiverInternalPubKey, err := schnorr.ParsePubKey( + receiverInternalKey[1:], + ) + if err != nil { + return nil, err + } + + // Calculate the internal aggregate key. + aggregateKey, _, _, err := musig2.AggregateKeys( + []*btcec.PublicKey{ + senderInternalPubKey, receiverInternalPubKey, + }, true, + ) + if err != nil { + return nil, err + } + // Calculate top level taproot key. taprootKey := txscript.ComputeTaprootOutputKey( aggregateKey.PreTweakedKey, rootHash[:], diff --git a/swap/htlc_test.go b/swap/htlc_test.go index 96b2f16..dd309e5 100644 --- a/swap/htlc_test.go +++ b/swap/htlc_test.go @@ -134,9 +134,9 @@ func TestHtlcV2(t *testing.T) { hash := sha256.Sum256(testPreimage[:]) // Create the htlc. - htlc, err := NewHtlc( - HtlcV2, testCltvExpiry, senderKey, receiverKey, hash, - HtlcP2WSH, &chaincfg.MainNetParams, + htlc, err := NewHtlcV2( + testCltvExpiry, senderKey, receiverKey, hash, + &chaincfg.MainNetParams, ) require.NoError(t, err) @@ -287,10 +287,9 @@ func TestHtlcV2(t *testing.T) { bogusKey := [33]byte{0xb, 0xa, 0xd} // Create the htlc with the bogus key. - htlc, err = NewHtlc( - HtlcV2, testCltvExpiry, - bogusKey, receiverKey, hash, - HtlcP2WSH, &chaincfg.MainNetParams, + htlc, err = NewHtlcV2( + testCltvExpiry, bogusKey, receiverKey, + hash, &chaincfg.MainNetParams, ) require.NoError(t, err) @@ -357,9 +356,9 @@ func TestHtlcV3(t *testing.T) { copy(receiverKey[:], receiverPubKey.SerializeCompressed()) copy(senderKey[:], senderPubKey.SerializeCompressed()) - htlc, err := NewHtlc( - HtlcV3, cltvExpiry, senderKey, receiverKey, - hashedPreimage, HtlcP2TR, &chaincfg.MainNetParams, + htlc, err := NewHtlcV3( + cltvExpiry, senderKey, receiverKey, senderKey, receiverKey, + hashedPreimage, &chaincfg.MainNetParams, ) require.NoError(t, err) @@ -540,10 +539,10 @@ func TestHtlcV3(t *testing.T) { bogusKey.SerializeCompressed(), ) - htlc, err := NewHtlc( - HtlcV3, cltvExpiry, bogusKeyBytes, - receiverKey, hashedPreimage, HtlcP2TR, - &chaincfg.MainNetParams, + htlc, err := NewHtlcV3( + cltvExpiry, senderKey, + receiverKey, bogusKeyBytes, receiverKey, + hashedPreimage, &chaincfg.MainNetParams, ) require.NoError(t, err) From 049b17ff969d4dfec2e01f2687dd4540b5a0f4b9 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 29 Nov 2022 21:36:14 +0100 Subject: [PATCH 2/2] misc: refactor loop tests to use require where possible --- client_test.go | 30 ++++------ loopd/swapclient_server_test.go | 26 ++++----- loopin_test.go | 51 +++++------------ loopin_testcontext_test.go | 5 +- loopout_test.go | 98 ++++++++++----------------------- store_mock_test.go | 16 ++---- swap/htlc_test.go | 4 +- test/context.go | 33 +++++------ test/testutils.go | 10 ++-- testcontext_test.go | 18 ++---- 10 files changed, 99 insertions(+), 192 deletions(-) diff --git a/client_test.go b/client_test.go index 3bca073..b49e0ba 100644 --- a/client_test.go +++ b/client_test.go @@ -1,7 +1,6 @@ package loop import ( - "bytes" "context" "crypto/sha256" "errors" @@ -57,9 +56,7 @@ func TestLoopOutSuccess(t *testing.T) { // Initiate loop out. info, err := ctx.swapClient.LoopOut(context.Background(), &req) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) ctx.assertStored() ctx.assertStatus(loopdb.StateInitiated) @@ -84,9 +81,7 @@ func TestLoopOutFailOffchain(t *testing.T) { ctx := createClientTestContext(t, nil) _, err := ctx.swapClient.LoopOut(context.Background(), testRequest) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) ctx.assertStored() ctx.assertStatus(loopdb.StateInitiated) @@ -208,14 +203,10 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed, amt := btcutil.Amount(50000) swapPayReq, err := getInvoice(hash, amt, swapInvoiceDesc) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) prePayReq, err := getInvoice(hash, 100, prepayInvoiceDesc) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) _, senderPubKey := test.CreateKey(1) var senderKey [33]byte @@ -373,10 +364,11 @@ func testLoopOutSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash, // Expect client on-chain sweep of HTLC. sweepTx := ctx.ReceiveTx() - if !bytes.Equal(sweepTx.TxIn[0].PreviousOutPoint.Hash[:], - htlcOutpoint.Hash[:]) { - ctx.T.Fatalf("client not sweeping from htlc tx") - } + require.Equal( + ctx.T, htlcOutpoint.Hash[:], + sweepTx.TxIn[0].PreviousOutPoint.Hash[:], + "client not sweeping from htlc tx", + ) var preImageIndex int switch scriptVersion { @@ -390,9 +382,7 @@ func testLoopOutSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash, // Check preimage. clientPreImage := sweepTx.TxIn[0].Witness[preImageIndex] clientPreImageHash := sha256.Sum256(clientPreImage) - if clientPreImageHash != hash { - ctx.T.Fatalf("incorrect preimage") - } + require.Equal(ctx.T, hash, lntypes.Hash(clientPreImageHash)) // Since we successfully published our sweep, we expect the preimage to // have been pushed to our mock server. diff --git a/loopd/swapclient_server_test.go b/loopd/swapclient_server_test.go index 53846ba..585d87f 100644 --- a/loopd/swapclient_server_test.go +++ b/loopd/swapclient_server_test.go @@ -130,16 +130,13 @@ func TestValidateConfTarget(t *testing.T) { test.confTarget, defaultConf, ) - haveErr := err != nil - if haveErr != test.expectErr { - t.Fatalf("expected err: %v, got: %v", - test.expectErr, err) + if test.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) } - if target != test.expectedTarget { - t.Fatalf("expected: %v, got: %v", - test.expectedTarget, target) - } + require.Equal(t, test.expectedTarget, target) }) } } @@ -199,16 +196,13 @@ func TestValidateLoopInRequest(t *testing.T) { test.confTarget, external, ) - haveErr := err != nil - if haveErr != test.expectErr { - t.Fatalf("expected err: %v, got: %v", - test.expectErr, err) + if test.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) } - if conf != test.expectedTarget { - t.Fatalf("expected: %v, got: %v", - test.expectedTarget, conf) - } + require.Equal(t, test.expectedTarget, conf) }) } } diff --git a/loopin_test.go b/loopin_test.go index 53f2b35..6940b68 100644 --- a/loopin_test.go +++ b/loopin_test.go @@ -58,9 +58,8 @@ func testLoopInSuccess(t *testing.T) { context.Background(), cfg, height, req, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + inSwap := initResult.swap ctx.store.assertLoopInStored() @@ -142,10 +141,7 @@ func testLoopInSuccess(t *testing.T) { ctx.assertState(loopdb.StateSuccess) ctx.store.assertLoopInState(loopdb.StateSuccess) - err = <-errChan - if err != nil { - t.Fatal(err) - } + require.NoError(t, <-errChan) } // TestLoopInTimeout tests scenarios where the server doesn't sweep the htlc @@ -215,9 +211,7 @@ func testLoopInTimeout(t *testing.T, externalValue int64) { context.Background(), cfg, height, &req, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) inSwap := initResult.swap ctx.store.assertLoopInStored() @@ -289,11 +283,7 @@ func testLoopInTimeout(t *testing.T, externalValue int64) { ctx.assertState(loopdb.StateFailIncorrectHtlcAmt) ctx.store.assertLoopInState(loopdb.StateFailIncorrectHtlcAmt) - err = <-errChan - if err != nil { - t.Fatal(err) - } - + require.NoError(t, <-errChan) return } @@ -308,9 +298,11 @@ func testLoopInTimeout(t *testing.T, externalValue int64) { // Expect a signing request for the htlc tx output value. signReq := <-ctx.lnd.SignOutputRawChannel - if signReq.SignDescriptors[0].Output.Value != htlcTx.TxOut[0].Value { - t.Fatal("invalid signing amount") - } + require.Equal( + t, htlcTx.TxOut[0].Value, + signReq.SignDescriptors[0].Output.Value, + "invalid signing amount", + ) // Expect timeout tx to be published. timeoutTx := <-ctx.lnd.TxPublishChannel @@ -341,10 +333,7 @@ func testLoopInTimeout(t *testing.T, externalValue int64) { state := ctx.store.assertLoopInState(loopdb.StateFailTimeout) require.Equal(t, cost, state.Cost) - err = <-errChan - if err != nil { - t.Fatal(err) - } + require.NoError(t, <-errChan) } // TestLoopInResume tests resuming swaps in various states. @@ -483,17 +472,10 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool, require.NoError(t, err) err = ctx.store.CreateLoopIn(testPreimage.Hash(), contract) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - inSwap, err := resumeLoopInSwap( - context.Background(), cfg, - pendSwap, - ) - if err != nil { - t.Fatal(err) - } + inSwap, err := resumeLoopInSwap(context.Background(), cfg, pendSwap) + require.NoError(t, err) var height int32 if expired { @@ -512,10 +494,7 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool, }() defer func() { - err = <-errChan - if err != nil { - t.Fatal(err) - } + require.NoError(t, <-errChan) select { case <-ctx.lnd.SendPaymentChannel: diff --git a/loopin_testcontext_test.go b/loopin_testcontext_test.go index 077622e..ca6b024 100644 --- a/loopin_testcontext_test.go +++ b/loopin_testcontext_test.go @@ -63,10 +63,7 @@ func newLoopInTestContext(t *testing.T) *loopInTestContext { func (c *loopInTestContext) assertState(expectedState loopdb.SwapState) { state := <-c.statusChan - if state.State != expectedState { - c.t.Fatalf("expected state %v but got %v", expectedState, - state.State) - } + require.Equal(c.t, expectedState, state.State) } // assertSubscribeInvoice asserts that the client subscribes to invoice updates diff --git a/loopout_test.go b/loopout_test.go index ea41d10..67566e8 100644 --- a/loopout_test.go +++ b/loopout_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "math" - "reflect" "testing" "time" @@ -66,7 +65,7 @@ func testLoopOutPaymentParameters(t *testing.T) { blockEpochChan := make(chan interface{}) statusChan := make(chan SwapInfo) - const maxParts = 5 + const maxParts = uint32(5) chanSet := loopdb.ChannelSet{2, 3} @@ -77,9 +76,7 @@ func testLoopOutPaymentParameters(t *testing.T) { initResult, err := newLoopOutSwap( context.Background(), cfg, height, &req, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) swap := initResult.swap // Execute the swap in its own goroutine. @@ -105,9 +102,7 @@ func testLoopOutPaymentParameters(t *testing.T) { store.assertLoopOutStored() state := <-statusChan - if state.State != loopdb.StateInitiated { - t.Fatal("unexpected state") - } + require.Equal(t, loopdb.StateInitiated, state.State) // Check that the SwapInfo contains the outgoing chan set require.Equal(t, chanSet, state.OutgoingChanSet) @@ -130,18 +125,12 @@ func testLoopOutPaymentParameters(t *testing.T) { } // Assert that it is sent as a multi-part payment. - if swapPayment.MaxParts != maxParts { - t.Fatalf("Expected %v parts, but got %v", - maxParts, swapPayment.MaxParts) - } + require.Equal(t, maxParts, swapPayment.MaxParts) // Verify the outgoing channel set restriction. - if !reflect.DeepEqual( - []uint64(req.OutgoingChanSet), swapPayment.OutgoingChanIds, - ) { - - t.Fatalf("Unexpected outgoing channel set") - } + require.Equal( + t, []uint64(req.OutgoingChanSet), swapPayment.OutgoingChanIds, + ) // Swap is expected to register for confirmation of the htlc. Assert // this to prevent a blocked channel in the mock. @@ -152,10 +141,7 @@ func testLoopOutPaymentParameters(t *testing.T) { cancel() // Expect the swap to signal that it was cancelled. - err = <-errChan - if err != context.Canceled { - t.Fatal(err) - } + require.Equal(t, context.Canceled, <-errChan) } // TestLateHtlcPublish tests that the client is not revealing the preimage if @@ -198,9 +184,7 @@ func testLateHtlcPublish(t *testing.T) { initResult, err := newLoopOutSwap( context.Background(), cfg, height, testRequest, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) swap := initResult.swap sweeper := &sweep.Sweeper{Lnd: &lnd.LndServices} @@ -225,11 +209,8 @@ func testLateHtlcPublish(t *testing.T) { }() store.assertLoopOutStored() - - state := <-statusChan - if state.State != loopdb.StateInitiated { - t.Fatal("unexpected state") - } + status := <-statusChan + require.Equal(t, loopdb.StateInitiated, status.State) signalSwapPaymentResult := ctx.AssertPaid(swapInvoiceDesc) signalPrepaymentResult := ctx.AssertPaid(prepayInvoiceDesc) @@ -249,15 +230,9 @@ func testLateHtlcPublish(t *testing.T) { store.assertStoreFinished(loopdb.StateFailTimeout) - status := <-statusChan - if status.State != loopdb.StateFailTimeout { - t.Fatal("unexpected state") - } - - err = <-errChan - if err != nil { - t.Fatal(err) - } + status = <-statusChan + require.Equal(t, loopdb.StateFailTimeout, status.State) + require.NoError(t, <-errChan) } // TestCustomSweepConfTarget ensures we are able to sweep a Loop Out HTLC with a @@ -304,9 +279,7 @@ func testCustomSweepConfTarget(t *testing.T) { initResult, err := newLoopOutSwap( context.Background(), cfg, ctx.Lnd.Height, &testReq, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) swap := initResult.swap // Set up the required dependencies to execute the swap. @@ -339,9 +312,7 @@ func testCustomSweepConfTarget(t *testing.T) { // The swap should be found in its initial state. cfg.store.(*storeMock).assertLoopOutStored() state := <-statusChan - if state.State != loopdb.StateInitiated { - t.Fatal("unexpected state") - } + require.Equal(t, loopdb.StateInitiated, state.State) // We'll then pay both the swap and prepay invoice, which should trigger // the server to publish the on-chain HTLC. @@ -381,10 +352,7 @@ func testCustomSweepConfTarget(t *testing.T) { cfg.store.(*storeMock).assertLoopOutState(loopdb.StatePreimageRevealed) status := <-statusChan - if status.State != loopdb.StatePreimageRevealed { - t.Fatalf("expected state %v, got %v", - loopdb.StatePreimageRevealed, status.State) - } + require.Equal(t, loopdb.StatePreimageRevealed, status.State) // When using taproot htlcs the flow is different as we do reveal the // preimage before sweeping in order for the server to trust us with @@ -410,10 +378,10 @@ func testCustomSweepConfTarget(t *testing.T) { t.Helper() sweepTx := ctx.ReceiveTx() - if sweepTx.TxIn[0].PreviousOutPoint.Hash != htlcTx.TxHash() { - t.Fatalf("expected sweep tx to spend %v, got %v", - htlcTx.TxHash(), sweepTx.TxIn[0].PreviousOutPoint) - } + require.Equal( + t, htlcTx.TxHash(), + sweepTx.TxIn[0].PreviousOutPoint.Hash, + ) // The fee used for the sweep transaction is an estimate based // on the maximum witness size, so we should expect to see a @@ -427,16 +395,14 @@ func testCustomSweepConfTarget(t *testing.T) { feeRate, err := ctx.Lnd.WalletKit.EstimateFeeRate( context.Background(), expConfTarget, ) - if err != nil { - t.Fatalf("unable to retrieve fee estimate: %v", err) - } + require.NoError(t, err, "unable to retrieve fee estimate") + minFee := feeRate.FeeForWeight(weight) - maxFee := btcutil.Amount(float64(minFee) * 1.1) + // Just an estimate that works to sanity check fee upper bound. + maxFee := btcutil.Amount(float64(minFee) * 1.5) - if fee < minFee && fee > maxFee { - t.Fatalf("expected sweep tx to have fee between %v-%v, "+ - "got %v", minFee, maxFee, fee) - } + require.GreaterOrEqual(t, fee, minFee) + require.LessOrEqual(t, fee, maxFee) return sweepTx } @@ -479,14 +445,8 @@ func testCustomSweepConfTarget(t *testing.T) { cfg.store.(*storeMock).assertLoopOutState(loopdb.StateSuccess) status = <-statusChan - if status.State != loopdb.StateSuccess { - t.Fatalf("expected state %v, got %v", loopdb.StateSuccess, - status.State) - } - - if err := <-errChan; err != nil { - t.Fatal(err) - } + require.Equal(t, loopdb.StateSuccess, status.State) + require.NoError(t, <-errChan) } // TestPreimagePush tests or logic that decides whether to push our preimage to diff --git a/store_mock_test.go b/store_mock_test.go index a366fcd..4ec9daa 100644 --- a/store_mock_test.go +++ b/store_mock_test.go @@ -8,6 +8,7 @@ import ( "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/test" "github.com/lightningnetwork/lnd/lntypes" + "github.com/stretchr/testify/require" ) // storeMock implements a mock client swap store. @@ -239,9 +240,7 @@ func (s *storeMock) assertLoopInState( s.t.Helper() state := <-s.loopInUpdateChan - if state.State != expectedState { - s.t.Fatalf("expected state %v, got %v", expectedState, state) - } + require.Equal(s.t, expectedState, state.State) return state } @@ -252,9 +251,8 @@ func (s *storeMock) assertStorePreimageReveal() { select { case state := <-s.loopOutUpdateChan: - if state.State != loopdb.StatePreimageRevealed { - s.t.Fatalf("unexpected state") - } + require.Equal(s.t, loopdb.StatePreimageRevealed, state.State) + case <-time.After(test.Timeout): s.t.Fatalf("expected swap to be marked as preimage revealed") } @@ -265,10 +263,8 @@ func (s *storeMock) assertStoreFinished(expectedResult loopdb.SwapState) { select { case state := <-s.loopOutUpdateChan: - if state.State != expectedResult { - s.t.Fatalf("expected result %v, but got %v", - expectedResult, state) - } + require.Equal(s.t, expectedResult, state.State) + case <-time.After(test.Timeout): s.t.Fatalf("expected swap to be finished") } diff --git a/swap/htlc_test.go b/swap/htlc_test.go index dd309e5..09f9264 100644 --- a/swap/htlc_test.go +++ b/swap/htlc_test.go @@ -54,9 +54,7 @@ func assertEngineExecution(t *testing.T, valid bool, done := false for !done { dis, err := vm.DisasmPC() - if err != nil { - t.Fatalf("stepping (%v)\n", err) - } + require.NoError(t, err, "stepping") debugBuf.WriteString(fmt.Sprintf("stepping %v\n", dis)) done, err = vm.Step() diff --git a/test/context.go b/test/context.go index ca83322..e8c517c 100644 --- a/test/context.go +++ b/test/context.go @@ -86,9 +86,11 @@ func (ctx *Context) AssertRegisterSpendNtfn(script []byte) { select { case spendIntent := <-ctx.Lnd.RegisterSpendChannel: - if !bytes.Equal(spendIntent.PkScript, script) { - ctx.T.Fatalf("server not listening for published htlc script") - } + require.Equal( + ctx.T, script, spendIntent.PkScript, + "server not listening for published htlc script", + ) + case <-time.After(Timeout): DumpGoroutines() ctx.T.Fatalf("spend not subscribed to") @@ -163,10 +165,11 @@ func (ctx *Context) AssertPaid( payReq := ctx.DecodeInvoice(swapPayment.Invoice) - if _, ok := ctx.PaidInvoices[*payReq.Description]; ok { - ctx.T.Fatalf("duplicate invoice paid: %v", - *payReq.Description) - } + _, ok := ctx.PaidInvoices[*payReq.Description] + require.False( + ctx.T, ok, + "duplicate invoice paid: %v", *payReq.Description, + ) done := func(result error) { if result != nil { @@ -195,9 +198,10 @@ func (ctx *Context) AssertSettled( select { case preimage := <-ctx.Lnd.SettleInvoiceChannel: hash := sha256.Sum256(preimage[:]) - if expectedHash != hash { - ctx.T.Fatalf("server claims with wrong preimage") - } + require.Equal( + ctx.T, expectedHash, lntypes.Hash(hash), + "server claims with wrong preimage", + ) return preimage case <-time.After(Timeout): @@ -232,9 +236,8 @@ func (ctx *Context) DecodeInvoice(request string) *zpay32.Invoice { ctx.T.Helper() payReq, err := ctx.Lnd.DecodeInvoice(request) - if err != nil { - ctx.T.Fatal(err) - } + require.NoError(ctx.T, err) + return payReq } @@ -256,7 +259,5 @@ func (ctx *Context) GetOutputIndex(tx *wire.MsgTx, // waits for the notification to be processed by selecting on a // dedicated test channel. func (ctx *Context) NotifyServerHeight(height int32) { - if err := ctx.Lnd.NotifyHeight(height); err != nil { - ctx.T.Fatal(err) - } + require.NoError(ctx.T, ctx.Lnd.NotifyHeight(height)) } diff --git a/test/testutils.go b/test/testutils.go index c00ea96..6aa312e 100644 --- a/test/testutils.go +++ b/test/testutils.go @@ -14,6 +14,7 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/zpay32" + "github.com/stretchr/testify/require" ) var ( @@ -29,11 +30,10 @@ var ( // GetDestAddr deterministically generates a sweep address for testing. func GetDestAddr(t *testing.T, nr byte) btcutil.Address { - destAddr, err := btcutil.NewAddressScriptHash([]byte{nr}, - &chaincfg.MainNetParams) - if err != nil { - t.Fatal(err) - } + destAddr, err := btcutil.NewAddressScriptHash( + []byte{nr}, &chaincfg.MainNetParams, + ) + require.NoError(t, err) return destAddr } diff --git a/testcontext_test.go b/testcontext_test.go index afc6a90..a221a28 100644 --- a/testcontext_test.go +++ b/testcontext_test.go @@ -140,9 +140,8 @@ func (ctx *testContext) finish() { ctx.stop() select { case err := <-ctx.runErr: - if err != nil { - ctx.T.Fatal(err) - } + require.NoError(ctx.T, err) + case <-time.After(test.Timeout): ctx.T.Fatal("client not stopping") } @@ -156,19 +155,12 @@ func (ctx *testContext) finish() { func (ctx *testContext) notifyHeight(height int32) { ctx.T.Helper() - if err := ctx.Lnd.NotifyHeight(height); err != nil { - ctx.T.Fatal(err) - } + require.NoError(ctx.T, ctx.Lnd.NotifyHeight(height)) } func (ctx *testContext) assertIsDone() { - if err := ctx.Lnd.IsDone(); err != nil { - ctx.T.Fatal(err) - } - - if err := ctx.store.isDone(); err != nil { - ctx.T.Fatal(err) - } + require.NoError(ctx.T, ctx.Lnd.IsDone()) + require.NoError(ctx.T, ctx.store.isDone()) select { case <-ctx.statusChan: