diff --git a/client.go b/client.go index 15ebf3f..177c58d 100644 --- a/client.go +++ b/client.go @@ -49,8 +49,15 @@ var ( ErrSweepConfTargetTooFar = errors.New("sweep confirmation target is " + "beyond swap expiration height") + // serverRPCTimeout is the maximum time a gRPC request to the server + // should be allowed to take. serverRPCTimeout = 30 * time.Second + // globalCallTimeout is the maximum time any call of the client to the + // server is allowed to take, including the time it may take to get + // and pay for an LSAT token. + globalCallTimeout = serverRPCTimeout + lsat.PaymentTimeout + republishDelay = 10 * time.Second ) diff --git a/lsat/interceptor.go b/lsat/interceptor.go index d836e80..ef31382 100644 --- a/lsat/interceptor.go +++ b/lsat/interceptor.go @@ -60,18 +60,22 @@ var ( // challenges with embedded payment requests. It uses a connection to lnd to // automatically pay for an authentication token. type Interceptor struct { - lnd *lndclient.LndServices - store Store - lock sync.Mutex + lnd *lndclient.LndServices + store Store + callTimeout time.Duration + lock sync.Mutex } // NewInterceptor creates a new gRPC client interceptor that uses the provided // lnd connection to automatically acquire and pay for LSAT tokens, unless the // indicated store already contains a usable token. -func NewInterceptor(lnd *lndclient.LndServices, store Store) *Interceptor { +func NewInterceptor(lnd *lndclient.LndServices, store Store, + rpcCallTimeout time.Duration) *Interceptor { + return &Interceptor{ - lnd: lnd, - store: store, + lnd: lnd, + store: store, + callTimeout: rpcCallTimeout, } } @@ -133,7 +137,9 @@ func (i *Interceptor) UnaryInterceptor(ctx context.Context, method string, // method again later with the paid LSAT token. trailerMetadata := &metadata.MD{} opts = append(opts, grpc.Trailer(trailerMetadata)) - err = invoker(ctx, method, req, reply, cc, opts...) + rpcCtx, cancel := context.WithTimeout(ctx, i.callTimeout) + defer cancel() + err = invoker(rpcCtx, method, req, reply, cc, opts...) // Only handle the LSAT error message that comes in the form of // a gRPC status error. @@ -149,7 +155,9 @@ func (i *Interceptor) UnaryInterceptor(ctx context.Context, method string, // Execute the same request again, now with the LSAT // token added as an RPC credential. - return invoker(ctx, method, req, reply, cc, opts...) + rpcCtx2, cancel2 := context.WithTimeout(ctx, i.callTimeout) + defer cancel2() + return invoker(rpcCtx2, method, req, reply, cc, opts...) } return err } diff --git a/lsat/interceptor_test.go b/lsat/interceptor_test.go index ae5b372..2f18a3a 100644 --- a/lsat/interceptor_test.go +++ b/lsat/interceptor_test.go @@ -45,10 +45,12 @@ func TestInterceptor(t *testing.T) { t.Parallel() var ( - lnd = test.NewMockLnd() - store = &mockStore{} - testTimeout = 5 * time.Second - interceptor = NewInterceptor(&lnd.LndServices, store) + lnd = test.NewMockLnd() + store = &mockStore{} + testTimeout = 5 * time.Second + interceptor = NewInterceptor( + &lnd.LndServices, store, testTimeout, + ) testMac = makeMac(t) testMacBytes = serializeMac(t, testMac) testMacHex = hex.EncodeToString(testMacBytes) diff --git a/swap_server_client.go b/swap_server_client.go index 07a4425..fc071cc 100644 --- a/swap_server_client.go +++ b/swap_server_client.go @@ -56,7 +56,9 @@ func newSwapServerClient(address string, insecure bool, tlsPath string, // Create the server connection with the interceptor that will handle // the LSAT protocol for us. - clientInterceptor := lsat.NewInterceptor(lnd, lsatStore) + clientInterceptor := lsat.NewInterceptor( + lnd, lsatStore, serverRPCTimeout, + ) serverConn, err := getSwapServerConn( address, insecure, tlsPath, clientInterceptor, ) @@ -75,7 +77,7 @@ func newSwapServerClient(address string, insecure bool, tlsPath string, func (s *grpcSwapServerClient) GetLoopOutTerms(ctx context.Context) ( *LoopOutTerms, error) { - rpcCtx, rpcCancel := context.WithTimeout(ctx, serverRPCTimeout) + rpcCtx, rpcCancel := context.WithTimeout(ctx, globalCallTimeout) defer rpcCancel() terms, err := s.server.LoopOutTerms(rpcCtx, &looprpc.ServerLoopOutTermsRequest{}, @@ -93,7 +95,7 @@ func (s *grpcSwapServerClient) GetLoopOutTerms(ctx context.Context) ( func (s *grpcSwapServerClient) GetLoopOutQuote(ctx context.Context, amt btcutil.Amount) (*LoopOutQuote, error) { - rpcCtx, rpcCancel := context.WithTimeout(ctx, serverRPCTimeout) + rpcCtx, rpcCancel := context.WithTimeout(ctx, globalCallTimeout) defer rpcCancel() quoteResp, err := s.server.LoopOutQuote(rpcCtx, &looprpc.ServerLoopOutQuoteRequest{ @@ -125,7 +127,7 @@ func (s *grpcSwapServerClient) GetLoopOutQuote(ctx context.Context, func (s *grpcSwapServerClient) GetLoopInTerms(ctx context.Context) ( *LoopInTerms, error) { - rpcCtx, rpcCancel := context.WithTimeout(ctx, serverRPCTimeout) + rpcCtx, rpcCancel := context.WithTimeout(ctx, globalCallTimeout) defer rpcCancel() terms, err := s.server.LoopInTerms(rpcCtx, &looprpc.ServerLoopInTermsRequest{}, @@ -143,7 +145,7 @@ func (s *grpcSwapServerClient) GetLoopInTerms(ctx context.Context) ( func (s *grpcSwapServerClient) GetLoopInQuote(ctx context.Context, amt btcutil.Amount) (*LoopInQuote, error) { - rpcCtx, rpcCancel := context.WithTimeout(ctx, serverRPCTimeout) + rpcCtx, rpcCancel := context.WithTimeout(ctx, globalCallTimeout) defer rpcCancel() quoteResp, err := s.server.LoopInQuote(rpcCtx, &looprpc.ServerLoopInQuoteRequest{ @@ -165,7 +167,7 @@ func (s *grpcSwapServerClient) NewLoopOutSwap(ctx context.Context, receiverKey [33]byte, swapPublicationDeadline time.Time) ( *newLoopOutResponse, error) { - rpcCtx, rpcCancel := context.WithTimeout(ctx, serverRPCTimeout) + rpcCtx, rpcCancel := context.WithTimeout(ctx, globalCallTimeout) defer rpcCancel() swapResp, err := s.server.NewLoopOutSwap(rpcCtx, &looprpc.ServerLoopOutRequest{ @@ -200,7 +202,7 @@ func (s *grpcSwapServerClient) NewLoopInSwap(ctx context.Context, swapHash lntypes.Hash, amount btcutil.Amount, senderKey [33]byte, swapInvoice string) (*newLoopInResponse, error) { - rpcCtx, rpcCancel := context.WithTimeout(ctx, serverRPCTimeout) + rpcCtx, rpcCancel := context.WithTimeout(ctx, globalCallTimeout) defer rpcCancel() swapResp, err := s.server.NewLoopInSwap(rpcCtx, &looprpc.ServerLoopInRequest{