diff --git a/setup.py b/setup.py index acf0173..94c5121 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ setup( author_email='roman.zeyde@gmail.com', url='http://github.com/romanz/trezor-agent', packages=['trezor_agent', 'trezor_agent.trezor'], - install_requires=['ecdsa>=0.13', 'ed25519>=1.4', 'Cython>=0.23.4', 'trezor>=0.6.6'], + install_requires=['ecdsa>=0.13', 'ed25519>=1.4', 'Cython>=0.23.4', 'trezor>=0.6.6', 'keepkey>=0.7.0'], platforms=['POSIX'], classifiers=[ 'Development Status :: 4 - Beta', diff --git a/trezor_agent/trezor/_factory.py b/trezor_agent/trezor/_factory.py index 2c4cebf..bad4138 100644 --- a/trezor_agent/trezor/_factory.py +++ b/trezor_agent/trezor/_factory.py @@ -4,20 +4,34 @@ def client(): # pylint: disable=import-error from trezorlib.client import TrezorClient - from trezorlib.transport_hid import HidTransport - from trezorlib.messages_pb2 import PassphraseAck + from trezorlib.transport_hid import HidTransport as TrezorHidTransport + from trezorlib.messages_pb2 import PassphraseAck as TrezorPassphraseAck - devices = list(HidTransport.enumerate()) - if len(devices) != 1: - msg = '{:d} Trezor devices found'.format(len(devices)) - raise IOError(msg) + from keepkeylib.client import KeepKeyClient + from keepkeylib.transport_hid import HidTransport as KeepKeyHidTransport + from keepkeylib.messages_pb2 import PassphraseAck as KeepKeyPassphraseAck + + devices = list(TrezorHidTransport.enumerate()) + if len(devices) == 1: + t = TrezorClient(TrezorHidTransport(devices[0])) + t.callback_PassphraseRequest = lambda msg: TrezorPassphraseAck(passphrase='') + else: + devices = list(KeepKeyHidTransport.enumerate()) + if len(devices) != 1: + msg = '{:d} devices found'.format(len(devices)) + raise IOError(msg) + t = KeepKeyClient(KeepKeyHidTransport(devices[0])) + t.callback_PassphraseRequest = lambda msg: KeepKeyPassphraseAck(passphrase='') - t = TrezorClient(HidTransport(devices[0])) - t.callback_PassphraseRequest = lambda msg: PassphraseAck(passphrase='') return t -def identity_type(**kwargs): +def trezor_identity_type(**kwargs): # pylint: disable=import-error from trezorlib.types_pb2 import IdentityType return IdentityType(**kwargs) + +def keepkey_identity_type(**kwargs): + # pylint: disable=import-error + from keepkeylib.types_pb2 import IdentityType + return IdentityType(**kwargs) \ No newline at end of file diff --git a/trezor_agent/trezor/client.py b/trezor_agent/trezor/client.py index 033b984..b153685 100644 --- a/trezor_agent/trezor/client.py +++ b/trezor_agent/trezor/client.py @@ -4,7 +4,7 @@ import logging import re import struct -from . import _factory as TrezorFactory +from . import _factory as Factory from .. import formats, util log = logging.getLogger(__name__) @@ -12,9 +12,10 @@ log = logging.getLogger(__name__) class Client(object): - MIN_VERSION = [1, 3, 4] + TREZOR_MIN_VERSION = [1, 3, 4] + KEEPKEY_MIN_VERSION = [1, 0, 4] - def __init__(self, factory=TrezorFactory, curve=formats.CURVE_NIST256): + def __init__(self, factory=Factory, curve=formats.CURVE_NIST256): self.curve = curve self.factory = factory self.client = self.factory.client() @@ -26,9 +27,13 @@ class Client(object): version_str = '.'.join([str(v) for v in version]) log.debug('version : %s', version_str) log.debug('revision : %s', binascii.hexlify(f.revision)) - if version < self.MIN_VERSION: + if f.vendor == 'bitcointrezor.com' and version < self.TREZOR_MIN_VERSION: fmt = 'Please upgrade your TREZOR to v{}+ firmware' - version_str = '.'.join([str(v) for v in self.MIN_VERSION]) + version_str = '.'.join([str(v) for v in self.TREZOR_MIN_VERSION]) + raise ValueError(fmt.format(version_str)) + elif f.vendor == 'keepkey.com' and version < self.KEEPKEY_MIN_VERSION: + fmt = 'Please upgrade your KEEPKEY to v{}+ firmware' + version_str = '.'.join([str(v) for v in self.KEEPKEY_MIN_VERSION]) raise ValueError(fmt.format(version_str)) def __enter__(self): @@ -42,7 +47,11 @@ class Client(object): self.client.close() def get_identity(self, label): - identity = string_to_identity(label, self.factory.identity_type) + identity = string_to_identity(label, self.factory.trezor_identity_type) + + if self.client.features.vendor == 'keepkey.com': + identity = string_to_identity(label, self.factory.keepkey_identity_type) + identity.proto = 'ssh' return identity