diff --git a/loopd/swapclient_server.go b/loopd/swapclient_server.go index a402baa..1c2de0c 100644 --- a/loopd/swapclient_server.go +++ b/loopd/swapclient_server.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcutil" "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop" @@ -34,6 +35,17 @@ const ( minConfTarget = 2 ) +var ( + // errIncorrectChain is returned when the format of the + // destination address provided does not match the active chain. + errIncorrectChain = errors.New("invalid address format for the " + + "active chain") + + // errConfTargetTooLow is returned when the chosen confirmation target + // is below the allowed minimum. + errConfTargetTooLow = errors.New("confirmation target too low") +) + // swapClientServer implements the grpc service exposed by loopd. type swapClientServer struct { network lndclient.Network @@ -58,13 +70,6 @@ func (s *swapClientServer) LoopOut(ctx context.Context, log.Infof("Loop out request received") - sweepConfTarget, err := validateConfTarget( - in.SweepConfTarget, loop.DefaultSweepConfTarget, - ) - if err != nil { - return nil, err - } - var sweepAddr btcutil.Address if in.Dest == "" { // Generate sweep address if none specified. @@ -83,8 +88,10 @@ func (s *swapClientServer) LoopOut(ctx context.Context, } } - // Check that the label is valid. - if err := labels.Validate(in.Label); err != nil { + sweepConfTarget, err := validateLoopOutRequest( + s.lnd.ChainParams, in.SweepConfTarget, sweepAddr, in.Label, + ) + if err != nil { return nil, err } @@ -943,8 +950,9 @@ func validateConfTarget(target, defaultTarget int32) (int32, error) { // Ensure the target respects our minimum threshold. case target < minConfTarget: - return 0, fmt.Errorf("a confirmation target of at least %v "+ - "must be provided", minConfTarget) + return 0, fmt.Errorf("%w: A confirmation target of at "+ + "least %v must be provided", errConfTargetTooLow, + minConfTarget) default: return target, nil @@ -969,3 +977,22 @@ func validateLoopInRequest(htlcConfTarget int32, external bool) (int32, error) { return validateConfTarget(htlcConfTarget, loop.DefaultHtlcConfTarget) } + +// validateLoopOutRequest validates the confirmation target, destination +// address and label of the loop out request. +func validateLoopOutRequest(chainParams *chaincfg.Params, confTarget int32, + sweepAddr btcutil.Address, label string) (int32, error) { + // Check that the provided destination address has the correct format + // for the active network. + if !sweepAddr.IsForNet(chainParams) { + return 0, fmt.Errorf("%w: Current active network is %s", + errIncorrectChain, chainParams.Name) + } + + // Check that the label is valid. + if err := labels.Validate(label); err != nil { + return 0, err + } + + return validateConfTarget(confTarget, loop.DefaultSweepConfTarget) +} diff --git a/loopd/swapclient_server_test.go b/loopd/swapclient_server_test.go index 5ad3718..47bf141 100644 --- a/loopd/swapclient_server_test.go +++ b/loopd/swapclient_server_test.go @@ -1,9 +1,24 @@ package loopd import ( + "errors" "testing" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcutil" "github.com/lightninglabs/loop" + "github.com/lightninglabs/loop/labels" + "github.com/stretchr/testify/require" +) + +var ( + testnetAddr, _ = btcutil.NewAddressScriptHash( + []byte{123}, &chaincfg.TestNet3Params, + ) + + mainnetAddr, _ = btcutil.NewAddressScriptHash( + []byte{123}, &chaincfg.MainNetParams, + ) ) // TestValidateConfTarget tests all failure and success cases for our conf @@ -143,3 +158,95 @@ func TestValidateLoopInRequest(t *testing.T) { }) } } + +// TestValidateLoopOutRequest tests validation of loop out requests. +func TestValidateLoopOutRequest(t *testing.T) { + tests := []struct { + name string + chain chaincfg.Params + confTarget int32 + destAddr btcutil.Address + label string + err error + expectedTarget int32 + }{ + { + name: "mainnet address with mainnet backend", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 2, + err: nil, + expectedTarget: 2, + }, + { + name: "mainnet address with testnet backend", + chain: chaincfg.TestNet3Params, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 2, + err: errIncorrectChain, + expectedTarget: 0, + }, + { + name: "testnet address with testnet backend", + chain: chaincfg.TestNet3Params, + destAddr: testnetAddr, + label: "label ok", + confTarget: 2, + err: nil, + expectedTarget: 2, + }, + { + name: "testnet address with mainnet backend", + chain: chaincfg.MainNetParams, + destAddr: testnetAddr, + label: "label ok", + confTarget: 2, + err: errIncorrectChain, + expectedTarget: 0, + }, + { + name: "invalid label", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, + label: labels.Reserved, + confTarget: 2, + err: labels.ErrReservedPrefix, + expectedTarget: 0, + }, + { + name: "invalid conf target", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 1, + err: errConfTargetTooLow, + expectedTarget: 0, + }, + { + name: "default conf target", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 0, + err: nil, + expectedTarget: 9, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + conf, err := validateLoopOutRequest( + &test.chain, test.confTarget, test.destAddr, + test.label, + ) + require.True(t, errors.Is(err, test.err)) + require.Equal(t, test.expectedTarget, conf) + }) + } +}