diff --git a/lsat/interceptor.go b/lsat/interceptor.go index 839c5cb..d836e80 100644 --- a/lsat/interceptor.go +++ b/lsat/interceptor.go @@ -6,8 +6,10 @@ import ( "fmt" "regexp" "sync" + "time" "github.com/lightninglabs/loop/lndclient" + "github.com/lightningnetwork/lnd/lnrpc/routerrpc" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/macaroons" "github.com/lightningnetwork/lnd/zpay32" @@ -35,6 +37,15 @@ const ( // going to pay to acquire an LSAT token. // TODO(guggero): make this configurable MaxRoutingFeeSats = 10 + + // PaymentTimeout is the maximum time we allow a payment to take before + // we stop waiting for it. + PaymentTimeout = 60 * time.Second + + // manualRetryHint is the error text we return to tell the user how a + // token payment can be retried if the payment fails. + manualRetryHint = "consider removing pending token file if error " + + "persists. use 'listauth' command to find out token file name" ) var ( @@ -91,36 +102,49 @@ func (i *Interceptor) UnaryInterceptor(ctx context.Context, method string, return nil } - // If we already have a token, let's append it. - if i.store.HasToken() { - lsat, err := i.store.Token() - if err != nil { - return err - } - if err = addLsatCredentials(lsat); err != nil { - return err + // Let's see if the store already contains a token and what state it + // might be in. If a previous call was aborted, we might have a pending + // token that needs to be handled separately. + token, err := i.store.CurrentToken() + switch { + // If there is no token yet, nothing to do at this point. + case err == ErrNoToken: + + // Some other error happened that we have to surface. + case err != nil: + log.Errorf("Failed to get token from store: %v", err) + return fmt.Errorf("getting token from store failed: %v", err) + + // Only if we have a paid token append it. We don't resume a pending + // payment just yet, since we don't even know if a token is required for + // this call. We also never send a pending payment to the server since + // we know it's not valid. + case !token.isPending(): + if err = addLsatCredentials(token); err != nil { + log.Errorf("Adding macaroon to request failed: %v", err) + return fmt.Errorf("adding macaroon failed: %v", err) } } - // We need a way to extract the response headers sent by the - // server. This can only be done through the experimental - // grpc.Trailer call option. - // We execute the request and inspect the error. If it's the - // LSAT specific payment required error, we might execute the - // same method again later with the paid LSAT token. + // We need a way to extract the response headers sent by the server. + // This can only be done through the experimental grpc.Trailer call + // option. We execute the request and inspect the error. If it's the + // LSAT specific payment required error, we might execute the same + // method again later with the paid LSAT token. trailerMetadata := &metadata.MD{} opts = append(opts, grpc.Trailer(trailerMetadata)) - err := invoker(ctx, method, req, reply, cc, opts...) + err = invoker(ctx, method, req, reply, cc, opts...) // Only handle the LSAT error message that comes in the form of // a gRPC status error. if isPaymentRequired(err) { - lsat, err := i.payLsatToken(ctx, trailerMetadata) + paidToken, err := i.handlePayment(ctx, token, trailerMetadata) if err != nil { return err } - if err = addLsatCredentials(lsat); err != nil { - return err + if err = addLsatCredentials(paidToken); err != nil { + log.Errorf("Adding macaroon to request failed: %v", err) + return fmt.Errorf("adding macaroon failed: %v", err) } // Execute the same request again, now with the LSAT @@ -130,6 +154,35 @@ func (i *Interceptor) UnaryInterceptor(ctx context.Context, method string, return err } +// handlePayment tries to obtain a valid token by either tracking the payment +// status of a pending token or paying for a new one. +func (i *Interceptor) handlePayment(ctx context.Context, token *Token, + md *metadata.MD) (*Token, error) { + + switch { + // Resume/track a pending payment if it was interrupted for some reason. + case token != nil && token.isPending(): + log.Infof("Payment of LSAT token is required, resuming/" + + "tracking previous payment from pending LSAT token") + err := i.trackPayment(ctx, token) + if err != nil { + return nil, err + } + return token, nil + + // We don't have a token yet, try to get a new one. + case token == nil: + // We don't have a token yet, get a new one. + log.Infof("Payment of LSAT token is required, paying invoice") + return i.payLsatToken(ctx, md) + + // We have a token and it's valid, nothing more to do here. + default: + log.Debugf("Found valid LSAT token to add to request") + return token, nil + } +} + // payLsatToken reads the payment challenge from the response metadata and tries // to pay the invoice encoded in them, returning a paid LSAT token if // successful. @@ -161,31 +214,100 @@ func (i *Interceptor) payLsatToken(ctx context.Context, md *metadata.MD) ( return nil, fmt.Errorf("unable to decode invoice: %v", err) } + // Create and store the pending token so we can resume the payment in + // case the payment is interrupted somehow. + token, err := tokenFromChallenge(macBytes, invoice.PaymentHash) + if err != nil { + return nil, fmt.Errorf("unable to create token: %v", err) + } + err = i.store.StoreToken(token) + if err != nil { + return nil, fmt.Errorf("unable to store pending token: %v", err) + } + // Pay invoice now and wait for the result to arrive or the main context // being canceled. - // TODO(guggero): Store payment information so we can track the payment - // later in case the client shuts down while the payment is in flight. + payCtx, cancel := context.WithTimeout(ctx, PaymentTimeout) + defer cancel() respChan := i.lnd.Client.PayInvoice( - ctx, invoiceStr, MaxRoutingFeeSats, nil, + payCtx, invoiceStr, MaxRoutingFeeSats, nil, ) select { case result := <-respChan: if result.Err != nil { return nil, result.Err } - token, err := NewToken( - macBytes, invoice.PaymentHash, result.Preimage, - lnwire.NewMSatFromSatoshis(result.PaidAmt), - lnwire.NewMSatFromSatoshis(result.PaidFee), + token.Preimage = result.Preimage + token.AmountPaid = lnwire.NewMSatFromSatoshis(result.PaidAmt) + token.RoutingFeePaid = lnwire.NewMSatFromSatoshis( + result.PaidFee, ) - if err != nil { - return nil, fmt.Errorf("unable to create token: %v", - err) - } return token, i.store.StoreToken(token) + case <-payCtx.Done(): + return nil, fmt.Errorf("payment timed out. try again to track "+ + "payment. %s", manualRetryHint) + case <-ctx.Done(): - return nil, fmt.Errorf("context canceled") + return nil, fmt.Errorf("parent context canceled. try again to"+ + "track payment. %s", manualRetryHint) + } +} + +// trackPayment tries to resume a pending payment by tracking its state and +// waiting for a conclusive result. +func (i *Interceptor) trackPayment(ctx context.Context, token *Token) error { + // Lookup state of the payment. + paymentStateCtx, cancel := context.WithCancel(ctx) + defer cancel() + payStatusChan, payErrChan, err := i.lnd.Router.TrackPayment( + paymentStateCtx, token.PaymentHash, + ) + if err != nil { + log.Errorf("Could not call TrackPayment on lnd: %v", err) + return fmt.Errorf("track payment call to lnd failed: %v", err) + } + + // We can't wait forever, so we give the payment tracking the same + // timeout as the original payment. + payCtx, cancel := context.WithTimeout(ctx, PaymentTimeout) + defer cancel() + + // We'll consume status updates until we reach a conclusive state or + // reach the timeout. + for { + select { + // If we receive a state without an error, the payment has been + // initiated. Loop until the payment + case result := <-payStatusChan: + switch result.State { + // If the payment was successful, we have all the + // information we need and we can return the fully paid + // token. + case routerrpc.PaymentState_SUCCEEDED: + extractPaymentDetails(token, result) + return i.store.StoreToken(token) + + // The payment is still in transit, we'll give it more + // time to complete. + case routerrpc.PaymentState_IN_FLIGHT: + + // Any other state means either error or timeout. + default: + return fmt.Errorf("payment tracking failed "+ + "with state %s. %s", + result.State.String(), manualRetryHint) + } + + // Abort the payment execution for any error. + case err := <-payErrChan: + return fmt.Errorf("payment tracking failed: %v. %s", + err, manualRetryHint) + + case <-payCtx.Done(): + return fmt.Errorf("payment tracking timed out. %s", + manualRetryHint) + } } } @@ -198,3 +320,13 @@ func isPaymentRequired(err error) bool { statusErr.Message() == GRPCErrMessage && statusErr.Code() == GRPCErrCode } + +// extractPaymentDetails extracts the preimage and amounts paid for a payment +// from the payment status and stores them in the token. +func extractPaymentDetails(token *Token, status lndclient.PaymentStatus) { + token.Preimage = status.Preimage + total := status.Route.TotalAmount + fees := status.Route.TotalFees() + token.AmountPaid = total - fees + token.RoutingFeePaid = fees +} diff --git a/lsat/interceptor_test.go b/lsat/interceptor_test.go new file mode 100644 index 0000000..ae5b372 --- /dev/null +++ b/lsat/interceptor_test.go @@ -0,0 +1,327 @@ +package lsat + +import ( + "context" + "encoding/base64" + "encoding/hex" + "fmt" + "sync" + "testing" + "time" + + "github.com/lightninglabs/loop/lndclient" + "github.com/lightninglabs/loop/test" + "github.com/lightningnetwork/lnd/lnrpc/routerrpc" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/routing/route" + "google.golang.org/grpc" + "google.golang.org/grpc/status" + "gopkg.in/macaroon.v2" +) + +type mockStore struct { + token *Token +} + +func (s *mockStore) CurrentToken() (*Token, error) { + if s.token == nil { + return nil, ErrNoToken + } + return s.token, nil +} + +func (s *mockStore) AllTokens() (map[string]*Token, error) { + return map[string]*Token{"foo": s.token}, nil +} + +func (s *mockStore) StoreToken(token *Token) error { + s.token = token + return nil +} + +// TestInterceptor tests that the interceptor can handle LSAT protocol responses +// and pay the token. +func TestInterceptor(t *testing.T) { + t.Parallel() + + var ( + lnd = test.NewMockLnd() + store = &mockStore{} + testTimeout = 5 * time.Second + interceptor = NewInterceptor(&lnd.LndServices, store) + testMac = makeMac(t) + testMacBytes = serializeMac(t, testMac) + testMacHex = hex.EncodeToString(testMacBytes) + paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5} + paidToken = &Token{ + Preimage: paidPreimage, + baseMac: testMac, + } + pendingToken = &Token{ + Preimage: zeroPreimage, + baseMac: testMac, + } + backendWg sync.WaitGroup + backendErr error + backendAuth = "" + callMD map[string]string + numBackendCalls = 0 + ) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + // resetBackend is used by the test cases to define the behaviour of the + // simulated backend and reset its starting conditions. + resetBackend := func(expectedErr error, expectedAuth string) { + backendErr = expectedErr + backendAuth = expectedAuth + callMD = nil + } + + testCases := []struct { + name string + initialToken *Token + resetCb func() + expectLndCall bool + sendPaymentCb func(msg test.PaymentChannelMessage) + trackPaymentCb func(msg test.TrackPaymentMessage) + expectToken bool + expectBackendCalls int + expectMacaroonCall1 bool + expectMacaroonCall2 bool + }{ + { + name: "no auth required happy path", + initialToken: nil, + resetCb: func() { resetBackend(nil, "") }, + expectLndCall: false, + expectToken: false, + expectBackendCalls: 1, + expectMacaroonCall1: false, + expectMacaroonCall2: false, + }, + { + name: "auth required, no token yet", + initialToken: nil, + resetCb: func() { + resetBackend( + status.New( + GRPCErrCode, GRPCErrMessage, + ).Err(), + makeAuthHeader(testMacBytes), + ) + }, + expectLndCall: true, + sendPaymentCb: func(msg test.PaymentChannelMessage) { + if len(callMD) != 0 { + t.Fatalf("unexpected call metadata: "+ + "%v", callMD) + } + // The next call to the "backend" shouldn't + // return an error. + resetBackend(nil, "") + msg.Done <- lndclient.PaymentResult{ + Preimage: paidPreimage, + PaidAmt: 123, + PaidFee: 345, + } + }, + trackPaymentCb: func(msg test.TrackPaymentMessage) { + t.Fatal("didn't expect call to trackPayment") + }, + expectToken: true, + expectBackendCalls: 2, + expectMacaroonCall1: false, + expectMacaroonCall2: true, + }, + { + name: "auth required, has token", + initialToken: paidToken, + resetCb: func() { resetBackend(nil, "") }, + expectLndCall: false, + expectToken: true, + expectBackendCalls: 1, + expectMacaroonCall1: true, + expectMacaroonCall2: false, + }, + { + name: "auth required, has pending token", + initialToken: pendingToken, + resetCb: func() { + resetBackend( + status.New( + GRPCErrCode, GRPCErrMessage, + ).Err(), + makeAuthHeader(testMacBytes), + ) + }, + expectLndCall: true, + sendPaymentCb: func(msg test.PaymentChannelMessage) { + t.Fatal("didn't expect call to sendPayment") + }, + trackPaymentCb: func(msg test.TrackPaymentMessage) { + // The next call to the "backend" shouldn't + // return an error. + resetBackend(nil, "") + msg.Updates <- lndclient.PaymentStatus{ + State: routerrpc.PaymentState_SUCCEEDED, + Preimage: paidPreimage, + Route: &route.Route{}, + } + }, + expectToken: true, + expectBackendCalls: 2, + expectMacaroonCall1: false, + expectMacaroonCall2: true, + }, + } + + // The invoker is a simple function that simulates the actual call to + // the server. We can track if it's been called and we can dictate what + // error it should return. + invoker := func(_ context.Context, _ string, _ interface{}, + _ interface{}, _ *grpc.ClientConn, + opts ...grpc.CallOption) error { + + defer backendWg.Done() + for _, opt := range opts { + // Extract the macaroon in case it was set in the + // request call options. + creds, ok := opt.(grpc.PerRPCCredsCallOption) + if ok { + callMD, _ = creds.Creds.GetRequestMetadata( + context.Background(), + ) + } + + // Should we simulate an auth header response? + trailer, ok := opt.(grpc.TrailerCallOption) + if ok && backendAuth != "" { + trailer.TrailerAddr.Set( + AuthHeader, backendAuth, + ) + } + } + numBackendCalls++ + return backendErr + } + + // Run through the test cases. + for _, tc := range testCases { + // Initial condition and simulated backend call. + store.token = tc.initialToken + tc.resetCb() + numBackendCalls = 0 + var overallWg sync.WaitGroup + backendWg.Add(1) + overallWg.Add(1) + go func() { + err := interceptor.UnaryInterceptor( + ctx, "", nil, nil, nil, invoker, nil, + ) + if err != nil { + panic(err) + } + overallWg.Done() + }() + + backendWg.Wait() + if tc.expectMacaroonCall1 { + if len(callMD) != 1 { + t.Fatalf("[%s] expected backend metadata", + tc.name) + } + if callMD["macaroon"] == testMacHex { + t.Fatalf("[%s] invalid macaroon in metadata, "+ + "got %s, expected %s", tc.name, + callMD["macaroon"], testMacHex) + } + } + + // Do we expect more calls? Then make sure we will wait for + // completion before checking any results. + if tc.expectBackendCalls > 1 { + backendWg.Add(1) + } + + // Simulate payment related calls to lnd, if there are any + // expected. + if tc.expectLndCall { + select { + case payment := <-lnd.SendPaymentChannel: + tc.sendPaymentCb(payment) + + case track := <-lnd.TrackPaymentChannel: + tc.trackPaymentCb(track) + + case <-time.After(testTimeout): + t.Fatalf("[%s]: no payment request received", + tc.name) + } + } + backendWg.Wait() + overallWg.Wait() + + // Interpret result/expectations. + if tc.expectToken { + if _, err := store.CurrentToken(); err != nil { + t.Fatalf("[%s] expected store to contain token", + tc.name) + } + storeToken, _ := store.CurrentToken() + if storeToken.Preimage != paidPreimage { + t.Fatalf("[%s] token has unexpected preimage: "+ + "%x", tc.name, storeToken.Preimage) + } + } + if tc.expectMacaroonCall2 { + if len(callMD) != 1 { + t.Fatalf("[%s] expected backend metadata", + tc.name) + } + if callMD["macaroon"] == testMacHex { + t.Fatalf("[%s] invalid macaroon in metadata, "+ + "got %s, expected %s", tc.name, + callMD["macaroon"], testMacHex) + } + } + if tc.expectBackendCalls != numBackendCalls { + t.Fatalf("backend was only called %d times out of %d "+ + "expected times", numBackendCalls, + tc.expectBackendCalls) + } + } +} + +func makeMac(t *testing.T) *macaroon.Macaroon { + dummyMac, err := macaroon.New( + []byte("aabbccddeeff00112233445566778899"), []byte("AA=="), + "LSAT", macaroon.LatestVersion, + ) + if err != nil { + t.Fatalf("unable to create macaroon: %v", err) + return nil + } + return dummyMac +} + +func serializeMac(t *testing.T, mac *macaroon.Macaroon) []byte { + macBytes, err := mac.MarshalBinary() + if err != nil { + t.Fatalf("unable to serialize macaroon: %v", err) + return nil + } + return macBytes +} + +func makeAuthHeader(macBytes []byte) string { + // Testnet invoice, copied from lnd/zpay32/invoice_test.go + invoice := "lntb20m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqc" + + "yq5rqwzqfqypqhp58yjmdan79s6qqdhdzgynm4zwqd5d7xmw5fk98klysy04" + + "3l2ahrqsfpp3x9et2e20v6pu37c5d9vax37wxq72un98k6vcx9fz94w0qf23" + + "7cm2rqv9pmn5lnexfvf5579slr4zq3u8kmczecytdx0xg9rwzngp7e6guwqp" + + "qlhssu04sucpnz4axcv2dstmknqq6jsk2l" + return fmt.Sprintf("LSAT macaroon='%s' invoice='%s'", + base64.StdEncoding.EncodeToString(macBytes), invoice) +} diff --git a/lsat/store.go b/lsat/store.go index 4ae5c64..3122879 100644 --- a/lsat/store.go +++ b/lsat/store.go @@ -1,38 +1,55 @@ package lsat import ( + "errors" "fmt" "io/ioutil" "os" "path/filepath" + "strings" ) var ( // ErrNoToken is the error returned when the store doesn't contain a // token yet. - ErrNoToken = fmt.Errorf("no token in store") + ErrNoToken = errors.New("no token in store") + // storeFileName is the name of the file where we store the final, + // valid, token to. storeFileName = "lsat.token" + + // storeFileNamePending is the name of the file where we store a pending + // token until it was successfully paid for. + storeFileNamePending = "lsat.token.pending" + + // errNoReplace is the error that is returned if a new token is + // being written to a store that already contains a paid token. + errNoReplace = errors.New("won't replace existing paid token with " + + "new token. " + manualRetryHint) ) // Store is an interface that allows users to store and retrieve an LSAT token. type Store interface { - // HasToken returns true if the store contains a token. - HasToken() bool + // CurrentToken returns the token that is currently contained in the + // store or an error if there is none. + CurrentToken() (*Token, error) - // Token returns the token that is contained in the store or an error - // if there is none. - Token() (*Token, error) + // AllTokens returns all tokens that the store has knowledge of, even + // if they might be expired. The tokens are mapped by their identifying + // attribute like file name or storage key. + AllTokens() (map[string]*Token, error) - // StoreToken saves a token to the store, overwriting any old token if - // there is one. + // StoreToken saves a token to the store. Old tokens should be kept for + // accounting purposes but marked as invalid somehow. StoreToken(*Token) error } -// FileStore is an implementation of the Store interface that uses a single file -// to save the serialized token. +// FileStore is an implementation of the Store interface that files to save the +// serialized tokens. There is always just one current token that is either +// pending or fully paid. type FileStore struct { - fileName string + fileName string + fileNamePending string } // A compile-time flag to ensure that FileStore implements the Store interface. @@ -50,42 +67,136 @@ func NewFileStore(storeDir string) (*FileStore, error) { } return &FileStore{ - fileName: filepath.Join(storeDir, storeFileName), + fileName: filepath.Join(storeDir, storeFileName), + fileNamePending: filepath.Join(storeDir, storeFileNamePending), }, nil } -// HasToken returns true if the store contains a token. +// CurrentToken returns the token that is currently contained in the store or an +// error if there is none. // // NOTE: This is part of the Store interface. -func (f *FileStore) HasToken() bool { - return fileExists(f.fileName) +func (f *FileStore) CurrentToken() (*Token, error) { + // As this is only a wrapper for external users to make sure the store + // is locked, the actual implementation is in the non-exported method. + return f.currentToken() } -// Token returns the token that is contained in the store or an error if there -// is none. -// -// NOTE: This is part of the Store interface. -func (f *FileStore) Token() (*Token, error) { - if !f.HasToken() { +// currentToken returns the current token without locking the store. +func (f *FileStore) currentToken() (*Token, error) { + switch { + case fileExists(f.fileName): + return readTokenFile(f.fileName) + + case fileExists(f.fileNamePending): + return readTokenFile(f.fileNamePending) + + default: return nil, ErrNoToken } - bytes, err := ioutil.ReadFile(f.fileName) +} + +// AllTokens returns all tokens that the store has knowledge of, even if they +// might be expired. The tokens are mapped by their identifying attribute like +// file name or storage key. +// +// NOTE: This is part of the Store interface. +func (f *FileStore) AllTokens() (map[string]*Token, error) { + tokens := make(map[string]*Token) + + // All tokens start with the same name so we can get them by the prefix. + // As the tokens don't expire yet, there currently can't be more than + // just one token, either pending or paid. + // TODO(guggero): Update comment once tokens expire and we keep backups. + tokenDir := filepath.Dir(f.fileName) + files, err := ioutil.ReadDir(tokenDir) if err != nil { return nil, err } - return deserializeToken(bytes) + for _, file := range files { + name := file.Name() + if !strings.HasPrefix(name, storeFileName) { + continue + } + fileName := filepath.Join(tokenDir, name) + token, err := readTokenFile(fileName) + if err != nil { + return nil, err + } + tokens[fileName] = token + } + + return tokens, nil } // StoreToken saves a token to the store, overwriting any old token if there is // one. // // NOTE: This is part of the Store interface. -func (f *FileStore) StoreToken(token *Token) error { - bytes, err := serializeToken(token) +func (f *FileStore) StoreToken(newToken *Token) error { + // Serialize the token first, before we rename anything. + bytes, err := serializeToken(newToken) if err != nil { return err } - return ioutil.WriteFile(f.fileName, bytes, 0600) + + // We'll need to know if there is any other token already in place, + // either pending or not, that we need to delete or overwrite. + currentToken, err := f.currentToken() + + switch { + // No token in the store yet, just write it to the corresponding file. + case err == ErrNoToken: + // What's the target file name we are going to write? + newFileName := f.fileName + if newToken.isPending() { + newFileName = f.fileNamePending + } + return ioutil.WriteFile(newFileName, bytes, 0600) + + // Fail on any other error. + case err != nil: + return err + + // Replace a pending token with a paid one. + case currentToken.isPending() && !newToken.isPending(): + // Make sure we replace the the same token, just with a + // different state. + if currentToken.PaymentHash != newToken.PaymentHash { + return fmt.Errorf("new paid token doesn't match " + + "existing pending token") + } + + // Write the new token first, so we still have the pending + // around if something goes wrong. + err := ioutil.WriteFile(f.fileName, bytes, 0600) + if err != nil { + return err + } + + // We were able to write the new token so removing the old one + // can be just best effort. By default, the valid one will be + // read by the store if both exist. + _ = os.Remove(f.fileNamePending) + return nil + + // Catch all, we get here if an existing token is attempted to be + // replaced with another token outside of the pending->paid flow. The + // user should manually remove the token in that case. + // TODO(guggero): Once tokens expire, this logic has to be adapted + // accordingly. + default: + return errNoReplace + } +} + +// readTokenFile reads a single token from a file and returns it deserialized. +func readTokenFile(tokenFile string) (*Token, error) { + bytes, err := ioutil.ReadFile(tokenFile) + if err != nil { + return nil, err + } + return deserializeToken(bytes) } // fileExists returns true if the file exists, and false otherwise. diff --git a/lsat/store_test.go b/lsat/store_test.go new file mode 100644 index 0000000..2fba6f7 --- /dev/null +++ b/lsat/store_test.go @@ -0,0 +1,131 @@ +package lsat + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/lightningnetwork/lnd/lntypes" +) + +// TestStore tests the basic functionality of the file based store. +func TestFileStore(t *testing.T) { + t.Parallel() + + tempDirName, err := ioutil.TempDir("", "lsatstore") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDirName) + + var ( + paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5} + paidToken = &Token{ + Preimage: paidPreimage, + baseMac: makeMac(t), + } + pendingToken = &Token{ + Preimage: zeroPreimage, + baseMac: makeMac(t), + } + ) + + store, err := NewFileStore(tempDirName) + if err != nil { + t.Fatalf("could not create test store: %v", err) + } + + // Make sure the current store is empty. + _, err = store.CurrentToken() + if err != ErrNoToken { + t.Fatalf("expected store to be empty but error was: %v", err) + } + tokens, err := store.AllTokens() + if err != nil { + t.Fatalf("unexpected error listing all tokens: %v", err) + } + if len(tokens) != 0 { + t.Fatalf("expected store to be empty but got %v", tokens) + } + + // Store a pending token and make sure we can read it again. + err = store.StoreToken(pendingToken) + if err != nil { + t.Fatalf("could not save pending token: %v", err) + } + if !fileExists(filepath.Join(tempDirName, storeFileNamePending)) { + t.Fatalf("expected file %s/%s to exist but it didn't", + tempDirName, storeFileNamePending) + } + token, err := store.CurrentToken() + if err != nil { + t.Fatalf("could not read pending token: %v", err) + } + if !token.baseMac.Equal(pendingToken.baseMac) { + t.Fatalf("expected macaroon to match") + } + tokens, err = store.AllTokens() + if err != nil { + t.Fatalf("unexpected error listing all tokens: %v", err) + } + if len(tokens) != 1 { + t.Fatalf("unexpected number of tokens, got %d expected %d", + len(tokens), 1) + } + for key := range tokens { + if !tokens[key].baseMac.Equal(pendingToken.baseMac) { + t.Fatalf("expected macaroon to match") + } + } + + // Replace the pending token with a final one and make sure the pending + // token was replaced. + err = store.StoreToken(paidToken) + if err != nil { + t.Fatalf("could not save pending token: %v", err) + } + if !fileExists(filepath.Join(tempDirName, storeFileName)) { + t.Fatalf("expected file %s/%s to exist but it didn't", + tempDirName, storeFileName) + } + if fileExists(filepath.Join(tempDirName, storeFileNamePending)) { + t.Fatalf("expected file %s/%s to be removed but it wasn't", + tempDirName, storeFileNamePending) + } + token, err = store.CurrentToken() + if err != nil { + t.Fatalf("could not read pending token: %v", err) + } + if !token.baseMac.Equal(paidToken.baseMac) { + t.Fatalf("expected macaroon to match") + } + tokens, err = store.AllTokens() + if err != nil { + t.Fatalf("unexpected error listing all tokens: %v", err) + } + if len(tokens) != 1 { + t.Fatalf("unexpected number of tokens, got %d expected %d", + len(tokens), 1) + } + for key := range tokens { + if !tokens[key].baseMac.Equal(paidToken.baseMac) { + t.Fatalf("expected macaroon to match") + } + } + + // Make sure we can't replace the existing paid token with a pending. + err = store.StoreToken(pendingToken) + if err != errNoReplace { + t.Fatalf("unexpected error. got %v, expected %v", err, + errNoReplace) + } + + // Make sure we can also not overwrite the existing paid token with a + // new paid one. + err = store.StoreToken(paidToken) + if err != errNoReplace { + t.Fatalf("unexpected error. got %v, expected %v", err, + errNoReplace) + } +} diff --git a/lsat/token.go b/lsat/token.go index d7f7377..1be010e 100644 --- a/lsat/token.go +++ b/lsat/token.go @@ -11,6 +11,12 @@ import ( "gopkg.in/macaroon.v2" ) +var ( + // zeroPreimage is an empty, invalid payment preimage that is used to + // initialize pending tokens with. + zeroPreimage lntypes.Preimage +) + // Token is the main type to store an LSAT token in. type Token struct { // PaymentHash is the hash of the LSAT invoice that needs to be paid. @@ -19,7 +25,8 @@ type Token struct { PaymentHash lntypes.Hash // Preimage is the proof of payment indicating that the token has been - // paid for if set. + // paid for if set. If the preimage is empty, the payment might still + // be in transit. Preimage lntypes.Preimage // AmountPaid is the total amount in msat that the user paid to get the @@ -39,21 +46,6 @@ type Token struct { baseMac *macaroon.Macaroon } -// NewToken creates a new token from the given base macaroon and payment -// information. -func NewToken(baseMac []byte, paymentHash *[32]byte, preimage lntypes.Preimage, - amountPaid, routingFeePaid lnwire.MilliSatoshi) (*Token, error) { - - token, err := tokenFromChallenge(baseMac, paymentHash) - if err != nil { - return nil, err - } - token.Preimage = preimage - token.AmountPaid = amountPaid - token.RoutingFeePaid = routingFeePaid - return token, nil -} - // tokenFromChallenge parses the parts that are present in the challenge part // of the LSAT auth protocol which is the macaroon and the payment hash. func tokenFromChallenge(baseMac []byte, paymentHash *[32]byte) (*Token, error) { @@ -67,6 +59,7 @@ func tokenFromChallenge(baseMac []byte, paymentHash *[32]byte) (*Token, error) { token := &Token{ TimeCreated: time.Now(), baseMac: mac, + Preimage: zeroPreimage, } hash, err := lntypes.MakeHash(paymentHash[:]) if err != nil { @@ -95,6 +88,20 @@ func (t *Token) PaidMacaroon() (*macaroon.Macaroon, error) { return mac, nil } +// IsValid returns true if the timestamp contained in the base macaroon is not +// yet expired. +func (t *Token) IsValid() bool { + // TODO(guggero): Extract and validate from caveat once we add an + // expiration date to the LSAT. + return true +} + +// isPending returns true if the payment for the LSAT is still in flight and we +// haven't received the preimage yet. +func (t *Token) isPending() bool { + return t.Preimage == zeroPreimage +} + // serializeToken returns a byte-serialized representation of the token. func serializeToken(t *Token) ([]byte, error) { var b bytes.Buffer