loopdb: add the local pubkey's keylocator to the persisted contract

pull/497/head
Andras Banki-Horvath 2 years ago
parent ce3026daa9
commit a252e2c706
No known key found for this signature in database
GPG Key ID: 80E5375C094198D8

@ -1,7 +1,11 @@
package loopdb
import (
"bytes"
"fmt"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/keychain"
)
// itob returns an 8-byte big endian representation of v.
@ -40,3 +44,39 @@ func MarshalProtocolVersion(version ProtocolVersion) []byte {
return versionBytes[:]
}
// MarshalKeyLocator marshals a keychain.KeyLocator to a byte slice.
func MarshalKeyLocator(keyLocator keychain.KeyLocator) ([]byte, error) {
var (
scratch [8]byte
buf bytes.Buffer
)
err := channeldb.EKeyLocator(&buf, &keyLocator, &scratch)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// UnmarshalKeyLocator unmarshals a keychain.KeyLocator from a byte slice.
func UnmarshalKeyLocator(data []byte) (keychain.KeyLocator, error) {
if data == nil {
return keychain.KeyLocator{}, nil
}
var (
scratch [8]byte
keyLocator keychain.KeyLocator
)
err := channeldb.DKeyLocator(
bytes.NewReader(data), &keyLocator, &scratch, 8,
)
if err != nil {
return keychain.KeyLocator{}, err
}
return keyLocator, nil
}

@ -1,8 +1,10 @@
package loopdb
import (
"math"
"testing"
"github.com/lightningnetwork/lnd/keychain"
"github.com/stretchr/testify/require"
)
@ -51,3 +53,46 @@ func TestProtocolVersionMarshalUnMarshal(t *testing.T) {
require.Equal(t, ProtocolVersionUnrecorded, version)
}
}
// TestKeyLocatorMarshalUnMarshal tests that marshalling and unmarshalling
// keychain.KeyLocator works correctly.
func TestKeyLocatorMarshalUnMarshal(t *testing.T) {
t.Parallel()
tests := []struct {
keyLoc keychain.KeyLocator
}{
{
// Test that an empty keylocator is serialized and
// deserialized correctly.
keyLoc: keychain.KeyLocator{},
},
{
// Test that the max value keylocator is serialized and
// deserialized correctly.
keyLoc: keychain.KeyLocator{
Family: keychain.KeyFamily(math.MaxUint32),
Index: math.MaxUint32,
},
},
{
// Test that an arbitrary keylocator is serialized and
// deserialized correctly.
keyLoc: keychain.KeyLocator{
Family: keychain.KeyFamily(5),
Index: 7,
},
},
}
for _, test := range tests {
test := test
buf, err := MarshalKeyLocator(test.keyLoc)
require.NoError(t, err)
keyLoc, err := UnmarshalKeyLocator(buf)
require.NoError(t, err)
require.Equal(t, test.keyLoc, keyLoc)
}
}

@ -6,6 +6,7 @@ import (
"time"
"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
)
@ -26,6 +27,11 @@ type SwapContract struct {
// HTLC.
ReceiverKey [33]byte
// ClientKeyLocator is the key locator (family and index) for the client
// key. It is for the receiver key if this is a loop out contract, or
// the sender key if this is a loop in contract.
ClientKeyLocator keychain.KeyLocator
// CltvExpiry is the total absolute CLTV expiry of the swap.
CltvExpiry int32

@ -102,6 +102,16 @@ var (
// parameters.
liquidtyParamsKey = []byte("params")
// keyLocatorKey is the key that stores the receiver key's locator info
// for loop outs or the sender key's locator info for loop ins. This is
// required for MuSig2 swaps. Only serialized/deserialized for swaps
// that have protocol version >= ProtocolVersionHtlcV3.
//
// path: loopInBucket/loopOutBucket -> swapBucket[hash] -> keyLocatorKey
//
// value: concatenation of uint32 values [family, index].
keyLocatorKey = []byte("keylocator")
byteOrder = binary.BigEndian
keyLength = 33
@ -327,6 +337,16 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) {
return err
}
// Try to unmarshal the key locator.
if contract.ProtocolVersion >= ProtocolVersionHtlcV3 {
contract.ClientKeyLocator, err = UnmarshalKeyLocator(
swapBucket.Get(keyLocatorKey),
)
if err != nil {
return err
}
}
loop := LoopOut{
Loop: Loop{
Events: updates,
@ -464,6 +484,16 @@ func (s *boltSwapStore) FetchLoopInSwaps() ([]*LoopIn, error) {
return err
}
// Try to unmarshal the key locator.
if contract.ProtocolVersion >= ProtocolVersionHtlcV3 {
contract.ClientKeyLocator, err = UnmarshalKeyLocator(
swapBucket.Get(keyLocatorKey),
)
if err != nil {
return err
}
}
loop := LoopIn{
Loop: Loop{
Events: updates,
@ -583,6 +613,21 @@ func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash,
return err
}
// Store the key locator for swaps with taproot htlc.
if swap.ProtocolVersion >= ProtocolVersionHtlcV3 {
keyLocator, err := MarshalKeyLocator(
swap.ClientKeyLocator,
)
if err != nil {
return err
}
err = swapBucket.Put(keyLocatorKey, keyLocator)
if err != nil {
return err
}
}
// Finally, we'll create an empty updates bucket for this swap
// to track any future updates to the swap itself.
_, err = swapBucket.CreateBucket(updatesBucketKey)
@ -634,6 +679,21 @@ func (s *boltSwapStore) CreateLoopIn(hash lntypes.Hash,
return err
}
// Store the key locator for swaps with taproot htlc.
if swap.ProtocolVersion >= ProtocolVersionHtlcV3 {
keyLocator, err := MarshalKeyLocator(
swap.ClientKeyLocator,
)
if err != nil {
return err
}
err = swapBucket.Put(keyLocatorKey, keyLocator)
if err != nil {
return err
}
}
// Finally, we'll create an empty updates bucket for this swap
// to track any future updates to the swap itself.
_, err = swapBucket.CreateBucket(updatesBucketKey)

@ -240,6 +240,7 @@ func newLoopInSwap(globalCtx context.Context, cfg *swapConfig,
InitiationTime: initiationTime,
ReceiverKey: swapResp.receiverKey,
SenderKey: senderKey,
ClientKeyLocator: keyDesc.KeyLocator,
Preimage: swapPreimage,
AmountRequested: request.Amount,
CltvExpiry: swapResp.expiry,

@ -171,6 +171,7 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig,
InitiationTime: initiationTime,
ReceiverKey: receiverKey,
SenderKey: swapResp.senderKey,
ClientKeyLocator: keyDesc.KeyLocator,
Preimage: swapPreimage,
AmountRequested: request.Amount,
CltvExpiry: request.Expiry,

Loading…
Cancel
Save