diff --git a/libagent/device/trezor.py b/libagent/device/trezor.py index 9b1f7e2..62f0bcd 100644 --- a/libagent/device/trezor.py +++ b/libagent/device/trezor.py @@ -106,13 +106,13 @@ class Trezor(interface.Device): def connect(self): """Enumerate and connect to the first available interface.""" - transports = self._defs.enumerate_transports() - if not transports: + transport = self._defs.find_device() + if not transport: raise interface.NotFoundError('{} not connected'.format(self)) - log.debug('transports: %s', transports) + log.debug('using transport: %s', transport) for _ in range(5): # Retry a few times in case of PIN failures - connection = self._defs.Client(transport=transports[0], + connection = self._defs.Client(transport=transport, state=self.__class__.cached_state) self._override_pin_handler(connection) self._override_passphrase_handler(connection) diff --git a/libagent/device/trezor_defs.py b/libagent/device/trezor_defs.py index 49662d3..cf798e4 100644 --- a/libagent/device/trezor_defs.py +++ b/libagent/device/trezor_defs.py @@ -1,13 +1,27 @@ """TREZOR-related definitions.""" # pylint: disable=unused-import,import-error +import os +import logging from trezorlib.client import CallException, PinException from trezorlib.client import TrezorClient as Client from trezorlib.messages import IdentityType, PassphraseAck, PinMatrixAck, PassphraseStateAck -from trezorlib.device import TrezorDevice +try: + from trezorlib.transport import get_transport +except ImportError: + from trezorlib.device import TrezorDevice + get_transport = TrezorDevice.find_by_path -def enumerate_transports(): - """Returns all available transports.""" - return TrezorDevice.enumerate() +log = logging.getLogger(__name__) + + +def find_device(): + """Selects a transport based on `TREZOR_PATH` env variable. + If unset, picks first connected device. + """ + try: + return get_transport(os.environ.get("TREZOR_PATH")) + except Exception as e: + log.debug("Failed to find a Trezor device: %s", e)