diff --git a/client_test.go b/client_test.go index d0c5bfd..f4a9382 100644 --- a/client_test.go +++ b/client_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcutil" "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop/loopdb" @@ -56,7 +57,7 @@ func TestSuccess(t *testing.T) { signalPrepaymentResult := ctx.AssertPaid(prepayInvoiceDesc) // Expect client to register for conf. - confIntent := ctx.AssertRegisterConf() + confIntent := ctx.AssertRegisterConf(false) testSuccess(ctx, testRequest.Amount, *hash, signalPrepaymentResult, signalSwapPaymentResult, false, @@ -82,7 +83,7 @@ func TestFailOffchain(t *testing.T) { signalSwapPaymentResult := ctx.AssertPaid(swapInvoiceDesc) signalPrepaymentResult := ctx.AssertPaid(prepayInvoiceDesc) - ctx.AssertRegisterConf() + ctx.AssertRegisterConf(false) signalSwapPaymentResult( errors.New(lndclient.PaymentResultUnknownPaymentHash), @@ -187,6 +188,7 @@ func testResume(t *testing.T, expired, preimageRevealed, expectSuccess bool) { if preimageRevealed { update.State = loopdb.StatePreimageRevealed + update.HtlcTxHash = &chainhash.Hash{1, 2, 6} } pendingSwap := &loopdb.LoopOut{ @@ -230,7 +232,7 @@ func testResume(t *testing.T, expired, preimageRevealed, expectSuccess bool) { signalPrepaymentResult := ctx.AssertPaid(prepayInvoiceDesc) // Expect client to register for conf. - confIntent := ctx.AssertRegisterConf() + confIntent := ctx.AssertRegisterConf(preimageRevealed) signalSwapPaymentResult(nil) signalPrepaymentResult(nil) diff --git a/loopout.go b/loopout.go index 4418f28..12af47e 100644 --- a/loopout.go +++ b/loopout.go @@ -8,6 +8,7 @@ import ( "fmt" "time" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/lightninglabs/lndclient" @@ -56,6 +57,9 @@ type loopOutSwap struct { htlc *swap.Htlc + // htlcTxHash is the confirmed htlc tx id. + htlcTxHash *chainhash.Hash + swapPaymentChan chan lndclient.PaymentResult prePaymentChan chan lndclient.PaymentResult } @@ -210,6 +214,7 @@ func resumeLoopOutSwap(reqContext context.Context, cfg *swapConfig, } else { swap.state = lastUpdate.State swap.lastUpdateTime = lastUpdate.Time + swap.htlcTxHash = lastUpdate.HtlcTxHash } return swap, nil @@ -376,6 +381,7 @@ func (s *loopOutSwap) executeSwap(globalCtx context.Context) error { // Try to spend htlc and continue (rbf) until a spend has confirmed. spendDetails, err := s.waitForHtlcSpendConfirmed(globalCtx, + *htlcOutpoint, func() error { return s.sweep(globalCtx, *htlcOutpoint, htlcValue) }, @@ -419,8 +425,9 @@ func (s *loopOutSwap) persistState(ctx context.Context) error { err := s.store.UpdateLoopOut( s.hash, updateTime, loopdb.SwapStateData{ - State: s.state, - Cost: s.cost, + State: s.state, + Cost: s.cost, + HtlcTxHash: s.htlcTxHash, }, ) if err != nil { @@ -563,11 +570,21 @@ func (s *loopOutSwap) waitForConfirmedHtlc(globalCtx context.Context) ( s.InitiationHeight, ) + // If we've revealed the preimage in a previous run, we expect to have + // recorded the htlc tx hash. We use this to re-register for + // confirmation, to be sure that we'll keep tracking the same htlc. For + // older swaps, this field may not be populated even though the preimage + // has already been revealed. + if s.state == loopdb.StatePreimageRevealed && s.htlcTxHash == nil { + s.log.Warnf("No htlc tx hash available, registering with " + + "just the pkscript") + } + ctx, cancel := context.WithCancel(globalCtx) defer cancel() htlcConfChan, htlcErrChan, err := s.lnd.ChainNotifier.RegisterConfirmationsNtfn( - ctx, nil, s.htlc.PkScript, 1, + ctx, s.htlcTxHash, s.htlc.PkScript, 1, s.InitiationHeight, ) if err != nil { @@ -680,8 +697,10 @@ func (s *loopOutSwap) waitForConfirmedHtlc(globalCtx context.Context) ( } } - s.log.Infof("Htlc tx %v at height %v", txConf.Tx.TxHash(), - txConf.BlockHeight) + htlcTxHash := txConf.Tx.TxHash() + s.log.Infof("Htlc tx %v at height %v", htlcTxHash, txConf.BlockHeight) + + s.htlcTxHash = &htlcTxHash return txConf, nil } @@ -694,13 +713,14 @@ func (s *loopOutSwap) waitForConfirmedHtlc(globalCtx context.Context) ( // sweep offchain. So we must make sure we sweep successfully before on-chain // timeout. func (s *loopOutSwap) waitForHtlcSpendConfirmed(globalCtx context.Context, - spendFunc func() error) (*chainntnfs.SpendDetail, error) { + htlc wire.OutPoint, spendFunc func() error) (*chainntnfs.SpendDetail, + error) { // Register the htlc spend notification. ctx, cancel := context.WithCancel(globalCtx) defer cancel() spendChan, spendErr, err := s.lnd.ChainNotifier.RegisterSpendNtfn( - ctx, nil, s.htlc.PkScript, s.InitiationHeight, + ctx, &htlc, s.htlc.PkScript, s.InitiationHeight, ) if err != nil { return nil, fmt.Errorf("register spend ntfn: %v", err) diff --git a/loopout_test.go b/loopout_test.go index 956b95d..2263ff0 100644 --- a/loopout_test.go +++ b/loopout_test.go @@ -115,7 +115,7 @@ func TestLoopOutPaymentParameters(t *testing.T) { // Swap is expected to register for confirmation of the htlc. Assert // this to prevent a blocked channel in the mock. - ctx.AssertRegisterConf() + ctx.AssertRegisterConf(false) // Cancel the swap. There is nothing else we need to assert. The payment // parameters don't play a role in the remainder of the swap process. @@ -187,7 +187,7 @@ func TestLateHtlcPublish(t *testing.T) { signalPrepaymentResult := ctx.AssertPaid(prepayInvoiceDesc) // Expect client to register for conf - ctx.AssertRegisterConf() + ctx.AssertRegisterConf(false) // // Wait too long before publishing htlc. blockEpochChan <- int32(swap.CltvExpiry - 10) @@ -283,7 +283,7 @@ func TestCustomSweepConfTarget(t *testing.T) { signalPrepaymentResult(nil) // Notify the confirmation notification for the HTLC. - ctx.AssertRegisterConf() + ctx.AssertRegisterConf(false) blockEpochChan <- ctx.Lnd.Height + 1 @@ -484,7 +484,7 @@ func TestPreimagePush(t *testing.T) { signalPrepaymentResult(nil) // Notify the confirmation notification for the HTLC. - ctx.AssertRegisterConf() + ctx.AssertRegisterConf(false) blockEpochChan <- ctx.Lnd.Height + 1 diff --git a/test/context.go b/test/context.go index 81740c4..146573c 100644 --- a/test/context.go +++ b/test/context.go @@ -113,14 +113,18 @@ func (ctx *Context) AssertTrackPayment() TrackPaymentMessage { } // AssertRegisterConf asserts that a register for conf has been received. -func (ctx *Context) AssertRegisterConf() *ConfRegistration { +func (ctx *Context) AssertRegisterConf(expectTxHash bool) *ConfRegistration { ctx.T.Helper() // Expect client to register for conf var confIntent *ConfRegistration select { case confIntent = <-ctx.Lnd.RegisterConfChannel: - if confIntent.TxID != nil { + switch { + case expectTxHash && confIntent.TxID == nil: + ctx.T.Fatalf("expected tx id for registration") + + case !expectTxHash && confIntent.TxID != nil: ctx.T.Fatalf("expected script only registration") } case <-time.After(Timeout):