Merge Trezor and KeepKey functionality

nistp521
Roman Zeyde 8 years ago
parent 5932a89dc5
commit 8c6ac43cf4

@ -5,7 +5,7 @@ python:
- "3.4"
install:
- pip install ecdsa ed25519 # test without trezorlib for now
- pip install ecdsa ed25519 semver # test without trezorlib for now
- pip install pylint coverage pep8
script:

@ -9,7 +9,7 @@ setup(
author_email='roman.zeyde@gmail.com',
url='http://github.com/romanz/trezor-agent',
packages=['trezor_agent', 'trezor_agent.trezor'],
install_requires=['ecdsa>=0.13', 'ed25519>=1.4', 'Cython>=0.23.4', 'trezor>=0.6.6', 'keepkey>=0.7.0'],
install_requires=['ecdsa>=0.13', 'ed25519>=1.4', 'Cython>=0.23.4', 'trezor>=0.6.6', 'keepkey>=0.7.0', 'semver>=2.2'],
platforms=['POSIX'],
classifiers=[
'Development Status :: 4 - Beta',

@ -7,6 +7,7 @@ deps=
pep8
coverage
pylint
semver
commands=
pep8 trezor_agent
pylint --reports=no --rcfile .pylintrc trezor_agent

@ -1,10 +1,9 @@
import io
import mock
import pytest
from .. import formats, util
from ..trezor import client
from ..trezor import client, factory
ADDR = [2147483661, 2810943954, 3938368396, 3454558782, 3848009040]
CURVE = 'nist256p1'
@ -18,15 +17,7 @@ PUBKEY_TEXT = ('ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzd'
class ConnectionMock(object):
def __init__(self, version):
self.features = mock.Mock(spec=[])
self.features.device_id = '123456789'
self.features.label = 'mywallet'
self.features.vendor = 'mock'
self.features.major_version = version[0]
self.features.minor_version = version[1]
self.features.patch_version = version[2]
self.features.revision = b'456'
def __init__(self):
self.closed = False
def close(self):
@ -49,21 +40,20 @@ class ConnectionMock(object):
return msg
class FactoryMock(object):
def identity_type(**kwargs):
result = mock.Mock(spec=[])
result.index = 0
result.proto = result.user = result.host = result.port = None
result.path = None
for k, v in kwargs.items():
setattr(result, k, v)
return result
@staticmethod
def client():
return ConnectionMock(version=(1, 3, 4))
@staticmethod
def identity_type(**kwargs):
result = mock.Mock(spec=[])
result.index = 0
result.proto = result.user = result.host = result.port = None
result.path = None
for k, v in kwargs.items():
setattr(result, k, v)
return result
def load_client():
return factory.ClientWrapper(connection=ConnectionMock(),
identity_type=identity_type,
device_name='DEVICE_NAME')
BLOB = (b'\x00\x00\x00 \xce\xe0\xc9\xd5\xceu/\xe8\xc5\xf2\xbfR+x\xa1\xcf\xb0'
@ -82,7 +72,7 @@ SIG = (b'\x00R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!'
def test_ssh_agent():
label = 'localhost:22'
c = client.Client(factory=FactoryMock)
c = client.Client(loader=load_client)
ident = c.get_identity(label=label)
assert ident.host == 'localhost'
assert ident.proto == 'ssh'
@ -129,15 +119,3 @@ def test_utils():
url = 'https://user@host:443/path'
assert client.identity_to_string(identity) == url
def test_old_version():
class OldFactoryMock(FactoryMock):
@staticmethod
def client():
return ConnectionMock(version=(1, 2, 3))
with pytest.raises(ValueError):
client.Client(factory=OldFactoryMock)

@ -1,37 +0,0 @@
''' Thin wrapper around trezorlib. '''
def client():
# pylint: disable=import-error
from trezorlib.client import TrezorClient
from trezorlib.transport_hid import HidTransport as TrezorHidTransport
from trezorlib.messages_pb2 import PassphraseAck as TrezorPassphraseAck
from keepkeylib.client import KeepKeyClient
from keepkeylib.transport_hid import HidTransport as KeepKeyHidTransport
from keepkeylib.messages_pb2 import PassphraseAck as KeepKeyPassphraseAck
devices = list(TrezorHidTransport.enumerate())
if len(devices) == 1:
t = TrezorClient(TrezorHidTransport(devices[0]))
t.callback_PassphraseRequest = lambda msg: TrezorPassphraseAck(passphrase='')
else:
devices = list(KeepKeyHidTransport.enumerate())
if len(devices) != 1:
msg = '{:d} devices found'.format(len(devices))
raise IOError(msg)
t = KeepKeyClient(KeepKeyHidTransport(devices[0]))
t.callback_PassphraseRequest = lambda msg: KeepKeyPassphraseAck(passphrase='')
return t
def trezor_identity_type(**kwargs):
# pylint: disable=import-error
from trezorlib.types_pb2 import IdentityType
return IdentityType(**kwargs)
def keepkey_identity_type(**kwargs):
# pylint: disable=import-error
from keepkeylib.types_pb2 import IdentityType
return IdentityType(**kwargs)

@ -4,7 +4,7 @@ import logging
import re
import struct
from . import _factory as Factory
from . import factory
from .. import formats, util
log = logging.getLogger(__name__)
@ -12,29 +12,12 @@ log = logging.getLogger(__name__)
class Client(object):
TREZOR_MIN_VERSION = [1, 3, 4]
KEEPKEY_MIN_VERSION = [1, 0, 4]
def __init__(self, factory=Factory, curve=formats.CURVE_NIST256):
def __init__(self, loader=factory.load, curve=formats.CURVE_NIST256):
client_wrapper = loader()
self.client = client_wrapper.connection
self.identity_type = client_wrapper.identity_type
self.device_name = client_wrapper.device_name
self.curve = curve
self.factory = factory
self.client = self.factory.client()
f = self.client.features
log.debug('connected to Trezor %s', f.device_id)
log.debug('label : %s', f.label)
log.debug('vendor : %s', f.vendor)
version = [f.major_version, f.minor_version, f.patch_version]
version_str = '.'.join([str(v) for v in version])
log.debug('version : %s', version_str)
log.debug('revision : %s', binascii.hexlify(f.revision))
if f.vendor == 'bitcointrezor.com' and version < self.TREZOR_MIN_VERSION:
fmt = 'Please upgrade your TREZOR to v{}+ firmware'
version_str = '.'.join([str(v) for v in self.TREZOR_MIN_VERSION])
raise ValueError(fmt.format(version_str))
elif f.vendor == 'keepkey.com' and version < self.KEEPKEY_MIN_VERSION:
fmt = 'Please upgrade your KEEPKEY to v{}+ firmware'
version_str = '.'.join([str(v) for v in self.KEEPKEY_MIN_VERSION])
raise ValueError(fmt.format(version_str))
def __enter__(self):
msg = 'Hello World!'
@ -42,24 +25,20 @@ class Client(object):
return self
def __exit__(self, *args):
log.info('disconnected from Trezor')
log.info('disconnected from %s', self.device_name)
self.client.clear_session() # forget PIN and shutdown screen
self.client.close()
def get_identity(self, label):
identity = string_to_identity(label, self.factory.trezor_identity_type)
if self.client.features.vendor == 'keepkey.com':
identity = string_to_identity(label, self.factory.keepkey_identity_type)
identity = string_to_identity(label, self.identity_type)
identity.proto = 'ssh'
return identity
def get_public_key(self, label):
identity = self.get_identity(label=label)
label = identity_to_string(identity) # canonize key label
log.info('getting "%s" public key (%s) from Trezor...',
label, self.curve)
log.info('getting "%s" public key (%s) from %s...',
label, self.curve, self.device_name)
addr = _get_address(identity)
node = self.client.get_public_node(n=addr,
ecdsa_curve_name=self.curve)
@ -72,8 +51,8 @@ class Client(object):
identity = self.get_identity(label=label)
msg = _parse_ssh_blob(blob)
log.info('please confirm user "%s" login to "%s" using Trezor...',
msg['user'], label)
log.info('please confirm user "%s" login to "%s" using %s...',
msg['user'], label, self.device_name)
visual = identity.path # not signed when proto='ssh'
result = self.client.sign_identity(identity=identity,

@ -0,0 +1,78 @@
''' Thin wrapper around trezor/keepkey libraries. '''
import binascii
import collections
import logging
import semver
log = logging.getLogger(__name__)
ClientWrapper = collections.namedtuple(
'ClientWrapper',
['connection', 'identity_type', 'device_name'])
# pylint: disable=too-many-arguments
def _load_client(name, client_type, hid_transport,
passphrase_ack, identity_type, required_version):
def empty_passphrase_handler(_):
return passphrase_ack(passphrase='')
for d in hid_transport.enumerate():
connection = client_type(hid_transport(d))
connection.callback_PassphraseRequest = empty_passphrase_handler
f = connection.features
log.debug('connected to %s %s', name, f.device_id)
log.debug('label : %s', f.label)
log.debug('vendor : %s', f.vendor)
current_version = '{}.{}.{}'.format(f.major_version,
f.minor_version,
f.patch_version)
log.debug('version : %s', current_version)
log.debug('revision : %s', binascii.hexlify(f.revision))
if not semver.match(current_version, required_version):
fmt = 'Please upgrade your {} firmware to {} version (current: {})'
raise ValueError(fmt.format(name,
required_version,
current_version))
yield ClientWrapper(connection=connection,
identity_type=identity_type,
device_name=name)
def _load_trezor():
# pylint: disable=import-error
from trezorlib.client import TrezorClient
from trezorlib.transport_hid import HidTransport
from trezorlib.messages_pb2 import PassphraseAck
from trezorlib.types_pb2 import IdentityType
return _load_client(name='Trezor',
client_type=TrezorClient,
hid_transport=HidTransport,
passphrase_ack=PassphraseAck,
identity_type=IdentityType,
required_version='>=1.3.4')
def _load_keepkey():
# pylint: disable=import-error
from keepkeylib.client import KeepKeyClient
from keepkeylib.transport_hid import HidTransport
from keepkeylib.messages_pb2 import PassphraseAck
from keepkeylib.types_pb2 import IdentityType
return _load_client(name='KeepKey',
client_type=KeepKeyClient,
hid_transport=HidTransport,
passphrase_ack=PassphraseAck,
identity_type=IdentityType,
required_version='>=1.0.4')
def load():
devices = list(_load_trezor()) + list(_load_keepkey())
if len(devices) == 1:
return devices[0]
msg = '{:d} devices found'.format(len(devices))
raise IOError(msg)
Loading…
Cancel
Save