loopd: refactor loop out request validation

This commit moves loop out request validation for labels and
confirmation targets into its own function for the purpose of easy
testing and also to make the additions of future request validation easy
to add and test.
pull/349/head
Elle Mouton 3 years ago
parent 9ce7fe4df9
commit 5c34dd1177

@ -34,6 +34,12 @@ const (
minConfTarget = 2
)
var (
// errConfTargetTooLow is returned when the chosen confirmation target
// is below the allowed minimum.
errConfTargetTooLow = errors.New("confirmation target too low")
)
// swapClientServer implements the grpc service exposed by loopd.
type swapClientServer struct {
network lndclient.Network
@ -58,13 +64,6 @@ func (s *swapClientServer) LoopOut(ctx context.Context,
log.Infof("Loop out request received")
sweepConfTarget, err := validateConfTarget(
in.SweepConfTarget, loop.DefaultSweepConfTarget,
)
if err != nil {
return nil, err
}
var sweepAddr btcutil.Address
if in.Dest == "" {
// Generate sweep address if none specified.
@ -83,8 +82,9 @@ func (s *swapClientServer) LoopOut(ctx context.Context,
}
}
// Check that the label is valid.
if err := labels.Validate(in.Label); err != nil {
sweepConfTarget, err := validateLoopOutRequest(in.SweepConfTarget,
in.Label)
if err != nil {
return nil, err
}
@ -894,8 +894,9 @@ func validateConfTarget(target, defaultTarget int32) (int32, error) {
// Ensure the target respects our minimum threshold.
case target < minConfTarget:
return 0, fmt.Errorf("a confirmation target of at least %v "+
"must be provided", minConfTarget)
return 0, fmt.Errorf("%w: A confirmation target of at "+
"least %v must be provided", errConfTargetTooLow,
minConfTarget)
default:
return target, nil
@ -920,3 +921,14 @@ func validateLoopInRequest(htlcConfTarget int32, external bool) (int32, error) {
return validateConfTarget(htlcConfTarget, loop.DefaultHtlcConfTarget)
}
// validateLoopOutRequest validates the confirmation target and label of the
// loop out request.
func validateLoopOutRequest(confTarget int32, label string) (int32, error) {
// Check that the label is valid.
if err := labels.Validate(label); err != nil {
return 0, err
}
return validateConfTarget(confTarget, loop.DefaultSweepConfTarget)
}

@ -1,9 +1,24 @@
package loopd
import (
"errors"
"testing"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcutil"
"github.com/lightninglabs/loop"
"github.com/lightninglabs/loop/labels"
"github.com/stretchr/testify/require"
)
var (
testnetAddr, _ = btcutil.NewAddressScriptHash(
[]byte{123}, &chaincfg.TestNet3Params,
)
mainnetAddr, _ = btcutil.NewAddressScriptHash(
[]byte{123}, &chaincfg.MainNetParams,
)
)
// TestValidateConfTarget tests all failure and success cases for our conf
@ -143,3 +158,50 @@ func TestValidateLoopInRequest(t *testing.T) {
})
}
}
// TestValidateLoopOutRequest tests validation of loop out requests.
func TestValidateLoopOutRequest(t *testing.T) {
tests := []struct {
name string
confTarget int32
label string
err error
expectedTarget int32
}{
{
name: "invalid label",
label: labels.Reserved,
confTarget: 2,
err: labels.ErrReservedPrefix,
expectedTarget: 0,
},
{
name: "invalid conf target",
label: "label ok",
confTarget: 1,
err: errConfTargetTooLow,
expectedTarget: 0,
},
{
name: "default conf target",
label: "label ok",
confTarget: 0,
err: nil,
expectedTarget: 9,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
conf, err := validateLoopOutRequest(
test.confTarget, test.label,
)
require.True(t, errors.Is(err, test.err))
require.Equal(t, test.expectedTarget, conf)
})
}
}

Loading…
Cancel
Save