Merge pull request #541 from bhandras/htlc-v3-interal-key

swap: refactor htlc construction to allow passing of internal keys
pull/527/head
András Bánki-Horváth 1 year ago committed by GitHub
commit 446f163530
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -213,29 +213,24 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) {
SwapHash: swp.Hash, SwapHash: swp.Hash,
LastUpdate: swp.LastUpdateTime(), LastUpdate: swp.LastUpdateTime(),
} }
scriptVersion := GetHtlcScriptVersion(
swp.Contract.ProtocolVersion,
)
outputType := swap.HtlcP2WSH
if scriptVersion == swap.HtlcV3 {
outputType = swap.HtlcP2TR
}
htlc, err := swap.NewHtlc( htlc, err := GetHtlc(
scriptVersion, swp.Hash, &swp.Contract.SwapContract,
swp.Contract.CltvExpiry, swp.Contract.SenderKey, s.lndServices.ChainParams,
swp.Contract.ReceiverKey, swp.Hash,
outputType, s.lndServices.ChainParams,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
if outputType == swap.HtlcP2TR { switch htlc.OutputType {
swapInfo.HtlcAddressP2TR = htlc.Address case swap.HtlcP2WSH:
} else {
swapInfo.HtlcAddressP2WSH = htlc.Address swapInfo.HtlcAddressP2WSH = htlc.Address
case swap.HtlcP2TR:
swapInfo.HtlcAddressP2TR = htlc.Address
default:
return nil, swap.ErrInvalidOutputType
} }
swaps = append(swaps, swapInfo) swaps = append(swaps, swapInfo)
@ -250,34 +245,23 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) {
LastUpdate: swp.LastUpdateTime(), LastUpdate: swp.LastUpdateTime(),
} }
scriptVersion := GetHtlcScriptVersion( htlc, err := GetHtlc(
swp.Contract.SwapContract.ProtocolVersion, swp.Hash, &swp.Contract.SwapContract,
s.lndServices.ChainParams,
) )
if err != nil {
return nil, err
}
if scriptVersion == swap.HtlcV3 { switch htlc.OutputType {
htlcP2TR, err := swap.NewHtlc( case swap.HtlcP2WSH:
swap.HtlcV3, swp.Contract.CltvExpiry, swapInfo.HtlcAddressP2WSH = htlc.Address
swp.Contract.SenderKey, swp.Contract.ReceiverKey,
swp.Hash, swap.HtlcP2TR,
s.lndServices.ChainParams,
)
if err != nil {
return nil, err
}
swapInfo.HtlcAddressP2TR = htlcP2TR.Address case swap.HtlcP2TR:
} else { swapInfo.HtlcAddressP2TR = htlc.Address
htlcP2WSH, err := swap.NewHtlc(
swap.HtlcV2, swp.Contract.CltvExpiry,
swp.Contract.SenderKey, swp.Contract.ReceiverKey,
swp.Hash, swap.HtlcP2WSH,
s.lndServices.ChainParams,
)
if err != nil {
return nil, err
}
swapInfo.HtlcAddressP2WSH = htlcP2WSH.Address default:
return nil, swap.ErrInvalidOutputType
} }
swaps = append(swaps, swapInfo) swaps = append(swaps, swapInfo)

@ -1,7 +1,6 @@
package loop package loop
import ( import (
"bytes"
"context" "context"
"crypto/sha256" "crypto/sha256"
"errors" "errors"
@ -57,9 +56,7 @@ func TestLoopOutSuccess(t *testing.T) {
// Initiate loop out. // Initiate loop out.
info, err := ctx.swapClient.LoopOut(context.Background(), &req) info, err := ctx.swapClient.LoopOut(context.Background(), &req)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
ctx.assertStored() ctx.assertStored()
ctx.assertStatus(loopdb.StateInitiated) ctx.assertStatus(loopdb.StateInitiated)
@ -84,9 +81,7 @@ func TestLoopOutFailOffchain(t *testing.T) {
ctx := createClientTestContext(t, nil) ctx := createClientTestContext(t, nil)
_, err := ctx.swapClient.LoopOut(context.Background(), testRequest) _, err := ctx.swapClient.LoopOut(context.Background(), testRequest)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
ctx.assertStored() ctx.assertStored()
ctx.assertStatus(loopdb.StateInitiated) ctx.assertStatus(loopdb.StateInitiated)
@ -208,14 +203,10 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed,
amt := btcutil.Amount(50000) amt := btcutil.Amount(50000)
swapPayReq, err := getInvoice(hash, amt, swapInvoiceDesc) swapPayReq, err := getInvoice(hash, amt, swapInvoiceDesc)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
prePayReq, err := getInvoice(hash, 100, prepayInvoiceDesc) prePayReq, err := getInvoice(hash, 100, prepayInvoiceDesc)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
_, senderPubKey := test.CreateKey(1) _, senderPubKey := test.CreateKey(1)
var senderKey [33]byte var senderKey [33]byte
@ -284,16 +275,26 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed,
// Assert that the loopout htlc equals to the expected one. // Assert that the loopout htlc equals to the expected one.
scriptVersion := GetHtlcScriptVersion(protocolVersion) scriptVersion := GetHtlcScriptVersion(protocolVersion)
var htlc *swap.Htlc
outputType := swap.HtlcP2TR switch scriptVersion {
if scriptVersion != swap.HtlcV3 { case swap.HtlcV2:
outputType = swap.HtlcP2WSH htlc, err = swap.NewHtlcV2(
pendingSwap.Contract.CltvExpiry, senderKey,
receiverKey, hash, &chaincfg.TestNet3Params,
)
case swap.HtlcV3:
htlc, err = swap.NewHtlcV3(
pendingSwap.Contract.CltvExpiry, senderKey,
receiverKey, senderKey, receiverKey, hash,
&chaincfg.TestNet3Params,
)
default:
t.Fatalf(swap.ErrInvalidScriptVersion.Error())
} }
htlc, err := swap.NewHtlc(
scriptVersion, pendingSwap.Contract.CltvExpiry, senderKey,
receiverKey, hash, outputType, &chaincfg.TestNet3Params,
)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, htlc.PkScript, confIntent.PkScript) require.Equal(t, htlc.PkScript, confIntent.PkScript)
@ -363,10 +364,11 @@ func testLoopOutSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash,
// Expect client on-chain sweep of HTLC. // Expect client on-chain sweep of HTLC.
sweepTx := ctx.ReceiveTx() sweepTx := ctx.ReceiveTx()
if !bytes.Equal(sweepTx.TxIn[0].PreviousOutPoint.Hash[:], require.Equal(
htlcOutpoint.Hash[:]) { ctx.T, htlcOutpoint.Hash[:],
ctx.T.Fatalf("client not sweeping from htlc tx") sweepTx.TxIn[0].PreviousOutPoint.Hash[:],
} "client not sweeping from htlc tx",
)
var preImageIndex int var preImageIndex int
switch scriptVersion { switch scriptVersion {
@ -380,9 +382,7 @@ func testLoopOutSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash,
// Check preimage. // Check preimage.
clientPreImage := sweepTx.TxIn[0].Witness[preImageIndex] clientPreImage := sweepTx.TxIn[0].Witness[preImageIndex]
clientPreImageHash := sha256.Sum256(clientPreImage) clientPreImageHash := sha256.Sum256(clientPreImage)
if clientPreImageHash != hash { require.Equal(ctx.T, hash, lntypes.Hash(clientPreImageHash))
ctx.T.Fatalf("incorrect preimage")
}
// Since we successfully published our sweep, we expect the preimage to // Since we successfully published our sweep, we expect the preimage to
// have been pushed to our mock server. // have been pushed to our mock server.

@ -130,16 +130,13 @@ func TestValidateConfTarget(t *testing.T) {
test.confTarget, defaultConf, test.confTarget, defaultConf,
) )
haveErr := err != nil if test.expectErr {
if haveErr != test.expectErr { require.Error(t, err)
t.Fatalf("expected err: %v, got: %v", } else {
test.expectErr, err) require.NoError(t, err)
} }
if target != test.expectedTarget { require.Equal(t, test.expectedTarget, target)
t.Fatalf("expected: %v, got: %v",
test.expectedTarget, target)
}
}) })
} }
} }
@ -199,16 +196,13 @@ func TestValidateLoopInRequest(t *testing.T) {
test.confTarget, external, test.confTarget, external,
) )
haveErr := err != nil if test.expectErr {
if haveErr != test.expectErr { require.Error(t, err)
t.Fatalf("expected err: %v, got: %v", } else {
test.expectErr, err) require.NoError(t, err)
} }
if conf != test.expectedTarget { require.Equal(t, test.expectedTarget, conf)
t.Fatalf("expected: %v, got: %v",
test.expectedTarget, conf)
}
}) })
} }
} }

