diff --git a/client.go b/client.go index 4581f0f..0bf6e46 100644 --- a/client.go +++ b/client.go @@ -213,29 +213,24 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) { SwapHash: swp.Hash, LastUpdate: swp.LastUpdateTime(), } - scriptVersion := GetHtlcScriptVersion( - swp.Contract.ProtocolVersion, - ) - - outputType := swap.HtlcP2WSH - if scriptVersion == swap.HtlcV3 { - outputType = swap.HtlcP2TR - } - htlc, err := swap.NewHtlc( - scriptVersion, - swp.Contract.CltvExpiry, swp.Contract.SenderKey, - swp.Contract.ReceiverKey, swp.Hash, - outputType, s.lndServices.ChainParams, + htlc, err := GetHtlc( + swp.Hash, &swp.Contract.SwapContract, + s.lndServices.ChainParams, ) if err != nil { return nil, err } - if outputType == swap.HtlcP2TR { - swapInfo.HtlcAddressP2TR = htlc.Address - } else { + switch htlc.OutputType { + case swap.HtlcP2WSH: swapInfo.HtlcAddressP2WSH = htlc.Address + + case swap.HtlcP2TR: + swapInfo.HtlcAddressP2TR = htlc.Address + + default: + return nil, swap.ErrInvalidOutputType } swaps = append(swaps, swapInfo) @@ -250,34 +245,23 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) { LastUpdate: swp.LastUpdateTime(), } - scriptVersion := GetHtlcScriptVersion( - swp.Contract.SwapContract.ProtocolVersion, + htlc, err := GetHtlc( + swp.Hash, &swp.Contract.SwapContract, + s.lndServices.ChainParams, ) + if err != nil { + return nil, err + } - if scriptVersion == swap.HtlcV3 { - htlcP2TR, err := swap.NewHtlc( - swap.HtlcV3, swp.Contract.CltvExpiry, - swp.Contract.SenderKey, swp.Contract.ReceiverKey, - swp.Hash, swap.HtlcP2TR, - s.lndServices.ChainParams, - ) - if err != nil { - return nil, err - } + switch htlc.OutputType { + case swap.HtlcP2WSH: + swapInfo.HtlcAddressP2WSH = htlc.Address - swapInfo.HtlcAddressP2TR = htlcP2TR.Address - } else { - htlcP2WSH, err := swap.NewHtlc( - swap.HtlcV2, swp.Contract.CltvExpiry, - swp.Contract.SenderKey, swp.Contract.ReceiverKey, - swp.Hash, swap.HtlcP2WSH, - s.lndServices.ChainParams, - ) - if err != nil { - return nil, err - } + case swap.HtlcP2TR: + swapInfo.HtlcAddressP2TR = htlc.Address - swapInfo.HtlcAddressP2WSH = htlcP2WSH.Address + default: + return nil, swap.ErrInvalidOutputType } swaps = append(swaps, swapInfo) diff --git a/client_test.go b/client_test.go index 7c8f88b..b49e0ba 100644 --- a/client_test.go +++ b/client_test.go @@ -1,7 +1,6 @@ package loop import ( - "bytes" "context" "crypto/sha256" "errors" @@ -57,9 +56,7 @@ func TestLoopOutSuccess(t *testing.T) { // Initiate loop out. info, err := ctx.swapClient.LoopOut(context.Background(), &req) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) ctx.assertStored() ctx.assertStatus(loopdb.StateInitiated) @@ -84,9 +81,7 @@ func TestLoopOutFailOffchain(t *testing.T) { ctx := createClientTestContext(t, nil) _, err := ctx.swapClient.LoopOut(context.Background(), testRequest) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) ctx.assertStored() ctx.assertStatus(loopdb.StateInitiated) @@ -208,14 +203,10 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed, amt := btcutil.Amount(50000) swapPayReq, err := getInvoice(hash, amt, swapInvoiceDesc) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) prePayReq, err := getInvoice(hash, 100, prepayInvoiceDesc) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) _, senderPubKey := test.CreateKey(1) var senderKey [33]byte @@ -284,16 +275,26 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed, // Assert that the loopout htlc equals to the expected one. scriptVersion := GetHtlcScriptVersion(protocolVersion) + var htlc *swap.Htlc - outputType := swap.HtlcP2TR - if scriptVersion != swap.HtlcV3 { - outputType = swap.HtlcP2WSH + switch scriptVersion { + case swap.HtlcV2: + htlc, err = swap.NewHtlcV2( + pendingSwap.Contract.CltvExpiry, senderKey, + receiverKey, hash, &chaincfg.TestNet3Params, + ) + + case swap.HtlcV3: + htlc, err = swap.NewHtlcV3( + pendingSwap.Contract.CltvExpiry, senderKey, + receiverKey, senderKey, receiverKey, hash, + &chaincfg.TestNet3Params, + ) + + default: + t.Fatalf(swap.ErrInvalidScriptVersion.Error()) } - htlc, err := swap.NewHtlc( - scriptVersion, pendingSwap.Contract.CltvExpiry, senderKey, - receiverKey, hash, outputType, &chaincfg.TestNet3Params, - ) require.NoError(t, err) require.Equal(t, htlc.PkScript, confIntent.PkScript) @@ -363,10 +364,11 @@ func testLoopOutSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash, // Expect client on-chain sweep of HTLC. sweepTx := ctx.ReceiveTx() - if !bytes.Equal(sweepTx.TxIn[0].PreviousOutPoint.Hash[:], - htlcOutpoint.Hash[:]) { - ctx.T.Fatalf("client not sweeping from htlc tx") - } + require.Equal( + ctx.T, htlcOutpoint.Hash[:], + sweepTx.TxIn[0].PreviousOutPoint.Hash[:], + "client not sweeping from htlc tx", + ) var preImageIndex int switch scriptVersion { @@ -380,9 +382,7 @@ func testLoopOutSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash, // Check preimage. clientPreImage := sweepTx.TxIn[0].Witness[preImageIndex] clientPreImageHash := sha256.Sum256(clientPreImage) - if clientPreImageHash != hash { - ctx.T.Fatalf("incorrect preimage") - } + require.Equal(ctx.T, hash, lntypes.Hash(clientPreImageHash)) // Since we successfully published our sweep, we expect the preimage to // have been pushed to our mock server. diff --git a/loopd/swapclient_server_test.go b/loopd/swapclient_server_test.go index 53846ba..585d87f 100644 --- a/loopd/swapclient_server_test.go +++ b/loopd/swapclient_server_test.go @@ -130,16 +130,13 @@ func TestValidateConfTarget(t *testing.T) { test.confTarget, defaultConf, ) - haveErr := err != nil - if haveErr != test.expectErr { - t.Fatalf("expected err: %v, got: %v", - test.expectErr, err) + if test.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) } - if target != test.expectedTarget { - t.Fatalf("expected: %v, got: %v", - test.expectedTarget, target) - } + require.Equal(t, test.expectedTarget, target) }) } } @@ -199,16 +196,13 @@ func TestValidateLoopInRequest(t *testing.T) { test.confTarget, external, ) - haveErr := err != nil - if haveErr != test.expectErr { - t.Fatalf("expected err: %v, got: %v", - test.expectErr, err) + if test.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) } - if conf != test.expectedTarget { - t.Fatalf("expected: %v, got: %v", - test.expectedTarget, conf) - } + require.Equal(t, test.expectedTarget, conf) }) } } diff --git a/loopd/view.go b/loopd/view.go index f99fdef..3a7b600 100644 --- a/loopd/view.go +++ b/loopd/view.go @@ -7,7 +7,6 @@ import ( "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop" "github.com/lightninglabs/loop/loopdb" - "github.com/lightninglabs/loop/swap" ) // view prints all swaps currently in the database. @@ -49,24 +48,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error { } for _, s := range swaps { - scriptVersion := loop.GetHtlcScriptVersion( - s.Contract.ProtocolVersion, - ) - - var outputType swap.HtlcOutputType - switch scriptVersion { - case swap.HtlcV2: - outputType = swap.HtlcP2WSH - - case swap.HtlcV3: - outputType = swap.HtlcP2TR - } - htlc, err := swap.NewHtlc( - loop.GetHtlcScriptVersion(s.Contract.ProtocolVersion), - s.Contract.CltvExpiry, - s.Contract.SenderKey, - s.Contract.ReceiverKey, - s.Hash, outputType, chainParams, + htlc, err := loop.GetHtlc( + s.Hash, &s.Contract.SwapContract, chainParams, ) if err != nil { return err @@ -77,7 +60,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error { s.Contract.InitiationTime, s.Contract.InitiationHeight, ) fmt.Printf(" Preimage: %v\n", s.Contract.Preimage) - fmt.Printf(" Htlc address: %v\n", htlc.Address) + fmt.Printf(" Htlc address (%s): %v\n", htlc.OutputType, + htlc.Address) fmt.Printf(" Uncharge channels: %v\n", s.Contract.OutgoingChanSet) @@ -113,12 +97,8 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error { } for _, s := range swaps { - htlc, err := swap.NewHtlc( - loop.GetHtlcScriptVersion(s.Contract.ProtocolVersion), - s.Contract.CltvExpiry, - s.Contract.SenderKey, - s.Contract.ReceiverKey, - s.Hash, swap.HtlcP2WSH, chainParams, + htlc, err := loop.GetHtlc( + s.Hash, &s.Contract.SwapContract, chainParams, ) if err != nil { return err @@ -129,7 +109,8 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error { s.Contract.InitiationTime, s.Contract.InitiationHeight, ) fmt.Printf(" Preimage: %v\n", s.Contract.Preimage) - fmt.Printf(" Htlc address: %v\n", htlc.Address) + fmt.Printf(" Htlc address (%s): %v\n", htlc.OutputType, + htlc.Address) fmt.Printf(" Amt: %v, Expiry: %v\n", s.Contract.AmountRequested, s.Contract.CltvExpiry, ) diff --git a/loopin.go b/loopin.go index 8b3e49b..62d2a42 100644 --- a/loopin.go +++ b/loopin.go @@ -404,26 +404,26 @@ func validateLoopInContract(lnd *lndclient.LndServices, // initHtlcs creates and updates the native and nested segwit htlcs // of the loopInSwap. func (s *loopInSwap) initHtlcs() error { - if IsTaprootSwap(&s.SwapContract) { - htlcP2TR, err := s.swapKit.getHtlc(swap.HtlcP2TR) - if err != nil { - return err - } + htlc, err := GetHtlc( + s.hash, &s.SwapContract, s.swapKit.lnd.ChainParams, + ) + if err != nil { + return err + } - s.swapKit.log.Infof("Htlc address (P2TR): %v", htlcP2TR.Address) - s.htlcP2TR = htlcP2TR + switch htlc.OutputType { + case swap.HtlcP2WSH: + s.htlcP2WSH = htlc - return nil - } + case swap.HtlcP2TR: + s.htlcP2TR = htlc - htlcP2WSH, err := s.swapKit.getHtlc(swap.HtlcP2WSH) - if err != nil { - return err + default: + return fmt.Errorf("invalid output type") } - // Log htlc addresses for debugging. - s.swapKit.log.Infof("Htlc address (P2WSH): %v", htlcP2WSH.Address) - s.htlcP2WSH = htlcP2WSH + s.swapKit.log.Infof("Htlc address (%s): %v", htlc.OutputType, + htlc.Address) return nil } diff --git a/loopin_test.go b/loopin_test.go index e4a317b..6940b68 100644 --- a/loopin_test.go +++ b/loopin_test.go @@ -58,9 +58,8 @@ func testLoopInSuccess(t *testing.T) { context.Background(), cfg, height, req, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + inSwap := initResult.swap ctx.store.assertLoopInStored() @@ -142,10 +141,7 @@ func testLoopInSuccess(t *testing.T) { ctx.assertState(loopdb.StateSuccess) ctx.store.assertLoopInState(loopdb.StateSuccess) - err = <-errChan - if err != nil { - t.Fatal(err) - } + require.NoError(t, <-errChan) } // TestLoopInTimeout tests scenarios where the server doesn't sweep the htlc @@ -215,9 +211,7 @@ func testLoopInTimeout(t *testing.T, externalValue int64) { context.Background(), cfg, height, &req, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) inSwap := initResult.swap ctx.store.assertLoopInStored() @@ -289,11 +283,7 @@ func testLoopInTimeout(t *testing.T, externalValue int64) { ctx.assertState(loopdb.StateFailIncorrectHtlcAmt) ctx.store.assertLoopInState(loopdb.StateFailIncorrectHtlcAmt) - err = <-errChan - if err != nil { - t.Fatal(err) - } - + require.NoError(t, <-errChan) return } @@ -308,9 +298,11 @@ func testLoopInTimeout(t *testing.T, externalValue int64) { // Expect a signing request for the htlc tx output value. signReq := <-ctx.lnd.SignOutputRawChannel - if signReq.SignDescriptors[0].Output.Value != htlcTx.TxOut[0].Value { - t.Fatal("invalid signing amount") - } + require.Equal( + t, htlcTx.TxOut[0].Value, + signReq.SignDescriptors[0].Output.Value, + "invalid signing amount", + ) // Expect timeout tx to be published. timeoutTx := <-ctx.lnd.TxPublishChannel @@ -341,10 +333,7 @@ func testLoopInTimeout(t *testing.T, externalValue int64) { state := ctx.store.assertLoopInState(loopdb.StateFailTimeout) require.Equal(t, cost, state.Cost) - err = <-errChan - if err != nil { - t.Fatal(err) - } + require.NoError(t, <-errChan) } // TestLoopInResume tests resuming swaps in various states. @@ -455,34 +444,38 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool, pendSwap.Loop.Events[0].Cost = cost } - scriptVersion := GetHtlcScriptVersion(storedVersion) + var ( + htlc *swap.Htlc + err error + ) - outputType := swap.HtlcP2WSH - if scriptVersion == swap.HtlcV3 { - outputType = swap.HtlcP2TR + switch GetHtlcScriptVersion(storedVersion) { + case swap.HtlcV2: + htlc, err = swap.NewHtlcV2( + contract.CltvExpiry, contract.SenderKey, + contract.ReceiverKey, testPreimage.Hash(), + cfg.lnd.ChainParams, + ) + + case swap.HtlcV3: + htlc, err = swap.NewHtlcV3( + contract.CltvExpiry, contract.SenderKey, + contract.ReceiverKey, contract.SenderKey, + contract.ReceiverKey, testPreimage.Hash(), + cfg.lnd.ChainParams, + ) + + default: + t.Fatalf("unknown HTLC script version") } - htlc, err := swap.NewHtlc( - scriptVersion, contract.CltvExpiry, contract.SenderKey, - contract.ReceiverKey, testPreimage.Hash(), outputType, - cfg.lnd.ChainParams, - ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) err = ctx.store.CreateLoopIn(testPreimage.Hash(), contract) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - inSwap, err := resumeLoopInSwap( - context.Background(), cfg, - pendSwap, - ) - if err != nil { - t.Fatal(err) - } + inSwap, err := resumeLoopInSwap(context.Background(), cfg, pendSwap) + require.NoError(t, err) var height int32 if expired { @@ -501,10 +494,7 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool, }() defer func() { - err = <-errChan - if err != nil { - t.Fatal(err) - } + require.NoError(t, <-errChan) select { case <-ctx.lnd.SendPaymentChannel: diff --git a/loopin_testcontext_test.go b/loopin_testcontext_test.go index 077622e..ca6b024 100644 --- a/loopin_testcontext_test.go +++ b/loopin_testcontext_test.go @@ -63,10 +63,7 @@ func newLoopInTestContext(t *testing.T) *loopInTestContext { func (c *loopInTestContext) assertState(expectedState loopdb.SwapState) { state := <-c.statusChan - if state.State != expectedState { - c.t.Fatalf("expected state %v but got %v", expectedState, - state.State) - } + require.Equal(c.t, expectedState, state.State) } // assertSubscribeInvoice asserts that the client subscribes to invoice updates diff --git a/loopout.go b/loopout.go index 30185e6..6b03bfd 100644 --- a/loopout.go +++ b/loopout.go @@ -201,21 +201,17 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig, swapKit.lastUpdateTime = initiationTime - scriptVersion := GetHtlcScriptVersion(loopdb.CurrentProtocolVersion()) - outputType := swap.HtlcP2TR - if scriptVersion != swap.HtlcV3 { - // Default to using P2WSH for legacy htlcs. - outputType = swap.HtlcP2WSH - } - // Create the htlc. - htlc, err := swapKit.getHtlc(outputType) + htlc, err := GetHtlc( + swapKit.hash, swapKit.contract, swapKit.lnd.ChainParams, + ) if err != nil { return nil, err } // Log htlc address for debugging. - swapKit.log.Infof("Htlc address: %v", htlc.Address) + swapKit.log.Infof("Htlc address (%s): %v", htlc.OutputType, + htlc.Address) // Obtain the payment addr since we'll need it later for routing plugin // recommendation and possibly for cancel. @@ -261,15 +257,10 @@ func resumeLoopOutSwap(reqContext context.Context, cfg *swapConfig, hash, swap.TypeOut, cfg, &pend.Contract.SwapContract, ) - scriptVersion := GetHtlcScriptVersion(pend.Contract.ProtocolVersion) - outputType := swap.HtlcP2TR - if scriptVersion != swap.HtlcV3 { - // Default to using P2WSH for legacy htlcs. - outputType = swap.HtlcP2WSH - } - // Create the htlc. - htlc, err := swapKit.getHtlc(outputType) + htlc, err := GetHtlc( + swapKit.hash, swapKit.contract, swapKit.lnd.ChainParams, + ) if err != nil { return nil, err } diff --git a/loopout_test.go b/loopout_test.go index ea41d10..67566e8 100644 --- a/loopout_test.go +++ b/loopout_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "math" - "reflect" "testing" "time" @@ -66,7 +65,7 @@ func testLoopOutPaymentParameters(t *testing.T) { blockEpochChan := make(chan interface{}) statusChan := make(chan SwapInfo) - const maxParts = 5 + const maxParts = uint32(5) chanSet := loopdb.ChannelSet{2, 3} @@ -77,9 +76,7 @@ func testLoopOutPaymentParameters(t *testing.T) { initResult, err := newLoopOutSwap( context.Background(), cfg, height, &req, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) swap := initResult.swap // Execute the swap in its own goroutine. @@ -105,9 +102,7 @@ func testLoopOutPaymentParameters(t *testing.T) { store.assertLoopOutStored() state := <-statusChan - if state.State != loopdb.StateInitiated { - t.Fatal("unexpected state") - } + require.Equal(t, loopdb.StateInitiated, state.State) // Check that the SwapInfo contains the outgoing chan set require.Equal(t, chanSet, state.OutgoingChanSet) @@ -130,18 +125,12 @@ func testLoopOutPaymentParameters(t *testing.T) { } // Assert that it is sent as a multi-part payment. - if swapPayment.MaxParts != maxParts { - t.Fatalf("Expected %v parts, but got %v", - maxParts, swapPayment.MaxParts) - } + require.Equal(t, maxParts, swapPayment.MaxParts) // Verify the outgoing channel set restriction. - if !reflect.DeepEqual( - []uint64(req.OutgoingChanSet), swapPayment.OutgoingChanIds, - ) { - - t.Fatalf("Unexpected outgoing channel set") - } + require.Equal( + t, []uint64(req.OutgoingChanSet), swapPayment.OutgoingChanIds, + ) // Swap is expected to register for confirmation of the htlc. Assert // this to prevent a blocked channel in the mock. @@ -152,10 +141,7 @@ func testLoopOutPaymentParameters(t *testing.T) { cancel() // Expect the swap to signal that it was cancelled. - err = <-errChan - if err != context.Canceled { - t.Fatal(err) - } + require.Equal(t, context.Canceled, <-errChan) } // TestLateHtlcPublish tests that the client is not revealing the preimage if @@ -198,9 +184,7 @@ func testLateHtlcPublish(t *testing.T) { initResult, err := newLoopOutSwap( context.Background(), cfg, height, testRequest, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) swap := initResult.swap sweeper := &sweep.Sweeper{Lnd: &lnd.LndServices} @@ -225,11 +209,8 @@ func testLateHtlcPublish(t *testing.T) { }() store.assertLoopOutStored() - - state := <-statusChan - if state.State != loopdb.StateInitiated { - t.Fatal("unexpected state") - } + status := <-statusChan + require.Equal(t, loopdb.StateInitiated, status.State) signalSwapPaymentResult := ctx.AssertPaid(swapInvoiceDesc) signalPrepaymentResult := ctx.AssertPaid(prepayInvoiceDesc) @@ -249,15 +230,9 @@ func testLateHtlcPublish(t *testing.T) { store.assertStoreFinished(loopdb.StateFailTimeout) - status := <-statusChan - if status.State != loopdb.StateFailTimeout { - t.Fatal("unexpected state") - } - - err = <-errChan - if err != nil { - t.Fatal(err) - } + status = <-statusChan + require.Equal(t, loopdb.StateFailTimeout, status.State) + require.NoError(t, <-errChan) } // TestCustomSweepConfTarget ensures we are able to sweep a Loop Out HTLC with a @@ -304,9 +279,7 @@ func testCustomSweepConfTarget(t *testing.T) { initResult, err := newLoopOutSwap( context.Background(), cfg, ctx.Lnd.Height, &testReq, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) swap := initResult.swap // Set up the required dependencies to execute the swap. @@ -339,9 +312,7 @@ func testCustomSweepConfTarget(t *testing.T) { // The swap should be found in its initial state. cfg.store.(*storeMock).assertLoopOutStored() state := <-statusChan - if state.State != loopdb.StateInitiated { - t.Fatal("unexpected state") - } + require.Equal(t, loopdb.StateInitiated, state.State) // We'll then pay both the swap and prepay invoice, which should trigger // the server to publish the on-chain HTLC. @@ -381,10 +352,7 @@ func testCustomSweepConfTarget(t *testing.T) { cfg.store.(*storeMock).assertLoopOutState(loopdb.StatePreimageRevealed) status := <-statusChan - if status.State != loopdb.StatePreimageRevealed { - t.Fatalf("expected state %v, got %v", - loopdb.StatePreimageRevealed, status.State) - } + require.Equal(t, loopdb.StatePreimageRevealed, status.State) // When using taproot htlcs the flow is different as we do reveal the // preimage before sweeping in order for the server to trust us with @@ -410,10 +378,10 @@ func testCustomSweepConfTarget(t *testing.T) { t.Helper() sweepTx := ctx.ReceiveTx() - if sweepTx.TxIn[0].PreviousOutPoint.Hash != htlcTx.TxHash() { - t.Fatalf("expected sweep tx to spend %v, got %v", - htlcTx.TxHash(), sweepTx.TxIn[0].PreviousOutPoint) - } + require.Equal( + t, htlcTx.TxHash(), + sweepTx.TxIn[0].PreviousOutPoint.Hash, + ) // The fee used for the sweep transaction is an estimate based // on the maximum witness size, so we should expect to see a @@ -427,16 +395,14 @@ func testCustomSweepConfTarget(t *testing.T) { feeRate, err := ctx.Lnd.WalletKit.EstimateFeeRate( context.Background(), expConfTarget, ) - if err != nil { - t.Fatalf("unable to retrieve fee estimate: %v", err) - } + require.NoError(t, err, "unable to retrieve fee estimate") + minFee := feeRate.FeeForWeight(weight) - maxFee := btcutil.Amount(float64(minFee) * 1.1) + // Just an estimate that works to sanity check fee upper bound. + maxFee := btcutil.Amount(float64(minFee) * 1.5) - if fee < minFee && fee > maxFee { - t.Fatalf("expected sweep tx to have fee between %v-%v, "+ - "got %v", minFee, maxFee, fee) - } + require.GreaterOrEqual(t, fee, minFee) + require.LessOrEqual(t, fee, maxFee) return sweepTx } @@ -479,14 +445,8 @@ func testCustomSweepConfTarget(t *testing.T) { cfg.store.(*storeMock).assertLoopOutState(loopdb.StateSuccess) status = <-statusChan - if status.State != loopdb.StateSuccess { - t.Fatalf("expected state %v, got %v", loopdb.StateSuccess, - status.State) - } - - if err := <-errChan; err != nil { - t.Fatal(err) - } + require.Equal(t, loopdb.StateSuccess, status.State) + require.NoError(t, <-errChan) } // TestPreimagePush tests or logic that decides whether to push our preimage to diff --git a/store_mock_test.go b/store_mock_test.go index a366fcd..4ec9daa 100644 --- a/store_mock_test.go +++ b/store_mock_test.go @@ -8,6 +8,7 @@ import ( "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/test" "github.com/lightningnetwork/lnd/lntypes" + "github.com/stretchr/testify/require" ) // storeMock implements a mock client swap store. @@ -239,9 +240,7 @@ func (s *storeMock) assertLoopInState( s.t.Helper() state := <-s.loopInUpdateChan - if state.State != expectedState { - s.t.Fatalf("expected state %v, got %v", expectedState, state) - } + require.Equal(s.t, expectedState, state.State) return state } @@ -252,9 +251,8 @@ func (s *storeMock) assertStorePreimageReveal() { select { case state := <-s.loopOutUpdateChan: - if state.State != loopdb.StatePreimageRevealed { - s.t.Fatalf("unexpected state") - } + require.Equal(s.t, loopdb.StatePreimageRevealed, state.State) + case <-time.After(test.Timeout): s.t.Fatalf("expected swap to be marked as preimage revealed") } @@ -265,10 +263,8 @@ func (s *storeMock) assertStoreFinished(expectedResult loopdb.SwapState) { select { case state := <-s.loopOutUpdateChan: - if state.State != expectedResult { - s.t.Fatalf("expected result %v, but got %v", - expectedResult, state) - } + require.Equal(s.t, expectedResult, state.State) + case <-time.After(test.Timeout): s.t.Fatalf("expected swap to be finished") } diff --git a/swap.go b/swap.go index a084e95..764af79 100644 --- a/swap.go +++ b/swap.go @@ -4,6 +4,7 @@ import ( "context" "time" + "github.com/btcsuite/btcd/chaincfg" "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/swap" @@ -67,14 +68,28 @@ func IsTaprootSwap(swapContract *loopdb.SwapContract) bool { return GetHtlcScriptVersion(swapContract.ProtocolVersion) == swap.HtlcV3 } -// getHtlc composes and returns the on-chain swap script. -func (s *swapKit) getHtlc(outputType swap.HtlcOutputType) (*swap.Htlc, error) { - return swap.NewHtlc( - GetHtlcScriptVersion(s.contract.ProtocolVersion), - s.contract.CltvExpiry, s.contract.SenderKey, - s.contract.ReceiverKey, s.hash, outputType, - s.swapConfig.lnd.ChainParams, - ) +// GetHtlc composes and returns the on-chain swap script. +func GetHtlc(hash lntypes.Hash, contract *loopdb.SwapContract, + chainParams *chaincfg.Params) (*swap.Htlc, error) { + + switch GetHtlcScriptVersion(contract.ProtocolVersion) { + case swap.HtlcV2: + return swap.NewHtlcV2( + contract.CltvExpiry, contract.SenderKey, + contract.ReceiverKey, hash, + chainParams, + ) + + case swap.HtlcV3: + return swap.NewHtlcV3( + contract.CltvExpiry, contract.SenderKey, + contract.ReceiverKey, contract.SenderKey, + contract.ReceiverKey, hash, + chainParams, + ) + } + + return nil, swap.ErrInvalidScriptVersion } // swapInfo constructs and returns a filled SwapInfo from diff --git a/swap/htlc.go b/swap/htlc.go index 1e174cc..8a1101a 100644 --- a/swap/htlc.go +++ b/swap/htlc.go @@ -114,16 +114,16 @@ var ( // QuoteHtlcP2WSH is a template script just used for sweep fee // estimation. - QuoteHtlcP2WSH, _ = NewHtlc( - HtlcV2, ^int32(0), dummyPubKey, dummyPubKey, quoteHash, - HtlcP2WSH, &chaincfg.MainNetParams, + QuoteHtlcP2WSH, _ = NewHtlcV2( + ^int32(0), dummyPubKey, dummyPubKey, quoteHash, + &chaincfg.MainNetParams, ) // QuoteHtlcP2TR is a template script just used for sweep fee // estimation. - QuoteHtlcP2TR, _ = NewHtlc( - HtlcV3, ^int32(0), dummyPubKey, dummyPubKey, quoteHash, - HtlcP2TR, &chaincfg.MainNetParams, + QuoteHtlcP2TR, _ = NewHtlcV3( + ^int32(0), dummyPubKey, dummyPubKey, dummyPubKey, dummyPubKey, + quoteHash, &chaincfg.MainNetParams, ) // ErrInvalidScriptVersion is returned when an unknown htlc version @@ -135,6 +135,10 @@ var ( // selected for a v2 script. ErrInvalidOutputSelected = fmt.Errorf("taproot output selected for " + "non taproot htlc") + + // ErrInvalidOutputType is returned when an unknown output type is + // associated with a certain swap htlc. + ErrInvalidOutputType = fmt.Errorf("invalid htlc output type") ) // String returns the string value of HtlcOutputType. @@ -151,38 +155,54 @@ func (h HtlcOutputType) String() string { } } -// NewHtlc returns a new instance. For v3 scripts, an internal pubkey generated -// by both participants must be provided. -func NewHtlc(version ScriptVersion, cltvExpiry int32, - senderKey, receiverKey [33]byte, hash lntypes.Hash, - outputType HtlcOutputType, chainParams *chaincfg.Params) (*Htlc, error) { +// NewHtlcV2 returns a new V2 (P2WSH) HTLC instance. +func NewHtlcV2(cltvExpiry int32, senderKey, receiverKey [33]byte, + hash lntypes.Hash, chainParams *chaincfg.Params) (*Htlc, error) { - var ( - err error - htlc HtlcScript + htlc, err := newHTLCScriptV2( + cltvExpiry, senderKey, receiverKey, hash, ) - switch version { - case HtlcV2: - htlc, err = newHTLCScriptV2( - cltvExpiry, senderKey, receiverKey, hash, - ) - - case HtlcV3: - htlc, err = newHTLCScriptV3( - cltvExpiry, senderKey, receiverKey, hash, - ) + if err != nil { + return nil, err + } - default: - return nil, ErrInvalidScriptVersion + address, pkScript, sigScript, err := htlc.lockingConditions( + HtlcP2WSH, chainParams, + ) + if err != nil { + return nil, fmt.Errorf("could not get address: %w", err) } + return &Htlc{ + HtlcScript: htlc, + Hash: hash, + Version: HtlcV2, + PkScript: pkScript, + OutputType: HtlcP2WSH, + ChainParams: chainParams, + Address: address, + SigScript: sigScript, + }, nil +} + +// NewHtlcV3 returns a new V3 HTLC (P2TR) instance. Internal pubkey generated +// by both participants must be provided. +func NewHtlcV3(cltvExpiry int32, senderInternalKey, receiverInternalKey, + senderKey, receiverKey [33]byte, hash lntypes.Hash, + chainParams *chaincfg.Params) (*Htlc, error) { + + htlc, err := newHTLCScriptV3( + cltvExpiry, senderInternalKey, receiverInternalKey, + senderKey, receiverKey, hash, + ) + if err != nil { return nil, err } address, pkScript, sigScript, err := htlc.lockingConditions( - outputType, chainParams, + HtlcP2TR, chainParams, ) if err != nil { return nil, fmt.Errorf("could not get address: %w", err) @@ -191,9 +211,9 @@ func NewHtlc(version ScriptVersion, cltvExpiry int32, return &Htlc{ HtlcScript: htlc, Hash: hash, - Version: version, + Version: HtlcV3, PkScript: pkScript, - OutputType: outputType, + OutputType: HtlcP2TR, ChainParams: chainParams, Address: address, SigScript: sigScript, @@ -481,7 +501,8 @@ type HtlcScriptV3 struct { } // newHTLCScriptV3 constructs a HtlcScipt with the HTLC V3 taproot script. -func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte, +func newHTLCScriptV3(cltvExpiry int32, senderInternalKey, receiverInternalKey, + senderHtlcKey, receiverHtlcKey [33]byte, swapHash lntypes.Hash) (*HtlcScriptV3, error) { senderPubKey, err := schnorr.ParsePubKey(senderHtlcKey[1:]) @@ -494,13 +515,6 @@ func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte, return nil, err } - aggregateKey, _, _, err := musig2.AggregateKeys( - []*btcec.PublicKey{senderPubKey, receiverPubKey}, true, - ) - if err != nil { - return nil, err - } - // Create our success path script, we'll use this separately // to generate the success path leaf. successPathScript, err := GenSuccessPathScript( @@ -527,6 +541,31 @@ func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte, rootHash := tree.RootNode.TapHash() + // Parse the pub keys used in the internal aggregate key. They are + // optional and may just be the same keys that are used for the script + // paths. + senderInternalPubKey, err := schnorr.ParsePubKey(senderInternalKey[1:]) + if err != nil { + return nil, err + } + + receiverInternalPubKey, err := schnorr.ParsePubKey( + receiverInternalKey[1:], + ) + if err != nil { + return nil, err + } + + // Calculate the internal aggregate key. + aggregateKey, _, _, err := musig2.AggregateKeys( + []*btcec.PublicKey{ + senderInternalPubKey, receiverInternalPubKey, + }, true, + ) + if err != nil { + return nil, err + } + // Calculate top level taproot key. taprootKey := txscript.ComputeTaprootOutputKey( aggregateKey.PreTweakedKey, rootHash[:], diff --git a/swap/htlc_test.go b/swap/htlc_test.go index 96b2f16..09f9264 100644 --- a/swap/htlc_test.go +++ b/swap/htlc_test.go @@ -54,9 +54,7 @@ func assertEngineExecution(t *testing.T, valid bool, done := false for !done { dis, err := vm.DisasmPC() - if err != nil { - t.Fatalf("stepping (%v)\n", err) - } + require.NoError(t, err, "stepping") debugBuf.WriteString(fmt.Sprintf("stepping %v\n", dis)) done, err = vm.Step() @@ -134,9 +132,9 @@ func TestHtlcV2(t *testing.T) { hash := sha256.Sum256(testPreimage[:]) // Create the htlc. - htlc, err := NewHtlc( - HtlcV2, testCltvExpiry, senderKey, receiverKey, hash, - HtlcP2WSH, &chaincfg.MainNetParams, + htlc, err := NewHtlcV2( + testCltvExpiry, senderKey, receiverKey, hash, + &chaincfg.MainNetParams, ) require.NoError(t, err) @@ -287,10 +285,9 @@ func TestHtlcV2(t *testing.T) { bogusKey := [33]byte{0xb, 0xa, 0xd} // Create the htlc with the bogus key. - htlc, err = NewHtlc( - HtlcV2, testCltvExpiry, - bogusKey, receiverKey, hash, - HtlcP2WSH, &chaincfg.MainNetParams, + htlc, err = NewHtlcV2( + testCltvExpiry, bogusKey, receiverKey, + hash, &chaincfg.MainNetParams, ) require.NoError(t, err) @@ -357,9 +354,9 @@ func TestHtlcV3(t *testing.T) { copy(receiverKey[:], receiverPubKey.SerializeCompressed()) copy(senderKey[:], senderPubKey.SerializeCompressed()) - htlc, err := NewHtlc( - HtlcV3, cltvExpiry, senderKey, receiverKey, - hashedPreimage, HtlcP2TR, &chaincfg.MainNetParams, + htlc, err := NewHtlcV3( + cltvExpiry, senderKey, receiverKey, senderKey, receiverKey, + hashedPreimage, &chaincfg.MainNetParams, ) require.NoError(t, err) @@ -540,10 +537,10 @@ func TestHtlcV3(t *testing.T) { bogusKey.SerializeCompressed(), ) - htlc, err := NewHtlc( - HtlcV3, cltvExpiry, bogusKeyBytes, - receiverKey, hashedPreimage, HtlcP2TR, - &chaincfg.MainNetParams, + htlc, err := NewHtlcV3( + cltvExpiry, senderKey, + receiverKey, bogusKeyBytes, receiverKey, + hashedPreimage, &chaincfg.MainNetParams, ) require.NoError(t, err) diff --git a/test/context.go b/test/context.go index ca83322..e8c517c 100644 --- a/test/context.go +++ b/test/context.go @@ -86,9 +86,11 @@ func (ctx *Context) AssertRegisterSpendNtfn(script []byte) { select { case spendIntent := <-ctx.Lnd.RegisterSpendChannel: - if !bytes.Equal(spendIntent.PkScript, script) { - ctx.T.Fatalf("server not listening for published htlc script") - } + require.Equal( + ctx.T, script, spendIntent.PkScript, + "server not listening for published htlc script", + ) + case <-time.After(Timeout): DumpGoroutines() ctx.T.Fatalf("spend not subscribed to") @@ -163,10 +165,11 @@ func (ctx *Context) AssertPaid( payReq := ctx.DecodeInvoice(swapPayment.Invoice) - if _, ok := ctx.PaidInvoices[*payReq.Description]; ok { - ctx.T.Fatalf("duplicate invoice paid: %v", - *payReq.Description) - } + _, ok := ctx.PaidInvoices[*payReq.Description] + require.False( + ctx.T, ok, + "duplicate invoice paid: %v", *payReq.Description, + ) done := func(result error) { if result != nil { @@ -195,9 +198,10 @@ func (ctx *Context) AssertSettled( select { case preimage := <-ctx.Lnd.SettleInvoiceChannel: hash := sha256.Sum256(preimage[:]) - if expectedHash != hash { - ctx.T.Fatalf("server claims with wrong preimage") - } + require.Equal( + ctx.T, expectedHash, lntypes.Hash(hash), + "server claims with wrong preimage", + ) return preimage case <-time.After(Timeout): @@ -232,9 +236,8 @@ func (ctx *Context) DecodeInvoice(request string) *zpay32.Invoice { ctx.T.Helper() payReq, err := ctx.Lnd.DecodeInvoice(request) - if err != nil { - ctx.T.Fatal(err) - } + require.NoError(ctx.T, err) + return payReq } @@ -256,7 +259,5 @@ func (ctx *Context) GetOutputIndex(tx *wire.MsgTx, // waits for the notification to be processed by selecting on a // dedicated test channel. func (ctx *Context) NotifyServerHeight(height int32) { - if err := ctx.Lnd.NotifyHeight(height); err != nil { - ctx.T.Fatal(err) - } + require.NoError(ctx.T, ctx.Lnd.NotifyHeight(height)) } diff --git a/test/testutils.go b/test/testutils.go index c00ea96..6aa312e 100644 --- a/test/testutils.go +++ b/test/testutils.go @@ -14,6 +14,7 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/zpay32" + "github.com/stretchr/testify/require" ) var ( @@ -29,11 +30,10 @@ var ( // GetDestAddr deterministically generates a sweep address for testing. func GetDestAddr(t *testing.T, nr byte) btcutil.Address { - destAddr, err := btcutil.NewAddressScriptHash([]byte{nr}, - &chaincfg.MainNetParams) - if err != nil { - t.Fatal(err) - } + destAddr, err := btcutil.NewAddressScriptHash( + []byte{nr}, &chaincfg.MainNetParams, + ) + require.NoError(t, err) return destAddr } diff --git a/testcontext_test.go b/testcontext_test.go index afc6a90..a221a28 100644 --- a/testcontext_test.go +++ b/testcontext_test.go @@ -140,9 +140,8 @@ func (ctx *testContext) finish() { ctx.stop() select { case err := <-ctx.runErr: - if err != nil { - ctx.T.Fatal(err) - } + require.NoError(ctx.T, err) + case <-time.After(test.Timeout): ctx.T.Fatal("client not stopping") } @@ -156,19 +155,12 @@ func (ctx *testContext) finish() { func (ctx *testContext) notifyHeight(height int32) { ctx.T.Helper() - if err := ctx.Lnd.NotifyHeight(height); err != nil { - ctx.T.Fatal(err) - } + require.NoError(ctx.T, ctx.Lnd.NotifyHeight(height)) } func (ctx *testContext) assertIsDone() { - if err := ctx.Lnd.IsDone(); err != nil { - ctx.T.Fatal(err) - } - - if err := ctx.store.isDone(); err != nil { - ctx.T.Fatal(err) - } + require.NoError(ctx.T, ctx.Lnd.IsDone()) + require.NoError(ctx.T, ctx.store.isDone()) select { case <-ctx.statusChan: