Merge pull request #349 from ellemouton/validate-dest-addr-network

loopd: verify that dest addr is for correct chain
pull/353/head
Carla Kirk-Cohen 3 years ago committed by GitHub
commit 4d9d398b23
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -8,6 +8,7 @@ import (
"sync"
"time"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcutil"
"github.com/lightninglabs/lndclient"
"github.com/lightninglabs/loop"
@ -34,6 +35,17 @@ const (
minConfTarget = 2
)
var (
// errIncorrectChain is returned when the format of the
// destination address provided does not match the active chain.
errIncorrectChain = errors.New("invalid address format for the " +
"active chain")
// errConfTargetTooLow is returned when the chosen confirmation target
// is below the allowed minimum.
errConfTargetTooLow = errors.New("confirmation target too low")
)
// swapClientServer implements the grpc service exposed by loopd.
type swapClientServer struct {
network lndclient.Network
@ -58,13 +70,6 @@ func (s *swapClientServer) LoopOut(ctx context.Context,
log.Infof("Loop out request received")
sweepConfTarget, err := validateConfTarget(
in.SweepConfTarget, loop.DefaultSweepConfTarget,
)
if err != nil {
return nil, err
}
var sweepAddr btcutil.Address
if in.Dest == "" {
// Generate sweep address if none specified.
@ -83,8 +88,10 @@ func (s *swapClientServer) LoopOut(ctx context.Context,
}
}
// Check that the label is valid.
if err := labels.Validate(in.Label); err != nil {
sweepConfTarget, err := validateLoopOutRequest(
s.lnd.ChainParams, in.SweepConfTarget, sweepAddr, in.Label,
)
if err != nil {
return nil, err
}
@ -943,8 +950,9 @@ func validateConfTarget(target, defaultTarget int32) (int32, error) {
// Ensure the target respects our minimum threshold.
case target < minConfTarget:
return 0, fmt.Errorf("a confirmation target of at least %v "+
"must be provided", minConfTarget)
return 0, fmt.Errorf("%w: A confirmation target of at "+
"least %v must be provided", errConfTargetTooLow,
minConfTarget)
default:
return target, nil
@ -969,3 +977,22 @@ func validateLoopInRequest(htlcConfTarget int32, external bool) (int32, error) {
return validateConfTarget(htlcConfTarget, loop.DefaultHtlcConfTarget)
}
// validateLoopOutRequest validates the confirmation target, destination
// address and label of the loop out request.
func validateLoopOutRequest(chainParams *chaincfg.Params, confTarget int32,
sweepAddr btcutil.Address, label string) (int32, error) {
// Check that the provided destination address has the correct format
// for the active network.
if !sweepAddr.IsForNet(chainParams) {
return 0, fmt.Errorf("%w: Current active network is %s",
errIncorrectChain, chainParams.Name)
}
// Check that the label is valid.
if err := labels.Validate(label); err != nil {
return 0, err
}
return validateConfTarget(confTarget, loop.DefaultSweepConfTarget)
}

@ -1,9 +1,24 @@
package loopd
import (
"errors"
"testing"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcutil"
"github.com/lightninglabs/loop"
"github.com/lightninglabs/loop/labels"
"github.com/stretchr/testify/require"
)
var (
testnetAddr, _ = btcutil.NewAddressScriptHash(
[]byte{123}, &chaincfg.TestNet3Params,
)
mainnetAddr, _ = btcutil.NewAddressScriptHash(
[]byte{123}, &chaincfg.MainNetParams,
)
)
// TestValidateConfTarget tests all failure and success cases for our conf
@ -143,3 +158,95 @@ func TestValidateLoopInRequest(t *testing.T) {
})
}
}
// TestValidateLoopOutRequest tests validation of loop out requests.
func TestValidateLoopOutRequest(t *testing.T) {
tests := []struct {
name string
chain chaincfg.Params
confTarget int32
destAddr btcutil.Address
label string
err error
expectedTarget int32
}{
{
name: "mainnet address with mainnet backend",
chain: chaincfg.MainNetParams,
destAddr: mainnetAddr,
label: "label ok",
confTarget: 2,
err: nil,
expectedTarget: 2,
},
{
name: "mainnet address with testnet backend",
chain: chaincfg.TestNet3Params,
destAddr: mainnetAddr,
label: "label ok",
confTarget: 2,
err: errIncorrectChain,
expectedTarget: 0,
},
{
name: "testnet address with testnet backend",
chain: chaincfg.TestNet3Params,
destAddr: testnetAddr,
label: "label ok",
confTarget: 2,
err: nil,
expectedTarget: 2,
},
{
name: "testnet address with mainnet backend",
chain: chaincfg.MainNetParams,
destAddr: testnetAddr,
label: "label ok",
confTarget: 2,
err: errIncorrectChain,
expectedTarget: 0,
},
{
name: "invalid label",
chain: chaincfg.MainNetParams,
destAddr: mainnetAddr,
label: labels.Reserved,
confTarget: 2,
err: labels.ErrReservedPrefix,
expectedTarget: 0,
},
{
name: "invalid conf target",
chain: chaincfg.MainNetParams,
destAddr: mainnetAddr,
label: "label ok",
confTarget: 1,
err: errConfTargetTooLow,
expectedTarget: 0,
},
{
name: "default conf target",
chain: chaincfg.MainNetParams,
destAddr: mainnetAddr,
label: "label ok",
confTarget: 0,
err: nil,
expectedTarget: 9,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
conf, err := validateLoopOutRequest(
&test.chain, test.confTarget, test.destAddr,
test.label,
)
require.True(t, errors.Is(err, test.err))
require.Equal(t, test.expectedTarget, conf)
})
}
}

Loading…
Cancel
Save