diff --git a/sshagent/trezor.py b/sshagent/trezor.py index 70b3891..80ad512 100644 --- a/sshagent/trezor.py +++ b/sshagent/trezor.py @@ -1,10 +1,6 @@ import io import binascii -from trezorlib.client import TrezorClient -from trezorlib.transport_hid import HidTransport -from trezorlib.types_pb2 import IdentityType - from . import util from . import formats @@ -12,14 +8,29 @@ import logging log = logging.getLogger(__name__) -class Client(object): +class TrezorLibrary(object): - def __init__(self): + @staticmethod + def client(): + from trezorlib.client import TrezorClient + from trezorlib.transport_hid import HidTransport devices = HidTransport.enumerate() if len(devices) != 1: raise ValueError('{:d} Trezor devices found'.format(len(devices))) - client = TrezorClient(HidTransport(devices[0])) - f = client.features + return TrezorClient(HidTransport(devices[0])) + + @staticmethod + def identity(label, proto='ssh'): + from trezorlib.types_pb2 import IdentityType + return IdentityType(host=label, proto=proto) + + +class Client(object): + + def __init__(self, factory=TrezorLibrary): + self.factory = factory + self.client = self.factory.client() + f = self.client.features log.info('connected to Trezor') log.debug('ID : %s', f.device_id) log.debug('label : %s', f.label) @@ -27,19 +38,18 @@ class Client(object): version = [f.major_version, f.minor_version, f.patch_version] log.debug('version : %s', '.'.join([str(v) for v in version])) log.debug('revision : %s', binascii.hexlify(f.revision)) - self.client = client def close(self): self.client.close() def get_public_key(self, label): - addr = _get_address(_get_identity(label)) + addr = _get_address(self.factory.identity(label)) log.info('getting %r SSH public key from Trezor...', label) node = self.client.get_public_node(addr) return node.node.public_key def sign_ssh_challenge(self, label, blob): - ident = _get_identity(label) + ident = self.factory.identity(label) msg = _parse_ssh_blob(blob) request = 'user: "{user}"'.format(**msg) @@ -54,10 +64,6 @@ class Client(object): return (r, s) -def _get_identity(label, proto='ssh'): - return IdentityType(host=label, proto=proto) - - def _get_address(ident): index = '\x00' * 4 addr = index + '{}://{}'.format(ident.proto, ident.host) diff --git a/sshagent/trezor_agent.py b/sshagent/trezor_agent.py index 2723248..46aa82f 100644 --- a/sshagent/trezor_agent.py +++ b/sshagent/trezor_agent.py @@ -22,7 +22,7 @@ def main(): level = verbosity[min(args.verbose, len(verbosity) - 1)] logging.basicConfig(level=level, format=fmt) - client = trezor.Client() + client = trezor.Client(factory=trezor.TrezorLibrary) key_files = [] for label in args.labels: