diff --git a/test/lightning_client_mock.go b/test/lightning_client_mock.go index ef97b25..2202223 100644 --- a/test/lightning_client_mock.go +++ b/test/lightning_client_mock.go @@ -14,6 +14,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc" "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/zpay32" "golang.org/x/net/context" ) @@ -180,6 +181,53 @@ func (h *mockLightningClient) ListTransactions( return txs, nil } +// GetNodeInfo retrieves info on the node, and if includeChannels is True, +// will return other channels the node may have with other peers +func (h *mockLightningClient) GetNodeInfo(ctx context.Context, + pubKeyBytes route.Vertex, includeChannels bool) (*lndclient.NodeInfo, error) { + + nodeInfo := lndclient.NodeInfo{} + + if !includeChannels { + return nil, nil + } + + nodePubKey, err := route.NewVertexFromStr(h.lnd.NodePubkey) + if err != nil { + return nil, err + } + + // NodeInfo.Channels should only contain channels which: do not belong + // to the queried node; are not private; have the provided vertex as a + // participant + for _, edge := range h.lnd.ChannelEdges { + if (edge.Node1 == pubKeyBytes || edge.Node2 == pubKeyBytes) && + (edge.Node1 != nodePubKey || edge.Node2 != nodePubKey) { + + for _, channel := range h.lnd.Channels { + if channel.ChannelID == edge.ChannelID && !channel.Private { + nodeInfo.Channels = append(nodeInfo.Channels, *edge) + } + } + } + } + + nodeInfo.ChannelCount = len(nodeInfo.Channels) + + return &nodeInfo, nil +} + +// GetChanInfo retrieves all the info the node has on the given channel +func (h *mockLightningClient) GetChanInfo(ctx context.Context, + channelID uint64) (*lndclient.ChannelEdge, error) { + + var channelEdge *lndclient.ChannelEdge + if channelEdge, ok := h.lnd.ChannelEdges[channelID]; ok { + return channelEdge, nil + } + return channelEdge, fmt.Errorf("not found") +} + // ListChannels retrieves all channels of the backing lnd node. func (h *mockLightningClient) ListChannels(ctx context.Context, _, _ bool) ( []lndclient.ChannelInfo, error) { diff --git a/test/lnd_services_mock.go b/test/lnd_services_mock.go index 7d3466c..ccb14d8 100644 --- a/test/lnd_services_mock.go +++ b/test/lnd_services_mock.go @@ -164,6 +164,7 @@ type LndMockServices struct { Invoices map[lntypes.Hash]*lndclient.Invoice Channels []lndclient.ChannelInfo + ChannelEdges map[uint64]*lndclient.ChannelEdge ClosedChannels []lndclient.ClosedChannel ForwardingEvents []lndclient.ForwardingEvent Payments []lndclient.Payment diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..02f40dc --- /dev/null +++ b/utils.go @@ -0,0 +1,203 @@ +package loop + +import ( + "context" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcutil" + "github.com/lightninglabs/lndclient" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/zpay32" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +var ( + // DefaultMaxHopHints is set to 20 as that is the default set in LND + DefaultMaxHopHints = 20 +) + +// SelectHopHints is a direct port of the SelectHopHints found in lnd. It was +// reimplemented because the current implementation in LND relies on internals +// not externalized through the API. Hopefully in the future SelectHopHints +// will be refactored to allow for custom data sources. It iterates through all +// the active and public channels available and returns eligible channels. +// Eligibility requirements are simple: does the channel have enough liquidity +// to fulfill the request and is the node whitelisted (if specified) +func SelectHopHints(ctx context.Context, lnd *lndclient.LndServices, + amtMSat btcutil.Amount, numMaxHophints int, + includeNodes map[route.Vertex]struct{}) ([][]zpay32.HopHint, error) { + + // Fetch all active and public channels. + openChannels, err := lnd.Client.ListChannels(ctx, false, false) + if err != nil { + return nil, err + } + + // We'll add our hop hints in two passes, first we'll add all channels + // that are eligible to be hop hints, and also have a local balance + // above the payment amount. + var totalHintBandwidth btcutil.Amount + + // chanInfoCache is a simple cache for any information we retrieve + // through GetChanInfo + chanInfoCache := make(map[uint64]*lndclient.ChannelEdge) + + // skipCache is a simple cache which holds the indice of any + // channel we've added to final hopHints + skipCache := make(map[int]struct{}) + + hopHints := make([][]zpay32.HopHint, 0, numMaxHophints) + + for i, channel := range openChannels { + // In this first pass, we'll ignore all channels in + // isolation that can't satisfy this payment. + + // Retrieve extra info for each channel not available in + // listChannels + chanInfo, err := lnd.Client.GetChanInfo(ctx, channel.ChannelID) + if err != nil { + return nil, err + } + + // Cache the GetChanInfo result since it might be useful + chanInfoCache[channel.ChannelID] = chanInfo + + // Skip if channel can't forward payment + if channel.RemoteBalance < amtMSat { + log.Debugf( + "Skipping ChannelID: %v for hints as "+ + "remote balance (%v sats) "+ + "insufficient appears to be private", + channel.ChannelID, channel.RemoteBalance, + ) + continue + } + // If includeNodes is set, we'll only add channels with peers in + // includeNodes. This is done to respect the last_hop parameter. + if len(includeNodes) > 0 { + if _, ok := includeNodes[channel.PubKeyBytes]; !ok { + continue + } + } + + // Mark the index to skip so we can skip it on the next + // iteration if needed. We'll skip all channels that make + // it past this point as they'll likely belong to private + // nodes or be selected. + skipCache[i] = struct{}{} + + // We want to prevent leaking private nodes, which we define as + // nodes with only private channels. + // + // GetNodeInfo will never return private channels, even if + // they're somehow known to us. If there are any channels + // returned, we can consider the node to be public. + nodeInfo, err := lnd.Client.GetNodeInfo( + ctx, channel.PubKeyBytes, true, + ) + + // If the error is node isn't found, just iterate. Otherwise, + // fail. + status, ok := status.FromError(err) + if ok && status.Code() == codes.NotFound { + log.Warnf("Skipping ChannelID: %v for hints as peer "+ + "(NodeID: %v) is not found: %v", + channel.ChannelID, channel.PubKeyBytes.String(), + err) + continue + } else if err != nil { + return nil, err + } + + if len(nodeInfo.Channels) == 0 { + log.Infof( + "Skipping ChannelID: %v for hints as peer "+ + "(NodeID: %v) appears to be private", + channel.ChannelID, channel.PubKeyBytes.String(), + ) + continue + } + + nodeID, err := btcec.ParsePubKey( + channel.PubKeyBytes[:], btcec.S256(), + ) + if err != nil { + return nil, err + } + + // Now that we now this channel use usable, add it as a hop + // hint and the indexes we'll use later. + hopHints = append(hopHints, []zpay32.HopHint{{ + NodeID: nodeID, + ChannelID: channel.ChannelID, + FeeBaseMSat: uint32(chanInfo.Node2Policy.FeeBaseMsat), + FeeProportionalMillionths: uint32( + chanInfo.Node2Policy.FeeRateMilliMsat, + ), + CLTVExpiryDelta: uint16( + chanInfo.Node2Policy.TimeLockDelta), + }}) + + totalHintBandwidth += channel.RemoteBalance + } + + // If we have enough hop hints at this point, then we'll exit early. + // Otherwise, we'll continue to add more that may help out mpp users. + if len(hopHints) >= numMaxHophints { + return hopHints, nil + } + + // In this second pass we'll add channels, and we'll either stop when + // we have 20 hop hints, we've run through all the available channels, + // or if the sum of available bandwidth in the routing hints exceeds 2x + // the payment amount. We do 2x here to account for a margin of error + // if some of the selected channels no longer become operable. + hopHintFactor := btcutil.Amount(lnwire.MilliSatoshi(2)) + + for i := 0; i < len(openChannels); i++ { + // If we hit either of our early termination conditions, then + // we'll break the loop here. + if totalHintBandwidth > amtMSat*hopHintFactor || + len(hopHints) >= numMaxHophints { + + break + } + + // Skip the channel if we already selected it. + if _, ok := skipCache[i]; ok { + continue + } + + channel := openChannels[i] + chanInfo := chanInfoCache[channel.ChannelID] + + nodeID, err := btcec.ParsePubKey( + channel.PubKeyBytes[:], btcec.S256()) + if err != nil { + continue + } + + // Include the route hint in our set of options that will be + // used when creating the invoice. + hopHints = append(hopHints, []zpay32.HopHint{{ + NodeID: nodeID, + ChannelID: channel.ChannelID, + FeeBaseMSat: uint32(chanInfo.Node2Policy.FeeBaseMsat), + FeeProportionalMillionths: uint32( + chanInfo.Node2Policy.FeeRateMilliMsat, + ), + CLTVExpiryDelta: uint16( + chanInfo.Node2Policy.TimeLockDelta), + }}) + + // As we've just added a new hop hint, we'll accumulate it's + // available balance now to update our tally. + // + // TODO(roasbeef): have a cut off based on min bandwidth? + totalHintBandwidth += channel.RemoteBalance + } + + return hopHints, nil +} diff --git a/utils_test.go b/utils_test.go new file mode 100644 index 0000000..d3bfbc6 --- /dev/null +++ b/utils_test.go @@ -0,0 +1,297 @@ +package loop + +import ( + "context" + "encoding/hex" + "math/big" + "testing" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcutil" + "github.com/lightninglabs/lndclient" + mock_lnd "github.com/lightninglabs/loop/test" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/zpay32" + "github.com/stretchr/testify/require" +) + +var ( + chanID1 = lnwire.NewShortChanIDFromInt(1) + chanID2 = lnwire.NewShortChanIDFromInt(2) + chanID3 = lnwire.NewShortChanIDFromInt(3) + chanID4 = lnwire.NewShortChanIDFromInt(4) + + // To generate a nodeID we'll have to perform a few steps. + // + // Step 1: We generate the corresponding Y value to an + // arbitrary X value on the secp25k1 curve. This + // is done outside this function, and the output + // is converted to string with big.Int.Text(16) + // and converted to []bytes. btcec.decompressPoint + // is used here to generate this Y value + // + // Step 2: Construct a btcec.PublicKey object with the + // aforementioned values + // + // Step 3: Convert the pubkey to a Vertex by passing a + // compressed pubkey. This compression looses the + // Y value as it can be inferred. + // + // The Vertex object mainly contains the X value information, + // and has the underlying []bytes type. We generate the Y + // value information ourselves as that is returned in the + // hophints, and we must ensure it's accuracy + + // Generate origin NodeID + originYBytes, _ = hex.DecodeString( + "bde70df51939b94c9c24979fa7dd04ebd9b" + + "3572da7802290438af2a681895441", + ) + pubKeyOrigin = &btcec.PublicKey{ + X: big.NewInt(0), + Y: new(big.Int).SetBytes(originYBytes), + Curve: btcec.S256(), + } + origin, _ = route.NewVertexFromBytes(pubKeyOrigin.SerializeCompressed()) + + // Generate peer1 NodeID + pubKey1YBytes, _ = hex.DecodeString( + "598ec453728e0ffe0ae2f5e174243cf58f2" + + "a3f2c83d2457b43036db568b11093", + ) + pubKeyPeer1 = &btcec.PublicKey{ + X: big.NewInt(4), + Y: new(big.Int).SetBytes(pubKey1YBytes), + Curve: btcec.S256(), + } + peer1, _ = route.NewVertexFromBytes(pubKeyPeer1.SerializeCompressed()) + + // Generate peer2 NodeID + pubKey2YBytes, _ = hex.DecodeString( + "bde70df51939b94c9c24979fa7dd04ebd" + + "9b3572da7802290438af2a681895441", + ) + pubKeyPeer2 = &btcec.PublicKey{ + X: big.NewInt(1), + Y: new(big.Int).SetBytes(pubKey2YBytes), + Curve: btcec.S256(), + } + peer2, _ = route.NewVertexFromBytes(pubKeyPeer2.SerializeCompressed()) + + // Construct channel1 which will be returned my listChannels and + // channelEdge1 which will be returned by getChanInfo + chan1Capacity = btcutil.Amount(10000) + channel1 = lndclient.ChannelInfo{ + Active: true, + ChannelID: chanID1.ToUint64(), + PubKeyBytes: peer1, + LocalBalance: 10000, + RemoteBalance: 0, + Capacity: chan1Capacity, + } + channelEdge1 = lndclient.ChannelEdge{ + ChannelID: chanID1.ToUint64(), + ChannelPoint: "b121f1d368b8f60648970bc36b37e7b9700d" + + "ed098c60b027e42e9c648e297502:0", + Capacity: chan1Capacity, + Node1: origin, + Node2: peer1, + Node1Policy: &lndclient.RoutingPolicy{ + FeeBaseMsat: 0, + FeeRateMilliMsat: 0, + TimeLockDelta: 144, + }, + Node2Policy: &lndclient.RoutingPolicy{ + FeeBaseMsat: 0, + FeeRateMilliMsat: 0, + TimeLockDelta: 144, + }, + } + + // Construct channel2 which will be returned my listChannels and + // channelEdge2 which will be returned by getChanInfo + chan2Capacity = btcutil.Amount(10000) + channel2 = lndclient.ChannelInfo{ + Active: true, + ChannelID: chanID1.ToUint64(), + PubKeyBytes: peer2, + LocalBalance: 0, + RemoteBalance: 10000, + Capacity: chan1Capacity, + } + channelEdge2 = lndclient.ChannelEdge{ + ChannelID: chanID2.ToUint64(), + ChannelPoint: "b121f1d368b8f60648970bc36b37e7b9700d" + + "ed098c60b027e42e9c648e297502:0", + Capacity: chan2Capacity, + Node1: origin, + Node2: peer2, + Node1Policy: &lndclient.RoutingPolicy{ + FeeBaseMsat: 0, + FeeRateMilliMsat: 0, + TimeLockDelta: 144, + }, + Node2Policy: &lndclient.RoutingPolicy{ + FeeBaseMsat: 0, + FeeRateMilliMsat: 0, + TimeLockDelta: 144, + }, + } + + // Construct channel3 which will be returned my listChannels and + // channelEdge3 which will be returned by getChanInfo + chan3Capacity = btcutil.Amount(10000) + channel3 = lndclient.ChannelInfo{ + Active: true, + ChannelID: chanID3.ToUint64(), + PubKeyBytes: peer2, + LocalBalance: 10000, + RemoteBalance: 0, + Capacity: chan1Capacity, + } + channelEdge3 = lndclient.ChannelEdge{ + ChannelID: chanID3.ToUint64(), + ChannelPoint: "b121f1d368b8f60648970bc36b37e7b9700d" + + "ed098c60b027e42e9c648e297502:0", + Capacity: chan3Capacity, + Node1: origin, + Node2: peer2, + Node1Policy: &lndclient.RoutingPolicy{ + FeeBaseMsat: 0, + FeeRateMilliMsat: 0, + TimeLockDelta: 144, + }, + Node2Policy: &lndclient.RoutingPolicy{ + FeeBaseMsat: 0, + FeeRateMilliMsat: 0, + TimeLockDelta: 144, + }, + } + + // Construct channel4 which will be returned my listChannels and + // channelEdge4 which will be returned by getChanInfo + chan4Capacity = btcutil.Amount(10000) + channel4 = lndclient.ChannelInfo{ + Active: true, + ChannelID: chanID4.ToUint64(), + PubKeyBytes: peer2, + LocalBalance: 10000, + RemoteBalance: 0, + Capacity: chan4Capacity, + } + channelEdge4 = lndclient.ChannelEdge{ + ChannelID: chanID4.ToUint64(), + ChannelPoint: "6fe4408bba52c0a0ee15365e107105de" + + "fabfc70c497556af69351c4cfbc167b:0", + Capacity: chan1Capacity, + Node1: origin, + Node2: peer2, + Node1Policy: &lndclient.RoutingPolicy{ + FeeBaseMsat: 0, + FeeRateMilliMsat: 0, + TimeLockDelta: 144, + }, + Node2Policy: &lndclient.RoutingPolicy{ + FeeBaseMsat: 0, + FeeRateMilliMsat: 0, + TimeLockDelta: 144, + }, + } +) + +func TestSelectHopHints(t *testing.T) { + tests := []struct { + name string + channels []lndclient.ChannelInfo + channelEdges map[uint64]*lndclient.ChannelEdge + expectedHopHints [][]zpay32.HopHint + amtMSat btcutil.Amount + numMaxHophints int + includeNodes map[route.Vertex]struct{} + expectedError error + }{ + // 3 inputs set assumes the host node has 3 channels to chose + // from. Only channel 2 with peer 2 is ideal, however we should + // still include the other 2 after in the order they were + // provided just in case + { + name: "3 inputs set", + channels: []lndclient.ChannelInfo{ + channel2, + channel3, + channel4, + }, + channelEdges: map[uint64]*lndclient.ChannelEdge{ + channel2.ChannelID: &channelEdge2, + channel3.ChannelID: &channelEdge3, + channel4.ChannelID: &channelEdge4, + }, + expectedHopHints: [][]zpay32.HopHint{ + {{ + NodeID: pubKeyPeer2, + ChannelID: channel2.ChannelID, + FeeBaseMSat: 0, + FeeProportionalMillionths: 0, + CLTVExpiryDelta: 144, + }}, + {{ + NodeID: pubKeyPeer2, + ChannelID: channel3.ChannelID, + FeeBaseMSat: 0, + FeeProportionalMillionths: 0, + CLTVExpiryDelta: 144, + }}, + {{ + NodeID: pubKeyPeer2, + ChannelID: channel4.ChannelID, + FeeBaseMSat: 0, + FeeProportionalMillionths: 0, + CLTVExpiryDelta: 144, + }}, + }, + amtMSat: chan1Capacity, + numMaxHophints: 20, + includeNodes: make(map[route.Vertex]struct{}), + expectedError: nil, + }, + { + name: "invalid set", + channels: []lndclient.ChannelInfo{ + channel1, + }, + channelEdges: map[uint64]*lndclient.ChannelEdge{ + channel1.ChannelID: &channelEdge1, + }, + expectedHopHints: [][]zpay32.HopHint{ + {{ + NodeID: pubKeyPeer1, + ChannelID: channel1.ChannelID, + FeeBaseMSat: 0, + FeeProportionalMillionths: 0, + CLTVExpiryDelta: 144, + }}, + }, amtMSat: chan1Capacity, + numMaxHophints: 20, + includeNodes: make(map[route.Vertex]struct{}), + expectedError: nil, + }, + } + for _, test := range tests { + test := test + ctx := context.Background() + + lnd := mock_lnd.NewMockLnd() + lnd.Channels = test.channels + lnd.ChannelEdges = test.channelEdges + t.Run(test.name, func(t *testing.T) { + hopHints, err := SelectHopHints( + ctx, &lnd.LndServices, test.amtMSat, + test.numMaxHophints, test.includeNodes, + ) + require.Equal(t, test.expectedError, err) + require.Equal(t, test.expectedHopHints, hopHints) + }) + } + +}