Merge pull request #541 from bhandras/htlc-v3-interal-key

swap: refactor htlc construction to allow passing of internal keys
pull/527/head
András Bánki-Horváth 1 year ago committed by GitHub
commit 446f163530
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -213,29 +213,24 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) {
SwapHash: swp.Hash,
LastUpdate: swp.LastUpdateTime(),
}
scriptVersion := GetHtlcScriptVersion(
swp.Contract.ProtocolVersion,
)
outputType := swap.HtlcP2WSH
if scriptVersion == swap.HtlcV3 {
outputType = swap.HtlcP2TR
}
htlc, err := swap.NewHtlc(
scriptVersion,
swp.Contract.CltvExpiry, swp.Contract.SenderKey,
swp.Contract.ReceiverKey, swp.Hash,
outputType, s.lndServices.ChainParams,
htlc, err := GetHtlc(
swp.Hash, &swp.Contract.SwapContract,
s.lndServices.ChainParams,
)
if err != nil {
return nil, err
}
if outputType == swap.HtlcP2TR {
swapInfo.HtlcAddressP2TR = htlc.Address
} else {
switch htlc.OutputType {
case swap.HtlcP2WSH:
swapInfo.HtlcAddressP2WSH = htlc.Address
case swap.HtlcP2TR:
swapInfo.HtlcAddressP2TR = htlc.Address
default:
return nil, swap.ErrInvalidOutputType
}
swaps = append(swaps, swapInfo)
@ -250,34 +245,23 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) {
LastUpdate: swp.LastUpdateTime(),
}
scriptVersion := GetHtlcScriptVersion(
swp.Contract.SwapContract.ProtocolVersion,
htlc, err := GetHtlc(
swp.Hash, &swp.Contract.SwapContract,
s.lndServices.ChainParams,
)
if err != nil {
return nil, err
}
if scriptVersion == swap.HtlcV3 {
htlcP2TR, err := swap.NewHtlc(
swap.HtlcV3, swp.Contract.CltvExpiry,
swp.Contract.SenderKey, swp.Contract.ReceiverKey,
swp.Hash, swap.HtlcP2TR,
s.lndServices.ChainParams,
)
if err != nil {
return nil, err
}
switch htlc.OutputType {
case swap.HtlcP2WSH:
swapInfo.HtlcAddressP2WSH = htlc.Address
swapInfo.HtlcAddressP2TR = htlcP2TR.Address
} else {
htlcP2WSH, err := swap.NewHtlc(
swap.HtlcV2, swp.Contract.CltvExpiry,
swp.Contract.SenderKey, swp.Contract.ReceiverKey,
swp.Hash, swap.HtlcP2WSH,
s.lndServices.ChainParams,
)
if err != nil {
return nil, err
}
case swap.HtlcP2TR:
swapInfo.HtlcAddressP2TR = htlc.Address
swapInfo.HtlcAddressP2WSH = htlcP2WSH.Address
default:
return nil, swap.ErrInvalidOutputType
}
swaps = append(swaps, swapInfo)

@ -1,7 +1,6 @@
package loop
import (
"bytes"
"context"
"crypto/sha256"
"errors"
@ -57,9 +56,7 @@ func TestLoopOutSuccess(t *testing.T) {
// Initiate loop out.
info, err := ctx.swapClient.LoopOut(context.Background(), &req)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
ctx.assertStored()
ctx.assertStatus(loopdb.StateInitiated)
@ -84,9 +81,7 @@ func TestLoopOutFailOffchain(t *testing.T) {
ctx := createClientTestContext(t, nil)
_, err := ctx.swapClient.LoopOut(context.Background(), testRequest)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
ctx.assertStored()
ctx.assertStatus(loopdb.StateInitiated)
@ -208,14 +203,10 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed,
amt := btcutil.Amount(50000)
swapPayReq, err := getInvoice(hash, amt, swapInvoiceDesc)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
prePayReq, err := getInvoice(hash, 100, prepayInvoiceDesc)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
_, senderPubKey := test.CreateKey(1)
var senderKey [33]byte
@ -284,16 +275,26 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed,
// Assert that the loopout htlc equals to the expected one.
scriptVersion := GetHtlcScriptVersion(protocolVersion)
var htlc *swap.Htlc
outputType := swap.HtlcP2TR
if scriptVersion != swap.HtlcV3 {
outputType = swap.HtlcP2WSH
switch scriptVersion {
case swap.HtlcV2:
htlc, err = swap.NewHtlcV2(
pendingSwap.Contract.CltvExpiry, senderKey,
receiverKey, hash, &chaincfg.TestNet3Params,
)
case swap.HtlcV3:
htlc, err = swap.NewHtlcV3(
pendingSwap.Contract.CltvExpiry, senderKey,
receiverKey, senderKey, receiverKey, hash,
&chaincfg.TestNet3Params,
)
default:
t.Fatalf(swap.ErrInvalidScriptVersion.Error())
}
htlc, err := swap.NewHtlc(
scriptVersion, pendingSwap.Contract.CltvExpiry, senderKey,
receiverKey, hash, outputType, &chaincfg.TestNet3Params,
)
require.NoError(t, err)
require.Equal(t, htlc.PkScript, confIntent.PkScript)
@ -363,10 +364,11 @@ func testLoopOutSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash,
// Expect client on-chain sweep of HTLC.
sweepTx := ctx.ReceiveTx()
if !bytes.Equal(sweepTx.TxIn[0].PreviousOutPoint.Hash[:],
htlcOutpoint.Hash[:]) {
ctx.T.Fatalf("client not sweeping from htlc tx")
}
require.Equal(
ctx.T, htlcOutpoint.Hash[:],
sweepTx.TxIn[0].PreviousOutPoint.Hash[:],
"client not sweeping from htlc tx",
)
var preImageIndex int
switch scriptVersion {
@ -380,9 +382,7 @@ func testLoopOutSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash,
// Check preimage.
clientPreImage := sweepTx.TxIn[0].Witness[preImageIndex]
clientPreImageHash := sha256.Sum256(clientPreImage)
if clientPreImageHash != hash {
ctx.T.Fatalf("incorrect preimage")
}
require.Equal(ctx.T, hash, lntypes.Hash(clientPreImageHash))
// Since we successfully published our sweep, we expect the preimage to
// have been pushed to our mock server.

@ -130,16 +130,13 @@ func TestValidateConfTarget(t *testing.T) {
test.confTarget, defaultConf,
)
haveErr := err != nil
if haveErr != test.expectErr {
t.Fatalf("expected err: %v, got: %v",
test.expectErr, err)
if test.expectErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
if target != test.expectedTarget {
t.Fatalf("expected: %v, got: %v",
test.expectedTarget, target)
}
require.Equal(t, test.expectedTarget, target)
})
}
}
@ -199,16 +196,13 @@ func TestValidateLoopInRequest(t *testing.T) {
test.confTarget, external,
)
haveErr := err != nil
if haveErr != test.expectErr {
t.Fatalf("expected err: %v, got: %v",
test.expectErr, err)
if test.expectErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
if conf != test.expectedTarget {
t.Fatalf("expected: %v, got: %v",
test.expectedTarget, conf)
}
require.Equal(t, test.expectedTarget, conf)
})
}
}

@ -7,7 +7,6 @@ import (
"github.com/lightninglabs/lndclient"
"github.com/lightninglabs/loop"
"github.com/lightninglabs/loop/loopdb"
"github.com/lightninglabs/loop/swap"
)
// view prints all swaps currently in the database.
@ -49,24 +48,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error {
}
for _, s := range swaps {
scriptVersion := loop.GetHtlcScriptVersion(
s.Contract.ProtocolVersion,
)
var outputType swap.HtlcOutputType
switch scriptVersion {
case swap.HtlcV2:
outputType = swap.HtlcP2WSH
case swap.HtlcV3:
outputType = swap.HtlcP2TR
}
htlc, err := swap.NewHtlc(
loop.GetHtlcScriptVersion(s.Contract.ProtocolVersion),
s.Contract.CltvExpiry,
s.Contract.SenderKey,
s.Contract.ReceiverKey,
s.Hash, outputType, chainParams,
htlc, err := loop.GetHtlc(
s.Hash, &s.Contract.SwapContract, chainParams,
)
if err != nil {
return err
@ -77,7 +60,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error {
s.Contract.InitiationTime, s.Contract.InitiationHeight,
)
fmt.Printf(" Preimage: %v\n", s.Contract.Preimage)
fmt.Printf(" Htlc address: %v\n", htlc.Address)
fmt.Printf(" Htlc address (%s): %v\n", htlc.OutputType,
htlc.Address)
fmt.Printf(" Uncharge channels: %v\n",
s.Contract.OutgoingChanSet)
@ -113,12 +97,8 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error {
}
for _, s := range swaps {
htlc, err := swap.NewHtlc(
loop.GetHtlcScriptVersion(s.Contract.ProtocolVersion),
s.Contract.CltvExpiry,
s.Contract.SenderKey,
s.Contract.ReceiverKey,
s.Hash, swap.HtlcP2WSH, chainParams,
htlc, err := loop.GetHtlc(
s.Hash, &s.Contract.SwapContract, chainParams,
)
if err != nil {
return err
@ -129,7 +109,8 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error {
s.Contract.InitiationTime, s.Contract.InitiationHeight,
)
fmt.Printf(" Preimage: %v\n", s.Contract.Preimage)
fmt.Printf(" Htlc address: %v\n", htlc.Address)
fmt.Printf(" Htlc address (%s): %v\n", htlc.OutputType,
htlc.Address)
fmt.Printf(" Amt: %v, Expiry: %v\n",
s.Contract.AmountRequested, s.Contract.CltvExpiry,
)

@ -404,26 +404,26 @@ func validateLoopInContract(lnd *lndclient.LndServices,
// initHtlcs creates and updates the native and nested segwit htlcs
// of the loopInSwap.
func (s *loopInSwap) initHtlcs() error {
if IsTaprootSwap(&s.SwapContract) {
htlcP2TR, err := s.swapKit.getHtlc(swap.HtlcP2TR)
if err != nil {
return err
}
htlc, err := GetHtlc(
s.hash, &s.SwapContract, s.swapKit.lnd.ChainParams,
)
if err != nil {
return err
}
s.swapKit.log.Infof("Htlc address (P2TR): %v", htlcP2TR.Address)
s.htlcP2TR = htlcP2TR
switch htlc.OutputType {
case swap.HtlcP2WSH:
s.htlcP2WSH = htlc
return nil
}
case swap.HtlcP2TR:
s.htlcP2TR = htlc
htlcP2WSH, err := s.swapKit.getHtlc(swap.HtlcP2WSH)
if err != nil {
return err
default:
return fmt.Errorf("invalid output type")
}
// Log htlc addresses for debugging.
s.swapKit.log.Infof("Htlc address (P2WSH): %v", htlcP2WSH.Address)
s.htlcP2WSH = htlcP2WSH
s.swapKit.log.Infof("Htlc address (%s): %v", htlc.OutputType,
htlc.Address)
return nil
}

@ -58,9 +58,8 @@ func testLoopInSuccess(t *testing.T) {
context.Background(), cfg,
height, req,
)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
inSwap := initResult.swap
ctx.store.assertLoopInStored()
@ -142,10 +141,7 @@ func testLoopInSuccess(t *testing.T) {
ctx.assertState(loopdb.StateSuccess)
ctx.store.assertLoopInState(loopdb.StateSuccess)
err = <-errChan
if err != nil {
t.Fatal(err)
}
require.NoError(t, <-errChan)
}
// TestLoopInTimeout tests scenarios where the server doesn't sweep the htlc
@ -215,9 +211,7 @@ func testLoopInTimeout(t *testing.T, externalValue int64) {
context.Background(), cfg,
height, &req,
)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
inSwap := initResult.swap
ctx.store.assertLoopInStored()
@ -289,11 +283,7 @@ func testLoopInTimeout(t *testing.T, externalValue int64) {
ctx.assertState(loopdb.StateFailIncorrectHtlcAmt)
ctx.store.assertLoopInState(loopdb.StateFailIncorrectHtlcAmt)
err = <-errChan
if err != nil {
t.Fatal(err)
}
require.NoError(t, <-errChan)
return
}
@ -308,9 +298,11 @@ func testLoopInTimeout(t *testing.T, externalValue int64) {
// Expect a signing request for the htlc tx output value.
signReq := <-ctx.lnd.SignOutputRawChannel
if signReq.SignDescriptors[0].Output.Value != htlcTx.TxOut[0].Value {
t.Fatal("invalid signing amount")
}
require.Equal(
t, htlcTx.TxOut[0].Value,
signReq.SignDescriptors[0].Output.Value,
"invalid signing amount",
)
// Expect timeout tx to be published.
timeoutTx := <-ctx.lnd.TxPublishChannel
@ -341,10 +333,7 @@ func testLoopInTimeout(t *testing.T, externalValue int64) {
state := ctx.store.assertLoopInState(loopdb.StateFailTimeout)
require.Equal(t, cost, state.Cost)
err = <-errChan
if err != nil {
t.Fatal(err)
}
require.NoError(t, <-errChan)
}
// TestLoopInResume tests resuming swaps in various states.
@ -455,34 +444,38 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool,
pendSwap.Loop.Events[0].Cost = cost
}
scriptVersion := GetHtlcScriptVersion(storedVersion)
var (
htlc *swap.Htlc
err error
)
outputType := swap.HtlcP2WSH
if scriptVersion == swap.HtlcV3 {
outputType = swap.HtlcP2TR
switch GetHtlcScriptVersion(storedVersion) {
case swap.HtlcV2:
htlc, err = swap.NewHtlcV2(
contract.CltvExpiry, contract.SenderKey,
contract.ReceiverKey, testPreimage.Hash(),
cfg.lnd.ChainParams,
)
case swap.HtlcV3:
htlc, err = swap.NewHtlcV3(
contract.CltvExpiry, contract.SenderKey,
contract.ReceiverKey, contract.SenderKey,
contract.ReceiverKey, testPreimage.Hash(),
cfg.lnd.ChainParams,
)
default:
t.Fatalf("unknown HTLC script version")
}
htlc, err := swap.NewHtlc(
scriptVersion, contract.CltvExpiry, contract.SenderKey,
contract.ReceiverKey, testPreimage.Hash(), outputType,
cfg.lnd.ChainParams,
)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
err = ctx.store.CreateLoopIn(testPreimage.Hash(), contract)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
inSwap, err := resumeLoopInSwap(
context.Background(), cfg,
pendSwap,
)
if err != nil {
t.Fatal(err)
}
inSwap, err := resumeLoopInSwap(context.Background(), cfg, pendSwap)
require.NoError(t, err)
var height int32
if expired {
@ -501,10 +494,7 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool,
}()
defer func() {
err = <-errChan
if err != nil {
t.Fatal(err)
}
require.NoError(t, <-errChan)
select {
case <-ctx.lnd.SendPaymentChannel:

@ -63,10 +63,7 @@ func newLoopInTestContext(t *testing.T) *loopInTestContext {
func (c *loopInTestContext) assertState(expectedState loopdb.SwapState) {
state := <-c.statusChan
if state.State != expectedState {
c.t.Fatalf("expected state %v but got %v", expectedState,
state.State)
}
require.Equal(c.t, expectedState, state.State)
}
// assertSubscribeInvoice asserts that the client subscribes to invoice updates

@ -201,21 +201,17 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig,
swapKit.lastUpdateTime = initiationTime
scriptVersion := GetHtlcScriptVersion(loopdb.CurrentProtocolVersion())
outputType := swap.HtlcP2TR
if scriptVersion != swap.HtlcV3 {
// Default to using P2WSH for legacy htlcs.
outputType = swap.HtlcP2WSH
}
// Create the htlc.
htlc, err := swapKit.getHtlc(outputType)
htlc, err := GetHtlc(
swapKit.hash, swapKit.contract, swapKit.lnd.ChainParams,
)
if err != nil {
return nil, err
}
// Log htlc address for debugging.
swapKit.log.Infof("Htlc address: %v", htlc.Address)
swapKit.log.Infof("Htlc address (%s): %v", htlc.OutputType,
htlc.Address)
// Obtain the payment addr since we'll need it later for routing plugin
// recommendation and possibly for cancel.
@ -261,15 +257,10 @@ func resumeLoopOutSwap(reqContext context.Context, cfg *swapConfig,
hash, swap.TypeOut, cfg, &pend.Contract.SwapContract,
)
scriptVersion := GetHtlcScriptVersion(pend.Contract.ProtocolVersion)
outputType := swap.HtlcP2TR
if scriptVersion != swap.HtlcV3 {
// Default to using P2WSH for legacy htlcs.
outputType = swap.HtlcP2WSH
}
// Create the htlc.
htlc, err := swapKit.getHtlc(outputType)
htlc, err := GetHtlc(
swapKit.hash, swapKit.contract, swapKit.lnd.ChainParams,
)
if err != nil {
return nil, err
}

@ -4,7 +4,6 @@ import (
"context"
"errors"
"math"
"reflect"
"testing"
"time"
@ -66,7 +65,7 @@ func testLoopOutPaymentParameters(t *testing.T) {
blockEpochChan := make(chan interface{})
statusChan := make(chan SwapInfo)
const maxParts = 5
const maxParts = uint32(5)
chanSet := loopdb.ChannelSet{2, 3}
@ -77,9 +76,7 @@ func testLoopOutPaymentParameters(t *testing.T) {
initResult, err := newLoopOutSwap(
context.Background(), cfg, height, &req,
)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
swap := initResult.swap
// Execute the swap in its own goroutine.
@ -105,9 +102,7 @@ func testLoopOutPaymentParameters(t *testing.T) {
store.assertLoopOutStored()
state := <-statusChan
if state.State != loopdb.StateInitiated {
t.Fatal("unexpected state")
}
require.Equal(t, loopdb.StateInitiated, state.State)
// Check that the SwapInfo contains the outgoing chan set
require.Equal(t, chanSet, state.OutgoingChanSet)
@ -130,18 +125,12 @@ func testLoopOutPaymentParameters(t *testing.T) {
}
// Assert that it is sent as a multi-part payment.
if swapPayment.MaxParts != maxParts {
t.Fatalf("Expected %v parts, but got %v",
maxParts, swapPayment.MaxParts)
}
require.Equal(t, maxParts, swapPayment.MaxParts)
// Verify the outgoing channel set restriction.
if !reflect.DeepEqual(
[]uint64(req.OutgoingChanSet), swapPayment.OutgoingChanIds,
) {
t.Fatalf("Unexpected outgoing channel set")
}
require.Equal(
t, []uint64(req.OutgoingChanSet), swapPayment.OutgoingChanIds,
)
// Swap is expected to register for confirmation of the htlc. Assert
// this to prevent a blocked channel in the mock.
@ -152,10 +141,7 @@ func testLoopOutPaymentParameters(t *testing.T) {
cancel()
// Expect the swap to signal that it was cancelled.
err = <-errChan
if err != context.Canceled {
t.Fatal(err)
}
require.Equal(t, context.Canceled, <-errChan)
}
// TestLateHtlcPublish tests that the client is not revealing the preimage if
@ -198,9 +184,7 @@ func testLateHtlcPublish(t *testing.T) {
initResult, err := newLoopOutSwap(
context.Background(), cfg, height, testRequest,
)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
swap := initResult.swap
sweeper := &sweep.Sweeper{Lnd: &lnd.LndServices}
@ -225,11 +209,8 @@ func testLateHtlcPublish(t *testing.T) {
}()
store.assertLoopOutStored()
state := <-statusChan
if state.State != loopdb.StateInitiated {
t.Fatal("unexpected state")
}
status := <-statusChan
require.Equal(t, loopdb.StateInitiated, status.State)
signalSwapPaymentResult := ctx.AssertPaid(swapInvoiceDesc)
signalPrepaymentResult := ctx.AssertPaid(prepayInvoiceDesc)
@ -249,15 +230,9 @@ func testLateHtlcPublish(t *testing.T) {
store.assertStoreFinished(loopdb.StateFailTimeout)
status := <-statusChan
if status.State != loopdb.StateFailTimeout {
t.Fatal("unexpected state")
}
err = <-errChan
if err != nil {
t.Fatal(err)
}
status = <-statusChan
require.Equal(t, loopdb.StateFailTimeout, status.State)
require.NoError(t, <-errChan)
}
// TestCustomSweepConfTarget ensures we are able to sweep a Loop Out HTLC with a
@ -304,9 +279,7 @@ func testCustomSweepConfTarget(t *testing.T) {
initResult, err := newLoopOutSwap(
context.Background(), cfg, ctx.Lnd.Height, &testReq,
)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
swap := initResult.swap
// Set up the required dependencies to execute the swap.
@ -339,9 +312,7 @@ func testCustomSweepConfTarget(t *testing.T) {
// The swap should be found in its initial state.
cfg.store.(*storeMock).assertLoopOutStored()
state := <-statusChan
if state.State != loopdb.StateInitiated {
t.Fatal("unexpected state")
}
require.Equal(t, loopdb.StateInitiated, state.State)
// We'll then pay both the swap and prepay invoice, which should trigger
// the server to publish the on-chain HTLC.
@ -381,10 +352,7 @@ func testCustomSweepConfTarget(t *testing.T) {
cfg.store.(*storeMock).assertLoopOutState(loopdb.StatePreimageRevealed)
status := <-statusChan
if status.State != loopdb.StatePreimageRevealed {
t.Fatalf("expected state %v, got %v",
loopdb.StatePreimageRevealed, status.State)
}
require.Equal(t, loopdb.StatePreimageRevealed, status.State)
// When using taproot htlcs the flow is different as we do reveal the
// preimage before sweeping in order for the server to trust us with
@ -410,10 +378,10 @@ func testCustomSweepConfTarget(t *testing.T) {
t.Helper()
sweepTx := ctx.ReceiveTx()
if sweepTx.TxIn[0].PreviousOutPoint.Hash != htlcTx.TxHash() {
t.Fatalf("expected sweep tx to spend %v, got %v",
htlcTx.TxHash(), sweepTx.TxIn[0].PreviousOutPoint)
}
require.Equal(
t, htlcTx.TxHash(),
sweepTx.TxIn[0].PreviousOutPoint.Hash,
)
// The fee used for the sweep transaction is an estimate based
// on the maximum witness size, so we should expect to see a
@ -427,16 +395,14 @@ func testCustomSweepConfTarget(t *testing.T) {
feeRate, err := ctx.Lnd.WalletKit.EstimateFeeRate(
context.Background(), expConfTarget,
)
if err != nil {
t.Fatalf("unable to retrieve fee estimate: %v", err)
}
require.NoError(t, err, "unable to retrieve fee estimate")
minFee := feeRate.FeeForWeight(weight)
maxFee := btcutil.Amount(float64(minFee) * 1.1)
// Just an estimate that works to sanity check fee upper bound.
maxFee := btcutil.Amount(float64(minFee) * 1.5)
if fee < minFee && fee > maxFee {
t.Fatalf("expected sweep tx to have fee between %v-%v, "+
"got %v", minFee, maxFee, fee)
}
require.GreaterOrEqual(t, fee, minFee)
require.LessOrEqual(t, fee, maxFee)
return sweepTx
}
@ -479,14 +445,8 @@ func testCustomSweepConfTarget(t *testing.T) {
cfg.store.(*storeMock).assertLoopOutState(loopdb.StateSuccess)
status = <-statusChan
if status.State != loopdb.StateSuccess {
t.Fatalf("expected state %v, got %v", loopdb.StateSuccess,
status.State)
}
if err := <-errChan; err != nil {
t.Fatal(err)
}
require.Equal(t, loopdb.StateSuccess, status.State)
require.NoError(t, <-errChan)
}
// TestPreimagePush tests or logic that decides whether to push our preimage to

@ -8,6 +8,7 @@ import (
"github.com/lightninglabs/loop/loopdb"
"github.com/lightninglabs/loop/test"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/stretchr/testify/require"
)
// storeMock implements a mock client swap store.
@ -239,9 +240,7 @@ func (s *storeMock) assertLoopInState(
s.t.Helper()
state := <-s.loopInUpdateChan
if state.State != expectedState {
s.t.Fatalf("expected state %v, got %v", expectedState, state)
}
require.Equal(s.t, expectedState, state.State)
return state
}
@ -252,9 +251,8 @@ func (s *storeMock) assertStorePreimageReveal() {
select {
case state := <-s.loopOutUpdateChan:
if state.State != loopdb.StatePreimageRevealed {
s.t.Fatalf("unexpected state")
}
require.Equal(s.t, loopdb.StatePreimageRevealed, state.State)
case <-time.After(test.Timeout):
s.t.Fatalf("expected swap to be marked as preimage revealed")
}
@ -265,10 +263,8 @@ func (s *storeMock) assertStoreFinished(expectedResult loopdb.SwapState) {
select {
case state := <-s.loopOutUpdateChan:
if state.State != expectedResult {
s.t.Fatalf("expected result %v, but got %v",
expectedResult, state)
}
require.Equal(s.t, expectedResult, state.State)
case <-time.After(test.Timeout):
s.t.Fatalf("expected swap to be finished")
}

@ -4,6 +4,7 @@ import (
"context"
"time"
"github.com/btcsuite/btcd/chaincfg"
"github.com/lightninglabs/lndclient"
"github.com/lightninglabs/loop/loopdb"
"github.com/lightninglabs/loop/swap"
@ -67,14 +68,28 @@ func IsTaprootSwap(swapContract *loopdb.SwapContract) bool {
return GetHtlcScriptVersion(swapContract.ProtocolVersion) == swap.HtlcV3
}
// getHtlc composes and returns the on-chain swap script.
func (s *swapKit) getHtlc(outputType swap.HtlcOutputType) (*swap.Htlc, error) {
return swap.NewHtlc(
GetHtlcScriptVersion(s.contract.ProtocolVersion),
s.contract.CltvExpiry, s.contract.SenderKey,
s.contract.ReceiverKey, s.hash, outputType,
s.swapConfig.lnd.ChainParams,
)
// GetHtlc composes and returns the on-chain swap script.
func GetHtlc(hash lntypes.Hash, contract *loopdb.SwapContract,
chainParams *chaincfg.Params) (*swap.Htlc, error) {
switch GetHtlcScriptVersion(contract.ProtocolVersion) {
case swap.HtlcV2:
return swap.NewHtlcV2(
contract.CltvExpiry, contract.SenderKey,
contract.ReceiverKey, hash,
chainParams,
)
case swap.HtlcV3:
return swap.NewHtlcV3(
contract.CltvExpiry, contract.SenderKey,
contract.ReceiverKey, contract.SenderKey,
contract.ReceiverKey, hash,
chainParams,
)
}
return nil, swap.ErrInvalidScriptVersion
}
// swapInfo constructs and returns a filled SwapInfo from

@ -114,16 +114,16 @@ var (
// QuoteHtlcP2WSH is a template script just used for sweep fee
// estimation.
QuoteHtlcP2WSH, _ = NewHtlc(
HtlcV2, ^int32(0), dummyPubKey, dummyPubKey, quoteHash,
HtlcP2WSH, &chaincfg.MainNetParams,
QuoteHtlcP2WSH, _ = NewHtlcV2(
^int32(0), dummyPubKey, dummyPubKey, quoteHash,
&chaincfg.MainNetParams,
)
// QuoteHtlcP2TR is a template script just used for sweep fee
// estimation.
QuoteHtlcP2TR, _ = NewHtlc(
HtlcV3, ^int32(0), dummyPubKey, dummyPubKey, quoteHash,
HtlcP2TR, &chaincfg.MainNetParams,
QuoteHtlcP2TR, _ = NewHtlcV3(
^int32(0), dummyPubKey, dummyPubKey, dummyPubKey, dummyPubKey,
quoteHash, &chaincfg.MainNetParams,
)
// ErrInvalidScriptVersion is returned when an unknown htlc version
@ -135,6 +135,10 @@ var (
// selected for a v2 script.
ErrInvalidOutputSelected = fmt.Errorf("taproot output selected for " +
"non taproot htlc")
// ErrInvalidOutputType is returned when an unknown output type is
// associated with a certain swap htlc.
ErrInvalidOutputType = fmt.Errorf("invalid htlc output type")
)
// String returns the string value of HtlcOutputType.
@ -151,38 +155,54 @@ func (h HtlcOutputType) String() string {
}
}
// NewHtlc returns a new instance. For v3 scripts, an internal pubkey generated
// by both participants must be provided.
func NewHtlc(version ScriptVersion, cltvExpiry int32,
senderKey, receiverKey [33]byte, hash lntypes.Hash,
outputType HtlcOutputType, chainParams *chaincfg.Params) (*Htlc, error) {
// NewHtlcV2 returns a new V2 (P2WSH) HTLC instance.
func NewHtlcV2(cltvExpiry int32, senderKey, receiverKey [33]byte,
hash lntypes.Hash, chainParams *chaincfg.Params) (*Htlc, error) {
var (
err error
htlc HtlcScript
htlc, err := newHTLCScriptV2(
cltvExpiry, senderKey, receiverKey, hash,
)
switch version {
case HtlcV2:
htlc, err = newHTLCScriptV2(
cltvExpiry, senderKey, receiverKey, hash,
)
case HtlcV3:
htlc, err = newHTLCScriptV3(
cltvExpiry, senderKey, receiverKey, hash,
)
if err != nil {
return nil, err
}
default:
return nil, ErrInvalidScriptVersion
address, pkScript, sigScript, err := htlc.lockingConditions(
HtlcP2WSH, chainParams,
)
if err != nil {
return nil, fmt.Errorf("could not get address: %w", err)
}
return &Htlc{
HtlcScript: htlc,
Hash: hash,
Version: HtlcV2,
PkScript: pkScript,
OutputType: HtlcP2WSH,
ChainParams: chainParams,
Address: address,
SigScript: sigScript,
}, nil
}
// NewHtlcV3 returns a new V3 HTLC (P2TR) instance. Internal pubkey generated
// by both participants must be provided.
func NewHtlcV3(cltvExpiry int32, senderInternalKey, receiverInternalKey,
senderKey, receiverKey [33]byte, hash lntypes.Hash,
chainParams *chaincfg.Params) (*Htlc, error) {
htlc, err := newHTLCScriptV3(
cltvExpiry, senderInternalKey, receiverInternalKey,
senderKey, receiverKey, hash,
)
if err != nil {
return nil, err
}
address, pkScript, sigScript, err := htlc.lockingConditions(
outputType, chainParams,
HtlcP2TR, chainParams,
)
if err != nil {
return nil, fmt.Errorf("could not get address: %w", err)
@ -191,9 +211,9 @@ func NewHtlc(version ScriptVersion, cltvExpiry int32,
return &Htlc{
HtlcScript: htlc,
Hash: hash,
Version: version,
Version: HtlcV3,
PkScript: pkScript,
OutputType: outputType,
OutputType: HtlcP2TR,
ChainParams: chainParams,
Address: address,
SigScript: sigScript,
@ -481,7 +501,8 @@ type HtlcScriptV3 struct {
}
// newHTLCScriptV3 constructs a HtlcScipt with the HTLC V3 taproot script.
func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte,
func newHTLCScriptV3(cltvExpiry int32, senderInternalKey, receiverInternalKey,
senderHtlcKey, receiverHtlcKey [33]byte,
swapHash lntypes.Hash) (*HtlcScriptV3, error) {
senderPubKey, err := schnorr.ParsePubKey(senderHtlcKey[1:])
@ -494,13 +515,6 @@ func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte,
return nil, err
}
aggregateKey, _, _, err := musig2.AggregateKeys(
[]*btcec.PublicKey{senderPubKey, receiverPubKey}, true,
)
if err != nil {
return nil, err
}
// Create our success path script, we'll use this separately
// to generate the success path leaf.
successPathScript, err := GenSuccessPathScript(
@ -527,6 +541,31 @@ func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte,
rootHash := tree.RootNode.TapHash()
// Parse the pub keys used in the internal aggregate key. They are
// optional and may just be the same keys that are used for the script
// paths.
senderInternalPubKey, err := schnorr.ParsePubKey(senderInternalKey[1:])
if err != nil {
return nil, err
}
receiverInternalPubKey, err := schnorr.ParsePubKey(
receiverInternalKey[1:],
)
if err != nil {
return nil, err
}
// Calculate the internal aggregate key.
aggregateKey, _, _, err := musig2.AggregateKeys(
[]*btcec.PublicKey{
senderInternalPubKey, receiverInternalPubKey,
}, true,
)
if err != nil {
return nil, err
}
// Calculate top level taproot key.
taprootKey := txscript.ComputeTaprootOutputKey(
aggregateKey.PreTweakedKey, rootHash[:],

@ -54,9 +54,7 @@ func assertEngineExecution(t *testing.T, valid bool,
done := false
for !done {
dis, err := vm.DisasmPC()
if err != nil {
t.Fatalf("stepping (%v)\n", err)
}
require.NoError(t, err, "stepping")
debugBuf.WriteString(fmt.Sprintf("stepping %v\n", dis))
done, err = vm.Step()
@ -134,9 +132,9 @@ func TestHtlcV2(t *testing.T) {
hash := sha256.Sum256(testPreimage[:])
// Create the htlc.
htlc, err := NewHtlc(
HtlcV2, testCltvExpiry, senderKey, receiverKey, hash,
HtlcP2WSH, &chaincfg.MainNetParams,
htlc, err := NewHtlcV2(
testCltvExpiry, senderKey, receiverKey, hash,
&chaincfg.MainNetParams,
)
require.NoError(t, err)
@ -287,10 +285,9 @@ func TestHtlcV2(t *testing.T) {
bogusKey := [33]byte{0xb, 0xa, 0xd}
// Create the htlc with the bogus key.
htlc, err = NewHtlc(
HtlcV2, testCltvExpiry,
bogusKey, receiverKey, hash,
HtlcP2WSH, &chaincfg.MainNetParams,
htlc, err = NewHtlcV2(
testCltvExpiry, bogusKey, receiverKey,
hash, &chaincfg.MainNetParams,
)
require.NoError(t, err)
@ -357,9 +354,9 @@ func TestHtlcV3(t *testing.T) {
copy(receiverKey[:], receiverPubKey.SerializeCompressed())
copy(senderKey[:], senderPubKey.SerializeCompressed())
htlc, err := NewHtlc(
HtlcV3, cltvExpiry, senderKey, receiverKey,
hashedPreimage, HtlcP2TR, &chaincfg.MainNetParams,
htlc, err := NewHtlcV3(
cltvExpiry, senderKey, receiverKey, senderKey, receiverKey,
hashedPreimage, &chaincfg.MainNetParams,
)
require.NoError(t, err)
@ -540,10 +537,10 @@ func TestHtlcV3(t *testing.T) {
bogusKey.SerializeCompressed(),
)
htlc, err := NewHtlc(
HtlcV3, cltvExpiry, bogusKeyBytes,
receiverKey, hashedPreimage, HtlcP2TR,
&chaincfg.MainNetParams,
htlc, err := NewHtlcV3(
cltvExpiry, senderKey,
receiverKey, bogusKeyBytes, receiverKey,
hashedPreimage, &chaincfg.MainNetParams,
)
require.NoError(t, err)

@ -86,9 +86,11 @@ func (ctx *Context) AssertRegisterSpendNtfn(script []byte) {
select {
case spendIntent := <-ctx.Lnd.RegisterSpendChannel:
if !bytes.Equal(spendIntent.PkScript, script) {
ctx.T.Fatalf("server not listening for published htlc script")
}
require.Equal(
ctx.T, script, spendIntent.PkScript,
"server not listening for published htlc script",
)
case <-time.After(Timeout):
DumpGoroutines()
ctx.T.Fatalf("spend not subscribed to")
@ -163,10 +165,11 @@ func (ctx *Context) AssertPaid(
payReq := ctx.DecodeInvoice(swapPayment.Invoice)
if _, ok := ctx.PaidInvoices[*payReq.Description]; ok {
ctx.T.Fatalf("duplicate invoice paid: %v",
*payReq.Description)
}
_, ok := ctx.PaidInvoices[*payReq.Description]
require.False(
ctx.T, ok,
"duplicate invoice paid: %v", *payReq.Description,
)
done := func(result error) {
if result != nil {
@ -195,9 +198,10 @@ func (ctx *Context) AssertSettled(
select {
case preimage := <-ctx.Lnd.SettleInvoiceChannel:
hash := sha256.Sum256(preimage[:])
if expectedHash != hash {
ctx.T.Fatalf("server claims with wrong preimage")
}
require.Equal(
ctx.T, expectedHash, lntypes.Hash(hash),
"server claims with wrong preimage",
)
return preimage
case <-time.After(Timeout):
@ -232,9 +236,8 @@ func (ctx *Context) DecodeInvoice(request string) *zpay32.Invoice {
ctx.T.Helper()
payReq, err := ctx.Lnd.DecodeInvoice(request)
if err != nil {
ctx.T.Fatal(err)
}
require.NoError(ctx.T, err)
return payReq
}
@ -256,7 +259,5 @@ func (ctx *Context) GetOutputIndex(tx *wire.MsgTx,
// waits for the notification to be processed by selecting on a
// dedicated test channel.
func (ctx *Context) NotifyServerHeight(height int32) {
if err := ctx.Lnd.NotifyHeight(height); err != nil {
ctx.T.Fatal(err)
}
require.NoError(ctx.T, ctx.Lnd.NotifyHeight(height))
}

@ -14,6 +14,7 @@ import (
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/zpay32"
"github.com/stretchr/testify/require"
)
var (
@ -29,11 +30,10 @@ var (
// GetDestAddr deterministically generates a sweep address for testing.
func GetDestAddr(t *testing.T, nr byte) btcutil.Address {
destAddr, err := btcutil.NewAddressScriptHash([]byte{nr},
&chaincfg.MainNetParams)
if err != nil {
t.Fatal(err)
}
destAddr, err := btcutil.NewAddressScriptHash(
[]byte{nr}, &chaincfg.MainNetParams,
)
require.NoError(t, err)
return destAddr
}

@ -140,9 +140,8 @@ func (ctx *testContext) finish() {
ctx.stop()
select {
case err := <-ctx.runErr:
if err != nil {
ctx.T.Fatal(err)
}
require.NoError(ctx.T, err)
case <-time.After(test.Timeout):
ctx.T.Fatal("client not stopping")
}
@ -156,19 +155,12 @@ func (ctx *testContext) finish() {
func (ctx *testContext) notifyHeight(height int32) {
ctx.T.Helper()
if err := ctx.Lnd.NotifyHeight(height); err != nil {
ctx.T.Fatal(err)
}
require.NoError(ctx.T, ctx.Lnd.NotifyHeight(height))
}
func (ctx *testContext) assertIsDone() {
if err := ctx.Lnd.IsDone(); err != nil {
ctx.T.Fatal(err)
}
if err := ctx.store.isDone(); err != nil {
ctx.T.Fatal(err)
}
require.NoError(ctx.T, ctx.Lnd.IsDone())
require.NoError(ctx.T, ctx.store.isDone())
select {
case <-ctx.statusChan:

Loading…
Cancel
Save