swap: refactor htlc construction to allow passing of internal keys

This commit is a refactor of how we construct htlcs to make it possible
to pass in internal keys for the sender and receiver when creating P2TR
htlcs. Furthermore the commit also cleans up constructors to not pass in
script versions and output types to make the code more readable.
pull/541/head
Andras Banki-Horvath 1 year ago
parent 35e0120e8f
commit bdb4b773ed
No known key found for this signature in database
GPG Key ID: 80E5375C094198D8

@ -213,29 +213,24 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) {
SwapHash: swp.Hash,
LastUpdate: swp.LastUpdateTime(),
}
scriptVersion := GetHtlcScriptVersion(
swp.Contract.ProtocolVersion,
)
outputType := swap.HtlcP2WSH
if scriptVersion == swap.HtlcV3 {
outputType = swap.HtlcP2TR
}
htlc, err := swap.NewHtlc(
scriptVersion,
swp.Contract.CltvExpiry, swp.Contract.SenderKey,
swp.Contract.ReceiverKey, swp.Hash,
outputType, s.lndServices.ChainParams,
htlc, err := GetHtlc(
swp.Hash, &swp.Contract.SwapContract,
s.lndServices.ChainParams,
)
if err != nil {
return nil, err
}
if outputType == swap.HtlcP2TR {
swapInfo.HtlcAddressP2TR = htlc.Address
} else {
switch htlc.OutputType {
case swap.HtlcP2WSH:
swapInfo.HtlcAddressP2WSH = htlc.Address
case swap.HtlcP2TR:
swapInfo.HtlcAddressP2TR = htlc.Address
default:
return nil, swap.ErrInvalidOutputType
}
swaps = append(swaps, swapInfo)
@ -250,34 +245,23 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) {
LastUpdate: swp.LastUpdateTime(),
}
scriptVersion := GetHtlcScriptVersion(
swp.Contract.SwapContract.ProtocolVersion,
htlc, err := GetHtlc(
swp.Hash, &swp.Contract.SwapContract,
s.lndServices.ChainParams,
)
if err != nil {
return nil, err
}
if scriptVersion == swap.HtlcV3 {
htlcP2TR, err := swap.NewHtlc(
swap.HtlcV3, swp.Contract.CltvExpiry,
swp.Contract.SenderKey, swp.Contract.ReceiverKey,
swp.Hash, swap.HtlcP2TR,
s.lndServices.ChainParams,
)
if err != nil {
return nil, err
}
switch htlc.OutputType {
case swap.HtlcP2WSH:
swapInfo.HtlcAddressP2WSH = htlc.Address
swapInfo.HtlcAddressP2TR = htlcP2TR.Address
} else {
htlcP2WSH, err := swap.NewHtlc(
swap.HtlcV2, swp.Contract.CltvExpiry,
swp.Contract.SenderKey, swp.Contract.ReceiverKey,
swp.Hash, swap.HtlcP2WSH,
s.lndServices.ChainParams,
)
if err != nil {
return nil, err
}
case swap.HtlcP2TR:
swapInfo.HtlcAddressP2TR = htlc.Address
swapInfo.HtlcAddressP2WSH = htlcP2WSH.Address
default:
return nil, swap.ErrInvalidOutputType
}
swaps = append(swaps, swapInfo)

@ -284,16 +284,26 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed,
// Assert that the loopout htlc equals to the expected one.
scriptVersion := GetHtlcScriptVersion(protocolVersion)
var htlc *swap.Htlc
outputType := swap.HtlcP2TR
if scriptVersion != swap.HtlcV3 {
outputType = swap.HtlcP2WSH
switch scriptVersion {
case swap.HtlcV2:
htlc, err = swap.NewHtlcV2(
pendingSwap.Contract.CltvExpiry, senderKey,
receiverKey, hash, &chaincfg.TestNet3Params,
)
case swap.HtlcV3:
htlc, err = swap.NewHtlcV3(
pendingSwap.Contract.CltvExpiry, senderKey,
receiverKey, senderKey, receiverKey, hash,
&chaincfg.TestNet3Params,
)
default:
t.Fatalf(swap.ErrInvalidScriptVersion.Error())
}
htlc, err := swap.NewHtlc(
scriptVersion, pendingSwap.Contract.CltvExpiry, senderKey,
receiverKey, hash, outputType, &chaincfg.TestNet3Params,
)
require.NoError(t, err)
require.Equal(t, htlc.PkScript, confIntent.PkScript)

@ -7,7 +7,6 @@ import (
"github.com/lightninglabs/lndclient"
"github.com/lightninglabs/loop"
"github.com/lightninglabs/loop/loopdb"
"github.com/lightninglabs/loop/swap"
)
// view prints all swaps currently in the database.
@ -49,24 +48,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error {
}
for _, s := range swaps {
scriptVersion := loop.GetHtlcScriptVersion(
s.Contract.ProtocolVersion,
)
var outputType swap.HtlcOutputType
switch scriptVersion {
case swap.HtlcV2:
outputType = swap.HtlcP2WSH
case swap.HtlcV3:
outputType = swap.HtlcP2TR
}
htlc, err := swap.NewHtlc(
loop.GetHtlcScriptVersion(s.Contract.ProtocolVersion),
s.Contract.CltvExpiry,
s.Contract.SenderKey,
s.Contract.ReceiverKey,
s.Hash, outputType, chainParams,
htlc, err := loop.GetHtlc(
s.Hash, &s.Contract.SwapContract, chainParams,
)
if err != nil {
return err
@ -77,7 +60,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error {
s.Contract.InitiationTime, s.Contract.InitiationHeight,
)
fmt.Printf(" Preimage: %v\n", s.Contract.Preimage)
fmt.Printf(" Htlc address: %v\n", htlc.Address)
fmt.Printf(" Htlc address (%s): %v\n", htlc.OutputType,
htlc.Address)
fmt.Printf(" Uncharge channels: %v\n",
s.Contract.OutgoingChanSet)
@ -113,12 +97,8 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error {
}
for _, s := range swaps {
htlc, err := swap.NewHtlc(
loop.GetHtlcScriptVersion(s.Contract.ProtocolVersion),
s.Contract.CltvExpiry,
s.Contract.SenderKey,
s.Contract.ReceiverKey,
s.Hash, swap.HtlcP2WSH, chainParams,
htlc, err := loop.GetHtlc(
s.Hash, &s.Contract.SwapContract, chainParams,
)
if err != nil {
return err
@ -129,7 +109,8 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error {
s.Contract.InitiationTime, s.Contract.InitiationHeight,
)
fmt.Printf(" Preimage: %v\n", s.Contract.Preimage)
fmt.Printf(" Htlc address: %v\n", htlc.Address)
fmt.Printf(" Htlc address (%s): %v\n", htlc.OutputType,
htlc.Address)
fmt.Printf(" Amt: %v, Expiry: %v\n",
s.Contract.AmountRequested, s.Contract.CltvExpiry,
)

@ -404,26 +404,26 @@ func validateLoopInContract(lnd *lndclient.LndServices,
// initHtlcs creates and updates the native and nested segwit htlcs
// of the loopInSwap.
func (s *loopInSwap) initHtlcs() error {
if IsTaprootSwap(&s.SwapContract) {
htlcP2TR, err := s.swapKit.getHtlc(swap.HtlcP2TR)
if err != nil {
return err
}
htlc, err := GetHtlc(
s.hash, &s.SwapContract, s.swapKit.lnd.ChainParams,
)
if err != nil {
return err
}
s.swapKit.log.Infof("Htlc address (P2TR): %v", htlcP2TR.Address)
s.htlcP2TR = htlcP2TR
switch htlc.OutputType {
case swap.HtlcP2WSH:
s.htlcP2WSH = htlc
return nil
}
case swap.HtlcP2TR:
s.htlcP2TR = htlc
htlcP2WSH, err := s.swapKit.getHtlc(swap.HtlcP2WSH)
if err != nil {
return err
default:
return fmt.Errorf("invalid output type")
}
// Log htlc addresses for debugging.
s.swapKit.log.Infof("Htlc address (P2WSH): %v", htlcP2WSH.Address)
s.htlcP2WSH = htlcP2WSH
s.swapKit.log.Infof("Htlc address (%s): %v", htlc.OutputType,
htlc.Address)
return nil
}

@ -455,21 +455,32 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool,
pendSwap.Loop.Events[0].Cost = cost
}
scriptVersion := GetHtlcScriptVersion(storedVersion)
var (
htlc *swap.Htlc
err error
)
outputType := swap.HtlcP2WSH
if scriptVersion == swap.HtlcV3 {
outputType = swap.HtlcP2TR
switch GetHtlcScriptVersion(storedVersion) {
case swap.HtlcV2:
htlc, err = swap.NewHtlcV2(
contract.CltvExpiry, contract.SenderKey,
contract.ReceiverKey, testPreimage.Hash(),
cfg.lnd.ChainParams,
)
case swap.HtlcV3:
htlc, err = swap.NewHtlcV3(
contract.CltvExpiry, contract.SenderKey,
contract.ReceiverKey, contract.SenderKey,
contract.ReceiverKey, testPreimage.Hash(),
cfg.lnd.ChainParams,
)
default:
t.Fatalf("unknown HTLC script version")
}
htlc, err := swap.NewHtlc(
scriptVersion, contract.CltvExpiry, contract.SenderKey,
contract.ReceiverKey, testPreimage.Hash(), outputType,
cfg.lnd.ChainParams,
)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
err = ctx.store.CreateLoopIn(testPreimage.Hash(), contract)
if err != nil {

@ -201,21 +201,17 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig,
swapKit.lastUpdateTime = initiationTime
scriptVersion := GetHtlcScriptVersion(loopdb.CurrentProtocolVersion())
outputType := swap.HtlcP2TR
if scriptVersion != swap.HtlcV3 {
// Default to using P2WSH for legacy htlcs.
outputType = swap.HtlcP2WSH
}
// Create the htlc.
htlc, err := swapKit.getHtlc(outputType)
htlc, err := GetHtlc(
swapKit.hash, swapKit.contract, swapKit.lnd.ChainParams,
)
if err != nil {
return nil, err
}
// Log htlc address for debugging.
swapKit.log.Infof("Htlc address: %v", htlc.Address)
swapKit.log.Infof("Htlc address (%s): %v", htlc.OutputType,
htlc.Address)
// Obtain the payment addr since we'll need it later for routing plugin
// recommendation and possibly for cancel.
@ -261,15 +257,10 @@ func resumeLoopOutSwap(reqContext context.Context, cfg *swapConfig,
hash, swap.TypeOut, cfg, &pend.Contract.SwapContract,
)
scriptVersion := GetHtlcScriptVersion(pend.Contract.ProtocolVersion)
outputType := swap.HtlcP2TR
if scriptVersion != swap.HtlcV3 {
// Default to using P2WSH for legacy htlcs.
outputType = swap.HtlcP2WSH
}
// Create the htlc.
htlc, err := swapKit.getHtlc(outputType)
htlc, err := GetHtlc(
swapKit.hash, swapKit.contract, swapKit.lnd.ChainParams,
)
if err != nil {
return nil, err
}

@ -4,6 +4,7 @@ import (
"context"
"time"
"github.com/btcsuite/btcd/chaincfg"
"github.com/lightninglabs/lndclient"
"github.com/lightninglabs/loop/loopdb"
"github.com/lightninglabs/loop/swap"
@ -67,14 +68,28 @@ func IsTaprootSwap(swapContract *loopdb.SwapContract) bool {
return GetHtlcScriptVersion(swapContract.ProtocolVersion) == swap.HtlcV3
}
// getHtlc composes and returns the on-chain swap script.
func (s *swapKit) getHtlc(outputType swap.HtlcOutputType) (*swap.Htlc, error) {
return swap.NewHtlc(
GetHtlcScriptVersion(s.contract.ProtocolVersion),
s.contract.CltvExpiry, s.contract.SenderKey,
s.contract.ReceiverKey, s.hash, outputType,
s.swapConfig.lnd.ChainParams,
)
// GetHtlc composes and returns the on-chain swap script.
func GetHtlc(hash lntypes.Hash, contract *loopdb.SwapContract,
chainParams *chaincfg.Params) (*swap.Htlc, error) {
switch GetHtlcScriptVersion(contract.ProtocolVersion) {
case swap.HtlcV2:
return swap.NewHtlcV2(
contract.CltvExpiry, contract.SenderKey,
contract.ReceiverKey, hash,
chainParams,
)
case swap.HtlcV3:
return swap.NewHtlcV3(
contract.CltvExpiry, contract.SenderKey,
contract.ReceiverKey, contract.SenderKey,
contract.ReceiverKey, hash,
chainParams,
)
}
return nil, swap.ErrInvalidScriptVersion
}
// swapInfo constructs and returns a filled SwapInfo from

@ -114,16 +114,16 @@ var (
// QuoteHtlcP2WSH is a template script just used for sweep fee
// estimation.
QuoteHtlcP2WSH, _ = NewHtlc(
HtlcV2, ^int32(0), dummyPubKey, dummyPubKey, quoteHash,
HtlcP2WSH, &chaincfg.MainNetParams,
QuoteHtlcP2WSH, _ = NewHtlcV2(
^int32(0), dummyPubKey, dummyPubKey, quoteHash,
&chaincfg.MainNetParams,
)
// QuoteHtlcP2TR is a template script just used for sweep fee
// estimation.
QuoteHtlcP2TR, _ = NewHtlc(
HtlcV3, ^int32(0), dummyPubKey, dummyPubKey, quoteHash,
HtlcP2TR, &chaincfg.MainNetParams,
QuoteHtlcP2TR, _ = NewHtlcV3(
^int32(0), dummyPubKey, dummyPubKey, dummyPubKey, dummyPubKey,
quoteHash, &chaincfg.MainNetParams,
)
// ErrInvalidScriptVersion is returned when an unknown htlc version
@ -135,6 +135,10 @@ var (
// selected for a v2 script.
ErrInvalidOutputSelected = fmt.Errorf("taproot output selected for " +
"non taproot htlc")
// ErrInvalidOutputType is returned when an unknown output type is
// associated with a certain swap htlc.
ErrInvalidOutputType = fmt.Errorf("invalid htlc output type")
)
// String returns the string value of HtlcOutputType.
@ -151,38 +155,54 @@ func (h HtlcOutputType) String() string {
}
}
// NewHtlc returns a new instance. For v3 scripts, an internal pubkey generated
// by both participants must be provided.
func NewHtlc(version ScriptVersion, cltvExpiry int32,
senderKey, receiverKey [33]byte, hash lntypes.Hash,
outputType HtlcOutputType, chainParams *chaincfg.Params) (*Htlc, error) {
// NewHtlcV2 returns a new V2 (P2WSH) HTLC instance.
func NewHtlcV2(cltvExpiry int32, senderKey, receiverKey [33]byte,
hash lntypes.Hash, chainParams *chaincfg.Params) (*Htlc, error) {
var (
err error
htlc HtlcScript
htlc, err := newHTLCScriptV2(
cltvExpiry, senderKey, receiverKey, hash,
)
switch version {
case HtlcV2:
htlc, err = newHTLCScriptV2(
cltvExpiry, senderKey, receiverKey, hash,
)
case HtlcV3:
htlc, err = newHTLCScriptV3(
cltvExpiry, senderKey, receiverKey, hash,
)
if err != nil {
return nil, err
}
default:
return nil, ErrInvalidScriptVersion
address, pkScript, sigScript, err := htlc.lockingConditions(
HtlcP2WSH, chainParams,
)
if err != nil {
return nil, fmt.Errorf("could not get address: %w", err)
}
return &Htlc{
HtlcScript: htlc,
Hash: hash,
Version: HtlcV2,
PkScript: pkScript,
OutputType: HtlcP2WSH,
ChainParams: chainParams,
Address: address,
SigScript: sigScript,
}, nil
}
// NewHtlcV3 returns a new V3 HTLC (P2TR) instance. Internal pubkey generated
// by both participants must be provided.
func NewHtlcV3(cltvExpiry int32, senderInternalKey, receiverInternalKey,
senderKey, receiverKey [33]byte, hash lntypes.Hash,
chainParams *chaincfg.Params) (*Htlc, error) {
htlc, err := newHTLCScriptV3(
cltvExpiry, senderInternalKey, receiverInternalKey,
senderKey, receiverKey, hash,
)
if err != nil {
return nil, err
}
address, pkScript, sigScript, err := htlc.lockingConditions(
outputType, chainParams,
HtlcP2TR, chainParams,
)
if err != nil {
return nil, fmt.Errorf("could not get address: %w", err)
@ -191,9 +211,9 @@ func NewHtlc(version ScriptVersion, cltvExpiry int32,
return &Htlc{
HtlcScript: htlc,
Hash: hash,
Version: version,
Version: HtlcV3,
PkScript: pkScript,
OutputType: outputType,
OutputType: HtlcP2TR,
ChainParams: chainParams,
Address: address,
SigScript: sigScript,
@ -481,7 +501,8 @@ type HtlcScriptV3 struct {
}
// newHTLCScriptV3 constructs a HtlcScipt with the HTLC V3 taproot script.
func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte,
func newHTLCScriptV3(cltvExpiry int32, senderInternalKey, receiverInternalKey,
senderHtlcKey, receiverHtlcKey [33]byte,
swapHash lntypes.Hash) (*HtlcScriptV3, error) {
senderPubKey, err := schnorr.ParsePubKey(senderHtlcKey[1:])
@ -494,13 +515,6 @@ func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte,
return nil, err
}
aggregateKey, _, _, err := musig2.AggregateKeys(
[]*btcec.PublicKey{senderPubKey, receiverPubKey}, true,
)
if err != nil {
return nil, err
}
// Create our success path script, we'll use this separately
// to generate the success path leaf.
successPathScript, err := GenSuccessPathScript(
@ -527,6 +541,31 @@ func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte,
rootHash := tree.RootNode.TapHash()
// Parse the pub keys used in the internal aggregate key. They are
// optional and may just be the same keys that are used for the script
// paths.
senderInternalPubKey, err := schnorr.ParsePubKey(senderInternalKey[1:])
if err != nil {
return nil, err
}
receiverInternalPubKey, err := schnorr.ParsePubKey(
receiverInternalKey[1:],
)
if err != nil {
return nil, err
}
// Calculate the internal aggregate key.
aggregateKey, _, _, err := musig2.AggregateKeys(
[]*btcec.PublicKey{
senderInternalPubKey, receiverInternalPubKey,
}, true,
)
if err != nil {
return nil, err
}
// Calculate top level taproot key.
taprootKey := txscript.ComputeTaprootOutputKey(
aggregateKey.PreTweakedKey, rootHash[:],

@ -134,9 +134,9 @@ func TestHtlcV2(t *testing.T) {
hash := sha256.Sum256(testPreimage[:])
// Create the htlc.
htlc, err := NewHtlc(
HtlcV2, testCltvExpiry, senderKey, receiverKey, hash,
HtlcP2WSH, &chaincfg.MainNetParams,
htlc, err := NewHtlcV2(
testCltvExpiry, senderKey, receiverKey, hash,
&chaincfg.MainNetParams,
)
require.NoError(t, err)
@ -287,10 +287,9 @@ func TestHtlcV2(t *testing.T) {
bogusKey := [33]byte{0xb, 0xa, 0xd}
// Create the htlc with the bogus key.
htlc, err = NewHtlc(
HtlcV2, testCltvExpiry,
bogusKey, receiverKey, hash,
HtlcP2WSH, &chaincfg.MainNetParams,
htlc, err = NewHtlcV2(
testCltvExpiry, bogusKey, receiverKey,
hash, &chaincfg.MainNetParams,
)
require.NoError(t, err)
@ -357,9 +356,9 @@ func TestHtlcV3(t *testing.T) {
copy(receiverKey[:], receiverPubKey.SerializeCompressed())
copy(senderKey[:], senderPubKey.SerializeCompressed())
htlc, err := NewHtlc(
HtlcV3, cltvExpiry, senderKey, receiverKey,
hashedPreimage, HtlcP2TR, &chaincfg.MainNetParams,
htlc, err := NewHtlcV3(
cltvExpiry, senderKey, receiverKey, senderKey, receiverKey,
hashedPreimage, &chaincfg.MainNetParams,
)
require.NoError(t, err)
@ -540,10 +539,10 @@ func TestHtlcV3(t *testing.T) {
bogusKey.SerializeCompressed(),
)
htlc, err := NewHtlc(
HtlcV3, cltvExpiry, bogusKeyBytes,
receiverKey, hashedPreimage, HtlcP2TR,
&chaincfg.MainNetParams,
htlc, err := NewHtlcV3(
cltvExpiry, senderKey,
receiverKey, bogusKeyBytes, receiverKey,
hashedPreimage, &chaincfg.MainNetParams,
)
require.NoError(t, err)

Loading…
Cancel
Save