From 5c34dd1177216e4d2282de6d3434ae65783be8d0 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 8 Mar 2021 09:50:17 +0200 Subject: [PATCH 1/2] loopd: refactor loop out request validation This commit moves loop out request validation for labels and confirmation targets into its own function for the purpose of easy testing and also to make the additions of future request validation easy to add and test. --- loopd/swapclient_server.go | 34 ++++++++++++------ loopd/swapclient_server_test.go | 62 +++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 11 deletions(-) diff --git a/loopd/swapclient_server.go b/loopd/swapclient_server.go index d0e1e32..1e3f925 100644 --- a/loopd/swapclient_server.go +++ b/loopd/swapclient_server.go @@ -34,6 +34,12 @@ const ( minConfTarget = 2 ) +var ( + // errConfTargetTooLow is returned when the chosen confirmation target + // is below the allowed minimum. + errConfTargetTooLow = errors.New("confirmation target too low") +) + // swapClientServer implements the grpc service exposed by loopd. type swapClientServer struct { network lndclient.Network @@ -58,13 +64,6 @@ func (s *swapClientServer) LoopOut(ctx context.Context, log.Infof("Loop out request received") - sweepConfTarget, err := validateConfTarget( - in.SweepConfTarget, loop.DefaultSweepConfTarget, - ) - if err != nil { - return nil, err - } - var sweepAddr btcutil.Address if in.Dest == "" { // Generate sweep address if none specified. @@ -83,8 +82,9 @@ func (s *swapClientServer) LoopOut(ctx context.Context, } } - // Check that the label is valid. - if err := labels.Validate(in.Label); err != nil { + sweepConfTarget, err := validateLoopOutRequest(in.SweepConfTarget, + in.Label) + if err != nil { return nil, err } @@ -894,8 +894,9 @@ func validateConfTarget(target, defaultTarget int32) (int32, error) { // Ensure the target respects our minimum threshold. case target < minConfTarget: - return 0, fmt.Errorf("a confirmation target of at least %v "+ - "must be provided", minConfTarget) + return 0, fmt.Errorf("%w: A confirmation target of at "+ + "least %v must be provided", errConfTargetTooLow, + minConfTarget) default: return target, nil @@ -920,3 +921,14 @@ func validateLoopInRequest(htlcConfTarget int32, external bool) (int32, error) { return validateConfTarget(htlcConfTarget, loop.DefaultHtlcConfTarget) } + +// validateLoopOutRequest validates the confirmation target and label of the +// loop out request. +func validateLoopOutRequest(confTarget int32, label string) (int32, error) { + // Check that the label is valid. + if err := labels.Validate(label); err != nil { + return 0, err + } + + return validateConfTarget(confTarget, loop.DefaultSweepConfTarget) +} diff --git a/loopd/swapclient_server_test.go b/loopd/swapclient_server_test.go index 5ad3718..418eb36 100644 --- a/loopd/swapclient_server_test.go +++ b/loopd/swapclient_server_test.go @@ -1,9 +1,24 @@ package loopd import ( + "errors" "testing" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcutil" "github.com/lightninglabs/loop" + "github.com/lightninglabs/loop/labels" + "github.com/stretchr/testify/require" +) + +var ( + testnetAddr, _ = btcutil.NewAddressScriptHash( + []byte{123}, &chaincfg.TestNet3Params, + ) + + mainnetAddr, _ = btcutil.NewAddressScriptHash( + []byte{123}, &chaincfg.MainNetParams, + ) ) // TestValidateConfTarget tests all failure and success cases for our conf @@ -143,3 +158,50 @@ func TestValidateLoopInRequest(t *testing.T) { }) } } + +// TestValidateLoopOutRequest tests validation of loop out requests. +func TestValidateLoopOutRequest(t *testing.T) { + tests := []struct { + name string + confTarget int32 + label string + err error + expectedTarget int32 + }{ + { + name: "invalid label", + label: labels.Reserved, + confTarget: 2, + err: labels.ErrReservedPrefix, + expectedTarget: 0, + }, + { + name: "invalid conf target", + label: "label ok", + confTarget: 1, + err: errConfTargetTooLow, + expectedTarget: 0, + }, + { + name: "default conf target", + label: "label ok", + confTarget: 0, + err: nil, + expectedTarget: 9, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + conf, err := validateLoopOutRequest( + test.confTarget, test.label, + ) + require.True(t, errors.Is(err, test.err)) + require.Equal(t, test.expectedTarget, conf) + }) + } +} From 5399e605545376be4bc6e79fa75a725a981b71f4 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 8 Mar 2021 10:52:57 +0200 Subject: [PATCH 2/2] loopd: verify that dest addr is for correct chain This commit adds verification to the loop out request to ensure that the formatting of the specified destination address matches the network that lnd is running on. --- loopd/swapclient_server.go | 25 ++++++++++++++---- loopd/swapclient_server_test.go | 47 ++++++++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/loopd/swapclient_server.go b/loopd/swapclient_server.go index 1e3f925..550504a 100644 --- a/loopd/swapclient_server.go +++ b/loopd/swapclient_server.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcutil" "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop" @@ -35,6 +36,11 @@ const ( ) var ( + // errIncorrectChain is returned when the format of the + // destination address provided does not match the active chain. + errIncorrectChain = errors.New("invalid address format for the " + + "active chain") + // errConfTargetTooLow is returned when the chosen confirmation target // is below the allowed minimum. errConfTargetTooLow = errors.New("confirmation target too low") @@ -82,8 +88,9 @@ func (s *swapClientServer) LoopOut(ctx context.Context, } } - sweepConfTarget, err := validateLoopOutRequest(in.SweepConfTarget, - in.Label) + sweepConfTarget, err := validateLoopOutRequest( + s.lnd.ChainParams, in.SweepConfTarget, sweepAddr, in.Label, + ) if err != nil { return nil, err } @@ -922,9 +929,17 @@ func validateLoopInRequest(htlcConfTarget int32, external bool) (int32, error) { return validateConfTarget(htlcConfTarget, loop.DefaultHtlcConfTarget) } -// validateLoopOutRequest validates the confirmation target and label of the -// loop out request. -func validateLoopOutRequest(confTarget int32, label string) (int32, error) { +// validateLoopOutRequest validates the confirmation target, destination +// address and label of the loop out request. +func validateLoopOutRequest(chainParams *chaincfg.Params, confTarget int32, + sweepAddr btcutil.Address, label string) (int32, error) { + // Check that the provided destination address has the correct format + // for the active network. + if !sweepAddr.IsForNet(chainParams) { + return 0, fmt.Errorf("%w: Current active network is %s", + errIncorrectChain, chainParams.Name) + } + // Check that the label is valid. if err := labels.Validate(label); err != nil { return 0, err diff --git a/loopd/swapclient_server_test.go b/loopd/swapclient_server_test.go index 418eb36..47bf141 100644 --- a/loopd/swapclient_server_test.go +++ b/loopd/swapclient_server_test.go @@ -163,13 +163,53 @@ func TestValidateLoopInRequest(t *testing.T) { func TestValidateLoopOutRequest(t *testing.T) { tests := []struct { name string + chain chaincfg.Params confTarget int32 + destAddr btcutil.Address label string err error expectedTarget int32 }{ + { + name: "mainnet address with mainnet backend", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 2, + err: nil, + expectedTarget: 2, + }, + { + name: "mainnet address with testnet backend", + chain: chaincfg.TestNet3Params, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 2, + err: errIncorrectChain, + expectedTarget: 0, + }, + { + name: "testnet address with testnet backend", + chain: chaincfg.TestNet3Params, + destAddr: testnetAddr, + label: "label ok", + confTarget: 2, + err: nil, + expectedTarget: 2, + }, + { + name: "testnet address with mainnet backend", + chain: chaincfg.MainNetParams, + destAddr: testnetAddr, + label: "label ok", + confTarget: 2, + err: errIncorrectChain, + expectedTarget: 0, + }, { name: "invalid label", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, label: labels.Reserved, confTarget: 2, err: labels.ErrReservedPrefix, @@ -177,6 +217,8 @@ func TestValidateLoopOutRequest(t *testing.T) { }, { name: "invalid conf target", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, label: "label ok", confTarget: 1, err: errConfTargetTooLow, @@ -184,6 +226,8 @@ func TestValidateLoopOutRequest(t *testing.T) { }, { name: "default conf target", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, label: "label ok", confTarget: 0, err: nil, @@ -198,7 +242,8 @@ func TestValidateLoopOutRequest(t *testing.T) { t.Parallel() conf, err := validateLoopOutRequest( - test.confTarget, test.label, + &test.chain, test.confTarget, test.destAddr, + test.label, ) require.True(t, errors.Is(err, test.err)) require.Equal(t, test.expectedTarget, conf)