trezor: allow expiring cached passphrase

master
Roman Zeyde 6 years ago
parent 91f70e7a96
commit 766536d2c4
No known key found for this signature in database
GPG Key ID: 87CAE5FA46917CBB

@ -7,6 +7,7 @@ import mnemonic
import semver
from . import interface
from .. import util
log = logging.getLogger(__name__)
@ -46,7 +47,8 @@ class Trezor(interface.Device):
conn.callback_PinMatrixRequest = new_handler
cached_passphrase_ack = None
# Remembers the passphrase for an hour.
cached_passphrase_ack = util.ExpiringCache(seconds=60*60)
cached_state = None
def _override_passphrase_handler(self, conn):
@ -57,9 +59,10 @@ class Trezor(interface.Device):
try:
if msg.on_device is True:
return self._defs.PassphraseAck()
if self.__class__.cached_passphrase_ack:
ack = self.__class__.cached_passphrase_ack.get()
if ack:
log.debug('re-using cached %s passphrase', self)
return self.__class__.cached_passphrase_ack
return ack
passphrase = self.ui.get_passphrase()
passphrase = mnemonic.Mnemonic.normalize_string(passphrase)
@ -70,7 +73,7 @@ class Trezor(interface.Device):
msg = 'Too long passphrase ({} chars)'.format(length)
raise ValueError(msg)
self.__class__.cached_passphrase_ack = ack
self.__class__.cached_passphrase_ack.set(ack)
return ack
except: # noqa
conn.init_device()

@ -121,3 +121,26 @@ def test_assuan_serialize():
assert util.assuan_serialize(b'') == b''
assert util.assuan_serialize(b'123\n456') == b'123%0A456'
assert util.assuan_serialize(b'\r\n') == b'%0D%0A'
def test_cache():
timer = mock.Mock(side_effect=range(7))
c = util.ExpiringCache(seconds=2, timer=timer) # t=0
assert c.get() is None # t=1
obj = 'foo'
c.set(obj) # t=2
assert c.get() is obj # t=3
assert c.get() is obj # t=4
assert c.get() is None # t=5
assert c.get() is None # t=6
def test_cache_inf():
timer = mock.Mock(side_effect=range(6))
c = util.ExpiringCache(seconds=float('inf'), timer=timer)
obj = 'foo'
c.set(obj)
assert c.get() is obj
assert c.get() is obj
assert c.get() is obj
assert c.get() is obj

@ -5,6 +5,7 @@ import functools
import io
import logging
import struct
import time
log = logging.getLogger(__name__)
@ -255,3 +256,25 @@ def assuan_serialize(data):
escaped = '%{:02X}'.format(ord(c)).encode('ascii')
data = data.replace(c, escaped)
return data
class ExpiringCache(object):
"""Simple cache with a deadline."""
def __init__(self, seconds, timer=time.time):
"""C-tor."""
self.duration = seconds
self.timer = timer
self.value = None
self.set(None)
def get(self):
"""Returns existing value, or None if deadline has expired."""
if self.timer() > self.deadline:
self.value = None
return self.value
def set(self, value):
"""Set new value and reset the deadline for expiration."""
self.deadline = self.timer() + self.duration
self.value = value

Loading…
Cancel
Save