diff --git a/test/chainnotifier_mock.go b/test/chainnotifier_mock.go index bedd766..fa8ff97 100644 --- a/test/chainnotifier_mock.go +++ b/test/chainnotifier_mock.go @@ -12,8 +12,10 @@ import ( ) type mockChainNotifier struct { - lnd *LndMockServices - wg sync.WaitGroup + sync.Mutex + lnd *LndMockServices + confRegistrations []*ConfRegistration + wg sync.WaitGroup } // SpendRegistration contains registration details. @@ -29,6 +31,7 @@ type ConfRegistration struct { PkScript []byte HeightHint int32 NumConfs int32 + ConfChan chan *chainntnfs.TxConfirmation } func (c *mockChainNotifier) RegisterSpendNtfn(ctx context.Context, @@ -103,7 +106,18 @@ func (c *mockChainNotifier) RegisterConfirmationsNtfn(ctx context.Context, txid *chainhash.Hash, pkScript []byte, numConfs, heightHint int32) ( chan *chainntnfs.TxConfirmation, chan error, error) { - confChan := make(chan *chainntnfs.TxConfirmation, 1) + reg := &ConfRegistration{ + PkScript: pkScript, + TxID: txid, + HeightHint: heightHint, + NumConfs: numConfs, + ConfChan: make(chan *chainntnfs.TxConfirmation, 1), + } + + c.Lock() + c.confRegistrations = append(c.confRegistrations, reg) + c.Unlock() + errChan := make(chan error, 1) c.wg.Add(1) @@ -112,26 +126,35 @@ func (c *mockChainNotifier) RegisterConfirmationsNtfn(ctx context.Context, select { case m := <-c.lnd.ConfChannel: - if bytes.Equal(m.Tx.TxOut[0].PkScript, pkScript) { - select { - case confChan <- m: - case <-ctx.Done(): + c.Lock() + for i := 0; i < len(c.confRegistrations); i++ { + r := c.confRegistrations[i] + + // Whichever conf notifier catches the confirmation + // will forward it to all matching subscibers. + if bytes.Equal(m.Tx.TxOut[0].PkScript, r.PkScript) { + // Unregister the "notifier". + c.confRegistrations = append( + c.confRegistrations[:i], c.confRegistrations[i+1:]..., + ) + i-- + + select { + case r.ConfChan <- m: + case <-ctx.Done(): + } } } + c.Unlock() case <-ctx.Done(): } }() select { - case c.lnd.RegisterConfChannel <- &ConfRegistration{ - PkScript: pkScript, - TxID: txid, - HeightHint: heightHint, - NumConfs: numConfs, - }: + case c.lnd.RegisterConfChannel <- reg: case <-time.After(Timeout): return nil, nil, ErrTimeout } - return confChan, errChan, nil + return reg.ConfChan, errChan, nil }