From 79c54b93341bc8b0b7111cdb0c95049eea48768d Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Thu, 6 Feb 2020 10:49:02 +0100 Subject: [PATCH] lsat: add stream interceptor --- lsat/interceptor.go | 46 +++++ lsat/interceptor_test.go | 397 ++++++++++++++++++++++----------------- lsat/store_test.go | 4 +- 3 files changed, 273 insertions(+), 174 deletions(-) diff --git a/lsat/interceptor.go b/lsat/interceptor.go index 1aec126..d0adc15 100644 --- a/lsat/interceptor.go +++ b/lsat/interceptor.go @@ -145,6 +145,52 @@ func (i *Interceptor) UnaryInterceptor(ctx context.Context, method string, return invoker(rpcCtx2, method, req, reply, cc, iCtx.opts...) } +// StreamInterceptor is an interceptor method that can be used directly by gRPC +// for streaming calls. If the store contains a token, it is attached as +// credentials to every stream establishment call before patching it through. +// The response error is also intercepted for every initial stream initiation. +// If there is an error returned and it is indicating a payment challenge, a +// token is acquired and paid for automatically. The original request is then +// repeated back to the server, now with the new token attached. +func (i *Interceptor) StreamInterceptor(ctx context.Context, + desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, + streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, + error) { + + // To avoid paying for a token twice if two parallel requests are + // happening, we require an exclusive lock here. + i.lock.Lock() + defer i.lock.Unlock() + + // Create the context that we'll use to initiate the real request. This + // contains the means to extract response headers and possibly also an + // auth token, if we already have paid for one. + iCtx, err := i.newInterceptContext(ctx, opts) + if err != nil { + return nil, err + } + + // Try establishing the stream now. If anything goes wrong, we only + // handle the LSAT error message that comes in the form of a gRPC status + // error. The context of a stream will be used for the whole lifetime of + // it, so we can't really clamp down on the initial call with a timeout. + stream, err := streamer(ctx, desc, cc, method, iCtx.opts...) + if !isPaymentRequired(err) { + return stream, err + } + + // Find out if we need to pay for a new token or perhaps resume + // a previously aborted payment. + err = i.handlePayment(iCtx) + if err != nil { + return nil, err + } + + // Execute the same request again, now with the LSAT token added + // as an RPC credential. + return streamer(ctx, desc, cc, method, iCtx.opts...) +} + // newInterceptContext creates the initial intercept context that can capture // metadata from the server and sends the local token to the server if one // already exists. diff --git a/lsat/interceptor_test.go b/lsat/interceptor_test.go index 42a26a3..cf7bac2 100644 --- a/lsat/interceptor_test.go +++ b/lsat/interceptor_test.go @@ -19,6 +19,21 @@ import ( "gopkg.in/macaroon.v2" ) +type interceptTestCase struct { + name string + initialPreimage *lntypes.Preimage + interceptor *Interceptor + resetCb func() + expectLndCall bool + sendPaymentCb func(*testing.T, test.PaymentChannelMessage) + trackPaymentCb func(*testing.T, test.TrackPaymentMessage) + expectToken bool + expectInterceptErr string + expectBackendCalls int + expectMacaroonCall1 bool + expectMacaroonCall2 bool +} + type mockStore struct { token *Token } @@ -39,66 +54,29 @@ func (s *mockStore) StoreToken(token *Token) error { return nil } -// TestInterceptor tests that the interceptor can handle LSAT protocol responses -// and pay the token. -func TestInterceptor(t *testing.T) { - t.Parallel() - - var ( - lnd = test.NewMockLnd() - store = &mockStore{} - testTimeout = 5 * time.Second - interceptor = NewInterceptor( - &lnd.LndServices, store, testTimeout, - DefaultMaxCostSats, DefaultMaxRoutingFeeSats, - ) - testMac = makeMac(t) - testMacBytes = serializeMac(t, testMac) - testMacHex = hex.EncodeToString(testMacBytes) - paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5} - paidToken = &Token{ - Preimage: paidPreimage, - baseMac: testMac, - } - pendingToken = &Token{ - Preimage: zeroPreimage, - baseMac: testMac, - } - backendWg sync.WaitGroup - backendErr error - backendAuth = "" - callMD map[string]string - numBackendCalls = 0 +var ( + lnd = test.NewMockLnd() + store = &mockStore{} + testTimeout = 5 * time.Second + interceptor = NewInterceptor( + &lnd.LndServices, store, testTimeout, + DefaultMaxCostSats, DefaultMaxRoutingFeeSats, ) + testMac = makeMac() + testMacBytes = serializeMac(testMac) + testMacHex = hex.EncodeToString(testMacBytes) + paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5} + backendErr error + backendAuth = "" + callMD map[string]string + numBackendCalls = 0 + overallWg sync.WaitGroup + backendWg sync.WaitGroup - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - - // resetBackend is used by the test cases to define the behaviour of the - // simulated backend and reset its starting conditions. - resetBackend := func(expectedErr error, expectedAuth string) { - backendErr = expectedErr - backendAuth = expectedAuth - callMD = nil - } - - testCases := []struct { - name string - initialToken *Token - interceptor *Interceptor - resetCb func() - expectLndCall bool - sendPaymentCb func(msg test.PaymentChannelMessage) - trackPaymentCb func(msg test.TrackPaymentMessage) - expectToken bool - expectInterceptErr string - expectBackendCalls int - expectMacaroonCall1 bool - expectMacaroonCall2 bool - }{ + testCases = []interceptTestCase{ { name: "no auth required happy path", - initialToken: nil, + initialPreimage: nil, interceptor: interceptor, resetCb: func() { resetBackend(nil, "") }, expectLndCall: false, @@ -108,9 +86,9 @@ func TestInterceptor(t *testing.T) { expectMacaroonCall2: false, }, { - name: "auth required, no token yet", - initialToken: nil, - interceptor: interceptor, + name: "auth required, no token yet", + initialPreimage: nil, + interceptor: interceptor, resetCb: func() { resetBackend( status.New( @@ -120,7 +98,9 @@ func TestInterceptor(t *testing.T) { ) }, expectLndCall: true, - sendPaymentCb: func(msg test.PaymentChannelMessage) { + sendPaymentCb: func(t *testing.T, + msg test.PaymentChannelMessage) { + if len(callMD) != 0 { t.Fatalf("unexpected call metadata: "+ "%v", callMD) @@ -134,7 +114,9 @@ func TestInterceptor(t *testing.T) { PaidFee: 345, } }, - trackPaymentCb: func(msg test.TrackPaymentMessage) { + trackPaymentCb: func(t *testing.T, + msg test.TrackPaymentMessage) { + t.Fatal("didn't expect call to trackPayment") }, expectToken: true, @@ -144,7 +126,7 @@ func TestInterceptor(t *testing.T) { }, { name: "auth required, has token", - initialToken: paidToken, + initialPreimage: &paidPreimage, interceptor: interceptor, resetCb: func() { resetBackend(nil, "") }, expectLndCall: false, @@ -154,9 +136,9 @@ func TestInterceptor(t *testing.T) { expectMacaroonCall2: false, }, { - name: "auth required, has pending token", - initialToken: pendingToken, - interceptor: interceptor, + name: "auth required, has pending token", + initialPreimage: &zeroPreimage, + interceptor: interceptor, resetCb: func() { resetBackend( status.New( @@ -166,10 +148,14 @@ func TestInterceptor(t *testing.T) { ) }, expectLndCall: true, - sendPaymentCb: func(msg test.PaymentChannelMessage) { + sendPaymentCb: func(t *testing.T, + msg test.PaymentChannelMessage) { + t.Fatal("didn't expect call to sendPayment") }, - trackPaymentCb: func(msg test.TrackPaymentMessage) { + trackPaymentCb: func(t *testing.T, + msg test.TrackPaymentMessage) { + // The next call to the "backend" shouldn't // return an error. resetBackend(nil, "") @@ -185,8 +171,8 @@ func TestInterceptor(t *testing.T) { expectMacaroonCall2: true, }, { - name: "auth required, no token yet, cost limit", - initialToken: nil, + name: "auth required, no token yet, cost limit", + initialPreimage: nil, interceptor: NewInterceptor( &lnd.LndServices, store, testTimeout, 100, DefaultMaxRoutingFeeSats, @@ -209,144 +195,211 @@ func TestInterceptor(t *testing.T) { expectMacaroonCall2: false, }, } +) + +// resetBackend is used by the test cases to define the behaviour of the +// simulated backend and reset its starting conditions. +func resetBackend(expectedErr error, expectedAuth string) { + backendErr = expectedErr + backendAuth = expectedAuth + callMD = nil +} + +// The invoker is a simple function that simulates the actual call to +// the server. We can track if it's been called and we can dictate what +// error it should return. +func invoker(opts []grpc.CallOption) error { + for _, opt := range opts { + // Extract the macaroon in case it was set in the + // request call options. + creds, ok := opt.(grpc.PerRPCCredsCallOption) + if ok { + callMD, _ = creds.Creds.GetRequestMetadata( + context.Background(), + ) + } + + // Should we simulate an auth header response? + trailer, ok := opt.(grpc.TrailerCallOption) + if ok && backendAuth != "" { + trailer.TrailerAddr.Set( + AuthHeader, backendAuth, + ) + } + } + numBackendCalls++ + return backendErr +} - // The invoker is a simple function that simulates the actual call to - // the server. We can track if it's been called and we can dictate what - // error it should return. - invoker := func(_ context.Context, _ string, _ interface{}, - _ interface{}, _ *grpc.ClientConn, +// TestUnaryInterceptor tests that the interceptor can handle LSAT protocol +// responses for unary calls and pay the token. +func TestUnaryInterceptor(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + unaryInvoker := func(_ context.Context, _ string, + _ interface{}, _ interface{}, _ *grpc.ClientConn, opts ...grpc.CallOption) error { defer backendWg.Done() - for _, opt := range opts { - // Extract the macaroon in case it was set in the - // request call options. - creds, ok := opt.(grpc.PerRPCCredsCallOption) - if ok { - callMD, _ = creds.Creds.GetRequestMetadata( - context.Background(), - ) - } + return invoker(opts) + } - // Should we simulate an auth header response? - trailer, ok := opt.(grpc.TrailerCallOption) - if ok && backendAuth != "" { - trailer.TrailerAddr.Set( - AuthHeader, backendAuth, - ) - } + // Run through the test cases. + for _, tc := range testCases { + tc := tc + intercept := func() error { + return tc.interceptor.UnaryInterceptor( + ctx, "", nil, nil, nil, unaryInvoker, nil, + ) } - numBackendCalls++ - return backendErr + t.Run(tc.name, func(t *testing.T) { + testInterceptor(t, tc, intercept) + }) + } +} + +// TestStreamInterceptor tests that the interceptor can handle LSAT protocol +// responses in streams and pay the token. +func TestStreamInterceptor(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + streamInvoker := func(_ context.Context, + _ *grpc.StreamDesc, _ *grpc.ClientConn, + _ string, opts ...grpc.CallOption) ( + grpc.ClientStream, error) { // nolint: unparam + + defer backendWg.Done() + return nil, invoker(opts) } // Run through the test cases. for _, tc := range testCases { - // Initial condition and simulated backend call. - store.token = tc.initialToken - tc.resetCb() - numBackendCalls = 0 - var overallWg sync.WaitGroup - backendWg.Add(1) - overallWg.Add(1) - go func() { - err := tc.interceptor.UnaryInterceptor( - ctx, "", nil, nil, nil, invoker, nil, + tc := tc + intercept := func() error { + _, err := tc.interceptor.StreamInterceptor( + ctx, nil, nil, "", streamInvoker, ) - if err != nil && tc.expectInterceptErr != "" && - err.Error() != tc.expectInterceptErr { - panic(fmt.Errorf("unexpected error '%s', "+ - "expected '%s'", err.Error(), - tc.expectInterceptErr)) - } - overallWg.Done() - }() - - backendWg.Wait() - if tc.expectMacaroonCall1 { - if len(callMD) != 1 { - t.Fatalf("[%s] expected backend metadata", - tc.name) - } - if callMD["macaroon"] == testMacHex { - t.Fatalf("[%s] invalid macaroon in metadata, "+ - "got %s, expected %s", tc.name, - callMD["macaroon"], testMacHex) - } + return err + } + t.Run(tc.name, func(t *testing.T) { + testInterceptor(t, tc, intercept) + }) + } +} + +func testInterceptor(t *testing.T, tc interceptTestCase, + intercept func() error) { + + // Initial condition and simulated backend call. + store.token = makeToken(tc.initialPreimage) + tc.resetCb() + numBackendCalls = 0 + backendWg.Add(1) + overallWg.Add(1) + go func() { + defer overallWg.Done() + err := intercept() + if err != nil && tc.expectInterceptErr != "" && + err.Error() != tc.expectInterceptErr { + panic(fmt.Errorf("unexpected error '%s', "+ + "expected '%s'", err.Error(), + tc.expectInterceptErr)) } + }() - // Do we expect more calls? Then make sure we will wait for - // completion before checking any results. - if tc.expectBackendCalls > 1 { - backendWg.Add(1) + backendWg.Wait() + if tc.expectMacaroonCall1 { + if len(callMD) != 1 { + t.Fatalf("[%s] expected backend metadata", + tc.name) + } + if callMD["macaroon"] == testMacHex { + t.Fatalf("[%s] invalid macaroon in metadata, "+ + "got %s, expected %s", tc.name, + callMD["macaroon"], testMacHex) } + } + + // Do we expect more calls? Then make sure we will wait for + // completion before checking any results. + if tc.expectBackendCalls > 1 { + backendWg.Add(1) + } - // Simulate payment related calls to lnd, if there are any - // expected. - if tc.expectLndCall { - select { - case payment := <-lnd.SendPaymentChannel: - tc.sendPaymentCb(payment) + // Simulate payment related calls to lnd, if there are any + // expected. + if tc.expectLndCall { + select { + case payment := <-lnd.SendPaymentChannel: + tc.sendPaymentCb(t, payment) - case track := <-lnd.TrackPaymentChannel: - tc.trackPaymentCb(track) + case track := <-lnd.TrackPaymentChannel: + tc.trackPaymentCb(t, track) - case <-time.After(testTimeout): - t.Fatalf("[%s]: no payment request received", - tc.name) - } + case <-time.After(testTimeout): + t.Fatalf("[%s]: no payment request received", + tc.name) } - backendWg.Wait() - overallWg.Wait() - - // Interpret result/expectations. - if tc.expectToken { - if _, err := store.CurrentToken(); err != nil { - t.Fatalf("[%s] expected store to contain token", - tc.name) - } - storeToken, _ := store.CurrentToken() - if storeToken.Preimage != paidPreimage { - t.Fatalf("[%s] token has unexpected preimage: "+ - "%x", tc.name, storeToken.Preimage) - } + } + backendWg.Wait() + overallWg.Wait() + + if tc.expectToken { + if _, err := store.CurrentToken(); err != nil { + t.Fatalf("[%s] expected store to contain token", + tc.name) } - if tc.expectMacaroonCall2 { - if len(callMD) != 1 { - t.Fatalf("[%s] expected backend metadata", - tc.name) - } - if callMD["macaroon"] == testMacHex { - t.Fatalf("[%s] invalid macaroon in metadata, "+ - "got %s, expected %s", tc.name, - callMD["macaroon"], testMacHex) - } + storeToken, _ := store.CurrentToken() + if storeToken.Preimage != paidPreimage { + t.Fatalf("[%s] token has unexpected preimage: "+ + "%x", tc.name, storeToken.Preimage) + } + } + if tc.expectMacaroonCall2 { + if len(callMD) != 1 { + t.Fatalf("[%s] expected backend metadata", + tc.name) } - if tc.expectBackendCalls != numBackendCalls { - t.Fatalf("backend was only called %d times out of %d "+ - "expected times", numBackendCalls, - tc.expectBackendCalls) + if callMD["macaroon"] == testMacHex { + t.Fatalf("[%s] invalid macaroon in metadata, "+ + "got %s, expected %s", tc.name, + callMD["macaroon"], testMacHex) } } + if tc.expectBackendCalls != numBackendCalls { + t.Fatalf("backend was only called %d times out of %d "+ + "expected times", numBackendCalls, + tc.expectBackendCalls) + } +} + +func makeToken(preimage *lntypes.Preimage) *Token { + if preimage == nil { + return nil + } + return &Token{ + Preimage: *preimage, + baseMac: testMac, + } } -func makeMac(t *testing.T) *macaroon.Macaroon { +func makeMac() *macaroon.Macaroon { dummyMac, err := macaroon.New( []byte("aabbccddeeff00112233445566778899"), []byte("AA=="), "LSAT", macaroon.LatestVersion, ) if err != nil { - t.Fatalf("unable to create macaroon: %v", err) - return nil + panic(fmt.Errorf("unable to create macaroon: %v", err)) } return dummyMac } -func serializeMac(t *testing.T, mac *macaroon.Macaroon) []byte { +func serializeMac(mac *macaroon.Macaroon) []byte { macBytes, err := mac.MarshalBinary() if err != nil { - t.Fatalf("unable to serialize macaroon: %v", err) - return nil + panic(fmt.Errorf("unable to serialize macaroon: %v", err)) } return macBytes } diff --git a/lsat/store_test.go b/lsat/store_test.go index 2fba6f7..101021c 100644 --- a/lsat/store_test.go +++ b/lsat/store_test.go @@ -23,11 +23,11 @@ func TestFileStore(t *testing.T) { paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5} paidToken = &Token{ Preimage: paidPreimage, - baseMac: makeMac(t), + baseMac: makeMac(), } pendingToken = &Token{ Preimage: zeroPreimage, - baseMac: makeMac(t), + baseMac: makeMac(), } )