diff --git a/client.go b/client.go index b709c8a..a7d2f86 100644 --- a/client.go +++ b/client.go @@ -175,9 +175,9 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) { for _, swp := range loopOutSwaps { htlc, err := swap.NewHtlc( - swp.Contract.CltvExpiry, swp.Contract.SenderKey, - swp.Contract.ReceiverKey, swp.Hash, swap.HtlcP2WSH, - s.lndServices.ChainParams, + swap.HtlcV1, swp.Contract.CltvExpiry, + swp.Contract.SenderKey, swp.Contract.ReceiverKey, + swp.Hash, swap.HtlcP2WSH, s.lndServices.ChainParams, ) if err != nil { return nil, err @@ -195,18 +195,18 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) { for _, swp := range loopInSwaps { htlcNP2WSH, err := swap.NewHtlc( - swp.Contract.CltvExpiry, swp.Contract.SenderKey, - swp.Contract.ReceiverKey, swp.Hash, swap.HtlcNP2WSH, - s.lndServices.ChainParams, + swap.HtlcV1, swp.Contract.CltvExpiry, + swp.Contract.SenderKey, swp.Contract.ReceiverKey, + swp.Hash, swap.HtlcNP2WSH, s.lndServices.ChainParams, ) if err != nil { return nil, err } htlcP2WSH, err := swap.NewHtlc( - swp.Contract.CltvExpiry, swp.Contract.SenderKey, - swp.Contract.ReceiverKey, swp.Hash, swap.HtlcP2WSH, - s.lndServices.ChainParams, + swap.HtlcV1, swp.Contract.CltvExpiry, + swp.Contract.SenderKey, swp.Contract.ReceiverKey, + swp.Hash, swap.HtlcP2WSH, s.lndServices.ChainParams, ) if err != nil { return nil, err diff --git a/loopd/view.go b/loopd/view.go index a0e7a4f..8fd60a2 100644 --- a/loopd/view.go +++ b/loopd/view.go @@ -50,6 +50,7 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error { for _, s := range swaps { htlc, err := swap.NewHtlc( + swap.HtlcV1, s.Contract.CltvExpiry, s.Contract.SenderKey, s.Contract.ReceiverKey, @@ -101,6 +102,7 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error { for _, s := range swaps { htlc, err := swap.NewHtlc( + swap.HtlcV1, s.Contract.CltvExpiry, s.Contract.SenderKey, s.Contract.ReceiverKey, diff --git a/loopin.go b/loopin.go index 89777d9..0b818cb 100644 --- a/loopin.go +++ b/loopin.go @@ -756,12 +756,13 @@ func (s *loopInSwap) publishTimeoutTx(ctx context.Context, } witnessFunc := func(sig []byte) (wire.TxWitness, error) { - return s.htlc.GenTimeoutWitness(sig) + return s.htlc.GenTimeoutWitness(sig), nil } + sequence := uint32(0) timeoutTx, err := s.sweeper.CreateSweepTx( - ctx, s.height, s.htlc, *htlcOutpoint, s.SenderKey, witnessFunc, - htlcValue, fee, s.timeoutAddr, + ctx, s.height, sequence, s.htlc, *htlcOutpoint, s.SenderKey, + witnessFunc, htlcValue, fee, s.timeoutAddr, ) if err != nil { return err diff --git a/loopin_test.go b/loopin_test.go index 9848236..7b08bcc 100644 --- a/loopin_test.go +++ b/loopin_test.go @@ -331,8 +331,9 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool) { } htlc, err := swap.NewHtlc( - contract.CltvExpiry, contract.SenderKey, contract.ReceiverKey, - testPreimage.Hash(), swap.HtlcNP2WSH, cfg.lnd.ChainParams, + swap.HtlcV1, contract.CltvExpiry, contract.SenderKey, + contract.ReceiverKey, testPreimage.Hash(), swap.HtlcNP2WSH, + cfg.lnd.ChainParams, ) if err != nil { t.Fatal(err) diff --git a/loopout.go b/loopout.go index 0229cdb..5042e21 100644 --- a/loopout.go +++ b/loopout.go @@ -937,8 +937,8 @@ func (s *loopOutSwap) sweep(ctx context.Context, // Create sweep tx. sweepTx, err := s.sweeper.CreateSweepTx( - ctx, s.height, s.htlc, htlcOutpoint, s.ReceiverKey, witnessFunc, - htlcValue, fee, s.DestAddr, + ctx, s.height, s.htlc.SuccessSequence(), s.htlc, htlcOutpoint, + s.ReceiverKey, witnessFunc, htlcValue, fee, s.DestAddr, ) if err != nil { return err diff --git a/swap.go b/swap.go index fd5dc54..1ca50fd 100644 --- a/swap.go +++ b/swap.go @@ -51,7 +51,7 @@ func newSwapKit(hash lntypes.Hash, swapType swap.Type, cfg *swapConfig, // getHtlc composes and returns the on-chain swap script. func (s *swapKit) getHtlc(outputType swap.HtlcOutputType) (*swap.Htlc, error) { return swap.NewHtlc( - s.contract.CltvExpiry, s.contract.SenderKey, + swap.HtlcV1, s.contract.CltvExpiry, s.contract.SenderKey, s.contract.ReceiverKey, s.hash, outputType, s.swapConfig.lnd.ChainParams, ) diff --git a/swap/htlc.go b/swap/htlc.go index c207b3d..9982cc4 100644 --- a/swap/htlc.go +++ b/swap/htlc.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/sha256" "errors" + "fmt" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/txscript" @@ -25,9 +26,52 @@ const ( HtlcNP2WSH ) +// ScriptVersion defines the HTLC script version. +type ScriptVersion uint8 + +const ( + // HtlcV1 refers to the original version of the HTLC script. + HtlcV1 ScriptVersion = iota + + // HtlcV2 refers to the improved version of the HTLC script. + HtlcV2 +) + +// htlcScript defines an interface for the different HTLC implementations. +type HtlcScript interface { + // genSuccessWitness returns the success script to spend this htlc with + // the preimage. + genSuccessWitness(receiverSig []byte, preimage lntypes.Preimage) wire.TxWitness + + // GenTimeoutWitness returns the timeout script to spend this htlc after + // timeout. + GenTimeoutWitness(senderSig []byte) wire.TxWitness + + // IsSuccessWitness checks whether the given stack is valid for + // redeeming the htlc. + IsSuccessWitness(witness wire.TxWitness) bool + + // Script returns the htlc script. + Script() []byte + + // MaxSuccessWitnessSize returns the maximum witness size for the + // success case witness. + MaxSuccessWitnessSize() int + + // MaxTimeoutWitnessSize returns the maximum witness size for the + // timeout case witness. + MaxTimeoutWitnessSize() int + + // SuccessSequence returns the sequence to spend this htlc in the + // success case. + SuccessSequence() uint32 +} + // Htlc contains relevant htlc information from the receiver perspective. type Htlc struct { - Script []byte + HtlcScript + + Version ScriptVersion PkScript []byte Hash lntypes.Hash OutputType HtlcOutputType @@ -45,9 +89,12 @@ var ( // the maximum value for cltv expiry to get the maximum (worst case) // script size. QuoteHtlc, _ = NewHtlc( + HtlcV2, ^int32(0), quoteKey, quoteKey, quoteHash, HtlcP2WSH, &chaincfg.MainNetParams, ) + + ErrInvalidScriptVersion = fmt.Errorf("invalid script version") ) // String returns the string value of HtlcOutputType. @@ -65,18 +112,36 @@ func (h HtlcOutputType) String() string { } // NewHtlc returns a new instance. -func NewHtlc(cltvExpiry int32, senderKey, receiverKey [33]byte, +func NewHtlc(version ScriptVersion, cltvExpiry int32, + senderKey, receiverKey [33]byte, hash lntypes.Hash, outputType HtlcOutputType, chainParams *chaincfg.Params) (*Htlc, error) { - script, err := swapHTLCScript( - cltvExpiry, senderKey, receiverKey, hash, + var ( + err error + htlc HtlcScript ) + + switch version { + case HtlcV1: + htlc, err = newHTLCScriptV1( + cltvExpiry, senderKey, receiverKey, hash, + ) + + case HtlcV2: + htlc, err = newHTLCScriptV2( + cltvExpiry, senderKey, receiverKey, hash, + ) + + default: + return nil, ErrInvalidScriptVersion + } + if err != nil { return nil, err } - p2wshPkScript, err := input.WitnessScriptHash(script) + p2wshPkScript, err := input.WitnessScriptHash(htlc.Script()) if err != nil { return nil, err } @@ -134,8 +199,9 @@ func NewHtlc(cltvExpiry int32, senderKey, receiverKey [33]byte, } return &Htlc{ + HtlcScript: htlc, Hash: hash, - Script: script, + Version: version, PkScript: pkScript, OutputType: outputType, ChainParams: chainParams, @@ -144,20 +210,63 @@ func NewHtlc(cltvExpiry int32, senderKey, receiverKey [33]byte, }, nil } -// SwapHTLCScript returns the on-chain HTLC witness script. +// GenSuccessWitness returns the success script to spend this htlc with +// the preimage. +func (h *Htlc) GenSuccessWitness(receiverSig []byte, + preimage lntypes.Preimage) (wire.TxWitness, error) { + + if h.Hash != preimage.Hash() { + return nil, errors.New("preimage doesn't match hash") + } + + return h.genSuccessWitness(receiverSig, preimage), nil +} + +// AddSuccessToEstimator adds a successful spend to a weight estimator. +func (h *Htlc) AddSuccessToEstimator(estimator *input.TxWeightEstimator) { + maxSuccessWitnessSize := h.MaxSuccessWitnessSize() + + switch h.OutputType { + case HtlcP2WSH: + estimator.AddWitnessInput(maxSuccessWitnessSize) + + case HtlcNP2WSH: + estimator.AddNestedP2WSHInput(maxSuccessWitnessSize) + } +} + +// AddTimeoutToEstimator adds a timeout spend to a weight estimator. +func (h *Htlc) AddTimeoutToEstimator(estimator *input.TxWeightEstimator) { + maxTimeoutWitnessSize := h.MaxTimeoutWitnessSize() + + switch h.OutputType { + case HtlcP2WSH: + estimator.AddWitnessInput(maxTimeoutWitnessSize) + + case HtlcNP2WSH: + estimator.AddNestedP2WSHInput(maxTimeoutWitnessSize) + } +} + +// HtlcScriptV1 encapsulates the htlc v1 script. +type HtlcScriptV1 struct { + script []byte +} + +// newHTLCScriptV1 constructs an HtlcScript with the HTLC V1 witness script. // // OP_SIZE 32 OP_EQUAL // OP_IF -// OP_HASH160 OP_EQUALVERIFY -// +// OP_HASH160 OP_EQUALVERIFY +// // OP_ELSE // OP_DROP // OP_CHECKLOCKTIMEVERIFY OP_DROP -// +// // OP_ENDIF // OP_CHECKSIG -func swapHTLCScript(cltvExpiry int32, senderHtlcKey, - receiverHtlcKey [33]byte, swapHash lntypes.Hash) ([]byte, error) { +func newHTLCScriptV1(cltvExpiry int32, senderHtlcKey, + receiverHtlcKey [33]byte, swapHash lntypes.Hash) (*HtlcScriptV1, error) { builder := txscript.NewScriptBuilder() @@ -187,90 +296,219 @@ func swapHTLCScript(cltvExpiry int32, senderHtlcKey, builder.AddOp(txscript.OP_CHECKSIG) - return builder.Script() + script, err := builder.Script() + if err != nil { + return nil, err + } + + return &HtlcScriptV1{ + script: script, + }, nil } -// GenSuccessWitness returns the success script to spend this htlc with the -// preimage. -func (h *Htlc) GenSuccessWitness(receiverSig []byte, - preimage lntypes.Preimage) (wire.TxWitness, error) { - - if h.Hash != preimage.Hash() { - return nil, errors.New("preimage doesn't match hash") - } +// genSuccessWitness returns the success script to spend this htlc with +// the preimage. +func (h *HtlcScriptV1) genSuccessWitness(receiverSig []byte, + preimage lntypes.Preimage) wire.TxWitness { witnessStack := make(wire.TxWitness, 3) witnessStack[0] = append(receiverSig, byte(txscript.SigHashAll)) witnessStack[1] = preimage[:] - witnessStack[2] = h.Script + witnessStack[2] = h.script - return witnessStack, nil + return witnessStack +} + +// GenTimeoutWitness returns the timeout script to spend this htlc after +// timeout. +func (h *HtlcScriptV1) GenTimeoutWitness(senderSig []byte) wire.TxWitness { + + witnessStack := make(wire.TxWitness, 3) + witnessStack[0] = append(senderSig, byte(txscript.SigHashAll)) + witnessStack[1] = []byte{0} + witnessStack[2] = h.script + + return witnessStack } // IsSuccessWitness checks whether the given stack is valid for redeeming the // htlc. -func (h *Htlc) IsSuccessWitness(witness wire.TxWitness) bool { +func (h *HtlcScriptV1) IsSuccessWitness(witness wire.TxWitness) bool { if len(witness) != 3 { return false } isTimeoutTx := bytes.Equal([]byte{0}, witness[1]) + return !isTimeoutTx +} + +// Script returns the htlc script. +func (h *HtlcScriptV1) Script() []byte { + return h.script +} + +// MaxSuccessWitnessSize returns the maximum success witness size. +func (h *HtlcScriptV1) MaxSuccessWitnessSize() int { + // Calculate maximum success witness size + // + // - number_of_witness_elements: 1 byte + // - receiver_sig_length: 1 byte + // - receiver_sig: 73 bytes + // - preimage_length: 1 byte + // - preimage: 32 bytes + // - witness_script_length: 1 byte + // - witness_script: len(script) bytes + return 1 + 1 + 73 + 1 + 32 + 1 + len(h.script) +} + +// MaxTimeoutWitnessSize return the maximum timeout witness size. +func (h *HtlcScriptV1) MaxTimeoutWitnessSize() int { + // Calculate maximum timeout witness size + // + // - number_of_witness_elements: 1 byte + // - sender_sig_length: 1 byte + // - sender_sig: 73 bytes + // - zero_length: 1 byte + // - zero: 1 byte + // - witness_script_length: 1 byte + // - witness_script: len(script) bytes + return 1 + 1 + 73 + 1 + 1 + 1 + len(h.script) +} + +// SuccessSequence returns the sequence to spend this htlc in the success case. +func (h *HtlcScriptV1) SuccessSequence() uint32 { + return 0 +} + +// HtlcScriptV2 encapsulates the htlc v2 script. +type HtlcScriptV2 struct { + script []byte + senderKey [33]byte +} + +// newHTLCScriptV2 construct an HtlcScipt with the HTLC V2 witness script. +// +// OP_CHECKSIG OP_NOTIF +// OP_DUP OP_HASH160 OP_EQUALVERIFY OP_CHECKSIGVERIFY +// OP_CHECKLOCKTIMEVERIFY +// OP_ELSE +// OP_SIZE <20> OP_EQUALVERIFY OP_HASH160 OP_EQUALVERIFY 1 +// OP_CHECKSEQUENCEVERIFY +// OP_ENDIF +func newHTLCScriptV2(cltvExpiry int32, senderHtlcKey, + receiverHtlcKey [33]byte, swapHash lntypes.Hash) (*HtlcScriptV2, error) { + + builder := txscript.NewScriptBuilder() + builder.AddData(receiverHtlcKey[:]) + builder.AddOp(txscript.OP_CHECKSIG) + + builder.AddOp(txscript.OP_NOTIF) + + builder.AddOp(txscript.OP_DUP) + builder.AddOp(txscript.OP_HASH160) + senderHtlcKeyHash := sha256.Sum256(senderHtlcKey[:]) + builder.AddData(input.Ripemd160H(senderHtlcKeyHash[:])) + + builder.AddOp(txscript.OP_EQUALVERIFY) + builder.AddOp(txscript.OP_CHECKSIGVERIFY) + + builder.AddInt64(int64(cltvExpiry)) + builder.AddOp(txscript.OP_CHECKLOCKTIMEVERIFY) + + builder.AddOp(txscript.OP_ELSE) + + builder.AddOp(txscript.OP_SIZE) + builder.AddInt64(0x20) + builder.AddOp(txscript.OP_EQUALVERIFY) + builder.AddOp(txscript.OP_HASH160) + builder.AddData(input.Ripemd160H(swapHash[:])) + builder.AddOp(txscript.OP_EQUALVERIFY) + builder.AddOp(txscript.OP_1) + + builder.AddOp(txscript.OP_CHECKSEQUENCEVERIFY) + + builder.AddOp(txscript.OP_ENDIF) + + script, err := builder.Script() + if err != nil { + return nil, err + } + + return &HtlcScriptV2{ + script: script, + senderKey: senderHtlcKey, + }, nil +} + +// genSuccessWitness returns the success script to spend this htlc with +// the preimage. +func (h *HtlcScriptV2) genSuccessWitness(receiverSig []byte, + preimage lntypes.Preimage) wire.TxWitness { + + witnessStack := make(wire.TxWitness, 3) + witnessStack[0] = preimage[:] + witnessStack[1] = append(receiverSig, byte(txscript.SigHashAll)) + witnessStack[2] = h.script + + return witnessStack +} + +// IsSuccessWitness checks whether the given stack is valid for redeeming the +// htlc. +func (h *HtlcScriptV2) IsSuccessWitness(witness wire.TxWitness) bool { + isTimeoutTx := len(witness) == 4 return !isTimeoutTx } // GenTimeoutWitness returns the timeout script to spend this htlc after // timeout. -func (h *Htlc) GenTimeoutWitness(senderSig []byte) (wire.TxWitness, error) { +func (h *HtlcScriptV2) GenTimeoutWitness(senderSig []byte) wire.TxWitness { - witnessStack := make(wire.TxWitness, 3) + witnessStack := make(wire.TxWitness, 4) witnessStack[0] = append(senderSig, byte(txscript.SigHashAll)) - witnessStack[1] = []byte{0} - witnessStack[2] = h.Script + witnessStack[1] = h.senderKey[:] + witnessStack[2] = []byte{} + witnessStack[3] = h.script - return witnessStack, nil + return witnessStack } -// AddSuccessToEstimator adds a successful spend to a weight estimator. -func (h *Htlc) AddSuccessToEstimator(estimator *input.TxWeightEstimator) { +// Script returns the htlc script. +func (h *HtlcScriptV2) Script() []byte { + return h.script +} + +// MaxSuccessWitnessSize returns maximum success witness size. +func (h *HtlcScriptV2) MaxSuccessWitnessSize() int { // Calculate maximum success witness size // // - number_of_witness_elements: 1 byte // - receiver_sig_length: 1 byte // - receiver_sig: 73 bytes // - preimage_length: 1 byte - // - preimage: 33 bytes + // - preimage: 32 bytes // - witness_script_length: 1 byte // - witness_script: len(script) bytes - maxSuccessWitnessSize := 1 + 1 + 73 + 1 + 33 + 1 + len(h.Script) - - switch h.OutputType { - case HtlcP2WSH: - estimator.AddWitnessInput(maxSuccessWitnessSize) - - case HtlcNP2WSH: - estimator.AddNestedP2WSHInput(maxSuccessWitnessSize) - } + return 1 + 1 + 73 + 1 + 32 + 1 + len(h.script) } -// AddTimeoutToEstimator adds a timeout spend to a weight estimator. -func (h *Htlc) AddTimeoutToEstimator(estimator *input.TxWeightEstimator) { +// MaxTimeoutWitnessSize returns maximum timeout witness size. +func (h *HtlcScriptV2) MaxTimeoutWitnessSize() int { // Calculate maximum timeout witness size // // - number_of_witness_elements: 1 byte // - sender_sig_length: 1 byte // - sender_sig: 73 bytes - // - zero_length: 1 byte + // - sender_key_length: 1 byte + // - sender_key: 33 bytes // - zero: 1 byte // - witness_script_length: 1 byte // - witness_script: len(script) bytes - maxTimeoutWitnessSize := 1 + 1 + 73 + 1 + 1 + 1 + len(h.Script) - - switch h.OutputType { - case HtlcP2WSH: - estimator.AddWitnessInput(maxTimeoutWitnessSize) + return 1 + 1 + 73 + 1 + 33 + 1 + 1 + len(h.script) +} - case HtlcNP2WSH: - estimator.AddNestedP2WSHInput(maxTimeoutWitnessSize) - } +// SuccessSequence returns the sequence to spend this htlc in the success case. +func (h *HtlcScriptV2) SuccessSequence() uint32 { + return 1 } diff --git a/swap/htlc_test.go b/swap/htlc_test.go new file mode 100644 index 0000000..e2aa492 --- /dev/null +++ b/swap/htlc_test.go @@ -0,0 +1,315 @@ +package swap + +import ( + "bytes" + "crypto/sha256" + "fmt" + "testing" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + "github.com/lightninglabs/loop/test" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/stretchr/testify/require" +) + +// assertEngineExecution executes the VM returned by the newEngine closure, +// asserting the result matches the validity expectation. In the case where it +// doesn't match the expectation, it executes the script step-by-step and +// prints debug information to stdout. +// This code is adopted from: lnd/input/script_utils_test.go +func assertEngineExecution(t *testing.T, valid bool, + newEngine func() (*txscript.Engine, error)) { + + t.Helper() + + // Get a new VM to execute. + vm, err := newEngine() + require.NoError(t, err, "unable to create engine") + + // Execute the VM, only go on to the step-by-step execution if it + // doesn't validate as expected. + vmErr := vm.Execute() + executionValid := vmErr == nil + if valid == executionValid { + return + } + + // Now that the execution didn't match what we expected, fetch a new VM + // to step through. + vm, err = newEngine() + require.NoError(t, err, "unable to create engine") + + // This buffer will trace execution of the Script, dumping out to + // stdout. + var debugBuf bytes.Buffer + + done := false + for !done { + dis, err := vm.DisasmPC() + if err != nil { + t.Fatalf("stepping (%v)\n", err) + } + debugBuf.WriteString(fmt.Sprintf("stepping %v\n", dis)) + + done, err = vm.Step() + if err != nil && valid { + fmt.Println(debugBuf.String()) + t.Fatalf("spend test case failed, spend "+ + "should be valid: %v", err) + } else if err == nil && !valid && done { + fmt.Println(debugBuf.String()) + t.Fatalf("spend test case succeed, spend "+ + "should be invalid: %v", err) + } + + debugBuf.WriteString( + fmt.Sprintf("Stack: %v", vm.GetStack()), + ) + debugBuf.WriteString( + fmt.Sprintf("AltStack: %v", vm.GetAltStack()), + ) + } + + // If we get to this point the unexpected case was not reached + // during step execution, which happens for some checks, like + // the clean-stack rule. + validity := "invalid" + if valid { + validity = "valid" + } + + fmt.Println(debugBuf.String()) + t.Fatalf( + "%v spend test case execution ended with: %v", validity, vmErr, + ) +} + +// TestHtlcV2 tests the HTLC V2 script success and timeout spend cases. +func TestHtlcV2(t *testing.T) { + const ( + htlcValue = btcutil.Amount(1 * 10e8) + testCltvExpiry = 24 + ) + + var ( + testPreimage = lntypes.Preimage([32]byte{1, 2, 3}) + err error + ) + + // We generate a fake output, and the corresponding txin. This output + // doesn't need to exist, as we'll only be validating spending from the + // transaction that references this. + fundingOut := &wire.OutPoint{ + Hash: chainhash.Hash(sha256.Sum256([]byte{1, 2, 3})), + Index: 50, + } + fakeFundingTxIn := wire.NewTxIn(fundingOut, nil, nil) + + sweepTx := wire.NewMsgTx(2) + sweepTx.AddTxIn(fakeFundingTxIn) + sweepTx.AddTxOut( + &wire.TxOut{ + PkScript: []byte("doesn't matter"), + Value: int64(htlcValue), + }, + ) + + // Create sender and receiver keys. + senderPrivKey, senderPubKey := test.CreateKey(1) + receiverPrivKey, receiverPubKey := test.CreateKey(2) + + var ( + senderKey [33]byte + receiverKey [33]byte + ) + copy(senderKey[:], senderPubKey.SerializeCompressed()) + copy(receiverKey[:], receiverPubKey.SerializeCompressed()) + + hash := sha256.Sum256(testPreimage[:]) + + // Create the htlc. + htlc, err := NewHtlc( + HtlcV2, testCltvExpiry, + senderKey, receiverKey, hash, + HtlcP2WSH, &chaincfg.MainNetParams, + ) + require.NoError(t, err) + + // Create the htlc output we'll try to spend. + htlcOutput := &wire.TxOut{ + Value: int64(htlcValue), + PkScript: htlc.PkScript, + } + + // Create signers for sender and receiver. + senderSigner := &input.MockSigner{ + Privkeys: []*btcec.PrivateKey{senderPrivKey}, + } + receiverSigner := &input.MockSigner{ + Privkeys: []*btcec.PrivateKey{receiverPrivKey}, + } + + signTx := func(tx *wire.MsgTx, pubkey *btcec.PublicKey, + signer *input.MockSigner) (input.Signature, error) { + + signDesc := &input.SignDescriptor{ + KeyDesc: keychain.KeyDescriptor{ + PubKey: pubkey, + }, + + WitnessScript: htlc.Script(), + Output: htlcOutput, + HashType: txscript.SigHashAll, + SigHashes: txscript.NewTxSigHashes(tx), + InputIndex: 0, + } + + return signer.SignOutputRaw(tx, signDesc) + } + + testCases := []struct { + name string + witness func(*testing.T) wire.TxWitness + valid bool + }{ + { + // Receiver can spend with valid preimage. + "success case spend with valid preimage", + func(t *testing.T) wire.TxWitness { + sweepTx.TxIn[0].Sequence = htlc.SuccessSequence() + sweepSig, err := signTx( + sweepTx, receiverPubKey, receiverSigner, + ) + require.NoError(t, err) + + witness, err := htlc.GenSuccessWitness( + sweepSig.Serialize(), testPreimage, + ) + require.NoError(t, err) + + return witness + + }, true, + }, + { + // Receiver can't spend with the valid preimage and with + // zero sequence. + "success case no spend with valid preimage and zero sequence", + func(t *testing.T) wire.TxWitness { + sweepTx.TxIn[0].Sequence = 0 + sweepSig, err := signTx( + sweepTx, receiverPubKey, receiverSigner, + ) + require.NoError(t, err) + + witness, err := htlc.GenSuccessWitness( + sweepSig.Serialize(), testPreimage, + ) + require.NoError(t, err) + + return witness + }, false, + }, + { + // Sender can't spend when haven't yet timed out. + "timeout case no spend before timeout", + func(t *testing.T) wire.TxWitness { + sweepTx.LockTime = testCltvExpiry - 1 + sweepSig, err := signTx( + sweepTx, senderPubKey, senderSigner, + ) + require.NoError(t, err) + + return htlc.GenTimeoutWitness( + sweepSig.Serialize(), + ) + }, false, + }, + { + // Sender can spend after timeout. + "timeout case spend after timeout", + func(t *testing.T) wire.TxWitness { + sweepTx.LockTime = testCltvExpiry + sweepSig, err := signTx( + sweepTx, senderPubKey, senderSigner, + ) + require.NoError(t, err) + + return htlc.GenTimeoutWitness( + sweepSig.Serialize(), + ) + }, true, + }, + { + // Receiver can't spend after timeout. + "timeout case receiver cannot spend", + func(t *testing.T) wire.TxWitness { + sweepTx.LockTime = testCltvExpiry + sweepSig, err := signTx( + sweepTx, receiverPubKey, receiverSigner, + ) + require.NoError(t, err) + + return htlc.GenTimeoutWitness( + sweepSig.Serialize(), + ) + }, false, + }, + { + // Sender can't spend after timeout with wrong sender + // key. + "timeout case cannot spend with wrong key", + func(t *testing.T) wire.TxWitness { + bogusKey := [33]byte{0xb, 0xa, 0xd} + + // Create the htlc with the bogus key. + htlc, err = NewHtlc( + HtlcV2, testCltvExpiry, + bogusKey, receiverKey, hash, + HtlcP2WSH, &chaincfg.MainNetParams, + ) + require.NoError(t, err) + + // Create the htlc output we'll try to spend. + htlcOutput = &wire.TxOut{ + Value: int64(htlcValue), + PkScript: htlc.PkScript, + } + + sweepTx.LockTime = testCltvExpiry + sweepSig, err := signTx( + sweepTx, senderPubKey, senderSigner, + ) + require.NoError(t, err) + + return htlc.GenTimeoutWitness( + sweepSig.Serialize(), + ) + }, false, + }, + } + + for _, testCase := range testCases { + testCase := testCase + + t.Run(testCase.name, func(t *testing.T) { + sweepTx.TxIn[0].Witness = testCase.witness(t) + + newEngine := func() (*txscript.Engine, error) { + return txscript.NewEngine( + htlc.PkScript, sweepTx, 0, + txscript.StandardVerifyFlags, nil, + nil, int64(htlcValue)) + } + + assertEngineExecution(t, testCase.valid, newEngine) + }) + } +} diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 22db0c5..c22b8da 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -21,7 +21,7 @@ type Sweeper struct { // CreateSweepTx creates an htlc sweep tx. func (s *Sweeper) CreateSweepTx( - globalCtx context.Context, height int32, + globalCtx context.Context, height int32, sequence uint32, htlc *swap.Htlc, htlcOutpoint wire.OutPoint, keyBytes [33]byte, witnessFunc func(sig []byte) (wire.TxWitness, error), @@ -37,6 +37,7 @@ func (s *Sweeper) CreateSweepTx( sweepTx.AddTxIn(&wire.TxIn{ PreviousOutPoint: htlcOutpoint, SignatureScript: htlc.SigScript, + Sequence: sequence, }) // Add output for the destination address. @@ -58,7 +59,7 @@ func (s *Sweeper) CreateSweepTx( } signDesc := lndclient.SignDescriptor{ - WitnessScript: htlc.Script, + WitnessScript: htlc.Script(), Output: &wire.TxOut{ Value: int64(amount), },