diff --git a/client_test.go b/client_test.go index 3bca073..b49e0ba 100644 --- a/client_test.go +++ b/client_test.go @@ -1,7 +1,6 @@ package loop import ( - "bytes" "context" "crypto/sha256" "errors" @@ -57,9 +56,7 @@ func TestLoopOutSuccess(t *testing.T) { // Initiate loop out. info, err := ctx.swapClient.LoopOut(context.Background(), &req) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) ctx.assertStored() ctx.assertStatus(loopdb.StateInitiated) @@ -84,9 +81,7 @@ func TestLoopOutFailOffchain(t *testing.T) { ctx := createClientTestContext(t, nil) _, err := ctx.swapClient.LoopOut(context.Background(), testRequest) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) ctx.assertStored() ctx.assertStatus(loopdb.StateInitiated) @@ -208,14 +203,10 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed, amt := btcutil.Amount(50000) swapPayReq, err := getInvoice(hash, amt, swapInvoiceDesc) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) prePayReq, err := getInvoice(hash, 100, prepayInvoiceDesc) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) _, senderPubKey := test.CreateKey(1) var senderKey [33]byte @@ -373,10 +364,11 @@ func testLoopOutSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash, // Expect client on-chain sweep of HTLC. sweepTx := ctx.ReceiveTx() - if !bytes.Equal(sweepTx.TxIn[0].PreviousOutPoint.Hash[:], - htlcOutpoint.Hash[:]) { - ctx.T.Fatalf("client not sweeping from htlc tx") - } + require.Equal( + ctx.T, htlcOutpoint.Hash[:], + sweepTx.TxIn[0].PreviousOutPoint.Hash[:], + "client not sweeping from htlc tx", + ) var preImageIndex int switch scriptVersion { @@ -390,9 +382,7 @@ func testLoopOutSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash, // Check preimage. clientPreImage := sweepTx.TxIn[0].Witness[preImageIndex] clientPreImageHash := sha256.Sum256(clientPreImage) - if clientPreImageHash != hash { - ctx.T.Fatalf("incorrect preimage") - } + require.Equal(ctx.T, hash, lntypes.Hash(clientPreImageHash)) // Since we successfully published our sweep, we expect the preimage to // have been pushed to our mock server. diff --git a/loopd/swapclient_server_test.go b/loopd/swapclient_server_test.go index 53846ba..585d87f 100644 --- a/loopd/swapclient_server_test.go +++ b/loopd/swapclient_server_test.go @@ -130,16 +130,13 @@ func TestValidateConfTarget(t *testing.T) { test.confTarget, defaultConf, ) - haveErr := err != nil - if haveErr != test.expectErr { - t.Fatalf("expected err: %v, got: %v", - test.expectErr, err) + if test.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) } - if target != test.expectedTarget { - t.Fatalf("expected: %v, got: %v", - test.expectedTarget, target) - } + require.Equal(t, test.expectedTarget, target) }) } } @@ -199,16 +196,13 @@ func TestValidateLoopInRequest(t *testing.T) { test.confTarget, external, ) - haveErr := err != nil - if haveErr != test.expectErr { - t.Fatalf("expected err: %v, got: %v", - test.expectErr, err) + if test.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) } - if conf != test.expectedTarget { - t.Fatalf("expected: %v, got: %v", - test.expectedTarget, conf) - } + require.Equal(t, test.expectedTarget, conf) }) } } diff --git a/loopin_test.go b/loopin_test.go index 53f2b35..6940b68 100644 --- a/loopin_test.go +++ b/loopin_test.go @@ -58,9 +58,8 @@ func testLoopInSuccess(t *testing.T) { context.Background(), cfg, height, req, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + inSwap := initResult.swap ctx.store.assertLoopInStored() @@ -142,10 +141,7 @@ func testLoopInSuccess(t *testing.T) { ctx.assertState(loopdb.StateSuccess) ctx.store.assertLoopInState(loopdb.StateSuccess) - err = <-errChan - if err != nil { - t.Fatal(err) - } + require.NoError(t, <-errChan) } // TestLoopInTimeout tests scenarios where the server doesn't sweep the htlc @@ -215,9 +211,7 @@ func testLoopInTimeout(t *testing.T, externalValue int64) { context.Background(), cfg, height, &req, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) inSwap := initResult.swap ctx.store.assertLoopInStored() @@ -289,11 +283,7 @@ func testLoopInTimeout(t *testing.T, externalValue int64) { ctx.assertState(loopdb.StateFailIncorrectHtlcAmt) ctx.store.assertLoopInState(loopdb.StateFailIncorrectHtlcAmt) - err = <-errChan - if err != nil { - t.Fatal(err) - } - + require.NoError(t, <-errChan) return } @@ -308,9 +298,11 @@ func testLoopInTimeout(t *testing.T, externalValue int64) { // Expect a signing request for the htlc tx output value. signReq := <-ctx.lnd.SignOutputRawChannel - if signReq.SignDescriptors[0].Output.Value != htlcTx.TxOut[0].Value { - t.Fatal("invalid signing amount") - } + require.Equal( + t, htlcTx.TxOut[0].Value, + signReq.SignDescriptors[0].Output.Value, + "invalid signing amount", + ) // Expect timeout tx to be published. timeoutTx := <-ctx.lnd.TxPublishChannel @@ -341,10 +333,7 @@ func testLoopInTimeout(t *testing.T, externalValue int64) { state := ctx.store.assertLoopInState(loopdb.StateFailTimeout) require.Equal(t, cost, state.Cost) - err = <-errChan - if err != nil { - t.Fatal(err) - } + require.NoError(t, <-errChan) } // TestLoopInResume tests resuming swaps in various states. @@ -483,17 +472,10 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool, require.NoError(t, err) err = ctx.store.CreateLoopIn(testPreimage.Hash(), contract) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - inSwap, err := resumeLoopInSwap( - context.Background(), cfg, - pendSwap, - ) - if err != nil { - t.Fatal(err) - } + inSwap, err := resumeLoopInSwap(context.Background(), cfg, pendSwap) + require.NoError(t, err) var height int32 if expired { @@ -512,10 +494,7 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool, }() defer func() { - err = <-errChan - if err != nil { - t.Fatal(err) - } + require.NoError(t, <-errChan) select { case <-ctx.lnd.SendPaymentChannel: diff --git a/loopin_testcontext_test.go b/loopin_testcontext_test.go index 077622e..ca6b024 100644 --- a/loopin_testcontext_test.go +++ b/loopin_testcontext_test.go @@ -63,10 +63,7 @@ func newLoopInTestContext(t *testing.T) *loopInTestContext { func (c *loopInTestContext) assertState(expectedState loopdb.SwapState) { state := <-c.statusChan - if state.State != expectedState { - c.t.Fatalf("expected state %v but got %v", expectedState, - state.State) - } + require.Equal(c.t, expectedState, state.State) } // assertSubscribeInvoice asserts that the client subscribes to invoice updates diff --git a/loopout_test.go b/loopout_test.go index ea41d10..67566e8 100644 --- a/loopout_test.go +++ b/loopout_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "math" - "reflect" "testing" "time" @@ -66,7 +65,7 @@ func testLoopOutPaymentParameters(t *testing.T) { blockEpochChan := make(chan interface{}) statusChan := make(chan SwapInfo) - const maxParts = 5 + const maxParts = uint32(5) chanSet := loopdb.ChannelSet{2, 3} @@ -77,9 +76,7 @@ func testLoopOutPaymentParameters(t *testing.T) { initResult, err := newLoopOutSwap( context.Background(), cfg, height, &req, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) swap := initResult.swap // Execute the swap in its own goroutine. @@ -105,9 +102,7 @@ func testLoopOutPaymentParameters(t *testing.T) { store.assertLoopOutStored() state := <-statusChan - if state.State != loopdb.StateInitiated { - t.Fatal("unexpected state") - } + require.Equal(t, loopdb.StateInitiated, state.State) // Check that the SwapInfo contains the outgoing chan set require.Equal(t, chanSet, state.OutgoingChanSet) @@ -130,18 +125,12 @@ func testLoopOutPaymentParameters(t *testing.T) { } // Assert that it is sent as a multi-part payment. - if swapPayment.MaxParts != maxParts { - t.Fatalf("Expected %v parts, but got %v", - maxParts, swapPayment.MaxParts) - } + require.Equal(t, maxParts, swapPayment.MaxParts) // Verify the outgoing channel set restriction. - if !reflect.DeepEqual( - []uint64(req.OutgoingChanSet), swapPayment.OutgoingChanIds, - ) { - - t.Fatalf("Unexpected outgoing channel set") - } + require.Equal( + t, []uint64(req.OutgoingChanSet), swapPayment.OutgoingChanIds, + ) // Swap is expected to register for confirmation of the htlc. Assert // this to prevent a blocked channel in the mock. @@ -152,10 +141,7 @@ func testLoopOutPaymentParameters(t *testing.T) { cancel() // Expect the swap to signal that it was cancelled. - err = <-errChan - if err != context.Canceled { - t.Fatal(err) - } + require.Equal(t, context.Canceled, <-errChan) } // TestLateHtlcPublish tests that the client is not revealing the preimage if @@ -198,9 +184,7 @@ func testLateHtlcPublish(t *testing.T) { initResult, err := newLoopOutSwap( context.Background(), cfg, height, testRequest, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) swap := initResult.swap sweeper := &sweep.Sweeper{Lnd: &lnd.LndServices} @@ -225,11 +209,8 @@ func testLateHtlcPublish(t *testing.T) { }() store.assertLoopOutStored() - - state := <-statusChan - if state.State != loopdb.StateInitiated { - t.Fatal("unexpected state") - } + status := <-statusChan + require.Equal(t, loopdb.StateInitiated, status.State) signalSwapPaymentResult := ctx.AssertPaid(swapInvoiceDesc) signalPrepaymentResult := ctx.AssertPaid(prepayInvoiceDesc) @@ -249,15 +230,9 @@ func testLateHtlcPublish(t *testing.T) { store.assertStoreFinished(loopdb.StateFailTimeout) - status := <-statusChan - if status.State != loopdb.StateFailTimeout { - t.Fatal("unexpected state") - } - - err = <-errChan - if err != nil { - t.Fatal(err) - } + status = <-statusChan + require.Equal(t, loopdb.StateFailTimeout, status.State) + require.NoError(t, <-errChan) } // TestCustomSweepConfTarget ensures we are able to sweep a Loop Out HTLC with a @@ -304,9 +279,7 @@ func testCustomSweepConfTarget(t *testing.T) { initResult, err := newLoopOutSwap( context.Background(), cfg, ctx.Lnd.Height, &testReq, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) swap := initResult.swap // Set up the required dependencies to execute the swap. @@ -339,9 +312,7 @@ func testCustomSweepConfTarget(t *testing.T) { // The swap should be found in its initial state. cfg.store.(*storeMock).assertLoopOutStored() state := <-statusChan - if state.State != loopdb.StateInitiated { - t.Fatal("unexpected state") - } + require.Equal(t, loopdb.StateInitiated, state.State) // We'll then pay both the swap and prepay invoice, which should trigger // the server to publish the on-chain HTLC. @@ -381,10 +352,7 @@ func testCustomSweepConfTarget(t *testing.T) { cfg.store.(*storeMock).assertLoopOutState(loopdb.StatePreimageRevealed) status := <-statusChan - if status.State != loopdb.StatePreimageRevealed { - t.Fatalf("expected state %v, got %v", - loopdb.StatePreimageRevealed, status.State) - } + require.Equal(t, loopdb.StatePreimageRevealed, status.State) // When using taproot htlcs the flow is different as we do reveal the // preimage before sweeping in order for the server to trust us with @@ -410,10 +378,10 @@ func testCustomSweepConfTarget(t *testing.T) { t.Helper() sweepTx := ctx.ReceiveTx() - if sweepTx.TxIn[0].PreviousOutPoint.Hash != htlcTx.TxHash() { - t.Fatalf("expected sweep tx to spend %v, got %v", - htlcTx.TxHash(), sweepTx.TxIn[0].PreviousOutPoint) - } + require.Equal( + t, htlcTx.TxHash(), + sweepTx.TxIn[0].PreviousOutPoint.Hash, + ) // The fee used for the sweep transaction is an estimate based // on the maximum witness size, so we should expect to see a @@ -427,16 +395,14 @@ func testCustomSweepConfTarget(t *testing.T) { feeRate, err := ctx.Lnd.WalletKit.EstimateFeeRate( context.Background(), expConfTarget, ) - if err != nil { - t.Fatalf("unable to retrieve fee estimate: %v", err) - } + require.NoError(t, err, "unable to retrieve fee estimate") + minFee := feeRate.FeeForWeight(weight) - maxFee := btcutil.Amount(float64(minFee) * 1.1) + // Just an estimate that works to sanity check fee upper bound. + maxFee := btcutil.Amount(float64(minFee) * 1.5) - if fee < minFee && fee > maxFee { - t.Fatalf("expected sweep tx to have fee between %v-%v, "+ - "got %v", minFee, maxFee, fee) - } + require.GreaterOrEqual(t, fee, minFee) + require.LessOrEqual(t, fee, maxFee) return sweepTx } @@ -479,14 +445,8 @@ func testCustomSweepConfTarget(t *testing.T) { cfg.store.(*storeMock).assertLoopOutState(loopdb.StateSuccess) status = <-statusChan - if status.State != loopdb.StateSuccess { - t.Fatalf("expected state %v, got %v", loopdb.StateSuccess, - status.State) - } - - if err := <-errChan; err != nil { - t.Fatal(err) - } + require.Equal(t, loopdb.StateSuccess, status.State) + require.NoError(t, <-errChan) } // TestPreimagePush tests or logic that decides whether to push our preimage to diff --git a/store_mock_test.go b/store_mock_test.go index a366fcd..4ec9daa 100644 --- a/store_mock_test.go +++ b/store_mock_test.go @@ -8,6 +8,7 @@ import ( "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/test" "github.com/lightningnetwork/lnd/lntypes" + "github.com/stretchr/testify/require" ) // storeMock implements a mock client swap store. @@ -239,9 +240,7 @@ func (s *storeMock) assertLoopInState( s.t.Helper() state := <-s.loopInUpdateChan - if state.State != expectedState { - s.t.Fatalf("expected state %v, got %v", expectedState, state) - } + require.Equal(s.t, expectedState, state.State) return state } @@ -252,9 +251,8 @@ func (s *storeMock) assertStorePreimageReveal() { select { case state := <-s.loopOutUpdateChan: - if state.State != loopdb.StatePreimageRevealed { - s.t.Fatalf("unexpected state") - } + require.Equal(s.t, loopdb.StatePreimageRevealed, state.State) + case <-time.After(test.Timeout): s.t.Fatalf("expected swap to be marked as preimage revealed") } @@ -265,10 +263,8 @@ func (s *storeMock) assertStoreFinished(expectedResult loopdb.SwapState) { select { case state := <-s.loopOutUpdateChan: - if state.State != expectedResult { - s.t.Fatalf("expected result %v, but got %v", - expectedResult, state) - } + require.Equal(s.t, expectedResult, state.State) + case <-time.After(test.Timeout): s.t.Fatalf("expected swap to be finished") } diff --git a/swap/htlc_test.go b/swap/htlc_test.go index dd309e5..09f9264 100644 --- a/swap/htlc_test.go +++ b/swap/htlc_test.go @@ -54,9 +54,7 @@ func assertEngineExecution(t *testing.T, valid bool, done := false for !done { dis, err := vm.DisasmPC() - if err != nil { - t.Fatalf("stepping (%v)\n", err) - } + require.NoError(t, err, "stepping") debugBuf.WriteString(fmt.Sprintf("stepping %v\n", dis)) done, err = vm.Step() diff --git a/test/context.go b/test/context.go index ca83322..e8c517c 100644 --- a/test/context.go +++ b/test/context.go @@ -86,9 +86,11 @@ func (ctx *Context) AssertRegisterSpendNtfn(script []byte) { select { case spendIntent := <-ctx.Lnd.RegisterSpendChannel: - if !bytes.Equal(spendIntent.PkScript, script) { - ctx.T.Fatalf("server not listening for published htlc script") - } + require.Equal( + ctx.T, script, spendIntent.PkScript, + "server not listening for published htlc script", + ) + case <-time.After(Timeout): DumpGoroutines() ctx.T.Fatalf("spend not subscribed to") @@ -163,10 +165,11 @@ func (ctx *Context) AssertPaid( payReq := ctx.DecodeInvoice(swapPayment.Invoice) - if _, ok := ctx.PaidInvoices[*payReq.Description]; ok { - ctx.T.Fatalf("duplicate invoice paid: %v", - *payReq.Description) - } + _, ok := ctx.PaidInvoices[*payReq.Description] + require.False( + ctx.T, ok, + "duplicate invoice paid: %v", *payReq.Description, + ) done := func(result error) { if result != nil { @@ -195,9 +198,10 @@ func (ctx *Context) AssertSettled( select { case preimage := <-ctx.Lnd.SettleInvoiceChannel: hash := sha256.Sum256(preimage[:]) - if expectedHash != hash { - ctx.T.Fatalf("server claims with wrong preimage") - } + require.Equal( + ctx.T, expectedHash, lntypes.Hash(hash), + "server claims with wrong preimage", + ) return preimage case <-time.After(Timeout): @@ -232,9 +236,8 @@ func (ctx *Context) DecodeInvoice(request string) *zpay32.Invoice { ctx.T.Helper() payReq, err := ctx.Lnd.DecodeInvoice(request) - if err != nil { - ctx.T.Fatal(err) - } + require.NoError(ctx.T, err) + return payReq } @@ -256,7 +259,5 @@ func (ctx *Context) GetOutputIndex(tx *wire.MsgTx, // waits for the notification to be processed by selecting on a // dedicated test channel. func (ctx *Context) NotifyServerHeight(height int32) { - if err := ctx.Lnd.NotifyHeight(height); err != nil { - ctx.T.Fatal(err) - } + require.NoError(ctx.T, ctx.Lnd.NotifyHeight(height)) } diff --git a/test/testutils.go b/test/testutils.go index c00ea96..6aa312e 100644 --- a/test/testutils.go +++ b/test/testutils.go @@ -14,6 +14,7 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/zpay32" + "github.com/stretchr/testify/require" ) var ( @@ -29,11 +30,10 @@ var ( // GetDestAddr deterministically generates a sweep address for testing. func GetDestAddr(t *testing.T, nr byte) btcutil.Address { - destAddr, err := btcutil.NewAddressScriptHash([]byte{nr}, - &chaincfg.MainNetParams) - if err != nil { - t.Fatal(err) - } + destAddr, err := btcutil.NewAddressScriptHash( + []byte{nr}, &chaincfg.MainNetParams, + ) + require.NoError(t, err) return destAddr } diff --git a/testcontext_test.go b/testcontext_test.go index afc6a90..a221a28 100644 --- a/testcontext_test.go +++ b/testcontext_test.go @@ -140,9 +140,8 @@ func (ctx *testContext) finish() { ctx.stop() select { case err := <-ctx.runErr: - if err != nil { - ctx.T.Fatal(err) - } + require.NoError(ctx.T, err) + case <-time.After(test.Timeout): ctx.T.Fatal("client not stopping") } @@ -156,19 +155,12 @@ func (ctx *testContext) finish() { func (ctx *testContext) notifyHeight(height int32) { ctx.T.Helper() - if err := ctx.Lnd.NotifyHeight(height); err != nil { - ctx.T.Fatal(err) - } + require.NoError(ctx.T, ctx.Lnd.NotifyHeight(height)) } func (ctx *testContext) assertIsDone() { - if err := ctx.Lnd.IsDone(); err != nil { - ctx.T.Fatal(err) - } - - if err := ctx.store.isDone(); err != nil { - ctx.T.Fatal(err) - } + require.NoError(ctx.T, ctx.Lnd.IsDone()) + require.NoError(ctx.T, ctx.store.isDone()) select { case <-ctx.statusChan: