diff --git a/interface.go b/interface.go index a1be6d6..46d88f0 100644 --- a/interface.go +++ b/interface.go @@ -74,6 +74,9 @@ type OutRequest struct { // Expiry is the absolute expiry height of the on-chain htlc. Expiry int32 + + // Label contains an optional label for the swap. + Label string } // Out contains the full details of a loop out request. This includes things @@ -186,6 +189,9 @@ type LoopInRequest struct { // ExternalHtlc specifies whether the htlc is published by an external // source. ExternalHtlc bool + + // Label contains an optional label for the swap. + Label string } // LoopInTerms are the server terms on which it executes loop in swaps. diff --git a/labels/labels.go b/labels/labels.go new file mode 100644 index 0000000..b2137f4 --- /dev/null +++ b/labels/labels.go @@ -0,0 +1,46 @@ +package labels + +import ( + "errors" +) + +const ( + // MaxLength is the maximum length we allow for labels. + MaxLength = 500 + + // Reserved is used as a prefix to separate labels that are created by + // loopd from those created by users. + Reserved = "[reserved]" +) + +var ( + // ErrLabelTooLong is returned when a label exceeds our length limit. + ErrLabelTooLong = errors.New("label exceeds maximum length") + + // ErrReservedPrefix is returned when a label contains the prefix + // which is reserved for internally produced labels. + ErrReservedPrefix = errors.New("label contains reserved prefix") +) + +// Validate checks that a label is of appropriate length and is not in our list +// of reserved labels. +func Validate(label string) error { + if len(label) > MaxLength { + return ErrLabelTooLong + } + + // If the label is shorter than our reserved prefix, it cannot contain + // it. + if len(label) < len(Reserved) { + return nil + } + + // Check if our label begins with our reserved prefix. We don't mind if + // it has our reserved prefix in another case, we just need to be able + // to reserve a subset of labels with this prefix. + if label[0:len(Reserved)] == Reserved { + return ErrReservedPrefix + } + + return nil +} diff --git a/labels/labels_test.go b/labels/labels_test.go new file mode 100644 index 0000000..6010d55 --- /dev/null +++ b/labels/labels_test.go @@ -0,0 +1,53 @@ +package labels + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestValidate tests validation of labels. +func TestValidate(t *testing.T) { + tests := []struct { + name string + label string + err error + }{ + { + name: "label ok", + label: "label", + err: nil, + }, + { + name: "exceeds limit", + label: strings.Repeat(" ", MaxLength+1), + err: ErrLabelTooLong, + }, + { + name: "exactly reserved prefix", + label: Reserved, + err: ErrReservedPrefix, + }, + { + name: "starts with reserved prefix", + label: fmt.Sprintf("%v test", Reserved), + err: ErrReservedPrefix, + }, + { + name: "ends with reserved prefix", + label: fmt.Sprintf("test %v", Reserved), + err: nil, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, test.err, Validate(test.label)) + }) + } +} diff --git a/loopdb/loop.go b/loopdb/loop.go index 870df9f..51ce917 100644 --- a/loopdb/loop.go +++ b/loopdb/loop.go @@ -43,6 +43,9 @@ type SwapContract struct { // InitiationTime is the time at which the swap was initiated. InitiationTime time.Time + + // Label contains an optional label for the swap. + Label string } // Loop contains fields shared between LoopIn and LoopOut diff --git a/loopdb/loopin.go b/loopdb/loopin.go index 409d679..756656c 100644 --- a/loopdb/loopin.go +++ b/loopdb/loopin.go @@ -6,6 +6,8 @@ import ( "fmt" "time" + "github.com/coreos/bbolt" + "github.com/lightninglabs/loop/labels" "github.com/lightningnetwork/lnd/routing/route" ) @@ -24,6 +26,10 @@ type LoopInContract struct { // ExternalHtlc specifies whether the htlc is published by an external // source. ExternalHtlc bool + + // Label contains an optional label for the swap. Note that this field + // is stored separately to the rest of the contract on disk. + Label string } // LoopIn is a combination of the contract and the updates. @@ -112,6 +118,31 @@ func serializeLoopInContract(swap *LoopInContract) ( return b.Bytes(), nil } +// putLabel performs validation of a label and writes it to the bucket provided +// under the label key if it is non-zero. +func putLabel(bucket *bbolt.Bucket, label string) error { + if len(label) == 0 { + return nil + } + + if err := labels.Validate(label); err != nil { + return err + } + + return bucket.Put(labelKey, []byte(label)) +} + +// getLabel attempts to get an optional label stored under the label key in a +// bucket. If it is not present, an empty label is returned. +func getLabel(bucket *bbolt.Bucket) string { + label := bucket.Get(labelKey) + if label == nil { + return "" + } + + return string(label) +} + // deserializeLoopInContract deserializes the loop in contract from a byte slice. func deserializeLoopInContract(value []byte) (*LoopInContract, error) { r := bytes.NewReader(value) diff --git a/loopdb/store.go b/loopdb/store.go index 1d48904..b4a6fad 100644 --- a/loopdb/store.go +++ b/loopdb/store.go @@ -60,6 +60,15 @@ var ( // value: time || rawSwapState contractKey = []byte("contract") + // labelKey is the key that stores an optional label for the swap. If + // a swap was created before we started adding labels, or was created + // without a label, this key will not be present. + // + // path: loopInBucket/loopOutBucket -> swapBucket[hash] -> labelKey + // + // value: string label + labelKey = []byte("label") + // outgoingChanSetKey is the key that stores a list of channel ids that // restrict the loop out swap payment. // @@ -207,6 +216,9 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { return err } + // Get our label for this swap, if it is present. + contract.Label = getLabel(swapBucket) + // Read the list of concatenated outgoing channel ids // that form the outgoing set. setBytes := swapBucket.Get(outgoingChanSetKey) @@ -352,6 +364,9 @@ func (s *boltSwapStore) FetchLoopInSwaps() ([]*LoopIn, error) { return err } + // Get our label for this swap, if it is present. + contract.Label = getLabel(swapBucket) + updates, err := deserializeUpdates(swapBucket) if err != nil { return err @@ -434,6 +449,10 @@ func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, return err } + if err := putLabel(swapBucket, swap.Label); err != nil { + return err + } + // Write the outgoing channel set. var b bytes.Buffer for _, chanID := range swap.OutgoingChanSet { @@ -447,6 +466,11 @@ func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, return err } + // Write label to disk if we have one. + if err := putLabel(swapBucket, swap.Label); err != nil { + return err + } + // Finally, we'll create an empty updates bucket for this swap // to track any future updates to the swap itself. _, err = swapBucket.CreateBucket(updatesBucketKey) @@ -485,6 +509,11 @@ func (s *boltSwapStore) CreateLoopIn(hash lntypes.Hash, return err } + // Write label to disk if we have one. + if err := putLabel(swapBucket, swap.Label); err != nil { + return err + } + // Finally, we'll create an empty updates bucket for this swap // to track any future updates to the swap itself. _, err = swapBucket.CreateBucket(updatesBucketKey) diff --git a/loopdb/store_test.go b/loopdb/store_test.go index d373876..bde8c25 100644 --- a/loopdb/store_test.go +++ b/loopdb/store_test.go @@ -83,6 +83,13 @@ func TestLoopOutStore(t *testing.T) { t.Run("two channel outgoing set", func(t *testing.T) { testLoopOutStore(t, &restrictedSwap) }) + + labelledSwap := unrestrictedSwap + labelledSwap.Label = "test label" + t.Run("labelled swap", func(t *testing.T) { + testLoopOutStore(t, &labelledSwap) + }) + } // testLoopOutStore tests the basic functionality of the current bbolt @@ -196,27 +203,6 @@ func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) { // TestLoopInStore tests all the basic functionality of the current bbolt // swap store. func TestLoopInStore(t *testing.T) { - tempDirName, err := ioutil.TempDir("", "clientstore") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tempDirName) - - store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) - if err != nil { - t.Fatal(err) - } - - // First, verify that an empty database has no active swaps. - swaps, err := store.FetchLoopInSwaps() - if err != nil { - t.Fatal(err) - } - if len(swaps) != 0 { - t.Fatal("expected empty store") - } - - hash := sha256.Sum256(testPreimage[:]) initiationTime := time.Date(2018, 11, 1, 0, 0, 0, 0, time.UTC) // Next, we'll make a new pending swap that we'll insert into the @@ -243,6 +229,38 @@ func TestLoopInStore(t *testing.T) { ExternalHtlc: true, } + t.Run("loop in", func(t *testing.T) { + testLoopInStore(t, pendingSwap) + }) + + labelledSwap := pendingSwap + labelledSwap.Label = "test label" + t.Run("loop in with label", func(t *testing.T) { + testLoopInStore(t, labelledSwap) + }) +} + +func testLoopInStore(t *testing.T, pendingSwap LoopInContract) { + tempDirName, err := ioutil.TempDir("", "clientstore") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDirName) + + store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) + if err != nil { + t.Fatal(err) + } + + // First, verify that an empty database has no active swaps. + swaps, err := store.FetchLoopInSwaps() + if err != nil { + t.Fatal(err) + } + if len(swaps) != 0 { + t.Fatal("expected empty store") + } + // checkSwap is a test helper function that'll assert the state of a // swap. checkSwap := func(expectedState SwapState) { @@ -269,6 +287,8 @@ func TestLoopInStore(t *testing.T) { } } + hash := sha256.Sum256(testPreimage[:]) + // If we create a new swap, then it should show up as being initialized // right after. if err := store.CreateLoopIn(hash, &pendingSwap); err != nil { diff --git a/loopin.go b/loopin.go index 703974e..89777d9 100644 --- a/loopin.go +++ b/loopin.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightninglabs/lndclient" + "github.com/lightninglabs/loop/labels" "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/swap" "github.com/lightningnetwork/lnd/chainntnfs" @@ -77,6 +78,11 @@ func newLoopInSwap(globalCtx context.Context, cfg *swapConfig, currentHeight int32, request *LoopInRequest) (*loopInInitResult, error) { + // Before we start, check that the label is valid. + if err := labels.Validate(request.Label); err != nil { + return nil, err + } + // Request current server loop in terms and use these to calculate the // swap fee that we should subtract from the swap amount in the payment // request that we send to the server. @@ -165,6 +171,7 @@ func newLoopInSwap(globalCtx context.Context, cfg *swapConfig, CltvExpiry: swapResp.expiry, MaxMinerFee: request.MaxMinerFee, MaxSwapFee: request.MaxSwapFee, + Label: request.Label, }, } diff --git a/loopout.go b/loopout.go index 5e26bc4..4f92ff5 100644 --- a/loopout.go +++ b/loopout.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/lightninglabs/lndclient" + "github.com/lightninglabs/loop/labels" "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/swap" "github.com/lightninglabs/loop/sweep" @@ -87,6 +88,11 @@ type loopOutInitResult struct { func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig, currentHeight int32, request *OutRequest) (*loopOutInitResult, error) { + // Before we start, check that the label is valid. + if err := labels.Validate(request.Label); err != nil { + return nil, err + } + // Generate random preimage. var swapPreimage [32]byte if _, err := rand.Read(swapPreimage[:]); err != nil { @@ -154,6 +160,7 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig, CltvExpiry: request.Expiry, MaxMinerFee: request.MaxMinerFee, MaxSwapFee: request.MaxSwapFee, + Label: request.Label, }, OutgoingChanSet: chanSet, }