diff --git a/libagent/device/trezor.py b/libagent/device/trezor.py index 0fbfc2b..19efccb 100644 --- a/libagent/device/trezor.py +++ b/libagent/device/trezor.py @@ -7,6 +7,7 @@ import mnemonic import semver from . import interface +from .. import util log = logging.getLogger(__name__) @@ -46,7 +47,8 @@ class Trezor(interface.Device): conn.callback_PinMatrixRequest = new_handler - cached_passphrase_ack = None + # Remembers the passphrase for an hour. + cached_passphrase_ack = util.ExpiringCache(seconds=60*60) cached_state = None def _override_passphrase_handler(self, conn): @@ -57,9 +59,10 @@ class Trezor(interface.Device): try: if msg.on_device is True: return self._defs.PassphraseAck() - if self.__class__.cached_passphrase_ack: + ack = self.__class__.cached_passphrase_ack.get() + if ack: log.debug('re-using cached %s passphrase', self) - return self.__class__.cached_passphrase_ack + return ack passphrase = self.ui.get_passphrase() passphrase = mnemonic.Mnemonic.normalize_string(passphrase) @@ -70,7 +73,7 @@ class Trezor(interface.Device): msg = 'Too long passphrase ({} chars)'.format(length) raise ValueError(msg) - self.__class__.cached_passphrase_ack = ack + self.__class__.cached_passphrase_ack.set(ack) return ack except: # noqa conn.init_device() diff --git a/libagent/tests/test_util.py b/libagent/tests/test_util.py index 2cef300..e85bc2f 100644 --- a/libagent/tests/test_util.py +++ b/libagent/tests/test_util.py @@ -121,3 +121,26 @@ def test_assuan_serialize(): assert util.assuan_serialize(b'') == b'' assert util.assuan_serialize(b'123\n456') == b'123%0A456' assert util.assuan_serialize(b'\r\n') == b'%0D%0A' + + +def test_cache(): + timer = mock.Mock(side_effect=range(7)) + c = util.ExpiringCache(seconds=2, timer=timer) # t=0 + assert c.get() is None # t=1 + obj = 'foo' + c.set(obj) # t=2 + assert c.get() is obj # t=3 + assert c.get() is obj # t=4 + assert c.get() is None # t=5 + assert c.get() is None # t=6 + + +def test_cache_inf(): + timer = mock.Mock(side_effect=range(6)) + c = util.ExpiringCache(seconds=float('inf'), timer=timer) + obj = 'foo' + c.set(obj) + assert c.get() is obj + assert c.get() is obj + assert c.get() is obj + assert c.get() is obj diff --git a/libagent/util.py b/libagent/util.py index c98891f..7df843b 100644 --- a/libagent/util.py +++ b/libagent/util.py @@ -5,6 +5,7 @@ import functools import io import logging import struct +import time log = logging.getLogger(__name__) @@ -255,3 +256,25 @@ def assuan_serialize(data): escaped = '%{:02X}'.format(ord(c)).encode('ascii') data = data.replace(c, escaped) return data + + +class ExpiringCache(object): + """Simple cache with a deadline.""" + + def __init__(self, seconds, timer=time.time): + """C-tor.""" + self.duration = seconds + self.timer = timer + self.value = None + self.set(None) + + def get(self): + """Returns existing value, or None if deadline has expired.""" + if self.timer() > self.deadline: + self.value = None + return self.value + + def set(self, value): + """Set new value and reset the deadline for expiration.""" + self.deadline = self.timer() + self.duration + self.value = value