diff --git a/loopd/swapclient_server.go b/loopd/swapclient_server.go index d0e1e32..1e3f925 100644 --- a/loopd/swapclient_server.go +++ b/loopd/swapclient_server.go @@ -34,6 +34,12 @@ const ( minConfTarget = 2 ) +var ( + // errConfTargetTooLow is returned when the chosen confirmation target + // is below the allowed minimum. + errConfTargetTooLow = errors.New("confirmation target too low") +) + // swapClientServer implements the grpc service exposed by loopd. type swapClientServer struct { network lndclient.Network @@ -58,13 +64,6 @@ func (s *swapClientServer) LoopOut(ctx context.Context, log.Infof("Loop out request received") - sweepConfTarget, err := validateConfTarget( - in.SweepConfTarget, loop.DefaultSweepConfTarget, - ) - if err != nil { - return nil, err - } - var sweepAddr btcutil.Address if in.Dest == "" { // Generate sweep address if none specified. @@ -83,8 +82,9 @@ func (s *swapClientServer) LoopOut(ctx context.Context, } } - // Check that the label is valid. - if err := labels.Validate(in.Label); err != nil { + sweepConfTarget, err := validateLoopOutRequest(in.SweepConfTarget, + in.Label) + if err != nil { return nil, err } @@ -894,8 +894,9 @@ func validateConfTarget(target, defaultTarget int32) (int32, error) { // Ensure the target respects our minimum threshold. case target < minConfTarget: - return 0, fmt.Errorf("a confirmation target of at least %v "+ - "must be provided", minConfTarget) + return 0, fmt.Errorf("%w: A confirmation target of at "+ + "least %v must be provided", errConfTargetTooLow, + minConfTarget) default: return target, nil @@ -920,3 +921,14 @@ func validateLoopInRequest(htlcConfTarget int32, external bool) (int32, error) { return validateConfTarget(htlcConfTarget, loop.DefaultHtlcConfTarget) } + +// validateLoopOutRequest validates the confirmation target and label of the +// loop out request. +func validateLoopOutRequest(confTarget int32, label string) (int32, error) { + // Check that the label is valid. + if err := labels.Validate(label); err != nil { + return 0, err + } + + return validateConfTarget(confTarget, loop.DefaultSweepConfTarget) +} diff --git a/loopd/swapclient_server_test.go b/loopd/swapclient_server_test.go index 5ad3718..418eb36 100644 --- a/loopd/swapclient_server_test.go +++ b/loopd/swapclient_server_test.go @@ -1,9 +1,24 @@ package loopd import ( + "errors" "testing" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcutil" "github.com/lightninglabs/loop" + "github.com/lightninglabs/loop/labels" + "github.com/stretchr/testify/require" +) + +var ( + testnetAddr, _ = btcutil.NewAddressScriptHash( + []byte{123}, &chaincfg.TestNet3Params, + ) + + mainnetAddr, _ = btcutil.NewAddressScriptHash( + []byte{123}, &chaincfg.MainNetParams, + ) ) // TestValidateConfTarget tests all failure and success cases for our conf @@ -143,3 +158,50 @@ func TestValidateLoopInRequest(t *testing.T) { }) } } + +// TestValidateLoopOutRequest tests validation of loop out requests. +func TestValidateLoopOutRequest(t *testing.T) { + tests := []struct { + name string + confTarget int32 + label string + err error + expectedTarget int32 + }{ + { + name: "invalid label", + label: labels.Reserved, + confTarget: 2, + err: labels.ErrReservedPrefix, + expectedTarget: 0, + }, + { + name: "invalid conf target", + label: "label ok", + confTarget: 1, + err: errConfTargetTooLow, + expectedTarget: 0, + }, + { + name: "default conf target", + label: "label ok", + confTarget: 0, + err: nil, + expectedTarget: 9, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + conf, err := validateLoopOutRequest( + test.confTarget, test.label, + ) + require.True(t, errors.Is(err, test.err)) + require.Equal(t, test.expectedTarget, conf) + }) + } +}