misc: refactor loop tests to use require where possible

pull/541/head
Andras Banki-Horvath 1 year ago
parent bdb4b773ed
commit 049b17ff96
No known key found for this signature in database
GPG Key ID: 80E5375C094198D8

@ -1,7 +1,6 @@
package loop
import (
"bytes"
"context"
"crypto/sha256"
"errors"
@ -57,9 +56,7 @@ func TestLoopOutSuccess(t *testing.T) {
// Initiate loop out.
info, err := ctx.swapClient.LoopOut(context.Background(), &req)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
ctx.assertStored()
ctx.assertStatus(loopdb.StateInitiated)
@ -84,9 +81,7 @@ func TestLoopOutFailOffchain(t *testing.T) {
ctx := createClientTestContext(t, nil)
_, err := ctx.swapClient.LoopOut(context.Background(), testRequest)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
ctx.assertStored()
ctx.assertStatus(loopdb.StateInitiated)
@ -208,14 +203,10 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed,
amt := btcutil.Amount(50000)
swapPayReq, err := getInvoice(hash, amt, swapInvoiceDesc)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
prePayReq, err := getInvoice(hash, 100, prepayInvoiceDesc)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
_, senderPubKey := test.CreateKey(1)
var senderKey [33]byte
@ -373,10 +364,11 @@ func testLoopOutSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash,
// Expect client on-chain sweep of HTLC.
sweepTx := ctx.ReceiveTx()
if !bytes.Equal(sweepTx.TxIn[0].PreviousOutPoint.Hash[:],
htlcOutpoint.Hash[:]) {
ctx.T.Fatalf("client not sweeping from htlc tx")
}
require.Equal(
ctx.T, htlcOutpoint.Hash[:],
sweepTx.TxIn[0].PreviousOutPoint.Hash[:],
"client not sweeping from htlc tx",
)
var preImageIndex int
switch scriptVersion {
@ -390,9 +382,7 @@ func testLoopOutSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash,
// Check preimage.
clientPreImage := sweepTx.TxIn[0].Witness[preImageIndex]
clientPreImageHash := sha256.Sum256(clientPreImage)
if clientPreImageHash != hash {
ctx.T.Fatalf("incorrect preimage")
}
require.Equal(ctx.T, hash, lntypes.Hash(clientPreImageHash))
// Since we successfully published our sweep, we expect the preimage to
// have been pushed to our mock server.

@ -130,16 +130,13 @@ func TestValidateConfTarget(t *testing.T) {
test.confTarget, defaultConf,
)
haveErr := err != nil
if haveErr != test.expectErr {
t.Fatalf("expected err: %v, got: %v",
test.expectErr, err)
if test.expectErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
if target != test.expectedTarget {
t.Fatalf("expected: %v, got: %v",
test.expectedTarget, target)
}
require.Equal(t, test.expectedTarget, target)
})
}
}
@ -199,16 +196,13 @@ func TestValidateLoopInRequest(t *testing.T) {
test.confTarget, external,
)
haveErr := err != nil
if haveErr != test.expectErr {
t.Fatalf("expected err: %v, got: %v",
test.expectErr, err)
if test.expectErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
if conf != test.expectedTarget {
t.Fatalf("expected: %v, got: %v",
test.expectedTarget, conf)
}
require.Equal(t, test.expectedTarget, conf)
})
}
}