@ -7,7 +7,6 @@ import (
"github.com/lightninglabs/lndclient" "github.com/lightninglabs/lndclient"
"github.com/lightninglabs/loop" "github.com/lightninglabs/loop"
"github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/loopdb"
"github.com/lightninglabs/loop/swap"
) )
// view prints all swaps currently in the database. // view prints all swaps currently in the database.
@ -49,24 +48,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error {
} }
for _, s := range swaps { for _, s := range swaps {
scriptVersion := loop.GetHtlcScriptVersion( htlc, err := loop.GetHtlc(
s.Contract.ProtocolVersion, s.Hash, &s.Contract.SwapContract, chainParams,
)
var outputType swap.HtlcOutputType
switch scriptVersion {
case swap.HtlcV2:
outputType = swap.HtlcP2WSH
case swap.HtlcV3:
outputType = swap.HtlcP2TR
}
htlc, err := swap.NewHtlc(
loop.GetHtlcScriptVersion(s.Contract.ProtocolVersion),
s.Contract.CltvExpiry,
s.Contract.SenderKey,
s.Contract.ReceiverKey,
s.Hash, outputType, chainParams,
) )
if err != nil { if err != nil {
return err return err
@ -77,7 +60,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error {
s.Contract.InitiationTime, s.Contract.InitiationHeight, s.Contract.InitiationTime, s.Contract.InitiationHeight,
) )
fmt.Printf(" Preimage: %v\n", s.Contract.Preimage) fmt.Printf(" Preimage: %v\n", s.Contract.Preimage)
fmt.Printf(" Htlc address: %v\n", htlc.Address) fmt.Printf(" Htlc address (%s): %v\n", htlc.OutputType,
htlc.Address)
fmt.Printf(" Uncharge channels: %v\n", fmt.Printf(" Uncharge channels: %v\n",
s.Contract.OutgoingChanSet) s.Contract.OutgoingChanSet)
@ -113,12 +97,8 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error {
} }
for _, s := range swaps { for _, s := range swaps {
htlc, err := swap.NewHtlc( htlc, err := loop.GetHtlc(
loop.GetHtlcScriptVersion(s.Contract.ProtocolVersion), s.Hash, &s.Contract.SwapContract, chainParams,
s.Contract.CltvExpiry,
s.Contract.SenderKey,
s.Contract.ReceiverKey,
s.Hash, swap.HtlcP2WSH, chainParams,
) )
if err != nil { if err != nil {
return err return err
@ -129,7 +109,8 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error {
s.Contract.InitiationTime, s.Contract.InitiationHeight, s.Contract.InitiationTime, s.Contract.InitiationHeight,
) )
fmt.Printf(" Preimage: %v\n", s.Contract.Preimage) fmt.Printf(" Preimage: %v\n", s.Contract.Preimage)
fmt.Printf(" Htlc address: %v\n", htlc.Address) fmt.Printf(" Htlc address (%s): %v\n", htlc.OutputType,
htlc.Address)
fmt.Printf(" Amt: %v, Expiry: %v\n", fmt.Printf(" Amt: %v, Expiry: %v\n",
s.Contract.AmountRequested, s.Contract.CltvExpiry, s.Contract.AmountRequested, s.Contract.CltvExpiry,
) )

@ -404,26 +404,26 @@ func validateLoopInContract(lnd *lndclient.LndServices,
// initHtlcs creates and updates the native and nested segwit htlcs // initHtlcs creates and updates the native and nested segwit htlcs
// of the loopInSwap. // of the loopInSwap.
func (s *loopInSwap) initHtlcs() error { func (s *loopInSwap) initHtlcs() error {
if IsTaprootSwap(&s.SwapContract) { htlc, err := GetHtlc(
htlcP2TR, err := s.swapKit.getHtlc(swap.HtlcP2TR) s.hash, &s.SwapContract, s.swapKit.lnd.ChainParams,
if err != nil { )
return err if err != nil {
} return err
}
s.swapKit.log.Infof("Htlc address (P2TR): %v", htlcP2TR.Address) switch htlc.OutputType {
s.htlcP2TR = htlcP2TR case swap.HtlcP2WSH:
s.htlcP2WSH = htlc
return nil case swap.HtlcP2TR:
} s.htlcP2TR = htlc
htlcP2WSH, err := s.swapKit.getHtlc(swap.HtlcP2WSH) default:
if err != nil { return fmt.Errorf("invalid output type")
return err
} }
// Log htlc addresses for debugging. s.swapKit.log.Infof("Htlc address (%s): %v", htlc.OutputType,
s.swapKit.log.Infof("Htlc address (P2WSH): %v", htlcP2WSH.Address) htlc.Address)
s.htlcP2WSH = htlcP2WSH
return nil return nil
} }

@ -58,9 +58,8 @@ func testLoopInSuccess(t *testing.T) {
context.Background(), cfg, context.Background(), cfg,
height, req, height, req,
) )
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
inSwap := initResult.swap inSwap := initResult.swap
ctx.store.assertLoopInStored() ctx.store.assertLoopInStored()
@ -142,10 +141,7 @@ func testLoopInSuccess(t *testing.T) {
ctx.assertState(loopdb.StateSuccess) ctx.assertState(loopdb.StateSuccess)
ctx.store.assertLoopInState(loopdb.StateSuccess) ctx.store.assertLoopInState(loopdb.StateSuccess)
err = <-errChan require.NoError(t, <-errChan)
if err != nil {
t.Fatal(err)
}
} }
// TestLoopInTimeout tests scenarios where the server doesn't sweep the htlc // TestLoopInTimeout tests scenarios where the server doesn't sweep the htlc
@ -215,9 +211,7 @@ func testLoopInTimeout(t *testing.T, externalValue int64) {
context.Background(), cfg, context.Background(), cfg,
height, &req, height, &req,
) )
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
inSwap := initResult.swap inSwap := initResult.swap
ctx.store.assertLoopInStored() ctx.store.assertLoopInStored()
@ -289,11 +283,7 @@ func testLoopInTimeout(t *testing.T, externalValue int64) {
ctx.assertState(loopdb.StateFailIncorrectHtlcAmt) ctx.assertState(loopdb.StateFailIncorrectHtlcAmt)
ctx.store.assertLoopInState(loopdb.StateFailIncorrectHtlcAmt) ctx.store.assertLoopInState(loopdb.StateFailIncorrectHtlcAmt)
err = <-errChan require.NoError(t, <-errChan)
if err != nil {
t.Fatal(err)
}
return return
} }
@ -308,9 +298,11 @@ func testLoopInTimeout(t *testing.T, externalValue int64) {
// Expect a signing request for the htlc tx output value. // Expect a signing request for the htlc tx output value.
signReq := <-ctx.lnd.SignOutputRawChannel signReq := <-ctx.lnd.SignOutputRawChannel
if signReq.SignDescriptors[0].Output.Value != htlcTx.TxOut[0].Value { require.Equal(
t.Fatal("invalid signing amount") t, htlcTx.TxOut[0].Value,
} signReq.SignDescriptors[0].Output.Value,
"invalid signing amount",
)
// Expect timeout tx to be published. // Expect timeout tx to be published.
timeoutTx := <-ctx.lnd.TxPublishChannel timeoutTx := <-ctx.lnd.TxPublishChannel
@ -341,10 +333,7 @@ func testLoopInTimeout(t *testing.T, externalValue int64) {
state := ctx.store.assertLoopInState(loopdb.StateFailTimeout) state := ctx.store.assertLoopInState(loopdb.StateFailTimeout)
require.Equal(t, cost, state.Cost) require.Equal(t, cost, state.Cost)
err = <-errChan require.NoError(t, <-errChan)
if err != nil {
t.Fatal(err)
}
} }
// TestLoopInResume tests resuming swaps in various states. // TestLoopInResume tests resuming swaps in various states.
@ -455,34 +444,38 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool,
pendSwap.Loop.Events[0].Cost = cost pendSwap.Loop.Events[0].Cost = cost
} }
scriptVersion := GetHtlcScriptVersion(storedVersion) var (
htlc *swap.Htlc
err error
)
outputType := swap.HtlcP2WSH switch GetHtlcScriptVersion(storedVersion) {
if scriptVersion == swap.HtlcV3 { case swap.HtlcV2:
outputType = swap.HtlcP2TR htlc, err = swap.NewHtlcV2(
contract.CltvExpiry, contract.SenderKey,
contract.ReceiverKey, testPreimage.Hash(),
cfg.lnd.ChainParams,
)
case swap.HtlcV3:
htlc, err = swap.NewHtlcV3(
contract.CltvExpiry, contract.SenderKey,
contract.ReceiverKey, contract.SenderKey,
contract.ReceiverKey, testPreimage.Hash(),
cfg.lnd.ChainParams,
)
default:
t.Fatalf("unknown HTLC script version")
} }
htlc, err := swap.NewHtlc( require.NoError(t, err)
scriptVersion, contract.CltvExpiry, contract.SenderKey,
contract.ReceiverKey, testPreimage.Hash(), outputType,
cfg.lnd.ChainParams,
)
if err != nil {
t.Fatal(err)
}
err = ctx.store.CreateLoopIn(testPreimage.Hash(), contract) err = ctx.store.CreateLoopIn(testPreimage.Hash(), contract)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
inSwap, err := resumeLoopInSwap( inSwap, err := resumeLoopInSwap(context.Background(), cfg, pendSwap)
context.Background(), cfg, require.NoError(t, err)
pendSwap,
)
if err != nil {
t.Fatal(err)
}
var height int32 var height int32
if expired { if expired {
@ -501,10 +494,7 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool,
}() }()
defer func() { defer func() {
err = <-errChan require.NoError(t, <-errChan)
if err != nil {
t.Fatal(err)
}
select { select {
case <-ctx.lnd.SendPaymentChannel: case <-ctx.lnd.SendPaymentChannel:

@ -63,10 +63,7 @@ func newLoopInTestContext(t *testing.T) *loopInTestContext {
func (c *loopInTestContext) assertState(expectedState loopdb.SwapState) { func (c *loopInTestContext) assertState(expectedState loopdb.SwapState) {
state := <-c.statusChan state := <-c.statusChan
if state.State != expectedState { require.Equal(c.t, expectedState, state.State)
c.t.Fatalf("expected state %v but got %v", expectedState,
state.State)
}
} }
// assertSubscribeInvoice asserts that the client subscribes to invoice updates // assertSubscribeInvoice asserts that the client subscribes to invoice updates

@ -201,21 +201,17 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig,
swapKit.lastUpdateTime = initiationTime swapKit.lastUpdateTime = initiationTime
scriptVersion := GetHtlcScriptVersion(loopdb.CurrentProtocolVersion())
outputType := swap.HtlcP2TR
if scriptVersion != swap.HtlcV3 {
// Default to using P2WSH for legacy htlcs.
outputType = swap.HtlcP2WSH
}
// Create the htlc. // Create the htlc.
htlc, err := swapKit.getHtlc(outputType) htlc, err := GetHtlc(
swapKit.hash, swapKit.contract, swapKit.lnd.ChainParams,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Log htlc address for debugging. // Log htlc address for debugging.
swapKit.log.Infof("Htlc address: %v", htlc.Address) swapKit.log.Infof("Htlc address (%s): %v", htlc.OutputType,
htlc.Address)
// Obtain the payment addr since we'll need it later for routing plugin // Obtain the payment addr since we'll need it later for routing plugin
// recommendation and possibly for cancel. // recommendation and possibly for cancel.
@ -261,15 +257,10 @@ func resumeLoopOutSwap(reqContext context.Context, cfg *swapConfig,
hash, swap.TypeOut, cfg, &pend.Contract.SwapContract, hash, swap.TypeOut, cfg, &pend.Contract.SwapContract,
) )
scriptVersion := GetHtlcScriptVersion(pend.Contract.ProtocolVersion)
outputType := swap.HtlcP2TR
if scriptVersion != swap.HtlcV3 {
// Default to using P2WSH for legacy htlcs.
outputType = swap.HtlcP2WSH
}
// Create the htlc. // Create the htlc.
htlc, err := swapKit.getHtlc(outputType) htlc, err := GetHtlc(
swapKit.hash, swapKit.contract, swapKit.lnd.ChainParams,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"math" "math"
"reflect"
"testing" "testing"
"time" "time"
@ -66,7 +65,7 @@ func testLoopOutPaymentParameters(t *testing.T) {
blockEpochChan := make(chan interface{}) blockEpochChan := make(chan interface{})
statusChan := make(chan SwapInfo) statusChan := make(chan SwapInfo)
const maxParts = 5 const maxParts = uint32(5)
chanSet := loopdb.ChannelSet{2, 3} chanSet := loopdb.ChannelSet{2, 3}
@ -77,9 +76,7 @@ func testLoopOutPaymentParameters(t *testing.T) {
initResult, err := newLoopOutSwap( initResult, err := newLoopOutSwap(
context.Background(), cfg, height, &req, context.Background(), cfg, height, &req,
) )
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
swap := initResult.swap swap := initResult.swap
// Execute the swap in its own goroutine. // Execute the swap in its own goroutine.
@ -105,9 +102,7 @@ func testLoopOutPaymentParameters(t *testing.T) {
store.assertLoopOutStored() store.assertLoopOutStored()
state := <-statusChan state := <-statusChan
if state.State != loopdb.StateInitiated { require.Equal(t, loopdb.StateInitiated, state.State)
t.Fatal("unexpected state")
}
// Check that the SwapInfo contains the outgoing chan set // Check that the SwapInfo contains the outgoing chan set
require.Equal(t, chanSet, state.OutgoingChanSet) require.Equal(t, chanSet, state.OutgoingChanSet)
@ -130,18 +125,12 @@ func testLoopOutPaymentParameters(t *testing.T) {
} }
// Assert that it is sent as a multi-part payment. // Assert that it is sent as a multi-part payment.
if swapPayment.MaxParts != maxParts { require.Equal(t, maxParts, swapPayment.MaxParts)
t.Fatalf("Expected %v parts, but got %v",
maxParts, swapPayment.MaxParts)
}
// Verify the outgoing channel set restriction. // Verify the outgoing channel set restriction.
if !reflect.DeepEqual( require.Equal(
[]uint64(req.OutgoingChanSet), swapPayment.OutgoingChanIds, t, []uint64(req.OutgoingChanSet), swapPayment.OutgoingChanIds,
) { )
t.Fatalf("Unexpected outgoing channel set")
}
// Swap is expected to register for confirmation of the htlc. Assert // Swap is expected to register for confirmation of the htlc. Assert
// this to prevent a blocked channel in the mock. // this to prevent a blocked channel in the mock.
@ -152,10 +141,7 @@ func testLoopOutPaymentParameters(t *testing.T) {
cancel() cancel()
// Expect the swap to signal that it was cancelled. // Expect the swap to signal that it was cancelled.
err = <-errChan require.Equal(t, context.Canceled, <-errChan)
if err != context.Canceled {
t.Fatal(err)
}
} }
// TestLateHtlcPublish tests that the client is not revealing the preimage if // TestLateHtlcPublish tests that the client is not revealing the preimage if
@ -198,9 +184,7 @@ func testLateHtlcPublish(t *testing.T) {
initResult, err := newLoopOutSwap( initResult, err := newLoopOutSwap(
context.Background(), cfg, height, testRequest, context.Background(), cfg, height, testRequest,
) )
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
swap := initResult.swap swap := initResult.swap
sweeper := &sweep.Sweeper{Lnd: &lnd.LndServices} sweeper := &sweep.Sweeper{Lnd: &lnd.LndServices}
@ -225,11 +209,8 @@ func testLateHtlcPublish(t *testing.T) {
}() }()
store.assertLoopOutStored() store.assertLoopOutStored()
status := <-statusChan
state := <-statusChan require.Equal(t, loopdb.StateInitiated, status.State)
if state.State != loopdb.StateInitiated {
t.Fatal("unexpected state")
}
signalSwapPaymentResult := ctx.AssertPaid(swapInvoiceDesc) signalSwapPaymentResult := ctx.AssertPaid(swapInvoiceDesc)
signalPrepaymentResult := ctx.AssertPaid(prepayInvoiceDesc) signalPrepaymentResult := ctx.AssertPaid(prepayInvoiceDesc)
@ -249,15 +230,9 @@ func testLateHtlcPublish(t *testing.T) {
store.assertStoreFinished(loopdb.StateFailTimeout) store.assertStoreFinished(loopdb.StateFailTimeout)
status := <-statusChan status = <-statusChan
if status.State != loopdb.StateFailTimeout { require.Equal(t, loopdb.StateFailTimeout, status.State)
t.Fatal("unexpected state") require.NoError(t, <-errChan)
}
err = <-errChan
if err != nil {
t.Fatal(err)
}
} }
// TestCustomSweepConfTarget ensures we are able to sweep a Loop Out HTLC with a // TestCustomSweepConfTarget ensures we are able to sweep a Loop Out HTLC with a
@ -304,9 +279,7 @@ func testCustomSweepConfTarget(t *testing.T) {
initResult, err := newLoopOutSwap( initResult, err := newLoopOutSwap(
context.Background(), cfg, ctx.Lnd.Height, &testReq, context.Background(), cfg, ctx.Lnd.Height, &testReq,
) )
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
swap := initResult.swap swap := initResult.swap
// Set up the required dependencies to execute the swap. // Set up the required dependencies to execute the swap.
@ -339,9 +312,7 @@ func testCustomSweepConfTarget(t *testing.T) {
// The swap should be found in its initial state. // The swap should be found in its initial state.
cfg.store.(*storeMock).assertLoopOutStored() cfg.store.(*storeMock).assertLoopOutStored()
state := <-statusChan state := <-statusChan
if state.State != loopdb.StateInitiated { require.Equal(t, loopdb.StateInitiated, state.State)
t.Fatal("unexpected state")
}
// We'll then pay both the swap and prepay invoice, which should trigger // We'll then pay both the swap and prepay invoice, which should trigger
// the server to publish the on-chain HTLC. // the server to publish the on-chain HTLC.
@ -381,10 +352,7 @@ func testCustomSweepConfTarget(t *testing.T) {
cfg.store.(*storeMock).assertLoopOutState(loopdb.StatePreimageRevealed) cfg.store.(*storeMock).assertLoopOutState(loopdb.StatePreimageRevealed)
status := <-statusChan status := <-statusChan
if status.State != loopdb.StatePreimageRevealed { require.Equal(t, loopdb.StatePreimageRevealed, status.State)
t.Fatalf("expected state %v, got %v",
loopdb.StatePreimageRevealed, status.State)
}
// When using taproot htlcs the flow is different as we do reveal the // When using taproot htlcs the flow is different as we do reveal the
// preimage before sweeping in order for the server to trust us with // preimage before sweeping in order for the server to trust us with
@ -410,10 +378,10 @@ func testCustomSweepConfTarget(t *testing.T) {
t.Helper() t.Helper()
sweepTx := ctx.ReceiveTx() sweepTx := ctx.ReceiveTx()
if sweepTx.TxIn[0].PreviousOutPoint.Hash != htlcTx.TxHash() { require.Equal(
t.Fatalf("expected sweep tx to spend %v, got %v", t, htlcTx.TxHash(),
htlcTx.TxHash(), sweepTx.TxIn[0].PreviousOutPoint) sweepTx.TxIn[0].PreviousOutPoint.Hash,
} )
// The fee used for the sweep transaction is an estimate based // The fee used for the sweep transaction is an estimate based
// on the maximum witness size, so we should expect to see a // on the maximum witness size, so we should expect to see a
@ -427,16 +395,14 @@ func testCustomSweepConfTarget(t *testing.T) {
feeRate, err := ctx.Lnd.WalletKit.EstimateFeeRate( feeRate, err := ctx.Lnd.WalletKit.EstimateFeeRate(
context.Background(), expConfTarget, context.Background(), expConfTarget,
) )
if err != nil { require.NoError(t, err, "unable to retrieve fee estimate")
t.Fatalf("unable to retrieve fee estimate: %v", err)
}
minFee := feeRate.FeeForWeight(weight) minFee := feeRate.FeeForWeight(weight)
maxFee := btcutil.Amount(float64(minFee) * 1.1) // Just an estimate that works to sanity check fee upper bound.
maxFee := btcutil.Amount(float64(minFee) * 1.5)
if fee < minFee && fee > maxFee { require.GreaterOrEqual(t, fee, minFee)
t.Fatalf("expected sweep tx to have fee between %v-%v, "+ require.LessOrEqual(t, fee, maxFee)
"got %v", minFee, maxFee, fee)
}
return sweepTx return sweepTx
} }
@ -479,14 +445,8 @@ func testCustomSweepConfTarget(t *testing.T) {
cfg.store.(*storeMock).assertLoopOutState(loopdb.StateSuccess) cfg.store.(*storeMock).assertLoopOutState(loopdb.StateSuccess)
status = <-statusChan status = <-statusChan
if status.State != loopdb.StateSuccess { require.Equal(t, loopdb.StateSuccess, status.State)
t.Fatalf("expected state %v, got %v", loopdb.StateSuccess, require.NoError(t, <-errChan)
status.State)
}
if err := <-errChan; err != nil {
t.Fatal(err)
}
} }
// TestPreimagePush tests or logic that decides whether to push our preimage to // TestPreimagePush tests or logic that decides whether to push our preimage to

@ -8,6 +8,7 @@ import (
"github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/loopdb"
"github.com/lightninglabs/loop/test" "github.com/lightninglabs/loop/test"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/stretchr/testify/require"
) )
// storeMock implements a mock client swap store. // storeMock implements a mock client swap store.
@ -239,9 +240,7 @@ func (s *storeMock) assertLoopInState(
s.t.Helper() s.t.Helper()
state := <-s.loopInUpdateChan state := <-s.loopInUpdateChan
if state.State != expectedState { require.Equal(s.t, expectedState, state.State)
s.t.Fatalf("expected state %v, got %v", expectedState, state)
}
return state return state
} }
@ -252,9 +251,8 @@ func (s *storeMock) assertStorePreimageReveal() {
select { select {
case state := <-s.loopOutUpdateChan: case state := <-s.loopOutUpdateChan:
if state.State != loopdb.StatePreimageRevealed { require.Equal(s.t, loopdb.StatePreimageRevealed, state.State)
s.t.Fatalf("unexpected state")
}
case <-time.After(test.Timeout): case <-time.After(test.Timeout):
s.t.Fatalf("expected swap to be marked as preimage revealed") s.t.Fatalf("expected swap to be marked as preimage revealed")
} }
@ -265,10 +263,8 @@ func (s *storeMock) assertStoreFinished(expectedResult loopdb.SwapState) {
select { select {
case state := <-s.loopOutUpdateChan: case state := <-s.loopOutUpdateChan:
if state.State != expectedResult { require.Equal(s.t, expectedResult, state.State)
s.t.Fatalf("expected result %v, but got %v",
expectedResult, state)
}
case <-time.After(test.Timeout): case <-time.After(test.Timeout):
s.t.Fatalf("expected swap to be finished") s.t.Fatalf("expected swap to be finished")
} }

@ -4,6 +4,7 @@ import (
"context" "context"
"time" "time"
"github.com/btcsuite/btcd/chaincfg"
"github.com/lightninglabs/lndclient" "github.com/lightninglabs/lndclient"
"github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/loopdb"
"github.com/lightninglabs/loop/swap" "github.com/lightninglabs/loop/swap"
@ -67,14 +68,28 @@ func IsTaprootSwap(swapContract *loopdb.SwapContract) bool {
return GetHtlcScriptVersion(swapContract.ProtocolVersion) == swap.HtlcV3 return GetHtlcScriptVersion(swapContract.ProtocolVersion) == swap.HtlcV3
} }
// getHtlc composes and returns the on-chain swap script. // GetHtlc composes and returns the on-chain swap script.
func (s *swapKit) getHtlc(outputType swap.HtlcOutputType) (*swap.Htlc, error) { func GetHtlc(hash lntypes.Hash, contract *loopdb.SwapContract,
return swap.NewHtlc( chainParams *chaincfg.Params) (*swap.Htlc, error) {
GetHtlcScriptVersion(s.contract.ProtocolVersion),
s.contract.CltvExpiry, s.contract.SenderKey, switch GetHtlcScriptVersion(contract.ProtocolVersion) {
s.contract.ReceiverKey, s.hash, outputType, case swap.HtlcV2:
s.swapConfig.lnd.ChainParams, return swap.NewHtlcV2(
) contract.CltvExpiry, contract.SenderKey,
contract.ReceiverKey, hash,
chainParams,
)
case swap.HtlcV3:
return swap.NewHtlcV3(
contract.CltvExpiry, contract.SenderKey,
contract.ReceiverKey, contract.SenderKey,
contract.ReceiverKey, hash,
chainParams,
)
}
return nil, swap.ErrInvalidScriptVersion
} }
// swapInfo constructs and returns a filled SwapInfo from // swapInfo constructs and returns a filled SwapInfo from

@ -114,16 +114,16 @@ var (
// QuoteHtlcP2WSH is a template script just used for sweep fee // QuoteHtlcP2WSH is a template script just used for sweep fee
// estimation. // estimation.
QuoteHtlcP2WSH, _ = NewHtlc( QuoteHtlcP2WSH, _ = NewHtlcV2(
HtlcV2, ^int32(0), dummyPubKey, dummyPubKey, quoteHash, ^int32(0), dummyPubKey, dummyPubKey, quoteHash,
HtlcP2WSH, &chaincfg.MainNetParams, &chaincfg.MainNetParams,
) )
// QuoteHtlcP2TR is a template script just used for sweep fee // QuoteHtlcP2TR is a template script just used for sweep fee
// estimation. // estimation.
QuoteHtlcP2TR, _ = NewHtlc( QuoteHtlcP2TR, _ = NewHtlcV3(
HtlcV3, ^int32(0), dummyPubKey, dummyPubKey, quoteHash, ^int32(0), dummyPubKey, dummyPubKey, dummyPubKey, dummyPubKey,
HtlcP2TR, &chaincfg.MainNetParams, quoteHash, &chaincfg.MainNetParams,
) )
// ErrInvalidScriptVersion is returned when an unknown htlc version // ErrInvalidScriptVersion is returned when an unknown htlc version
@ -135,6 +135,10 @@ var (
// selected for a v2 script. // selected for a v2 script.
ErrInvalidOutputSelected = fmt.Errorf("taproot output selected for " + ErrInvalidOutputSelected = fmt.Errorf("taproot output selected for " +
"non taproot htlc") "non taproot htlc")
// ErrInvalidOutputType is returned when an unknown output type is
// associated with a certain swap htlc.
ErrInvalidOutputType = fmt.Errorf("invalid htlc output type")
) )
// String returns the string value of HtlcOutputType. // String returns the string value of HtlcOutputType.
@ -151,38 +155,54 @@ func (h HtlcOutputType) String() string {
} }
} }
// NewHtlc returns a new instance. For v3 scripts, an internal pubkey generated // NewHtlcV2 returns a new V2 (P2WSH) HTLC instance.
// by both participants must be provided. func NewHtlcV2(cltvExpiry int32, senderKey, receiverKey [33]byte,
func NewHtlc(version ScriptVersion, cltvExpiry int32, hash lntypes.Hash, chainParams *chaincfg.Params) (*Htlc, error) {
senderKey, receiverKey [33]byte, hash lntypes.Hash,
outputType HtlcOutputType, chainParams *chaincfg.Params) (*Htlc, error) {
var ( htlc, err := newHTLCScriptV2(
err error cltvExpiry, senderKey, receiverKey, hash,
htlc HtlcScript
) )
switch version { if err != nil {
case HtlcV2: return nil, err
htlc, err = newHTLCScriptV2( }
cltvExpiry, senderKey, receiverKey, hash,
)
case HtlcV3:
htlc, err = newHTLCScriptV3(
cltvExpiry, senderKey, receiverKey, hash,
)
default: address, pkScript, sigScript, err := htlc.lockingConditions(
return nil, ErrInvalidScriptVersion HtlcP2WSH, chainParams,
)
if err != nil {
return nil, fmt.Errorf("could not get address: %w", err)
} }
return &Htlc{
HtlcScript: htlc,
Hash: hash,
Version: HtlcV2,
PkScript: pkScript,
OutputType: HtlcP2WSH,
ChainParams: chainParams,
Address: address,
SigScript: sigScript,
}, nil
}
// NewHtlcV3 returns a new V3 HTLC (P2TR) instance. Internal pubkey generated
// by both participants must be provided.
func NewHtlcV3(cltvExpiry int32, senderInternalKey, receiverInternalKey,
senderKey, receiverKey [33]byte, hash lntypes.Hash,
chainParams *chaincfg.Params) (*Htlc, error) {
htlc, err := newHTLCScriptV3(
cltvExpiry, senderInternalKey, receiverInternalKey,
senderKey, receiverKey, hash,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
address, pkScript, sigScript, err := htlc.lockingConditions( address, pkScript, sigScript, err := htlc.lockingConditions(
outputType, chainParams, HtlcP2TR, chainParams,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("could not get address: %w", err) return nil, fmt.Errorf("could not get address: %w", err)
@ -191,9 +211,9 @@ func NewHtlc(version ScriptVersion, cltvExpiry int32,
return &Htlc{ return &Htlc{
HtlcScript: htlc, HtlcScript: htlc,
Hash: hash, Hash: hash,
Version: version, Version: HtlcV3,
PkScript: pkScript, PkScript: pkScript,
OutputType: outputType, OutputType: HtlcP2TR,
ChainParams: chainParams, ChainParams: chainParams,
Address: address, Address: address,
SigScript: sigScript, SigScript: sigScript,
@ -481,7 +501,8 @@ type HtlcScriptV3 struct {
} }
// newHTLCScriptV3 constructs a HtlcScipt with the HTLC V3 taproot script. // newHTLCScriptV3 constructs a HtlcScipt with the HTLC V3 taproot script.
func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte, func newHTLCScriptV3(cltvExpiry int32, senderInternalKey, receiverInternalKey,
senderHtlcKey, receiverHtlcKey [33]byte,
swapHash lntypes.Hash) (*HtlcScriptV3, error) { swapHash lntypes.Hash) (*HtlcScriptV3, error) {
senderPubKey, err := schnorr.ParsePubKey(senderHtlcKey[1:]) senderPubKey, err := schnorr.ParsePubKey(senderHtlcKey[1:])
@ -494,13 +515,6 @@ func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte,
return nil, err return nil, err
} }
aggregateKey, _, _, err := musig2.AggregateKeys(
[]*btcec.PublicKey{senderPubKey, receiverPubKey}, true,
)
if err != nil {
return nil, err
}
// Create our success path script, we'll use this separately // Create our success path script, we'll use this separately
// to generate the success path leaf. // to generate the success path leaf.
successPathScript, err := GenSuccessPathScript( successPathScript, err := GenSuccessPathScript(
@ -527,6 +541,31 @@ func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte,
rootHash := tree.RootNode.TapHash() rootHash := tree.RootNode.TapHash()
// Parse the pub keys used in the internal aggregate key. They are
// optional and may just be the same keys that are used for the script
// paths.
senderInternalPubKey, err := schnorr.ParsePubKey(senderInternalKey[1:])
if err != nil {
return nil, err
}
receiverInternalPubKey, err := schnorr.ParsePubKey(
receiverInternalKey[1:],
)
if err != nil {
return nil, err
}
// Calculate the internal aggregate key.
aggregateKey, _, _, err := musig2.AggregateKeys(
[]*btcec.PublicKey{
senderInternalPubKey, receiverInternalPubKey,
}, true,
)
if err != nil {
return nil, err
}
// Calculate top level taproot key. // Calculate top level taproot key.
taprootKey := txscript.ComputeTaprootOutputKey( taprootKey := txscript.ComputeTaprootOutputKey(
aggregateKey.PreTweakedKey, rootHash[:], aggregateKey.PreTweakedKey, rootHash[:],

@ -54,9 +54,7 @@ func assertEngineExecution(t *testing.T, valid bool,
done := false done := false
for !done { for !done {
dis, err := vm.DisasmPC() dis, err := vm.DisasmPC()
if err != nil { require.NoError(t, err, "stepping")
t.Fatalf("stepping (%v)\n", err)
}
debugBuf.WriteString(fmt.Sprintf("stepping %v\n", dis)) debugBuf.WriteString(fmt.Sprintf("stepping %v\n", dis))
done, err = vm.Step() done, err = vm.Step()
@ -134,9 +132,9 @@ func TestHtlcV2(t *testing.T) {
hash := sha256.Sum256(testPreimage[:]) hash := sha256.Sum256(testPreimage[:])
// Create the htlc. // Create the htlc.
htlc, err := NewHtlc( htlc, err := NewHtlcV2(
HtlcV2, testCltvExpiry, senderKey, receiverKey, hash, testCltvExpiry, senderKey, receiverKey, hash,
HtlcP2WSH, &chaincfg.MainNetParams, &chaincfg.MainNetParams,
) )
require.NoError(t, err) require.NoError(t, err)
@ -287,10 +285,9 @@ func TestHtlcV2(t *testing.T) {
bogusKey := [33]byte{0xb, 0xa, 0xd} bogusKey := [33]byte{0xb, 0xa, 0xd}
// Create the htlc with the bogus key. // Create the htlc with the bogus key.
htlc, err = NewHtlc( htlc, err = NewHtlcV2(
HtlcV2, testCltvExpiry, testCltvExpiry, bogusKey, receiverKey,
bogusKey, receiverKey, hash, hash, &chaincfg.MainNetParams,
HtlcP2WSH, &chaincfg.MainNetParams,
) )
require.NoError(t, err) require.NoError(t, err)
@ -357,9 +354,9 @@ func TestHtlcV3(t *testing.T) {
copy(receiverKey[:], receiverPubKey.SerializeCompressed()) copy(receiverKey[:], receiverPubKey.SerializeCompressed())
copy(senderKey[:], senderPubKey.SerializeCompressed()) copy(senderKey[:], senderPubKey.SerializeCompressed())
htlc, err := NewHtlc( htlc, err := NewHtlcV3(
HtlcV3, cltvExpiry, senderKey, receiverKey, cltvExpiry, senderKey, receiverKey, senderKey, receiverKey,
hashedPreimage, HtlcP2TR, &chaincfg.MainNetParams, hashedPreimage, &chaincfg.MainNetParams,
) )
require.NoError(t, err) require.NoError(t, err)
@ -540,10 +537,10 @@ func TestHtlcV3(t *testing.T) {
bogusKey.SerializeCompressed(), bogusKey.SerializeCompressed(),
) )
htlc, err := NewHtlc( htlc, err := NewHtlcV3(
HtlcV3, cltvExpiry, bogusKeyBytes, cltvExpiry, senderKey,
receiverKey, hashedPreimage, HtlcP2TR, receiverKey, bogusKeyBytes, receiverKey,
&chaincfg.MainNetParams, hashedPreimage, &chaincfg.MainNetParams,
) )
require.NoError(t, err) require.NoError(t, err)

@ -86,9 +86,11 @@ func (ctx *Context) AssertRegisterSpendNtfn(script []byte) {
select { select {
case spendIntent := <-ctx.Lnd.RegisterSpendChannel: case spendIntent := <-ctx.Lnd.RegisterSpendChannel:
if !bytes.Equal(spendIntent.PkScript, script) { require.Equal(
ctx.T.Fatalf("server not listening for published htlc script") ctx.T, script, spendIntent.PkScript,
} "server not listening for published htlc script",
)
case <-time.After(Timeout): case <-time.After(Timeout):
DumpGoroutines() DumpGoroutines()
ctx.T.Fatalf("spend not subscribed to") ctx.T.Fatalf("spend not subscribed to")
@ -163,10 +165,11 @@ func (ctx *Context) AssertPaid(
payReq := ctx.DecodeInvoice(swapPayment.Invoice) payReq := ctx.DecodeInvoice(swapPayment.Invoice)
if _, ok := ctx.PaidInvoices[*payReq.Description]; ok { _, ok := ctx.PaidInvoices[*payReq.Description]
ctx.T.Fatalf("duplicate invoice paid: %v", require.False(
*payReq.Description) ctx.T, ok,
} "duplicate invoice paid: %v", *payReq.Description,
)
done := func(result error) { done := func(result error) {
if result != nil { if result != nil {
@ -195,9 +198,10 @@ func (ctx *Context) AssertSettled(
select { select {
case preimage := <-ctx.Lnd.SettleInvoiceChannel: case preimage := <-ctx.Lnd.SettleInvoiceChannel:
hash := sha256.Sum256(preimage[:]) hash := sha256.Sum256(preimage[:])
if expectedHash != hash { require.Equal(
ctx.T.Fatalf("server claims with wrong preimage") ctx.T, expectedHash, lntypes.Hash(hash),
} "server claims with wrong preimage",
)
return preimage return preimage
case <-time.After(Timeout): case <-time.After(Timeout):
@ -232,9 +236,8 @@ func (ctx *Context) DecodeInvoice(request string) *zpay32.Invoice {
ctx.T.Helper() ctx.T.Helper()
payReq, err := ctx.Lnd.DecodeInvoice(request) payReq, err := ctx.Lnd.DecodeInvoice(request)
if err != nil { require.NoError(ctx.T, err)
ctx.T.Fatal(err)
}
return payReq return payReq
} }
@ -256,7 +259,5 @@ func (ctx *Context) GetOutputIndex(tx *wire.MsgTx,
// waits for the notification to be processed by selecting on a // waits for the notification to be processed by selecting on a
// dedicated test channel. // dedicated test channel.
func (ctx *Context) NotifyServerHeight(height int32) { func (ctx *Context) NotifyServerHeight(height int32) {
if err := ctx.Lnd.NotifyHeight(height); err != nil { require.NoError(ctx.T, ctx.Lnd.NotifyHeight(height))
ctx.T.Fatal(err)
}
} }

@ -14,6 +14,7 @@ import (
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/zpay32" "github.com/lightningnetwork/lnd/zpay32"
"github.com/stretchr/testify/require"
) )
var ( var (
@ -29,11 +30,10 @@ var (
// GetDestAddr deterministically generates a sweep address for testing. // GetDestAddr deterministically generates a sweep address for testing.
func GetDestAddr(t *testing.T, nr byte) btcutil.Address { func GetDestAddr(t *testing.T, nr byte) btcutil.Address {
destAddr, err := btcutil.NewAddressScriptHash([]byte{nr}, destAddr, err := btcutil.NewAddressScriptHash(
&chaincfg.MainNetParams) []byte{nr}, &chaincfg.MainNetParams,
if err != nil { )
t.Fatal(err) require.NoError(t, err)
}
return destAddr return destAddr
} }

@ -140,9 +140,8 @@ func (ctx *testContext) finish() {
ctx.stop() ctx.stop()
select { select {
case err := <-ctx.runErr: case err := <-ctx.runErr:
if err != nil { require.NoError(ctx.T, err)
ctx.T.Fatal(err)
}
case <-time.After(test.Timeout): case <-time.After(test.Timeout):
ctx.T.Fatal("client not stopping") ctx.T.Fatal("client not stopping")
} }
@ -156,19 +155,12 @@ func (ctx *testContext) finish() {
func (ctx *testContext) notifyHeight(height int32) { func (ctx *testContext) notifyHeight(height int32) {
ctx.T.Helper() ctx.T.Helper()
if err := ctx.Lnd.NotifyHeight(height); err != nil { require.NoError(ctx.T, ctx.Lnd.NotifyHeight(height))
ctx.T.Fatal(err)
}
} }
func (ctx *testContext) assertIsDone() { func (ctx *testContext) assertIsDone() {
if err := ctx.Lnd.IsDone(); err != nil { require.NoError(ctx.T, ctx.Lnd.IsDone())
ctx.T.Fatal(err) require.NoError(ctx.T, ctx.store.isDone())
}
if err := ctx.store.isDone(); err != nil {
ctx.T.Fatal(err)
}
select { select {
case <-ctx.statusChan: case <-ctx.statusChan:

Loading…
Cancel
Save