loopout: enable p2tr without keyspend

pull/497/head
Andras Banki-Horvath 2 years ago
parent 901a935514
commit 391ef57ea3
No known key found for this signature in database
GPG Key ID: 80E5375C094198D8

@ -192,24 +192,39 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) {
swaps := make([]*SwapInfo, 0, len(loopInSwaps)+len(loopOutSwaps)) swaps := make([]*SwapInfo, 0, len(loopInSwaps)+len(loopOutSwaps))
for _, swp := range loopOutSwaps { for _, swp := range loopOutSwaps {
swapInfo := &SwapInfo{
SwapType: swap.TypeOut,
SwapContract: swp.Contract.SwapContract,
SwapStateData: swp.State(),
SwapHash: swp.Hash,
LastUpdate: swp.LastUpdateTime(),
}
scriptVersion := GetHtlcScriptVersion(
swp.Contract.ProtocolVersion,
)
outputType := swap.HtlcP2WSH
if scriptVersion == swap.HtlcV3 {
outputType = swap.HtlcP2TR
}
htlc, err := swap.NewHtlc( htlc, err := swap.NewHtlc(
GetHtlcScriptVersion(swp.Contract.ProtocolVersion), scriptVersion,
swp.Contract.CltvExpiry, swp.Contract.SenderKey, swp.Contract.CltvExpiry, swp.Contract.SenderKey,
swp.Contract.ReceiverKey, swp.Hash, swap.HtlcP2WSH, swp.Contract.ReceiverKey, swp.Hash,
s.lndServices.ChainParams, outputType, s.lndServices.ChainParams,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
swaps = append(swaps, &SwapInfo{ if outputType == swap.HtlcP2TR {
SwapType: swap.TypeOut, swapInfo.HtlcAddressP2TR = htlc.Address
SwapContract: swp.Contract.SwapContract, } else {
SwapStateData: swp.State(), swapInfo.HtlcAddressP2WSH = htlc.Address
SwapHash: swp.Hash, }
LastUpdate: swp.LastUpdateTime(),
HtlcAddressP2WSH: htlc.Address, swaps = append(swaps, swapInfo)
})
} }
for _, swp := range loopInSwaps { for _, swp := range loopInSwaps {
@ -426,9 +441,9 @@ func (s *Client) LoopOut(globalCtx context.Context,
// Return hash so that the caller can identify this swap in the updates // Return hash so that the caller can identify this swap in the updates
// stream. // stream.
return &LoopOutSwapInfo{ return &LoopOutSwapInfo{
SwapHash: swap.hash, SwapHash: swap.hash,
HtlcAddressP2WSH: swap.htlc.Address, HtlcAddress: swap.htlc.Address,
ServerMessage: initResult.serverMessage, ServerMessage: initResult.serverMessage,
}, nil }, nil
} }

@ -159,6 +159,7 @@ func TestLoopOutResume(t *testing.T) {
storedVersion := []loopdb.ProtocolVersion{ storedVersion := []loopdb.ProtocolVersion{
loopdb.ProtocolVersionUnrecorded, loopdb.ProtocolVersionUnrecorded,
loopdb.ProtocolVersionHtlcV2, loopdb.ProtocolVersionHtlcV2,
loopdb.ProtocolVersionHtlcV3,
} }
for _, version := range storedVersion { for _, version := range storedVersion {
@ -283,9 +284,15 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed,
// Assert that the loopout htlc equals to the expected one. // Assert that the loopout htlc equals to the expected one.
scriptVersion := GetHtlcScriptVersion(protocolVersion) scriptVersion := GetHtlcScriptVersion(protocolVersion)
outputType := swap.HtlcP2TR
if scriptVersion != swap.HtlcV3 {
outputType = swap.HtlcP2WSH
}
htlc, err := swap.NewHtlc( htlc, err := swap.NewHtlc(
scriptVersion, pendingSwap.Contract.CltvExpiry, senderKey, scriptVersion, pendingSwap.Contract.CltvExpiry, senderKey,
receiverKey, hash, swap.HtlcP2WSH, &chaincfg.TestNet3Params, receiverKey, hash, outputType, &chaincfg.TestNet3Params,
) )
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, htlc.PkScript, confIntent.PkScript) require.Equal(t, htlc.PkScript, confIntent.PkScript)
@ -345,8 +352,15 @@ func testLoopOutSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash,
ctx.T.Fatalf("client not sweeping from htlc tx") ctx.T.Fatalf("client not sweeping from htlc tx")
} }
preImageIndex := 1 var preImageIndex int
if scriptVersion == swap.HtlcV2 { switch scriptVersion {
case swap.HtlcV1:
preImageIndex = 1
case swap.HtlcV2:
preImageIndex = 0
case swap.HtlcV3:
preImageIndex = 0 preImageIndex = 0
} }

@ -312,9 +312,9 @@ type LoopOutSwapInfo struct { // nolint:revive
// SwapHash contains the sha256 hash of the swap preimage. // SwapHash contains the sha256 hash of the swap preimage.
SwapHash lntypes.Hash SwapHash lntypes.Hash
// HtlcAddressP2WSH contains the native segwit swap htlc address that // HtlcAddress contains the swap htlc address that the server will
// the server will publish to. // publish to.
HtlcAddressP2WSH btcutil.Address HtlcAddress btcutil.Address
// ServerMessages is the human-readable message received from the loop // ServerMessages is the human-readable message received from the loop
// server. // server.

@ -386,8 +386,7 @@ func (m *Manager) autoloop(ctx context.Context) error {
} }
log.Infof("loop out automatically dispatched: hash: %v, "+ log.Infof("loop out automatically dispatched: hash: %v, "+
"address: %v", loopOut.SwapHash, "address: %v", loopOut.SwapHash, loopOut.HtlcAddress)
loopOut.HtlcAddressP2WSH)
} }
for _, in := range suggestion.InSwaps { for _, in := range suggestion.InSwaps {

@ -150,13 +150,21 @@ func (s *swapClientServer) LoopOut(ctx context.Context,
return nil, err return nil, err
} }
return &clientrpc.SwapResponse{ htlcAddress := info.HtlcAddress.String()
Id: info.SwapHash.String(), resp := &clientrpc.SwapResponse{
IdBytes: info.SwapHash[:], Id: info.SwapHash.String(),
HtlcAddress: info.HtlcAddressP2WSH.String(), IdBytes: info.SwapHash[:],
HtlcAddressP2Wsh: info.HtlcAddressP2WSH.String(), HtlcAddress: htlcAddress,
ServerMessage: info.ServerMessage, ServerMessage: info.ServerMessage,
}, nil }
if loopdb.CurrentProtocolVersion() < loopdb.ProtocolVersionHtlcV3 {
resp.HtlcAddressP2Wsh = htlcAddress
} else {
resp.HtlcAddressP2Tr = htlcAddress
}
return resp, nil
} }
func (s *swapClientServer) marshallSwap(loopSwap *loop.SwapInfo) ( func (s *swapClientServer) marshallSwap(loopSwap *loop.SwapInfo) (
@ -252,8 +260,13 @@ func (s *swapClientServer) marshallSwap(loopSwap *loop.SwapInfo) (
case swap.TypeOut: case swap.TypeOut:
swapType = clientrpc.SwapType_LOOP_OUT swapType = clientrpc.SwapType_LOOP_OUT
htlcAddressP2WSH = loopSwap.HtlcAddressP2WSH.EncodeAddress() if loopSwap.HtlcAddressP2WSH != nil {
htlcAddress = htlcAddressP2WSH htlcAddressP2WSH = loopSwap.HtlcAddressP2WSH.EncodeAddress()
htlcAddress = htlcAddressP2WSH
} else {
htlcAddressP2TR = loopSwap.HtlcAddressP2TR.EncodeAddress()
htlcAddress = htlcAddressP2TR
}
outGoingChanSet = loopSwap.OutgoingChanSet outGoingChanSet = loopSwap.OutgoingChanSet

@ -183,14 +183,20 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig,
} }
swapKit := newSwapKit( swapKit := newSwapKit(
swapHash, swap.TypeOut, swapHash, swap.TypeOut, cfg, &contract.SwapContract,
cfg, &contract.SwapContract,
) )
swapKit.lastUpdateTime = initiationTime swapKit.lastUpdateTime = initiationTime
scriptVersion := GetHtlcScriptVersion(loopdb.CurrentProtocolVersion())
outputType := swap.HtlcP2TR
if scriptVersion != swap.HtlcV3 {
// Default to using P2WSH for legacy htlcs.
outputType = swap.HtlcP2WSH
}
// Create the htlc. // Create the htlc.
htlc, err := swapKit.getHtlc(swap.HtlcP2WSH) htlc, err := swapKit.getHtlc(outputType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -239,12 +245,18 @@ func resumeLoopOutSwap(reqContext context.Context, cfg *swapConfig,
log.Infof("Resuming loop out swap %v", hash) log.Infof("Resuming loop out swap %v", hash)
swapKit := newSwapKit( swapKit := newSwapKit(
hash, swap.TypeOut, cfg, hash, swap.TypeOut, cfg, &pend.Contract.SwapContract,
&pend.Contract.SwapContract,
) )
scriptVersion := GetHtlcScriptVersion(pend.Contract.ProtocolVersion)
outputType := swap.HtlcP2TR
if scriptVersion != swap.HtlcV3 {
// Default to using P2WSH for legacy htlcs.
outputType = swap.HtlcP2WSH
}
// Create the htlc. // Create the htlc.
htlc, err := swapKit.getHtlc(swap.HtlcP2WSH) htlc, err := swapKit.getHtlc(outputType)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -24,6 +24,22 @@ import (
// TestLoopOutPaymentParameters tests the first part of the loop out process up // TestLoopOutPaymentParameters tests the first part of the loop out process up
// to the point where the off-chain payments are made. // to the point where the off-chain payments are made.
func TestLoopOutPaymentParameters(t *testing.T) { func TestLoopOutPaymentParameters(t *testing.T) {
t.Run("stable protocol", func(t *testing.T) {
testLoopOutPaymentParameters(t)
})
t.Run("experimental protocol", func(t *testing.T) {
loopdb.EnableExperimentalProtocol()
defer loopdb.ResetCurrentProtocolVersion()
testLoopOutPaymentParameters(t)
})
}
// TestLoopOutPaymentParameters tests the first part of the loop out process up
// to the point where the off-chain payments are made.
func testLoopOutPaymentParameters(t *testing.T) {
defer test.Guard(t)() defer test.Guard(t)()
// Set up test context objects. // Set up test context objects.
@ -144,6 +160,19 @@ func TestLoopOutPaymentParameters(t *testing.T) {
// TestLateHtlcPublish tests that the client is not revealing the preimage if // TestLateHtlcPublish tests that the client is not revealing the preimage if
// there are not enough blocks left. // there are not enough blocks left.
func TestLateHtlcPublish(t *testing.T) { func TestLateHtlcPublish(t *testing.T) {
t.Run("stable protocol", func(t *testing.T) {
testLateHtlcPublish(t)
})
t.Run("experimental protocol", func(t *testing.T) {
loopdb.EnableExperimentalProtocol()
defer loopdb.ResetCurrentProtocolVersion()
testLateHtlcPublish(t)
})
}
func testLateHtlcPublish(t *testing.T) {
defer test.Guard(t)() defer test.Guard(t)()
lnd := test.NewMockLnd() lnd := test.NewMockLnd()
@ -232,6 +261,19 @@ func TestLateHtlcPublish(t *testing.T) {
// TestCustomSweepConfTarget ensures we are able to sweep a Loop Out HTLC with a // TestCustomSweepConfTarget ensures we are able to sweep a Loop Out HTLC with a
// custom confirmation target. // custom confirmation target.
func TestCustomSweepConfTarget(t *testing.T) { func TestCustomSweepConfTarget(t *testing.T) {
t.Run("stable protocol", func(t *testing.T) {
testCustomSweepConfTarget(t)
})
t.Run("experimental protocol", func(t *testing.T) {
loopdb.EnableExperimentalProtocol()
defer loopdb.ResetCurrentProtocolVersion()
testCustomSweepConfTarget(t)
})
}
func testCustomSweepConfTarget(t *testing.T) {
defer test.Guard(t)() defer test.Guard(t)()
lnd := test.NewMockLnd() lnd := test.NewMockLnd()
@ -433,6 +475,19 @@ func TestCustomSweepConfTarget(t *testing.T) {
// to start with a fee rate that will be too high, then progress to an // to start with a fee rate that will be too high, then progress to an
// acceptable one. // acceptable one.
func TestPreimagePush(t *testing.T) { func TestPreimagePush(t *testing.T) {
t.Run("stable protocol", func(t *testing.T) {
testPreimagePush(t)
})
t.Run("experimental protocol", func(t *testing.T) {
loopdb.EnableExperimentalProtocol()
defer loopdb.ResetCurrentProtocolVersion()
testPreimagePush(t)
})
}
func testPreimagePush(t *testing.T) {
defer test.Guard(t)() defer test.Guard(t)()
lnd := test.NewMockLnd() lnd := test.NewMockLnd()
@ -604,6 +659,19 @@ func TestPreimagePush(t *testing.T) {
// we have revealed our preimage, demonstrating that we do not reveal our // we have revealed our preimage, demonstrating that we do not reveal our
// preimage once we've reached our expiry height. // preimage once we've reached our expiry height.
func TestExpiryBeforeReveal(t *testing.T) { func TestExpiryBeforeReveal(t *testing.T) {
t.Run("stable protocol", func(t *testing.T) {
testExpiryBeforeReveal(t)
})
t.Run("experimental protocol", func(t *testing.T) {
loopdb.EnableExperimentalProtocol()
defer loopdb.ResetCurrentProtocolVersion()
testExpiryBeforeReveal(t)
})
}
func testExpiryBeforeReveal(t *testing.T) {
defer test.Guard(t)() defer test.Guard(t)()
lnd := test.NewMockLnd() lnd := test.NewMockLnd()
@ -719,6 +787,19 @@ func TestExpiryBeforeReveal(t *testing.T) {
// TestFailedOffChainCancelation tests sending of a cancelation message to // TestFailedOffChainCancelation tests sending of a cancelation message to
// the server when a swap fails due to off-chain routing. // the server when a swap fails due to off-chain routing.
func TestFailedOffChainCancelation(t *testing.T) { func TestFailedOffChainCancelation(t *testing.T) {
t.Run("stable protocol", func(t *testing.T) {
testFailedOffChainCancelation(t)
})
t.Run("experimental protocol", func(t *testing.T) {
loopdb.EnableExperimentalProtocol()
defer loopdb.ResetCurrentProtocolVersion()
testFailedOffChainCancelation(t)
})
}
func testFailedOffChainCancelation(t *testing.T) {
defer test.Guard(t)() defer test.Guard(t)()
lnd := test.NewMockLnd() lnd := test.NewMockLnd()

Loading…
Cancel
Save