@ -58,9 +58,8 @@ func testLoopInSuccess(t *testing.T) {
context.Background(), cfg,
height, req,
)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
inSwap := initResult.swap
ctx.store.assertLoopInStored()
@ -142,10 +141,7 @@ func testLoopInSuccess(t *testing.T) {
ctx.assertState(loopdb.StateSuccess)
ctx.store.assertLoopInState(loopdb.StateSuccess)
err = <-errChan
if err != nil {
t.Fatal(err)
}
require.NoError(t, <-errChan)
}
// TestLoopInTimeout tests scenarios where the server doesn't sweep the htlc
@ -215,9 +211,7 @@ func testLoopInTimeout(t *testing.T, externalValue int64) {
context.Background(), cfg,
height, &req,
)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
inSwap := initResult.swap
ctx.store.assertLoopInStored()
@ -289,11 +283,7 @@ func testLoopInTimeout(t *testing.T, externalValue int64) {
ctx.assertState(loopdb.StateFailIncorrectHtlcAmt)
ctx.store.assertLoopInState(loopdb.StateFailIncorrectHtlcAmt)
err = <-errChan
if err != nil {
t.Fatal(err)
}
require.NoError(t, <-errChan)
return
}
@ -308,9 +298,11 @@ func testLoopInTimeout(t *testing.T, externalValue int64) {
// Expect a signing request for the htlc tx output value.
signReq := <-ctx.lnd.SignOutputRawChannel
if signReq.SignDescriptors[0].Output.Value != htlcTx.TxOut[0].Value {
t.Fatal("invalid signing amount")
}
require.Equal(
t, htlcTx.TxOut[0].Value,
signReq.SignDescriptors[0].Output.Value,
"invalid signing amount",
)
// Expect timeout tx to be published.
timeoutTx := <-ctx.lnd.TxPublishChannel
@ -341,10 +333,7 @@ func testLoopInTimeout(t *testing.T, externalValue int64) {
state := ctx.store.assertLoopInState(loopdb.StateFailTimeout)
require.Equal(t, cost, state.Cost)
err = <-errChan
if err != nil {
t.Fatal(err)
}
require.NoError(t, <-errChan)
}
// TestLoopInResume tests resuming swaps in various states.
@ -483,17 +472,10 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool,
require.NoError(t, err)
err = ctx.store.CreateLoopIn(testPreimage.Hash(), contract)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
inSwap, err := resumeLoopInSwap(
context.Background(), cfg,
pendSwap,
)
if err != nil {
t.Fatal(err)
}
inSwap, err := resumeLoopInSwap(context.Background(), cfg, pendSwap)
require.NoError(t, err)
var height int32
if expired {
@ -512,10 +494,7 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool,
}()
defer func() {
err = <-errChan
if err != nil {
t.Fatal(err)
}
require.NoError(t, <-errChan)
select {
case <-ctx.lnd.SendPaymentChannel:

@ -63,10 +63,7 @@ func newLoopInTestContext(t *testing.T) *loopInTestContext {
func (c *loopInTestContext) assertState(expectedState loopdb.SwapState) {
state := <-c.statusChan
if state.State != expectedState {
c.t.Fatalf("expected state %v but got %v", expectedState,
state.State)
}
require.Equal(c.t, expectedState, state.State)
}
// assertSubscribeInvoice asserts that the client subscribes to invoice updates

@ -4,7 +4,6 @@ import (
"context"
"errors"
"math"
"reflect"
"testing"
"time"
@ -66,7 +65,7 @@ func testLoopOutPaymentParameters(t *testing.T) {
blockEpochChan := make(chan interface{})
statusChan := make(chan SwapInfo)
const maxParts = 5
const maxParts = uint32(5)
chanSet := loopdb.ChannelSet{2, 3}
@ -77,9 +76,7 @@ func testLoopOutPaymentParameters(t *testing.T) {
initResult, err := newLoopOutSwap(
context.Background(), cfg, height, &req,
)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
swap := initResult.swap
// Execute the swap in its own goroutine.
@ -105,9 +102,7 @@ func testLoopOutPaymentParameters(t *testing.T) {
store.assertLoopOutStored()
state := <-statusChan
if state.State != loopdb.StateInitiated {
t.Fatal("unexpected state")
}
require.Equal(t, loopdb.StateInitiated, state.State)
// Check that the SwapInfo contains the outgoing chan set
require.Equal(t, chanSet, state.OutgoingChanSet)
@ -130,18 +125,12 @@ func testLoopOutPaymentParameters(t *testing.T) {
}
// Assert that it is sent as a multi-part payment.
if swapPayment.MaxParts != maxParts {
t.Fatalf("Expected %v parts, but got %v",
maxParts, swapPayment.MaxParts)
}
require.Equal(t, maxParts, swapPayment.MaxParts)
// Verify the outgoing channel set restriction.
if !reflect.DeepEqual(
[]uint64(req.OutgoingChanSet), swapPayment.OutgoingChanIds,
) {
t.Fatalf("Unexpected outgoing channel set")
}
require.Equal(
t, []uint64(req.OutgoingChanSet), swapPayment.OutgoingChanIds,
)
// Swap is expected to register for confirmation of the htlc. Assert
// this to prevent a blocked channel in the mock.
@ -152,10 +141,7 @@ func testLoopOutPaymentParameters(t *testing.T) {
cancel()
// Expect the swap to signal that it was cancelled.
err = <-errChan
if err != context.Canceled {
t.Fatal(err)
}
require.Equal(t, context.Canceled, <-errChan)
}
// TestLateHtlcPublish tests that the client is not revealing the preimage if
@ -198,9 +184,7 @@ func testLateHtlcPublish(t *testing.T) {
initResult, err := newLoopOutSwap(
context.Background(), cfg, height, testRequest,
)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
swap := initResult.swap
sweeper := &sweep.Sweeper{Lnd: &lnd.LndServices}
@ -225,11 +209,8 @@ func testLateHtlcPublish(t *testing.T) {
}()
store.assertLoopOutStored()
state := <-statusChan
if state.State != loopdb.StateInitiated {
t.Fatal("unexpected state")
}
status := <-statusChan
require.Equal(t, loopdb.StateInitiated, status.State)
signalSwapPaymentResult := ctx.AssertPaid(swapInvoiceDesc)
signalPrepaymentResult := ctx.AssertPaid(prepayInvoiceDesc)
@ -249,15 +230,9 @@ func testLateHtlcPublish(t *testing.T) {
store.assertStoreFinished(loopdb.StateFailTimeout)
status := <-statusChan
if status.State != loopdb.StateFailTimeout {
t.Fatal("unexpected state")
}
err = <-errChan
if err != nil {
t.Fatal(err)
}
status = <-statusChan
require.Equal(t, loopdb.StateFailTimeout, status.State)
require.NoError(t, <-errChan)
}
// TestCustomSweepConfTarget ensures we are able to sweep a Loop Out HTLC with a
@ -304,9 +279,7 @@ func testCustomSweepConfTarget(t *testing.T) {
initResult, err := newLoopOutSwap(
context.Background(), cfg, ctx.Lnd.Height, &testReq,
)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
swap := initResult.swap
// Set up the required dependencies to execute the swap.
@ -339,9 +312,7 @@ func testCustomSweepConfTarget(t *testing.T) {
// The swap should be found in its initial state.
cfg.store.(*storeMock).assertLoopOutStored()
state := <-statusChan
if state.State != loopdb.StateInitiated {
t.Fatal("unexpected state")
}
require.Equal(t, loopdb.StateInitiated, state.State)
// We'll then pay both the swap and prepay invoice, which should trigger
// the server to publish the on-chain HTLC.
@ -381,10 +352,7 @@ func testCustomSweepConfTarget(t *testing.T) {
cfg.store.(*storeMock).assertLoopOutState(loopdb.StatePreimageRevealed)
status := <-statusChan
if status.State != loopdb.StatePreimageRevealed {
t.Fatalf("expected state %v, got %v",
loopdb.StatePreimageRevealed, status.State)
}
require.Equal(t, loopdb.StatePreimageRevealed, status.State)
// When using taproot htlcs the flow is different as we do reveal the
// preimage before sweeping in order for the server to trust us with
@ -410,10 +378,10 @@ func testCustomSweepConfTarget(t *testing.T) {
t.Helper()
sweepTx := ctx.ReceiveTx()
if sweepTx.TxIn[0].PreviousOutPoint.Hash != htlcTx.TxHash() {
t.Fatalf("expected sweep tx to spend %v, got %v",
htlcTx.TxHash(), sweepTx.TxIn[0].PreviousOutPoint)
}
require.Equal(
t, htlcTx.TxHash(),
sweepTx.TxIn[0].PreviousOutPoint.Hash,
)
// The fee used for the sweep transaction is an estimate based
// on the maximum witness size, so we should expect to see a
@ -427,16 +395,14 @@ func testCustomSweepConfTarget(t *testing.T) {
feeRate, err := ctx.Lnd.WalletKit.EstimateFeeRate(
context.Background(), expConfTarget,
)
if err != nil {
t.Fatalf("unable to retrieve fee estimate: %v", err)
}
require.NoError(t, err, "unable to retrieve fee estimate")
minFee := feeRate.FeeForWeight(weight)
maxFee := btcutil.Amount(float64(minFee) * 1.1)
// Just an estimate that works to sanity check fee upper bound.
maxFee := btcutil.Amount(float64(minFee) * 1.5)
if fee < minFee && fee > maxFee {
t.Fatalf("expected sweep tx to have fee between %v-%v, "+
"got %v", minFee, maxFee, fee)
}
require.GreaterOrEqual(t, fee, minFee)
require.LessOrEqual(t, fee, maxFee)
return sweepTx
}
@ -479,14 +445,8 @@ func testCustomSweepConfTarget(t *testing.T) {
cfg.store.(*storeMock).assertLoopOutState(loopdb.StateSuccess)
status = <-statusChan
if status.State != loopdb.StateSuccess {
t.Fatalf("expected state %v, got %v", loopdb.StateSuccess,
status.State)
}
if err := <-errChan; err != nil {
t.Fatal(err)
}
require.Equal(t, loopdb.StateSuccess, status.State)
require.NoError(t, <-errChan)
}
// TestPreimagePush tests or logic that decides whether to push our preimage to

@ -8,6 +8,7 @@ import (
"github.com/lightninglabs/loop/loopdb"
"github.com/lightninglabs/loop/test"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/stretchr/testify/require"
)
// storeMock implements a mock client swap store.
@ -239,9 +240,7 @@ func (s *storeMock) assertLoopInState(
s.t.Helper()
state := <-s.loopInUpdateChan
if state.State != expectedState {
s.t.Fatalf("expected state %v, got %v", expectedState, state)
}
require.Equal(s.t, expectedState, state.State)
return state
}
@ -252,9 +251,8 @@ func (s *storeMock) assertStorePreimageReveal() {
select {
case state := <-s.loopOutUpdateChan:
if state.State != loopdb.StatePreimageRevealed {
s.t.Fatalf("unexpected state")
}
require.Equal(s.t, loopdb.StatePreimageRevealed, state.State)
case <-time.After(test.Timeout):
s.t.Fatalf("expected swap to be marked as preimage revealed")
}
@ -265,10 +263,8 @@ func (s *storeMock) assertStoreFinished(expectedResult loopdb.SwapState) {
select {
case state := <-s.loopOutUpdateChan:
if state.State != expectedResult {
s.t.Fatalf("expected result %v, but got %v",
expectedResult, state)
}
require.Equal(s.t, expectedResult, state.State)
case <-time.After(test.Timeout):
s.t.Fatalf("expected swap to be finished")
}

@ -54,9 +54,7 @@ func assertEngineExecution(t *testing.T, valid bool,
done := false
for !done {
dis, err := vm.DisasmPC()
if err != nil {
t.Fatalf("stepping (%v)\n", err)
}
require.NoError(t, err, "stepping")
debugBuf.WriteString(fmt.Sprintf("stepping %v\n", dis))
done, err = vm.Step()

@ -86,9 +86,11 @@ func (ctx *Context) AssertRegisterSpendNtfn(script []byte) {
select {
case spendIntent := <-ctx.Lnd.RegisterSpendChannel:
if !bytes.Equal(spendIntent.PkScript, script) {
ctx.T.Fatalf("server not listening for published htlc script")
}
require.Equal(
ctx.T, script, spendIntent.PkScript,
"server not listening for published htlc script",
)
case <-time.After(Timeout):
DumpGoroutines()
ctx.T.Fatalf("spend not subscribed to")
@ -163,10 +165,11 @@ func (ctx *Context) AssertPaid(
payReq := ctx.DecodeInvoice(swapPayment.Invoice)
if _, ok := ctx.PaidInvoices[*payReq.Description]; ok {
ctx.T.Fatalf("duplicate invoice paid: %v",
*payReq.Description)
}
_, ok := ctx.PaidInvoices[*payReq.Description]
require.False(
ctx.T, ok,
"duplicate invoice paid: %v", *payReq.Description,
)
done := func(result error) {
if result != nil {
@ -195,9 +198,10 @@ func (ctx *Context) AssertSettled(
select {
case preimage := <-ctx.Lnd.SettleInvoiceChannel:
hash := sha256.Sum256(preimage[:])
if expectedHash != hash {
ctx.T.Fatalf("server claims with wrong preimage")
}
require.Equal(
ctx.T, expectedHash, lntypes.Hash(hash),
"server claims with wrong preimage",
)
return preimage
case <-time.After(Timeout):
@ -232,9 +236,8 @@ func (ctx *Context) DecodeInvoice(request string) *zpay32.Invoice {
ctx.T.Helper()
payReq, err := ctx.Lnd.DecodeInvoice(request)
if err != nil {
ctx.T.Fatal(err)
}
require.NoError(ctx.T, err)
return payReq
}
@ -256,7 +259,5 @@ func (ctx *Context) GetOutputIndex(tx *wire.MsgTx,
// waits for the notification to be processed by selecting on a
// dedicated test channel.
func (ctx *Context) NotifyServerHeight(height int32) {
if err := ctx.Lnd.NotifyHeight(height); err != nil {
ctx.T.Fatal(err)
}
require.NoError(ctx.T, ctx.Lnd.NotifyHeight(height))
}

@ -14,6 +14,7 @@ import (
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/zpay32"
"github.com/stretchr/testify/require"
)
var (
@ -29,11 +30,10 @@ var (
// GetDestAddr deterministically generates a sweep address for testing.
func GetDestAddr(t *testing.T, nr byte) btcutil.Address {
destAddr, err := btcutil.NewAddressScriptHash([]byte{nr},
&chaincfg.MainNetParams)
if err != nil {
t.Fatal(err)
}
destAddr, err := btcutil.NewAddressScriptHash(
[]byte{nr}, &chaincfg.MainNetParams,
)
require.NoError(t, err)
return destAddr
}

@ -140,9 +140,8 @@ func (ctx *testContext) finish() {
ctx.stop()
select {
case err := <-ctx.runErr:
if err != nil {
ctx.T.Fatal(err)
}
require.NoError(ctx.T, err)
case <-time.After(test.Timeout):
ctx.T.Fatal("client not stopping")
}
@ -156,19 +155,12 @@ func (ctx *testContext) finish() {
func (ctx *testContext) notifyHeight(height int32) {
ctx.T.Helper()
if err := ctx.Lnd.NotifyHeight(height); err != nil {
ctx.T.Fatal(err)
}
require.NoError(ctx.T, ctx.Lnd.NotifyHeight(height))
}
func (ctx *testContext) assertIsDone() {
if err := ctx.Lnd.IsDone(); err != nil {
ctx.T.Fatal(err)
}
if err := ctx.store.isDone(); err != nil {
ctx.T.Fatal(err)
}
require.NoError(ctx.T, ctx.Lnd.IsDone())
require.NoError(ctx.T, ctx.store.isDone())
select {
case <-ctx.statusChan:

Loading…
Cancel
Save