From 86db43a2cb52ee6445424d52e08581fa7ee597de Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Thu, 23 Jul 2020 17:28:36 +0200 Subject: [PATCH] loopdb: store protocol version alongside with swaps This commit adds the protocol version to each stored swap. This will be used to ensure that when swaps are resumed after a restart, they're correctly handled given any breaking protocol changes. --- loopdb/codec.go | 34 +++++++++++++++ loopdb/codec_test.go | 53 +++++++++++++++++++++++ loopdb/loop.go | 4 ++ loopdb/protocol_version.go | 76 +++++++++++++++++++++++++++++++++ loopdb/protocol_version_test.go | 47 ++++++++++++++++++++ loopdb/store.go | 49 +++++++++++++++++++++ swap_server_client.go | 23 +++++----- 7 files changed, 273 insertions(+), 13 deletions(-) create mode 100644 loopdb/codec_test.go create mode 100644 loopdb/protocol_version.go create mode 100644 loopdb/protocol_version_test.go diff --git a/loopdb/codec.go b/loopdb/codec.go index 990f164..d0f2ee2 100644 --- a/loopdb/codec.go +++ b/loopdb/codec.go @@ -1,8 +1,42 @@ package loopdb +import ( + "fmt" +) + // itob returns an 8-byte big endian representation of v. func itob(v uint64) []byte { b := make([]byte, 8) byteOrder.PutUint64(b, v) return b } + +// UnmarshalProtocolVersion attempts to unmarshal a byte slice to a +// ProtocolVersion value. If the unmarshal fails, ProtocolVersionUnrecorded is +// returned along with an error. +func UnmarshalProtocolVersion(b []byte) (ProtocolVersion, error) { + if b == nil { + return ProtocolVersionUnrecorded, nil + } + + if len(b) != 4 { + return ProtocolVersionUnrecorded, + fmt.Errorf("invalid size: %v", len(b)) + } + + version := ProtocolVersion(byteOrder.Uint32(b)) + if !version.Valid() { + return ProtocolVersionUnrecorded, + fmt.Errorf("invalid protocol version: %v", version) + } + + return version, nil +} + +// MarshalProtocolVersion marshals a ProtocolVersion value to a byte slice. +func MarshalProtocolVersion(version ProtocolVersion) []byte { + var versionBytes [4]byte + byteOrder.PutUint32(versionBytes[:], uint32(version)) + + return versionBytes[:] +} diff --git a/loopdb/codec_test.go b/loopdb/codec_test.go new file mode 100644 index 0000000..5d58e86 --- /dev/null +++ b/loopdb/codec_test.go @@ -0,0 +1,53 @@ +package loopdb + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestProtocolVersionMarshalUnMarshal tests that marshalling and unmarshalling +// looprpc.ProtocolVersion works correctly. +func TestProtocolVersionMarshalUnMarshal(t *testing.T) { + t.Parallel() + + testVersions := [...]ProtocolVersion{ + ProtocolVersionLegacy, + ProtocolVersionMultiLoopOut, + ProtocolVersionSegwitLoopIn, + ProtocolVersionPreimagePush, + ProtocolVersionUserExpiryLoopOut, + } + + bogusVersion := []byte{0xFF, 0xFF, 0xFF, 0xFF} + invalidSlice := []byte{0xFF, 0xFF, 0xFF} + + for i := 0; i < len(testVersions); i++ { + testVersion := testVersions[i] + + // Test that unmarshal(marshal(v)) == v. + version, err := UnmarshalProtocolVersion( + MarshalProtocolVersion(testVersion), + ) + require.NoError(t, err) + require.Equal(t, testVersion, version) + + // Test that unmarshalling a nil slice returns the default + // version along with no error. + version, err = UnmarshalProtocolVersion(nil) + require.NoError(t, err) + require.Equal(t, ProtocolVersionUnrecorded, version) + + // Test that unmarshalling an unknown version returns the + // default version along with an error. + version, err = UnmarshalProtocolVersion(bogusVersion) + require.Error(t, err, "expected invalid version") + require.Equal(t, ProtocolVersionUnrecorded, version) + + // Test that unmarshalling an invalid slice returns the + // default version along with an error. + version, err = UnmarshalProtocolVersion(invalidSlice) + require.Error(t, err, "expected invalid size") + require.Equal(t, ProtocolVersionUnrecorded, version) + } +} diff --git a/loopdb/loop.go b/loopdb/loop.go index 51ce917..f10f81b 100644 --- a/loopdb/loop.go +++ b/loopdb/loop.go @@ -46,6 +46,10 @@ type SwapContract struct { // Label contains an optional label for the swap. Label string + + // ProtocolVersion stores the protocol version when the swap was + // created. + ProtocolVersion ProtocolVersion } // Loop contains fields shared between LoopIn and LoopOut diff --git a/loopdb/protocol_version.go b/loopdb/protocol_version.go new file mode 100644 index 0000000..1fc7f54 --- /dev/null +++ b/loopdb/protocol_version.go @@ -0,0 +1,76 @@ +package loopdb + +import ( + "math" + + "github.com/lightninglabs/loop/looprpc" +) + +// ProtocolVersion represents the protocol version (declared on rpc level) that +// the client declared to us. +type ProtocolVersion uint32 + +const ( + // ProtocolVersionLegacy indicates that the client is a legacy version + // that did not report its protocol version. + ProtocolVersionLegacy ProtocolVersion = 0 + + // ProtocolVersionMultiLoopOut indicates that the client supports multi + // loop out. + ProtocolVersionMultiLoopOut ProtocolVersion = 1 + + // ProtocolVersionSegwitLoopIn indicates that the client supports segwit + // loop in. + ProtocolVersionSegwitLoopIn ProtocolVersion = 2 + + // ProtocolVersionPreimagePush indicates that the client will push loop + // out preimages to the sever to speed up claim. + ProtocolVersionPreimagePush ProtocolVersion = 3 + + // ProtocolVersionUserExpiryLoopOut indicates that the client will + // propose a cltv expiry height for loop out. + ProtocolVersionUserExpiryLoopOut ProtocolVersion = 4 + + // ProtocolVersionUnrecorded is set for swaps were created before we + // started saving protocol version with swaps. + ProtocolVersionUnrecorded ProtocolVersion = math.MaxUint32 + + // CurrentRpcProtocolVersion defines the version of the RPC protocol + // that is currently supported by the loop client. + CurrentRPCProtocolVersion = looprpc.ProtocolVersion_USER_EXPIRY_LOOP_OUT + + // CurrentInteranlProtocolVersionInternal defines the RPC current + // protocol in the internal representation. + CurrentInternalProtocolVersion = ProtocolVersion(CurrentRPCProtocolVersion) +) + +// Valid returns true if the value of the ProtocolVersion is valid. +func (p ProtocolVersion) Valid() bool { + return p <= CurrentInternalProtocolVersion +} + +// String returns the string representation of a protocol version. +func (p ProtocolVersion) String() string { + switch p { + case ProtocolVersionUnrecorded: + return "Unrecorded" + + case ProtocolVersionLegacy: + return "Legacy" + + case ProtocolVersionMultiLoopOut: + return "Multi Loop Out" + + case ProtocolVersionSegwitLoopIn: + return "Segwit Loop In" + + case ProtocolVersionPreimagePush: + return "Preimage Push" + + case ProtocolVersionUserExpiryLoopOut: + return "User Expiry Loop Out" + + default: + return "Unknown" + } +} diff --git a/loopdb/protocol_version_test.go b/loopdb/protocol_version_test.go new file mode 100644 index 0000000..7f8008f --- /dev/null +++ b/loopdb/protocol_version_test.go @@ -0,0 +1,47 @@ +package loopdb + +import ( + "testing" + + "github.com/lightninglabs/loop/looprpc" + "github.com/stretchr/testify/require" +) + +// TestProtocolVersionSanity tests that protocol versions are sane, meaning +// we always keep our stored protocol version in sync with the RPC protocol +// version except for the unrecorded version. +func TestProtocolVersionSanity(t *testing.T) { + t.Parallel() + + versions := [...]ProtocolVersion{ + ProtocolVersionLegacy, + ProtocolVersionMultiLoopOut, + ProtocolVersionSegwitLoopIn, + ProtocolVersionPreimagePush, + ProtocolVersionUserExpiryLoopOut, + } + + rpcVersions := [...]looprpc.ProtocolVersion{ + looprpc.ProtocolVersion_LEGACY, + looprpc.ProtocolVersion_MULTI_LOOP_OUT, + looprpc.ProtocolVersion_NATIVE_SEGWIT_LOOP_IN, + looprpc.ProtocolVersion_PREIMAGE_PUSH_LOOP_OUT, + looprpc.ProtocolVersion_USER_EXPIRY_LOOP_OUT, + } + + require.Equal(t, len(versions), len(rpcVersions)) + for i, version := range versions { + require.Equal(t, uint32(version), uint32(rpcVersions[i])) + } + + // Finally test that the current version contants are up to date + require.Equal(t, + CurrentInternalProtocolVersion, + versions[len(versions)-1], + ) + + require.Equal(t, + uint32(CurrentInternalProtocolVersion), + uint32(CurrentRPCProtocolVersion), + ) +} diff --git a/loopdb/store.go b/loopdb/store.go index adb88aa..d4651f6 100644 --- a/loopdb/store.go +++ b/loopdb/store.go @@ -69,6 +69,15 @@ var ( // value: string label labelKey = []byte("label") + // protocolVersionKey is used to optionally store the protocol version + // for the serialized swap contract. It is nested within the sub-bucket + // for each active swap. + // + // path: loopInBucket/loopOutBucket -> swapBucket[hash] -> protocolVersionKey + // + // value: protocol version as specified in server.proto + protocolVersionKey = []byte("protocol-version") + // outgoingChanSetKey is the key that stores a list of channel ids that // restrict the loop out swap payment. // @@ -276,6 +285,18 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { return err } + // Try to unmarshal the protocol version for the swap. + // If the protocol version is not stored (which is + // the case for old clients), we'll assume the + // ProtocolVersionUnrecorded instead. + contract.ProtocolVersion, err = + UnmarshalProtocolVersion( + swapBucket.Get(protocolVersionKey), + ) + if err != nil { + return err + } + loop := LoopOut{ Loop: Loop{ Events: updates, @@ -401,6 +422,18 @@ func (s *boltSwapStore) FetchLoopInSwaps() ([]*LoopIn, error) { return err } + // Try to unmarshal the protocol version for the swap. + // If the protocol version is not stored (which is + // the case for old clients), we'll assume the + // ProtocolVersionUnrecorded instead. + contract.ProtocolVersion, err = + UnmarshalProtocolVersion( + swapBucket.Get(protocolVersionKey), + ) + if err != nil { + return err + } + loop := LoopIn{ Loop: Loop{ Events: updates, @@ -512,6 +545,14 @@ func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, return err } + // Store the current protocol version. + err = swapBucket.Put(protocolVersionKey, + MarshalProtocolVersion(swap.ProtocolVersion), + ) + if err != nil { + return err + } + // Finally, we'll create an empty updates bucket for this swap // to track any future updates to the swap itself. _, err = swapBucket.CreateBucket(updatesBucketKey) @@ -550,6 +591,14 @@ func (s *boltSwapStore) CreateLoopIn(hash lntypes.Hash, return err } + // Store the current protocol version. + err = swapBucket.Put(protocolVersionKey, + MarshalProtocolVersion(swap.ProtocolVersion), + ) + if err != nil { + return err + } + // Write label to disk if we have one. if err := putLabel(swapBucket, swap.Label); err != nil { return err diff --git a/swap_server_client.go b/swap_server_client.go index 2af795c..ced9be5 100644 --- a/swap_server_client.go +++ b/swap_server_client.go @@ -14,6 +14,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcutil" + "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/looprpc" "github.com/lightninglabs/loop/lsat" "github.com/lightningnetwork/lnd/lntypes" @@ -23,10 +24,6 @@ import ( "google.golang.org/grpc/credentials" ) -// protocolVersion defines the version of the protocol that is currently -// supported by the loop client. -const protocolVersion = looprpc.ProtocolVersion_USER_EXPIRY_LOOP_OUT - var ( // errServerSubscriptionComplete is returned when our subscription to // server updates exits because the server has no more updates to @@ -130,7 +127,7 @@ func (s *grpcSwapServerClient) GetLoopOutTerms(ctx context.Context) ( defer rpcCancel() terms, err := s.server.LoopOutTerms(rpcCtx, &looprpc.ServerLoopOutTermsRequest{ - ProtocolVersion: protocolVersion, + ProtocolVersion: loopdb.CurrentRPCProtocolVersion, }, ) if err != nil { @@ -155,7 +152,7 @@ func (s *grpcSwapServerClient) GetLoopOutQuote(ctx context.Context, &looprpc.ServerLoopOutQuoteRequest{ Amt: uint64(amt), SwapPublicationDeadline: swapPublicationDeadline.Unix(), - ProtocolVersion: protocolVersion, + ProtocolVersion: loopdb.CurrentRPCProtocolVersion, Expiry: expiry, }, ) @@ -187,7 +184,7 @@ func (s *grpcSwapServerClient) GetLoopInTerms(ctx context.Context) ( defer rpcCancel() terms, err := s.server.LoopInTerms(rpcCtx, &looprpc.ServerLoopInTermsRequest{ - ProtocolVersion: protocolVersion, + ProtocolVersion: loopdb.CurrentRPCProtocolVersion, }, ) if err != nil { @@ -208,7 +205,7 @@ func (s *grpcSwapServerClient) GetLoopInQuote(ctx context.Context, quoteResp, err := s.server.LoopInQuote(rpcCtx, &looprpc.ServerLoopInQuoteRequest{ Amt: uint64(amt), - ProtocolVersion: protocolVersion, + ProtocolVersion: loopdb.CurrentRPCProtocolVersion, }, ) if err != nil { @@ -234,7 +231,7 @@ func (s *grpcSwapServerClient) NewLoopOutSwap(ctx context.Context, Amt: uint64(amount), ReceiverKey: receiverKey[:], SwapPublicationDeadline: swapPublicationDeadline.Unix(), - ProtocolVersion: protocolVersion, + ProtocolVersion: loopdb.CurrentRPCProtocolVersion, Expiry: expiry, }, ) @@ -268,7 +265,7 @@ func (s *grpcSwapServerClient) PushLoopOutPreimage(ctx context.Context, _, err := s.server.LoopOutPushPreimage(rpcCtx, &looprpc.ServerLoopOutPushPreimageRequest{ - ProtocolVersion: protocolVersion, + ProtocolVersion: loopdb.CurrentRPCProtocolVersion, Preimage: preimage[:], }, ) @@ -288,7 +285,7 @@ func (s *grpcSwapServerClient) NewLoopInSwap(ctx context.Context, Amt: uint64(amount), SenderKey: senderKey[:], SwapInvoice: swapInvoice, - ProtocolVersion: protocolVersion, + ProtocolVersion: loopdb.CurrentRPCProtocolVersion, } if lastHop != nil { req.LastHop = lastHop[:] @@ -331,7 +328,7 @@ func (s *grpcSwapServerClient) SubscribeLoopInUpdates(ctx context.Context, resp, err := s.server.SubscribeLoopInUpdates( ctx, &looprpc.SubscribeUpdatesRequest{ - ProtocolVersion: protocolVersion, + ProtocolVersion: loopdb.CurrentRPCProtocolVersion, SwapHash: hash[:], }, ) @@ -362,7 +359,7 @@ func (s *grpcSwapServerClient) SubscribeLoopOutUpdates(ctx context.Context, resp, err := s.server.SubscribeLoopOutUpdates( ctx, &looprpc.SubscribeUpdatesRequest{ - ProtocolVersion: protocolVersion, + ProtocolVersion: loopdb.CurrentRPCProtocolVersion, SwapHash: hash[:], }, )