multi: use prev output fetcher everywhere

pull/70/head
Oliver Gugger 1 year ago
parent 6ba000e1da
commit 890a1ca319
No known key found for this signature in database
GPG Key ID: 8E4256593F177720

@ -241,8 +241,11 @@ func closePoolAccount(extendedKey *hdkeychain.ExtendedKey, apiURL string,
// Calculate the fee based on the given fee rate and our weight // Calculate the fee based on the given fee rate and our weight
// estimation. // estimation.
var ( var (
estimator input.TxWeightEstimator estimator input.TxWeightEstimator
signDesc = &input.SignDescriptor{ prevOutFetcher = txscript.NewCannedPrevOutputFetcher(
pkScript, sweepValue,
)
signDesc = &input.SignDescriptor{
KeyDesc: keychain.KeyDescriptor{ KeyDesc: keychain.KeyDescriptor{
KeyLocator: keychain.KeyLocator{ KeyLocator: keychain.KeyLocator{
Family: poolscript.AccountKeyFamily, Family: poolscript.AccountKeyFamily,
@ -255,10 +258,8 @@ func closePoolAccount(extendedKey *hdkeychain.ExtendedKey, apiURL string,
PkScript: pkScript, PkScript: pkScript,
Value: sweepValue, Value: sweepValue,
}, },
InputIndex: 0, InputIndex: 0,
PrevOutputFetcher: txscript.NewCannedPrevOutputFetcher( PrevOutputFetcher: prevOutFetcher,
pkScript, sweepValue,
),
} }
) )
@ -267,7 +268,9 @@ func closePoolAccount(extendedKey *hdkeychain.ExtendedKey, apiURL string,
estimator.AddWitnessInput(poolscript.ExpiryWitnessSize) estimator.AddWitnessInput(poolscript.ExpiryWitnessSize)
signDesc.HashType = txscript.SigHashAll signDesc.HashType = txscript.SigHashAll
signDesc.SignMethod = input.WitnessV0SignMethod signDesc.SignMethod = input.WitnessV0SignMethod
signDesc.SigHashes = input.NewTxSigHashesV0Only(sweepTx) signDesc.SigHashes = txscript.NewTxSigHashes(
sweepTx, prevOutFetcher,
)
case account.VersionTaprootEnabled: case account.VersionTaprootEnabled:
estimator.AddWitnessInput(poolscript.TaprootExpiryWitnessSize) estimator.AddWitnessInput(poolscript.TaprootExpiryWitnessSize)

@ -258,8 +258,12 @@ func getSignedTx(signer *lnd.Signer, loopIn *loopdb.LoopIn, sweepTx *wire.MsgTx,
keyIndex uint32) ([]byte, error) { keyIndex uint32) ([]byte, error) {
// Create the sign descriptor. // Create the sign descriptor.
prevoutFetcher := txscript.NewCannedPrevOutputFetcher( prevTxOut := &wire.TxOut{
htlc.PkScript, int64(loopIn.Contract.AmountRequested), PkScript: htlc.PkScript,
Value: int64(loopIn.Contract.AmountRequested),
}
prevOutputFetcher := txscript.NewCannedPrevOutputFetcher(
prevTxOut.PkScript, prevTxOut.Value,
) )
signDesc := &input.SignDescriptor{ signDesc := &input.SignDescriptor{
@ -272,11 +276,8 @@ func getSignedTx(signer *lnd.Signer, loopIn *loopdb.LoopIn, sweepTx *wire.MsgTx,
WitnessScript: htlc.TimeoutScript(), WitnessScript: htlc.TimeoutScript(),
HashType: htlc.SigHash(), HashType: htlc.SigHash(),
InputIndex: 0, InputIndex: 0,
PrevOutputFetcher: prevoutFetcher, PrevOutputFetcher: prevOutputFetcher,
Output: &wire.TxOut{ Output: prevTxOut,
PkScript: htlc.PkScript,
Value: int64(loopIn.Contract.AmountRequested),
},
} }
switch htlc.Version { switch htlc.Version {
case swap.HtlcV2: case swap.HtlcV2:
@ -303,13 +304,13 @@ func getSignedTx(signer *lnd.Signer, loopIn *loopdb.LoopIn, sweepTx *wire.MsgTx,
return nil, err return nil, err
} }
sighashes := txscript.NewTxSigHashes(sweepTx, prevoutFetcher) sigHashes := txscript.NewTxSigHashes(sweepTx, prevOutputFetcher)
// Verify the signature. This will throw an error if the signature is // Verify the signature. This will throw an error if the signature is
// invalid and allows us to bruteforce the key index. // invalid and allows us to bruteforce the key index.
vm, err := txscript.NewEngine( vm, err := txscript.NewEngine(
htlc.PkScript, sweepTx, 0, txscript.StandardVerifyFlags, nil, prevTxOut.PkScript, sweepTx, 0, txscript.StandardVerifyFlags,
sighashes, int64(loopIn.Contract.AmountRequested), prevoutFetcher, nil, sigHashes, prevTxOut.Value, prevOutputFetcher,
) )
if err != nil { if err != nil {
return nil, err return nil, err

@ -173,6 +173,7 @@ func sweepRemoteClosed(extendedKey *hdkeychain.ExtendedKey, apiURL,
signDescs []*input.SignDescriptor signDescs []*input.SignDescriptor
sweepTx = wire.NewMsgTx(2) sweepTx = wire.NewMsgTx(2)
totalOutputValue = uint64(0) totalOutputValue = uint64(0)
prevOutFetcher = txscript.NewMultiPrevOutFetcher(nil)
) )
// Add all found target outputs. // Add all found target outputs.
@ -207,22 +208,25 @@ func sweepRemoteClosed(extendedKey *hdkeychain.ExtendedKey, apiURL,
sequence = 1 sequence = 1
} }
prevOutPoint := wire.OutPoint{
Hash: *txHash,
Index: uint32(vout.Outspend.Vin),
}
prevTxOut := &wire.TxOut{
PkScript: pkScript,
Value: int64(vout.Value),
}
prevOutFetcher.AddPrevOut(prevOutPoint, prevTxOut)
sweepTx.TxIn = append(sweepTx.TxIn, &wire.TxIn{ sweepTx.TxIn = append(sweepTx.TxIn, &wire.TxIn{
PreviousOutPoint: wire.OutPoint{ PreviousOutPoint: prevOutPoint,
Hash: *txHash, Sequence: sequence,
Index: uint32(vout.Outspend.Vin),
},
Sequence: sequence,
}) })
signDescs = append(signDescs, &input.SignDescriptor{ signDescs = append(signDescs, &input.SignDescriptor{
KeyDesc: *target.keyDesc, KeyDesc: *target.keyDesc,
WitnessScript: target.script, WitnessScript: target.script,
Output: &wire.TxOut{ Output: prevTxOut,
PkScript: pkScript, HashType: txscript.SigHashAll,
Value: int64(vout.Value),
},
HashType: txscript.SigHashAll,
}) })
} }
} }
@ -259,7 +263,7 @@ func sweepRemoteClosed(extendedKey *hdkeychain.ExtendedKey, apiURL,
ExtendedKey: extendedKey, ExtendedKey: extendedKey,
ChainParams: chainParams, ChainParams: chainParams,
} }
sigHashes = input.NewTxSigHashesV0Only(sweepTx) sigHashes = txscript.NewTxSigHashes(sweepTx, prevOutFetcher)
) )
for idx, desc := range signDescs { for idx, desc := range signDescs {
desc.SigHashes = sigHashes desc.SigHashes = sigHashes

@ -222,11 +222,13 @@ func sweepTimeLock(extendedKey *hdkeychain.ExtendedKey, apiURL string,
} }
api := &btc.ExplorerAPI{BaseURL: apiURL} api := &btc.ExplorerAPI{BaseURL: apiURL}
sweepTx := wire.NewMsgTx(2) var (
totalOutputValue := int64(0) sweepTx = wire.NewMsgTx(2)
signDescs := make([]*input.SignDescriptor, 0) totalOutputValue = int64(0)
var estimator input.TxWeightEstimator signDescs = make([]*input.SignDescriptor, 0)
prevOutFetcher = txscript.NewMultiPrevOutFetcher(nil)
estimator input.TxWeightEstimator
)
for _, target := range targets { for _, target := range targets {
// We can't rely on the CSV delay of the channel DB to be // We can't rely on the CSV delay of the channel DB to be
// correct. But it doesn't cost us a lot to just brute force it. // correct. But it doesn't cost us a lot to just brute force it.
@ -246,11 +248,17 @@ func sweepTimeLock(extendedKey *hdkeychain.ExtendedKey, apiURL string,
} }
// Create the transaction input. // Create the transaction input.
prevOutPoint := wire.OutPoint{
Hash: target.txid,
Index: target.index,
}
prevTxOut := &wire.TxOut{
PkScript: scriptHash,
Value: target.value,
}
prevOutFetcher.AddPrevOut(prevOutPoint, prevTxOut)
sweepTx.TxIn = append(sweepTx.TxIn, &wire.TxIn{ sweepTx.TxIn = append(sweepTx.TxIn, &wire.TxIn{
PreviousOutPoint: wire.OutPoint{ PreviousOutPoint: prevOutPoint,
Hash: target.txid,
Index: target.index,
},
Sequence: input.LockTimeToSequence( Sequence: input.LockTimeToSequence(
false, uint32(csvTimeout), false, uint32(csvTimeout),
), ),
@ -264,11 +272,8 @@ func sweepTimeLock(extendedKey *hdkeychain.ExtendedKey, apiURL string,
target.delayBasePointDesc.PubKey, target.delayBasePointDesc.PubKey,
), ),
WitnessScript: script, WitnessScript: script,
Output: &wire.TxOut{ Output: prevTxOut,
PkScript: scriptHash, HashType: txscript.SigHashAll,
Value: target.value,
},
HashType: txscript.SigHashAll,
} }
totalOutputValue += target.value totalOutputValue += target.value
signDescs = append(signDescs, signDesc) signDescs = append(signDescs, signDesc)
@ -298,7 +303,7 @@ func sweepTimeLock(extendedKey *hdkeychain.ExtendedKey, apiURL string,
}} }}
// Sign the transaction now. // Sign the transaction now.
sigHashes := input.NewTxSigHashesV0Only(sweepTx) sigHashes := txscript.NewTxSigHashes(sweepTx, prevOutFetcher)
for idx, desc := range signDescs { for idx, desc := range signDescs {
desc.SigHashes = sigHashes desc.SigHashes = sigHashes
desc.InputIndex = idx desc.InputIndex = idx

@ -257,7 +257,10 @@ func sweepTimeLockManual(extendedKey *hdkeychain.ExtendedKey, apiURL string,
totalFee, sweepValue, estimator.Weight()) totalFee, sweepValue, estimator.Weight())
// Create the sign descriptor for the input then sign the transaction. // Create the sign descriptor for the input then sign the transaction.
sigHashes := input.NewTxSigHashesV0Only(sweepTx) prevOutFetcher := txscript.NewCannedPrevOutputFetcher(
scriptHash, sweepValue,
)
sigHashes := txscript.NewTxSigHashes(sweepTx, prevOutFetcher)
signDesc := &input.SignDescriptor{ signDesc := &input.SignDescriptor{
KeyDesc: *delayDesc, KeyDesc: *delayDesc,
SingleTweak: input.SingleTweakBytes( SingleTweak: input.SingleTweakBytes(

@ -67,7 +67,6 @@ func (lc *LightningChannel) SignedCommitTx() (*wire.MsgTx, error) {
// With this, we then generate the full witness so the caller can // With this, we then generate the full witness so the caller can
// broadcast a fully signed transaction. // broadcast a fully signed transaction.
lc.SignDesc.SigHashes = input.NewTxSigHashesV0Only(commitTx)
ourSig, err := lc.TXSigner.SignOutputRaw(commitTx, lc.SignDesc) ourSig, err := lc.TXSigner.SignOutputRaw(commitTx, lc.SignDesc)
if err != nil { if err != nil {
return nil, err return nil, err

@ -12,6 +12,7 @@ import (
"github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/wallet"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
) )
@ -132,13 +133,16 @@ func (s *Signer) AddPartialSignature(packet *psbt.Packet,
inputIndex int) error { inputIndex int) error {
// Now we add our partial signature. // Now we add our partial signature.
prevOutFetcher := wallet.PsbtPrevOutputFetcher(packet)
signDesc := &input.SignDescriptor{ signDesc := &input.SignDescriptor{
KeyDesc: keyDesc, KeyDesc: keyDesc,
WitnessScript: witnessScript, WitnessScript: witnessScript,
Output: utxo, Output: utxo,
InputIndex: inputIndex, InputIndex: inputIndex,
HashType: txscript.SigHashAll, HashType: txscript.SigHashAll,
SigHashes: input.NewTxSigHashesV0Only(packet.UnsignedTx), SigHashes: txscript.NewTxSigHashes(
packet.UnsignedTx, prevOutFetcher,
),
} }
ourSigRaw, err := s.SignOutputRaw(packet.UnsignedTx, signDesc) ourSigRaw, err := s.SignOutputRaw(packet.UnsignedTx, signDesc)
if err != nil { if err != nil {

Loading…
Cancel
Save