diff --git a/loopdb/codec.go b/loopdb/codec.go index d0f2ee2..fd3fa6e 100644 --- a/loopdb/codec.go +++ b/loopdb/codec.go @@ -1,7 +1,11 @@ package loopdb import ( + "bytes" "fmt" + + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/keychain" ) // itob returns an 8-byte big endian representation of v. @@ -40,3 +44,39 @@ func MarshalProtocolVersion(version ProtocolVersion) []byte { return versionBytes[:] } + +// MarshalKeyLocator marshals a keychain.KeyLocator to a byte slice. +func MarshalKeyLocator(keyLocator keychain.KeyLocator) ([]byte, error) { + var ( + scratch [8]byte + buf bytes.Buffer + ) + + err := channeldb.EKeyLocator(&buf, &keyLocator, &scratch) + if err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// UnmarshalKeyLocator unmarshals a keychain.KeyLocator from a byte slice. +func UnmarshalKeyLocator(data []byte) (keychain.KeyLocator, error) { + if data == nil { + return keychain.KeyLocator{}, nil + } + + var ( + scratch [8]byte + keyLocator keychain.KeyLocator + ) + + err := channeldb.DKeyLocator( + bytes.NewReader(data), &keyLocator, &scratch, 8, + ) + if err != nil { + return keychain.KeyLocator{}, err + } + + return keyLocator, nil +} diff --git a/loopdb/codec_test.go b/loopdb/codec_test.go index 5d58e86..1d01c6a 100644 --- a/loopdb/codec_test.go +++ b/loopdb/codec_test.go @@ -1,8 +1,10 @@ package loopdb import ( + "math" "testing" + "github.com/lightningnetwork/lnd/keychain" "github.com/stretchr/testify/require" ) @@ -51,3 +53,46 @@ func TestProtocolVersionMarshalUnMarshal(t *testing.T) { require.Equal(t, ProtocolVersionUnrecorded, version) } } + +// TestKeyLocatorMarshalUnMarshal tests that marshalling and unmarshalling +// keychain.KeyLocator works correctly. +func TestKeyLocatorMarshalUnMarshal(t *testing.T) { + t.Parallel() + + tests := []struct { + keyLoc keychain.KeyLocator + }{ + { + // Test that an empty keylocator is serialized and + // deserialized correctly. + keyLoc: keychain.KeyLocator{}, + }, + { + // Test that the max value keylocator is serialized and + // deserialized correctly. + keyLoc: keychain.KeyLocator{ + Family: keychain.KeyFamily(math.MaxUint32), + Index: math.MaxUint32, + }, + }, + { + // Test that an arbitrary keylocator is serialized and + // deserialized correctly. + keyLoc: keychain.KeyLocator{ + Family: keychain.KeyFamily(5), + Index: 7, + }, + }, + } + + for _, test := range tests { + test := test + + buf, err := MarshalKeyLocator(test.keyLoc) + require.NoError(t, err) + + keyLoc, err := UnmarshalKeyLocator(buf) + require.NoError(t, err) + require.Equal(t, test.keyLoc, keyLoc) + } +} diff --git a/loopdb/loop.go b/loopdb/loop.go index 07d9d4b..98f195a 100644 --- a/loopdb/loop.go +++ b/loopdb/loop.go @@ -6,6 +6,7 @@ import ( "time" "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" ) @@ -26,6 +27,11 @@ type SwapContract struct { // HTLC. ReceiverKey [33]byte + // ClientKeyLocator is the key locator (family and index) for the client + // key. It is for the receiver key if this is a loop out contract, or + // the sender key if this is a loop in contract. + ClientKeyLocator keychain.KeyLocator + // CltvExpiry is the total absolute CLTV expiry of the swap. CltvExpiry int32 diff --git a/loopdb/store.go b/loopdb/store.go index 67fdcd6..1343974 100644 --- a/loopdb/store.go +++ b/loopdb/store.go @@ -102,6 +102,16 @@ var ( // parameters. liquidtyParamsKey = []byte("params") + // keyLocatorKey is the key that stores the receiver key's locator info + // for loop outs or the sender key's locator info for loop ins. This is + // required for MuSig2 swaps. Only serialized/deserialized for swaps + // that have protocol version >= ProtocolVersionHtlcV3. + // + // path: loopInBucket/loopOutBucket -> swapBucket[hash] -> keyLocatorKey + // + // value: concatenation of uint32 values [family, index]. + keyLocatorKey = []byte("keylocator") + byteOrder = binary.BigEndian keyLength = 33 @@ -327,6 +337,16 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { return err } + // Try to unmarshal the key locator. + if contract.ProtocolVersion >= ProtocolVersionHtlcV3 { + contract.ClientKeyLocator, err = UnmarshalKeyLocator( + swapBucket.Get(keyLocatorKey), + ) + if err != nil { + return err + } + } + loop := LoopOut{ Loop: Loop{ Events: updates, @@ -464,6 +484,16 @@ func (s *boltSwapStore) FetchLoopInSwaps() ([]*LoopIn, error) { return err } + // Try to unmarshal the key locator. + if contract.ProtocolVersion >= ProtocolVersionHtlcV3 { + contract.ClientKeyLocator, err = UnmarshalKeyLocator( + swapBucket.Get(keyLocatorKey), + ) + if err != nil { + return err + } + } + loop := LoopIn{ Loop: Loop{ Events: updates, @@ -583,6 +613,21 @@ func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, return err } + // Store the key locator for swaps with taproot htlc. + if swap.ProtocolVersion >= ProtocolVersionHtlcV3 { + keyLocator, err := MarshalKeyLocator( + swap.ClientKeyLocator, + ) + if err != nil { + return err + } + + err = swapBucket.Put(keyLocatorKey, keyLocator) + if err != nil { + return err + } + } + // Finally, we'll create an empty updates bucket for this swap // to track any future updates to the swap itself. _, err = swapBucket.CreateBucket(updatesBucketKey) @@ -634,6 +679,21 @@ func (s *boltSwapStore) CreateLoopIn(hash lntypes.Hash, return err } + // Store the key locator for swaps with taproot htlc. + if swap.ProtocolVersion >= ProtocolVersionHtlcV3 { + keyLocator, err := MarshalKeyLocator( + swap.ClientKeyLocator, + ) + if err != nil { + return err + } + + err = swapBucket.Put(keyLocatorKey, keyLocator) + if err != nil { + return err + } + } + // Finally, we'll create an empty updates bucket for this swap // to track any future updates to the swap itself. _, err = swapBucket.CreateBucket(updatesBucketKey) diff --git a/loopin.go b/loopin.go index 3e89c91..6326581 100644 --- a/loopin.go +++ b/loopin.go @@ -240,6 +240,7 @@ func newLoopInSwap(globalCtx context.Context, cfg *swapConfig, InitiationTime: initiationTime, ReceiverKey: swapResp.receiverKey, SenderKey: senderKey, + ClientKeyLocator: keyDesc.KeyLocator, Preimage: swapPreimage, AmountRequested: request.Amount, CltvExpiry: swapResp.expiry, diff --git a/loopout.go b/loopout.go index 56d679d..9368e40 100644 --- a/loopout.go +++ b/loopout.go @@ -171,6 +171,7 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig, InitiationTime: initiationTime, ReceiverKey: receiverKey, SenderKey: swapResp.senderKey, + ClientKeyLocator: keyDesc.KeyLocator, Preimage: swapPreimage, AmountRequested: request.Amount, CltvExpiry: request.Expiry,