diff --git a/sshagent/trezor.py b/sshagent/trezor.py index f0e30e7..b108a41 100644 --- a/sshagent/trezor.py +++ b/sshagent/trezor.py @@ -11,24 +11,9 @@ from . import formats import logging log = logging.getLogger(__name__) -def label_addr(ident): - index = '\x00' * 4 - addr = index + '{}://{}'.format(ident.proto, ident.host) - h = bytearray(formats.hashfunc(addr).digest()) - - address_n = [0] * 5 - address_n[0] = 13 - address_n[1] = h[0] | (h[1] << 8) | (h[2] << 16) | (h[3] << 24) - address_n[2] = h[4] | (h[5] << 8) | (h[6] << 16) | (h[7] << 24) - address_n[3] = h[8] | (h[9] << 8) | (h[10] << 16) | (h[11] << 24) - address_n[4] = h[12] | (h[13] << 8) | (h[14] << 16) | (h[15] << 24) - return [-x for x in address_n] # prime each address component - class Client(object): - proto = 'ssh' - def __init__(self): devices = HidTransport.enumerate() if len(devices) != 1: @@ -47,18 +32,15 @@ class Client(object): def close(self): self.client.close() - def _get_identity(self, label): - return IdentityType(host=label, proto=self.proto) - def get_public_key(self, label): - addr = label_addr(self._get_identity(label)) + addr = _get_address(_get_identity(label)) log.info('getting %r SSH public key from Trezor...', label) node = self.client.get_public_node(addr) return node.node.public_key def sign_ssh_challenge(self, label, blob): - ident = self._get_identity(label) - msg = parse_ssh_blob(blob) + ident = _get_identity(label) + msg = _parse_ssh_blob(blob) request = 'user: "{user}"'.format(**msg) log.info('confirm %s connection to %r using Trezor...', @@ -66,12 +48,27 @@ class Client(object): s = self.client.sign_identity(identity=ident, challenge_hidden=blob, challenge_visual=request) + assert len(s.signature) == 64 r = util.bytes2num(s.signature[:32]) s = util.bytes2num(s.signature[32:]) return (r, s) -def parse_ssh_blob(data): +def _get_identity(label, proto='ssh'): + return IdentityType(host=label, proto=proto) + + +def _get_address(ident): + index = '\x00' * 4 + addr = index + '{}://{}'.format(ident.proto, ident.host) + digest = formats.hashfunc(addr).digest() + s = io.BytesIO(bytearray(digest)) + + address_n = [13] + list(util.recv(s, '