diff --git a/loopdb/sqlc/batch.sql.go b/loopdb/sqlc/batch.sql.go index 9469292..65b88a3 100644 --- a/loopdb/sqlc/batch.sql.go +++ b/loopdb/sqlc/batch.sql.go @@ -144,6 +144,54 @@ func (q *Queries) GetBatchSweeps(ctx context.Context, batchID int32) ([]GetBatch return items, nil } +const getBatchSweptAmount = `-- name: GetBatchSweptAmount :one +SELECT + SUM(amt) AS total +FROM + sweeps +WHERE + batch_id = $1 +AND + completed = TRUE +` + +func (q *Queries) GetBatchSweptAmount(ctx context.Context, batchID int32) (int64, error) { + row := q.db.QueryRowContext(ctx, getBatchSweptAmount, batchID) + var total int64 + err := row.Scan(&total) + return total, err +} + +const getParentBatch = `-- name: GetParentBatch :one +SELECT + sweep_batches.id, sweep_batches.confirmed, sweep_batches.batch_tx_id, sweep_batches.batch_pk_script, sweep_batches.last_rbf_height, sweep_batches.last_rbf_sat_per_kw, sweep_batches.max_timeout_distance +FROM + sweep_batches +JOIN + sweeps ON sweep_batches.id = sweeps.batch_id +WHERE + sweeps.swap_hash = $1 +AND + sweeps.completed = TRUE +AND + sweep_batches.confirmed = TRUE +` + +func (q *Queries) GetParentBatch(ctx context.Context, swapHash []byte) (SweepBatch, error) { + row := q.db.QueryRowContext(ctx, getParentBatch, swapHash) + var i SweepBatch + err := row.Scan( + &i.ID, + &i.Confirmed, + &i.BatchTxID, + &i.BatchPkScript, + &i.LastRbfHeight, + &i.LastRbfSatPerKw, + &i.MaxTimeoutDistance, + ) + return i, err +} + const getSweepStatus = `-- name: GetSweepStatus :one SELECT COALESCE(s.completed, f.false_value) AS completed diff --git a/loopdb/sqlc/querier.go b/loopdb/sqlc/querier.go index bf9f71e..e2355c7 100644 --- a/loopdb/sqlc/querier.go +++ b/loopdb/sqlc/querier.go @@ -13,6 +13,7 @@ type Querier interface { CreateReservation(ctx context.Context, arg CreateReservationParams) error FetchLiquidityParams(ctx context.Context) ([]byte, error) GetBatchSweeps(ctx context.Context, batchID int32) ([]GetBatchSweepsRow, error) + GetBatchSweptAmount(ctx context.Context, batchID int32) (int64, error) GetInstantOutSwap(ctx context.Context, swapHash []byte) (GetInstantOutSwapRow, error) GetInstantOutSwapUpdates(ctx context.Context, swapHash []byte) ([]InstantoutUpdate, error) GetInstantOutSwaps(ctx context.Context) ([]GetInstantOutSwapsRow, error) @@ -20,6 +21,7 @@ type Querier interface { GetLoopInSwaps(ctx context.Context) ([]GetLoopInSwapsRow, error) GetLoopOutSwap(ctx context.Context, swapHash []byte) (GetLoopOutSwapRow, error) GetLoopOutSwaps(ctx context.Context) ([]GetLoopOutSwapsRow, error) + GetParentBatch(ctx context.Context, swapHash []byte) (SweepBatch, error) GetReservation(ctx context.Context, reservationID []byte) (Reservation, error) GetReservationUpdates(ctx context.Context, reservationID []byte) ([]ReservationUpdate, error) GetReservations(ctx context.Context) ([]Reservation, error) diff --git a/loopdb/sqlc/queries/batch.sql b/loopdb/sqlc/queries/batch.sql index 336ccf1..a9793f9 100644 --- a/loopdb/sqlc/queries/batch.sql +++ b/loopdb/sqlc/queries/batch.sql @@ -62,6 +62,29 @@ INSERT INTO sweeps ( amt = $5, completed = $6; +-- name: GetParentBatch :one +SELECT + sweep_batches.* +FROM + sweep_batches +JOIN + sweeps ON sweep_batches.id = sweeps.batch_id +WHERE + sweeps.swap_hash = $1 +AND + sweeps.completed = TRUE +AND + sweep_batches.confirmed = TRUE; + +-- name: GetBatchSweptAmount :one +SELECT + SUM(amt) AS total +FROM + sweeps +WHERE + batch_id = $1 +AND + completed = TRUE; -- name: GetBatchSweeps :many SELECT diff --git a/loopin.go b/loopin.go index 5aecb67..b92a3cf 100644 --- a/loopin.go +++ b/loopin.go @@ -919,9 +919,7 @@ func (s *loopInSwap) waitForSwapComplete(ctx context.Context, s.log.Infof("Htlc spend by tx: %v", spendDetails.SpenderTxHash) - err := s.processHtlcSpend( - ctx, spendDetails, htlcValue, sweepFee, - ) + err := s.processHtlcSpend(ctx, spendDetails, sweepFee) if err != nil { return err } @@ -959,8 +957,6 @@ func (s *loopInSwap) waitForSwapComplete(ctx context.Context, switch update.State { // Swap invoice was paid, so update server cost balance. case invpkg.ContractSettled: - s.cost.Server -= update.AmtPaid - // If invoice settlement and htlc spend happen // in the expected order, move the swap to an // intermediate state that indicates that the @@ -977,6 +973,8 @@ func (s *loopInSwap) waitForSwapComplete(ctx context.Context, invoiceFinalized = true htlcKeyRevealed = s.tryPushHtlcKey(ctx) + s.cost.Server = s.AmountRequested - + update.AmtPaid // Canceled invoice has no effect on server cost // balance. @@ -1023,8 +1021,7 @@ func (s *loopInSwap) tryPushHtlcKey(ctx context.Context) bool { } func (s *loopInSwap) processHtlcSpend(ctx context.Context, - spend *chainntnfs.SpendDetail, htlcValue, - sweepFee btcutil.Amount) error { + spend *chainntnfs.SpendDetail, sweepFee btcutil.Amount) error { // Determine the htlc input of the spending tx and inspect the witness // to find out whether a success or a timeout tx spent the htlc. @@ -1032,10 +1029,6 @@ func (s *loopInSwap) processHtlcSpend(ctx context.Context, if s.htlc.IsSuccessWitness(htlcInput.Witness) { s.setState(loopdb.StateSuccess) - - // Server swept the htlc. The htlc value can be added to the - // server cost balance. - s.cost.Server += htlcValue } else { // We needed another on chain tx to sweep the timeout clause, // which we now include in our costs. diff --git a/loopout.go b/loopout.go index 8035b9d..1fc9cdd 100644 --- a/loopout.go +++ b/loopout.go @@ -452,7 +452,8 @@ func (s *loopOutSwap) handlePaymentResult(result paymentResult) error { return nil case result.status.State == lnrpc.Payment_SUCCEEDED: - s.cost.Server += result.status.Value.ToSatoshis() + s.cost.Server += result.status.Value.ToSatoshis() - + s.AmountRequested s.cost.Offchain += result.status.Fee.ToSatoshis() return nil @@ -514,7 +515,7 @@ func (s *loopOutSwap) executeSwap(globalCtx context.Context) error { } // Try to spend htlc and continue (rbf) until a spend has confirmed. - spendTx, err := s.waitForHtlcSpendConfirmedV2( + spend, err := s.waitForHtlcSpendConfirmedV2( globalCtx, *htlcOutpoint, htlcValue, ) if err != nil { @@ -523,7 +524,7 @@ func (s *loopOutSwap) executeSwap(globalCtx context.Context) error { // If spend details are nil, we resolved the swap without waiting for // its spend, so we can exit. - if spendTx == nil { + if spend == nil { return nil } @@ -531,7 +532,7 @@ func (s *loopOutSwap) executeSwap(globalCtx context.Context) error { // don't just try to match with the hash of our sweep tx, because it // may be swept by a different (fee) sweep tx from a previous run. htlcInput, err := swap.GetTxInputByOutpoint( - spendTx, htlcOutpoint, + spend.Tx, htlcOutpoint, ) if err != nil { return err @@ -539,11 +540,7 @@ func (s *loopOutSwap) executeSwap(globalCtx context.Context) error { sweepSuccessful := s.htlc.IsSuccessWitness(htlcInput.Witness) if sweepSuccessful { - s.cost.Server -= htlcValue - - s.cost.Onchain = htlcValue - - btcutil.Amount(spendTx.TxOut[0].Value) - + s.cost.Onchain = spend.OnChainFeePortion s.state = loopdb.StateSuccess } else { s.state = loopdb.StateFailSweepTimeout @@ -1005,9 +1002,9 @@ func (s *loopOutSwap) waitForConfirmedHtlc(globalCtx context.Context) ( // sweep or a server revocation tx. func (s *loopOutSwap) waitForHtlcSpendConfirmedV2(globalCtx context.Context, htlcOutpoint wire.OutPoint, htlcValue btcutil.Amount) ( - *wire.MsgTx, error) { + *sweepbatcher.SpendDetail, error) { - spendChan := make(chan *wire.MsgTx) + spendChan := make(chan *sweepbatcher.SpendDetail) spendErrChan := make(chan error, 1) quitChan := make(chan bool, 1) @@ -1054,10 +1051,10 @@ func (s *loopOutSwap) waitForHtlcSpendConfirmedV2(globalCtx context.Context, for { select { // Htlc spend, break loop. - case spendTx := <-spendChan: - s.log.Infof("Htlc spend by tx: %v", spendTx.TxHash()) + case spend := <-spendChan: + s.log.Infof("Htlc spend by tx: %v", spend.Tx.TxHash()) - return spendTx, nil + return spend, nil // Spend notification error. case err := <-spendErrChan: diff --git a/sweepbatcher/store.go b/sweepbatcher/store.go index 2bd3ea2..3fcc32e 100644 --- a/sweepbatcher/store.go +++ b/sweepbatcher/store.go @@ -22,9 +22,17 @@ type BaseDB interface { GetBatchSweeps(ctx context.Context, batchID int32) ( []sqlc.GetBatchSweepsRow, error) + // GetBatchSweptAmount returns the total amount of sats swept by a + // (confirmed) batch. + GetBatchSweptAmount(ctx context.Context, batchID int32) (int64, error) + // GetSweepStatus returns true if the sweep has been completed. GetSweepStatus(ctx context.Context, swapHash []byte) (bool, error) + // GetParentBatch fetches the parent batch of a completed sweep. + GetParentBatch(ctx context.Context, swapHash []byte) (sqlc.SweepBatch, + error) + // GetSwapUpdates fetches all the updates for a swap. GetSwapUpdates(ctx context.Context, swapHash []byte) ( []sqlc.SwapUpdate, error) @@ -148,6 +156,34 @@ func (s *SQLStore) FetchBatchSweeps(ctx context.Context, id int32) ( return sweeps, nil } +// TotalSweptAmount returns the total amount swept by a (confirmed) batch. +func (s *SQLStore) TotalSweptAmount(ctx context.Context, id int32) ( + btcutil.Amount, error) { + + amt, err := s.baseDb.GetBatchSweptAmount(ctx, id) + if err != nil { + return 0, err + } + + return btcutil.Amount(amt), nil +} + +// GetParentBatch fetches the parent batch of a completed sweep. +func (s *SQLStore) GetParentBatch(ctx context.Context, swapHash lntypes.Hash) ( + *dbBatch, error) { + + batch, err := s.baseDb.GetParentBatch(ctx, swapHash[:]) + if err != nil { + return nil, err + } + + if err != nil { + return nil, err + } + + return convertBatchRow(batch), nil +} + // UpsertSweep inserts a sweep into the database, or updates an existing sweep // if it already exists. func (s *SQLStore) UpsertSweep(ctx context.Context, sweep *dbSweep) error { diff --git a/sweepbatcher/store_mock.go b/sweepbatcher/store_mock.go index fef88c7..f27fcb4 100644 --- a/sweepbatcher/store_mock.go +++ b/sweepbatcher/store_mock.go @@ -5,6 +5,7 @@ import ( "errors" "sort" + "github.com/btcsuite/btcd/btcutil" "github.com/lightningnetwork/lnd/lntypes" ) @@ -123,3 +124,44 @@ func (s *StoreMock) AssertSweepStored(id lntypes.Hash) bool { _, ok := s.sweeps[id] return ok } + +// GetParentBatch returns the parent batch of a swap. +func (s *StoreMock) GetParentBatch(ctx context.Context, swapHash lntypes.Hash) ( + *dbBatch, error) { + + for _, sweep := range s.sweeps { + if sweep.SwapHash == swapHash { + batch, ok := s.batches[sweep.BatchID] + if !ok { + return nil, errors.New("batch not found") + } + return &batch, nil + } + } + + return nil, errors.New("batch not found") +} + +// TotalSweptAmount returns the total amount of BTC that has been swept from a +// batch. +func (s *StoreMock) TotalSweptAmount(ctx context.Context, batchID int32) ( + btcutil.Amount, error) { + + batch, ok := s.batches[batchID] + if !ok { + return 0, errors.New("batch not found") + } + + if batch.State != batchConfirmed && batch.State != batchClosed { + return 0, nil + } + + var total btcutil.Amount + for _, sweep := range s.sweeps { + if sweep.BatchID == batchID { + total += sweep.Amount + } + } + + return 0, nil +} diff --git a/sweepbatcher/sweep_batch.go b/sweepbatcher/sweep_batch.go index 971b83d..0960f31 100644 --- a/sweepbatcher/sweep_batch.go +++ b/sweepbatcher/sweep_batch.go @@ -1136,6 +1136,33 @@ func (b *batch) monitorConfirmations(ctx context.Context) error { return nil } +// getFeePortionForSweep calculates the fee portion that each sweep should pay +// for the batch transaction. The fee is split evenly among the sweeps, If the +// fee cannot be split evenly, the remainder is paid by the first sweep. +func getFeePortionForSweep(spendTx *wire.MsgTx, numSweeps int, + totalSweptAmt btcutil.Amount) (btcutil.Amount, btcutil.Amount) { + + totalFee := spendTx.TxOut[0].Value - int64(totalSweptAmt) + feePortionPerSweep := (int64(totalSweptAmt) - + spendTx.TxOut[0].Value) / int64(numSweeps) + roundingDiff := totalFee - (int64(numSweeps) * feePortionPerSweep) + + return btcutil.Amount(feePortionPerSweep), btcutil.Amount(roundingDiff) +} + +// getFeePortionPaidBySweep returns the fee portion that the sweep should pay +// for the batch transaction. If the sweep is the first sweep in the batch, it +// pays the rounding difference. +func getFeePortionPaidBySweep(spendTx *wire.MsgTx, feePortionPerSweep, + roundingDiff btcutil.Amount, sweep *sweep) btcutil.Amount { + + if bytes.Equal(spendTx.TxIn[0].SignatureScript, sweep.htlc.SigScript) { + return feePortionPerSweep + roundingDiff + } + + return feePortionPerSweep +} + // handleSpend handles a spend notification. func (b *batch) handleSpend(ctx context.Context, spendTx *wire.MsgTx) error { var ( @@ -1151,12 +1178,14 @@ func (b *batch) handleSpend(ctx context.Context, spendTx *wire.MsgTx) error { // sweeps that did not make it to the confirmed transaction and feed // them back to the batcher. This will ensure that the sweeps will enter // a new batch instead of remaining dangling. + var totalSweptAmt btcutil.Amount for _, sweep := range b.sweeps { found := false for _, txIn := range spendTx.TxIn { if txIn.PreviousOutPoint == sweep.outpoint { found = true + totalSweptAmt += sweep.value notifyList = append(notifyList, sweep) } } @@ -1176,7 +1205,13 @@ func (b *batch) handleSpend(ctx context.Context, spendTx *wire.MsgTx) error { } } + // Calculate the fee portion that each sweep should pay for the batch. + feePortionPaidPerSweep, roundingDifference := getFeePortionForSweep( + spendTx, len(notifyList), totalSweptAmt, + ) + for _, sweep := range notifyList { + sweep := sweep // Save the sweep as completed. err := b.persistSweep(ctx, sweep, true) if err != nil { @@ -1192,9 +1227,17 @@ func (b *batch) handleSpend(ctx context.Context, spendTx *wire.MsgTx) error { continue } + spendDetail := SpendDetail{ + Tx: spendTx, + OnChainFeePortion: getFeePortionPaidBySweep( + spendTx, feePortionPaidPerSweep, + roundingDifference, &sweep, + ), + } + // Dispatch the sweep notifier, we don't care about the outcome // of this action so we don't wait for it. - go notifySweepSpend(ctx, sweep, spendTx) + go sweep.notifySweepSpend(ctx, &spendDetail) } // Proceed with purging the sweeps. This will feed the sweeps that @@ -1318,10 +1361,12 @@ func (b *batch) insertAndAcquireID(ctx context.Context) (int32, error) { } // notifySweepSpend writes the spendTx to the sweep's notifier channel. -func notifySweepSpend(ctx context.Context, s sweep, spendTx *wire.MsgTx) { +func (s *sweep) notifySweepSpend(ctx context.Context, + spendDetail *SpendDetail) { + select { // Try to write the update to the notification channel. - case s.notifier.SpendChan <- spendTx: + case s.notifier.SpendChan <- spendDetail: // If a quit signal was provided by the swap, continue. case <-s.notifier.QuitChan: diff --git a/sweepbatcher/sweep_batcher.go b/sweepbatcher/sweep_batcher.go index 5548085..7065a47 100644 --- a/sweepbatcher/sweep_batcher.go +++ b/sweepbatcher/sweep_batcher.go @@ -72,6 +72,14 @@ type BatcherStore interface { // GetSweepStatus returns the completed status of the sweep. GetSweepStatus(ctx context.Context, swapHash lntypes.Hash) ( bool, error) + + // GetParentBatch returns the parent batch of a (completed) sweep. + GetParentBatch(ctx context.Context, swapHash lntypes.Hash) ( + *dbBatch, error) + + // TotalSweptAmount returns the total amount swept by a (confirmed) + // batch. + TotalSweptAmount(ctx context.Context, id int32) (btcutil.Amount, error) } // MuSig2SignSweep is a function that can be used to sign a sweep transaction @@ -102,11 +110,22 @@ type SweepRequest struct { Notifier *SpendNotifier } +type SpendDetail struct { + // Tx is the transaction that spent the outpoint. + Tx *wire.MsgTx + + // OnChainFeePortion is the fee portion that was paid to get this sweep + // confirmed on chain. This is the difference between the value of the + // outpoint and the value of all sweeps that were included in the batch + // divided by the number of sweeps. + OnChainFeePortion btcutil.Amount +} + // SpendNotifier is a notifier that is used to notify the requester of a sweep // that the sweep was successful. type SpendNotifier struct { // SpendChan is a channel where the spend details are received. - SpendChan chan *wire.MsgTx + SpendChan chan *SpendDetail // SpendErrChan is a channel where spend errors are received. SpendErrChan chan error @@ -270,8 +289,7 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, // can't attach its notifier to the batch as that is no longer running. // Instead we directly detect and return the spend here. if completed && *notifier != (SpendNotifier{}) { - go b.monitorSpendAndNotify(ctx, sweep, notifier) - return nil + return b.monitorSpendAndNotify(ctx, sweep, notifier) } sweep.notifier = notifier @@ -509,57 +527,86 @@ func (b *Batcher) FetchUnconfirmedBatches(ctx context.Context) ([]*batch, // monitorSpendAndNotify monitors the spend of a specific outpoint and writes // the response back to the response channel. func (b *Batcher) monitorSpendAndNotify(ctx context.Context, sweep *sweep, - notifier *SpendNotifier) { - - b.wg.Add(1) - defer b.wg.Done() + notifier *SpendNotifier) error { spendCtx, cancel := context.WithCancel(ctx) defer cancel() + // First get the batch that completed the sweep. + parentBatch, err := b.store.GetParentBatch(ctx, sweep.swapHash) + if err != nil { + return err + } + + // Then we get the total amount that was swept by the batch. + totalSwept, err := b.store.TotalSweptAmount(ctx, parentBatch.ID) + if err != nil { + return err + } + spendChan, spendErr, err := b.chainNotifier.RegisterSpendNtfn( spendCtx, &sweep.outpoint, sweep.htlc.PkScript, sweep.initiationHeight, ) if err != nil { - select { - case notifier.SpendErrChan <- err: - case <-ctx.Done(): - } - - _ = b.writeToErrChan(ctx, err) - - return + return err } - log.Infof("Batcher monitoring spend for swap %x", sweep.swapHash[:6]) + b.wg.Add(1) + go func() { + defer b.wg.Done() + log.Infof("Batcher monitoring spend for swap %x", + sweep.swapHash[:6]) - for { - select { - case spend := <-spendChan: + for { select { - case notifier.SpendChan <- spend.SpendingTx: - case <-ctx.Done(): - } - - return + case spend := <-spendChan: + spendTx := spend.SpendingTx + // Calculate the fee portion that each sweep + // should pay for the batch. + feePortionPerSweep, roundingDifference := + getFeePortionForSweep( + spendTx, len(spendTx.TxIn), + totalSwept, + ) + + // Notify the requester of the spend + // with the spend details, including the fee + // portion for this particular sweep. + spendDetail := &SpendDetail{ + Tx: spendTx, + OnChainFeePortion: getFeePortionPaidBySweep( // nolint:lll + spendTx, feePortionPerSweep, + roundingDifference, sweep, + ), + } + + select { + case notifier.SpendChan <- spendDetail: + case <-ctx.Done(): + } + + return + + case err := <-spendErr: + select { + case notifier.SpendErrChan <- err: + case <-ctx.Done(): + } + + _ = b.writeToErrChan(ctx, err) + return + + case <-notifier.QuitChan: + return - case err := <-spendErr: - select { - case notifier.SpendErrChan <- err: case <-ctx.Done(): + return } - - _ = b.writeToErrChan(ctx, err) - return - - case <-notifier.QuitChan: - return - - case <-ctx.Done(): - return } - } + }() + + return nil } func (b *Batcher) writeToErrChan(ctx context.Context, err error) error { diff --git a/sweepbatcher/sweep_batcher_test.go b/sweepbatcher/sweep_batcher_test.go index 2d233b9..1671007 100644 --- a/sweepbatcher/sweep_batcher_test.go +++ b/sweepbatcher/sweep_batcher_test.go @@ -38,7 +38,7 @@ func testMuSig2SignSweep(ctx context.Context, } var dummyNotifier = SpendNotifier{ - SpendChan: make(chan *wire.MsgTx, ntfnBufferSize), + SpendChan: make(chan *SpendDetail, ntfnBufferSize), SpendErrChan: make(chan error, ntfnBufferSize), QuitChan: make(chan bool, ntfnBufferSize), }