diff --git a/.travis.yml b/.travis.yml index 81a35ba..6591b0f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,7 +5,7 @@ python: - "3.4" install: - - pip install ecdsa ed25519 # test without trezorlib for now + - pip install ecdsa ed25519 semver # test without trezorlib for now - pip install pylint coverage pep8 script: diff --git a/setup.py b/setup.py index 94c5121..3239f6c 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ setup( author_email='roman.zeyde@gmail.com', url='http://github.com/romanz/trezor-agent', packages=['trezor_agent', 'trezor_agent.trezor'], - install_requires=['ecdsa>=0.13', 'ed25519>=1.4', 'Cython>=0.23.4', 'trezor>=0.6.6', 'keepkey>=0.7.0'], + install_requires=['ecdsa>=0.13', 'ed25519>=1.4', 'Cython>=0.23.4', 'trezor>=0.6.6', 'keepkey>=0.7.0', 'semver>=2.2'], platforms=['POSIX'], classifiers=[ 'Development Status :: 4 - Beta', diff --git a/tox.ini b/tox.ini index b4c0024..11358c5 100644 --- a/tox.ini +++ b/tox.ini @@ -7,6 +7,7 @@ deps= pep8 coverage pylint + semver commands= pep8 trezor_agent pylint --reports=no --rcfile .pylintrc trezor_agent diff --git a/trezor_agent/tests/test_trezor.py b/trezor_agent/tests/test_trezor.py index 909c460..0af55ca 100644 --- a/trezor_agent/tests/test_trezor.py +++ b/trezor_agent/tests/test_trezor.py @@ -1,10 +1,9 @@ import io import mock -import pytest from .. import formats, util -from ..trezor import client +from ..trezor import client, factory ADDR = [2147483661, 2810943954, 3938368396, 3454558782, 3848009040] CURVE = 'nist256p1' @@ -18,15 +17,7 @@ PUBKEY_TEXT = ('ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzd' class ConnectionMock(object): - def __init__(self, version): - self.features = mock.Mock(spec=[]) - self.features.device_id = '123456789' - self.features.label = 'mywallet' - self.features.vendor = 'mock' - self.features.major_version = version[0] - self.features.minor_version = version[1] - self.features.patch_version = version[2] - self.features.revision = b'456' + def __init__(self): self.closed = False def close(self): @@ -49,21 +40,20 @@ class ConnectionMock(object): return msg -class FactoryMock(object): +def identity_type(**kwargs): + result = mock.Mock(spec=[]) + result.index = 0 + result.proto = result.user = result.host = result.port = None + result.path = None + for k, v in kwargs.items(): + setattr(result, k, v) + return result - @staticmethod - def client(): - return ConnectionMock(version=(1, 3, 4)) - @staticmethod - def identity_type(**kwargs): - result = mock.Mock(spec=[]) - result.index = 0 - result.proto = result.user = result.host = result.port = None - result.path = None - for k, v in kwargs.items(): - setattr(result, k, v) - return result +def load_client(): + return factory.ClientWrapper(connection=ConnectionMock(), + identity_type=identity_type, + device_name='DEVICE_NAME') BLOB = (b'\x00\x00\x00 \xce\xe0\xc9\xd5\xceu/\xe8\xc5\xf2\xbfR+x\xa1\xcf\xb0' @@ -82,7 +72,7 @@ SIG = (b'\x00R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!' def test_ssh_agent(): label = 'localhost:22' - c = client.Client(factory=FactoryMock) + c = client.Client(loader=load_client) ident = c.get_identity(label=label) assert ident.host == 'localhost' assert ident.proto == 'ssh' @@ -129,15 +119,3 @@ def test_utils(): url = 'https://user@host:443/path' assert client.identity_to_string(identity) == url - - -def test_old_version(): - - class OldFactoryMock(FactoryMock): - - @staticmethod - def client(): - return ConnectionMock(version=(1, 2, 3)) - - with pytest.raises(ValueError): - client.Client(factory=OldFactoryMock) diff --git a/trezor_agent/trezor/_factory.py b/trezor_agent/trezor/_factory.py deleted file mode 100644 index bad4138..0000000 --- a/trezor_agent/trezor/_factory.py +++ /dev/null @@ -1,37 +0,0 @@ -''' Thin wrapper around trezorlib. ''' - - -def client(): - # pylint: disable=import-error - from trezorlib.client import TrezorClient - from trezorlib.transport_hid import HidTransport as TrezorHidTransport - from trezorlib.messages_pb2 import PassphraseAck as TrezorPassphraseAck - - from keepkeylib.client import KeepKeyClient - from keepkeylib.transport_hid import HidTransport as KeepKeyHidTransport - from keepkeylib.messages_pb2 import PassphraseAck as KeepKeyPassphraseAck - - devices = list(TrezorHidTransport.enumerate()) - if len(devices) == 1: - t = TrezorClient(TrezorHidTransport(devices[0])) - t.callback_PassphraseRequest = lambda msg: TrezorPassphraseAck(passphrase='') - else: - devices = list(KeepKeyHidTransport.enumerate()) - if len(devices) != 1: - msg = '{:d} devices found'.format(len(devices)) - raise IOError(msg) - t = KeepKeyClient(KeepKeyHidTransport(devices[0])) - t.callback_PassphraseRequest = lambda msg: KeepKeyPassphraseAck(passphrase='') - - return t - - -def trezor_identity_type(**kwargs): - # pylint: disable=import-error - from trezorlib.types_pb2 import IdentityType - return IdentityType(**kwargs) - -def keepkey_identity_type(**kwargs): - # pylint: disable=import-error - from keepkeylib.types_pb2 import IdentityType - return IdentityType(**kwargs) \ No newline at end of file diff --git a/trezor_agent/trezor/client.py b/trezor_agent/trezor/client.py index b153685..a13d0d8 100644 --- a/trezor_agent/trezor/client.py +++ b/trezor_agent/trezor/client.py @@ -4,7 +4,7 @@ import logging import re import struct -from . import _factory as Factory +from . import factory from .. import formats, util log = logging.getLogger(__name__) @@ -12,29 +12,12 @@ log = logging.getLogger(__name__) class Client(object): - TREZOR_MIN_VERSION = [1, 3, 4] - KEEPKEY_MIN_VERSION = [1, 0, 4] - - def __init__(self, factory=Factory, curve=formats.CURVE_NIST256): + def __init__(self, loader=factory.load, curve=formats.CURVE_NIST256): + client_wrapper = loader() + self.client = client_wrapper.connection + self.identity_type = client_wrapper.identity_type + self.device_name = client_wrapper.device_name self.curve = curve - self.factory = factory - self.client = self.factory.client() - f = self.client.features - log.debug('connected to Trezor %s', f.device_id) - log.debug('label : %s', f.label) - log.debug('vendor : %s', f.vendor) - version = [f.major_version, f.minor_version, f.patch_version] - version_str = '.'.join([str(v) for v in version]) - log.debug('version : %s', version_str) - log.debug('revision : %s', binascii.hexlify(f.revision)) - if f.vendor == 'bitcointrezor.com' and version < self.TREZOR_MIN_VERSION: - fmt = 'Please upgrade your TREZOR to v{}+ firmware' - version_str = '.'.join([str(v) for v in self.TREZOR_MIN_VERSION]) - raise ValueError(fmt.format(version_str)) - elif f.vendor == 'keepkey.com' and version < self.KEEPKEY_MIN_VERSION: - fmt = 'Please upgrade your KEEPKEY to v{}+ firmware' - version_str = '.'.join([str(v) for v in self.KEEPKEY_MIN_VERSION]) - raise ValueError(fmt.format(version_str)) def __enter__(self): msg = 'Hello World!' @@ -42,24 +25,20 @@ class Client(object): return self def __exit__(self, *args): - log.info('disconnected from Trezor') + log.info('disconnected from %s', self.device_name) self.client.clear_session() # forget PIN and shutdown screen self.client.close() def get_identity(self, label): - identity = string_to_identity(label, self.factory.trezor_identity_type) - - if self.client.features.vendor == 'keepkey.com': - identity = string_to_identity(label, self.factory.keepkey_identity_type) - + identity = string_to_identity(label, self.identity_type) identity.proto = 'ssh' return identity def get_public_key(self, label): identity = self.get_identity(label=label) label = identity_to_string(identity) # canonize key label - log.info('getting "%s" public key (%s) from Trezor...', - label, self.curve) + log.info('getting "%s" public key (%s) from %s...', + label, self.curve, self.device_name) addr = _get_address(identity) node = self.client.get_public_node(n=addr, ecdsa_curve_name=self.curve) @@ -72,8 +51,8 @@ class Client(object): identity = self.get_identity(label=label) msg = _parse_ssh_blob(blob) - log.info('please confirm user "%s" login to "%s" using Trezor...', - msg['user'], label) + log.info('please confirm user "%s" login to "%s" using %s...', + msg['user'], label, self.device_name) visual = identity.path # not signed when proto='ssh' result = self.client.sign_identity(identity=identity, diff --git a/trezor_agent/trezor/factory.py b/trezor_agent/trezor/factory.py new file mode 100644 index 0000000..4b6259b --- /dev/null +++ b/trezor_agent/trezor/factory.py @@ -0,0 +1,78 @@ +''' Thin wrapper around trezor/keepkey libraries. ''' +import binascii +import collections +import logging + +import semver + +log = logging.getLogger(__name__) + +ClientWrapper = collections.namedtuple( + 'ClientWrapper', + ['connection', 'identity_type', 'device_name']) + + +# pylint: disable=too-many-arguments +def _load_client(name, client_type, hid_transport, + passphrase_ack, identity_type, required_version): + + def empty_passphrase_handler(_): + return passphrase_ack(passphrase='') + + for d in hid_transport.enumerate(): + connection = client_type(hid_transport(d)) + connection.callback_PassphraseRequest = empty_passphrase_handler + f = connection.features + log.debug('connected to %s %s', name, f.device_id) + log.debug('label : %s', f.label) + log.debug('vendor : %s', f.vendor) + current_version = '{}.{}.{}'.format(f.major_version, + f.minor_version, + f.patch_version) + log.debug('version : %s', current_version) + log.debug('revision : %s', binascii.hexlify(f.revision)) + if not semver.match(current_version, required_version): + fmt = 'Please upgrade your {} firmware to {} version (current: {})' + raise ValueError(fmt.format(name, + required_version, + current_version)) + yield ClientWrapper(connection=connection, + identity_type=identity_type, + device_name=name) + + +def _load_trezor(): + # pylint: disable=import-error + from trezorlib.client import TrezorClient + from trezorlib.transport_hid import HidTransport + from trezorlib.messages_pb2 import PassphraseAck + from trezorlib.types_pb2 import IdentityType + return _load_client(name='Trezor', + client_type=TrezorClient, + hid_transport=HidTransport, + passphrase_ack=PassphraseAck, + identity_type=IdentityType, + required_version='>=1.3.4') + + +def _load_keepkey(): + # pylint: disable=import-error + from keepkeylib.client import KeepKeyClient + from keepkeylib.transport_hid import HidTransport + from keepkeylib.messages_pb2 import PassphraseAck + from keepkeylib.types_pb2 import IdentityType + return _load_client(name='KeepKey', + client_type=KeepKeyClient, + hid_transport=HidTransport, + passphrase_ack=PassphraseAck, + identity_type=IdentityType, + required_version='>=1.0.4') + + +def load(): + devices = list(_load_trezor()) + list(_load_keepkey()) + if len(devices) == 1: + return devices[0] + + msg = '{:d} devices found'.format(len(devices)) + raise IOError(msg)