diff --git a/client.go b/client.go index e43dee4..d34dfb4 100644 --- a/client.go +++ b/client.go @@ -192,24 +192,39 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) { swaps := make([]*SwapInfo, 0, len(loopInSwaps)+len(loopOutSwaps)) for _, swp := range loopOutSwaps { + swapInfo := &SwapInfo{ + SwapType: swap.TypeOut, + SwapContract: swp.Contract.SwapContract, + SwapStateData: swp.State(), + SwapHash: swp.Hash, + LastUpdate: swp.LastUpdateTime(), + } + scriptVersion := GetHtlcScriptVersion( + swp.Contract.ProtocolVersion, + ) + + outputType := swap.HtlcP2WSH + if scriptVersion == swap.HtlcV3 { + outputType = swap.HtlcP2TR + } + htlc, err := swap.NewHtlc( - GetHtlcScriptVersion(swp.Contract.ProtocolVersion), + scriptVersion, swp.Contract.CltvExpiry, swp.Contract.SenderKey, - swp.Contract.ReceiverKey, swp.Hash, swap.HtlcP2WSH, - s.lndServices.ChainParams, + swp.Contract.ReceiverKey, swp.Hash, + outputType, s.lndServices.ChainParams, ) if err != nil { return nil, err } - swaps = append(swaps, &SwapInfo{ - SwapType: swap.TypeOut, - SwapContract: swp.Contract.SwapContract, - SwapStateData: swp.State(), - SwapHash: swp.Hash, - LastUpdate: swp.LastUpdateTime(), - HtlcAddressP2WSH: htlc.Address, - }) + if outputType == swap.HtlcP2TR { + swapInfo.HtlcAddressP2TR = htlc.Address + } else { + swapInfo.HtlcAddressP2WSH = htlc.Address + } + + swaps = append(swaps, swapInfo) } for _, swp := range loopInSwaps { @@ -426,9 +441,9 @@ func (s *Client) LoopOut(globalCtx context.Context, // Return hash so that the caller can identify this swap in the updates // stream. return &LoopOutSwapInfo{ - SwapHash: swap.hash, - HtlcAddressP2WSH: swap.htlc.Address, - ServerMessage: initResult.serverMessage, + SwapHash: swap.hash, + HtlcAddress: swap.htlc.Address, + ServerMessage: initResult.serverMessage, }, nil } diff --git a/client_test.go b/client_test.go index b12f1bf..7a77f10 100644 --- a/client_test.go +++ b/client_test.go @@ -159,6 +159,7 @@ func TestLoopOutResume(t *testing.T) { storedVersion := []loopdb.ProtocolVersion{ loopdb.ProtocolVersionUnrecorded, loopdb.ProtocolVersionHtlcV2, + loopdb.ProtocolVersionHtlcV3, } for _, version := range storedVersion { @@ -283,9 +284,15 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed, // Assert that the loopout htlc equals to the expected one. scriptVersion := GetHtlcScriptVersion(protocolVersion) + + outputType := swap.HtlcP2TR + if scriptVersion != swap.HtlcV3 { + outputType = swap.HtlcP2WSH + } + htlc, err := swap.NewHtlc( scriptVersion, pendingSwap.Contract.CltvExpiry, senderKey, - receiverKey, hash, swap.HtlcP2WSH, &chaincfg.TestNet3Params, + receiverKey, hash, outputType, &chaincfg.TestNet3Params, ) require.NoError(t, err) require.Equal(t, htlc.PkScript, confIntent.PkScript) @@ -345,8 +352,15 @@ func testLoopOutSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash, ctx.T.Fatalf("client not sweeping from htlc tx") } - preImageIndex := 1 - if scriptVersion == swap.HtlcV2 { + var preImageIndex int + switch scriptVersion { + case swap.HtlcV1: + preImageIndex = 1 + + case swap.HtlcV2: + preImageIndex = 0 + + case swap.HtlcV3: preImageIndex = 0 } diff --git a/interface.go b/interface.go index 9318d54..72b3e0e 100644 --- a/interface.go +++ b/interface.go @@ -312,9 +312,9 @@ type LoopOutSwapInfo struct { // nolint:revive // SwapHash contains the sha256 hash of the swap preimage. SwapHash lntypes.Hash - // HtlcAddressP2WSH contains the native segwit swap htlc address that - // the server will publish to. - HtlcAddressP2WSH btcutil.Address + // HtlcAddress contains the swap htlc address that the server will + // publish to. + HtlcAddress btcutil.Address // ServerMessages is the human-readable message received from the loop // server. diff --git a/liquidity/liquidity.go b/liquidity/liquidity.go index 59d6426..360e35c 100644 --- a/liquidity/liquidity.go +++ b/liquidity/liquidity.go @@ -386,8 +386,7 @@ func (m *Manager) autoloop(ctx context.Context) error { } log.Infof("loop out automatically dispatched: hash: %v, "+ - "address: %v", loopOut.SwapHash, - loopOut.HtlcAddressP2WSH) + "address: %v", loopOut.SwapHash, loopOut.HtlcAddress) } for _, in := range suggestion.InSwaps { diff --git a/loopd/swapclient_server.go b/loopd/swapclient_server.go index f97bfc6..d2bfbb6 100644 --- a/loopd/swapclient_server.go +++ b/loopd/swapclient_server.go @@ -150,13 +150,21 @@ func (s *swapClientServer) LoopOut(ctx context.Context, return nil, err } - return &clientrpc.SwapResponse{ - Id: info.SwapHash.String(), - IdBytes: info.SwapHash[:], - HtlcAddress: info.HtlcAddressP2WSH.String(), - HtlcAddressP2Wsh: info.HtlcAddressP2WSH.String(), - ServerMessage: info.ServerMessage, - }, nil + htlcAddress := info.HtlcAddress.String() + resp := &clientrpc.SwapResponse{ + Id: info.SwapHash.String(), + IdBytes: info.SwapHash[:], + HtlcAddress: htlcAddress, + ServerMessage: info.ServerMessage, + } + + if loopdb.CurrentProtocolVersion() < loopdb.ProtocolVersionHtlcV3 { + resp.HtlcAddressP2Wsh = htlcAddress + } else { + resp.HtlcAddressP2Tr = htlcAddress + } + + return resp, nil } func (s *swapClientServer) marshallSwap(loopSwap *loop.SwapInfo) ( @@ -252,8 +260,13 @@ func (s *swapClientServer) marshallSwap(loopSwap *loop.SwapInfo) ( case swap.TypeOut: swapType = clientrpc.SwapType_LOOP_OUT - htlcAddressP2WSH = loopSwap.HtlcAddressP2WSH.EncodeAddress() - htlcAddress = htlcAddressP2WSH + if loopSwap.HtlcAddressP2WSH != nil { + htlcAddressP2WSH = loopSwap.HtlcAddressP2WSH.EncodeAddress() + htlcAddress = htlcAddressP2WSH + } else { + htlcAddressP2TR = loopSwap.HtlcAddressP2TR.EncodeAddress() + htlcAddress = htlcAddressP2TR + } outGoingChanSet = loopSwap.OutgoingChanSet diff --git a/loopout.go b/loopout.go index 10d01de..56d679d 100644 --- a/loopout.go +++ b/loopout.go @@ -183,14 +183,20 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig, } swapKit := newSwapKit( - swapHash, swap.TypeOut, - cfg, &contract.SwapContract, + swapHash, swap.TypeOut, cfg, &contract.SwapContract, ) swapKit.lastUpdateTime = initiationTime + scriptVersion := GetHtlcScriptVersion(loopdb.CurrentProtocolVersion()) + outputType := swap.HtlcP2TR + if scriptVersion != swap.HtlcV3 { + // Default to using P2WSH for legacy htlcs. + outputType = swap.HtlcP2WSH + } + // Create the htlc. - htlc, err := swapKit.getHtlc(swap.HtlcP2WSH) + htlc, err := swapKit.getHtlc(outputType) if err != nil { return nil, err } @@ -239,12 +245,18 @@ func resumeLoopOutSwap(reqContext context.Context, cfg *swapConfig, log.Infof("Resuming loop out swap %v", hash) swapKit := newSwapKit( - hash, swap.TypeOut, cfg, - &pend.Contract.SwapContract, + hash, swap.TypeOut, cfg, &pend.Contract.SwapContract, ) + scriptVersion := GetHtlcScriptVersion(pend.Contract.ProtocolVersion) + outputType := swap.HtlcP2TR + if scriptVersion != swap.HtlcV3 { + // Default to using P2WSH for legacy htlcs. + outputType = swap.HtlcP2WSH + } + // Create the htlc. - htlc, err := swapKit.getHtlc(swap.HtlcP2WSH) + htlc, err := swapKit.getHtlc(outputType) if err != nil { return nil, err } diff --git a/loopout_test.go b/loopout_test.go index 8350a5a..e821aaa 100644 --- a/loopout_test.go +++ b/loopout_test.go @@ -24,6 +24,22 @@ import ( // TestLoopOutPaymentParameters tests the first part of the loop out process up // to the point where the off-chain payments are made. func TestLoopOutPaymentParameters(t *testing.T) { + t.Run("stable protocol", func(t *testing.T) { + testLoopOutPaymentParameters(t) + }) + + t.Run("experimental protocol", func(t *testing.T) { + loopdb.EnableExperimentalProtocol() + defer loopdb.ResetCurrentProtocolVersion() + + testLoopOutPaymentParameters(t) + }) +} + +// TestLoopOutPaymentParameters tests the first part of the loop out process up +// to the point where the off-chain payments are made. +func testLoopOutPaymentParameters(t *testing.T) { + defer test.Guard(t)() // Set up test context objects. @@ -144,6 +160,19 @@ func TestLoopOutPaymentParameters(t *testing.T) { // TestLateHtlcPublish tests that the client is not revealing the preimage if // there are not enough blocks left. func TestLateHtlcPublish(t *testing.T) { + t.Run("stable protocol", func(t *testing.T) { + testLateHtlcPublish(t) + }) + + t.Run("experimental protocol", func(t *testing.T) { + loopdb.EnableExperimentalProtocol() + defer loopdb.ResetCurrentProtocolVersion() + + testLateHtlcPublish(t) + }) +} + +func testLateHtlcPublish(t *testing.T) { defer test.Guard(t)() lnd := test.NewMockLnd() @@ -232,6 +261,19 @@ func TestLateHtlcPublish(t *testing.T) { // TestCustomSweepConfTarget ensures we are able to sweep a Loop Out HTLC with a // custom confirmation target. func TestCustomSweepConfTarget(t *testing.T) { + t.Run("stable protocol", func(t *testing.T) { + testCustomSweepConfTarget(t) + }) + + t.Run("experimental protocol", func(t *testing.T) { + loopdb.EnableExperimentalProtocol() + defer loopdb.ResetCurrentProtocolVersion() + + testCustomSweepConfTarget(t) + }) +} + +func testCustomSweepConfTarget(t *testing.T) { defer test.Guard(t)() lnd := test.NewMockLnd() @@ -433,6 +475,19 @@ func TestCustomSweepConfTarget(t *testing.T) { // to start with a fee rate that will be too high, then progress to an // acceptable one. func TestPreimagePush(t *testing.T) { + t.Run("stable protocol", func(t *testing.T) { + testPreimagePush(t) + }) + + t.Run("experimental protocol", func(t *testing.T) { + loopdb.EnableExperimentalProtocol() + defer loopdb.ResetCurrentProtocolVersion() + + testPreimagePush(t) + }) +} + +func testPreimagePush(t *testing.T) { defer test.Guard(t)() lnd := test.NewMockLnd() @@ -604,6 +659,19 @@ func TestPreimagePush(t *testing.T) { // we have revealed our preimage, demonstrating that we do not reveal our // preimage once we've reached our expiry height. func TestExpiryBeforeReveal(t *testing.T) { + t.Run("stable protocol", func(t *testing.T) { + testExpiryBeforeReveal(t) + }) + + t.Run("experimental protocol", func(t *testing.T) { + loopdb.EnableExperimentalProtocol() + defer loopdb.ResetCurrentProtocolVersion() + + testExpiryBeforeReveal(t) + }) +} + +func testExpiryBeforeReveal(t *testing.T) { defer test.Guard(t)() lnd := test.NewMockLnd() @@ -719,6 +787,19 @@ func TestExpiryBeforeReveal(t *testing.T) { // TestFailedOffChainCancelation tests sending of a cancelation message to // the server when a swap fails due to off-chain routing. func TestFailedOffChainCancelation(t *testing.T) { + t.Run("stable protocol", func(t *testing.T) { + testFailedOffChainCancelation(t) + }) + + t.Run("experimental protocol", func(t *testing.T) { + loopdb.EnableExperimentalProtocol() + defer loopdb.ResetCurrentProtocolVersion() + + testFailedOffChainCancelation(t) + }) +} + +func testFailedOffChainCancelation(t *testing.T) { defer test.Guard(t)() lnd := test.NewMockLnd()