From 4f4db9bdd57b5dccb05971a54d80a0d953556905 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Sat, 6 Jun 2015 17:52:10 +0300 Subject: [PATCH] add support Trezor SSH agent --- agent.py | 123 +++++++++++++++++++++++++++++++ protocol.py | 208 ++++++++++++++++++++++++++++++++++++++++++++++++++++ trezor.py | 108 +++++++++++++++++++++++++++ 3 files changed, 439 insertions(+) create mode 100755 agent.py create mode 100644 protocol.py create mode 100644 trezor.py diff --git a/agent.py b/agent.py new file mode 100755 index 0000000..5213c6e --- /dev/null +++ b/agent.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python +import socket +import os +import sys +import subprocess +import argparse +import tempfile +import contextlib +import threading +import logging +log = logging.getLogger(__name__) + +import protocol +import trezor + + +def load_keys(key_files): + keys = [] + for f in key_files: + k = protocol.load_public_key(f) + keys.append(k) + return keys + + +@contextlib.contextmanager +def unix_domain_socket_server(sock_path): + log.debug('serving on SSH_AUTH_SOCK=%s', sock_path) + try: + os.remove(sock_path) + except OSError: + if os.path.exists(sock_path): + raise + + server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server.bind(sock_path) + server.listen(1) + try: + yield server + finally: + os.remove(sock_path) + + +def worker_thread(server, keys, signer): + log.debug('worker thread started') + while True: + log.debug('waiting for connection on %s', server.getsockname()) + try: + conn, _ = server.accept() + except socket.error as e: + log.debug('server error: %s', e, exc_info=True) + break + with contextlib.closing(conn): + protocol.handle_connection(conn, keys, signer) + log.debug('worker thread stopped') + + +@contextlib.contextmanager +def spawn(func, **kwargs): + t = threading.Thread(target=func, kwargs=kwargs) + t.start() + yield + t.join() + + +def run(command, environ): + log.debug('running %r with %r', command, environ) + env = dict(os.environ) + env.update(environ) + p = subprocess.Popen(args=command, env=env) + log.debug('subprocess %d is running', p.pid) + ret = p.wait() + log.debug('subprocess %d exited: %d', p.pid, ret) + return ret + + +def serve(key_files, command, signer, sock_path=None): + if sock_path is None: + sock_path = tempfile.mktemp(prefix='ssh-agent-') + + keys = [protocol.parse_public_key(k) for k in key_files] + environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())} + with unix_domain_socket_server(sock_path) as server: + with spawn(worker_thread, server=server, keys=keys, signer=signer): + ret = run(command=command, environ=environ) + log.debug('closing server') + server.shutdown(socket.SHUT_RD) + + log.info('exitcode: %d', ret) + sys.exit(ret) + + +def main(): + fmt = '%(asctime)s %(levelname)-12s %(message)-100s [%(filename)s]' + p = argparse.ArgumentParser() + p.add_argument('-k', '--key-label', dest='labels', action='append') + p.add_argument('-v', '--verbose', action='count', default=0) + p.add_argument('command', type=str, nargs='*') + args = p.parse_args() + + verbosity = [logging.WARNING, logging.INFO, logging.DEBUG] + level = verbosity[min(args.verbose, len(verbosity))] + logging.basicConfig(level=level, format=fmt) + + client = trezor.Client() + + key_files = [] + for label in args.labels: + pubkey = client.get_public_key(label=label) + key_files.append(trezor.export_public_key(pubkey=pubkey, label=label)) + + if not args.command: + sys.stdout.write(''.join(key_files)) + return + + signer = client.sign_ssh_challenge + + try: + serve(key_files=key_files, command=args.command, signer=signer) + except KeyboardInterrupt: + log.info('server stopped') + +if __name__ == '__main__': + main() diff --git a/protocol.py b/protocol.py new file mode 100644 index 0000000..c3e5733 --- /dev/null +++ b/protocol.py @@ -0,0 +1,208 @@ +import io +import struct +import hashlib +import ecdsa +import base64 + +import logging +log = logging.getLogger(__name__) + +def send(conn, data, fmt=None): + if fmt: + data = struct.pack(fmt, *data) + conn.sendall(data) + +def recv(conn, size): + try: + fmt = size + size = struct.calcsize(fmt) + except TypeError: + fmt = None + try: + _read = conn.recv + except AttributeError: + _read = conn.read + + res = io.BytesIO() + while size > 0: + buf = _read(size) + if not buf: + raise EOFError + size = size - len(buf) + res.write(buf) + res = res.getvalue() + if fmt: + return struct.unpack(fmt, res) + else: + return res + + +def read_frame(conn): + size, = recv(conn, '>L') + return recv(conn, size) + +def bytes2num(s): + res = 0 + for i, c in enumerate(reversed(bytearray(s))): + res += c << (i * 8) + return res + + +def parse_pubkey(blob): + s = io.BytesIO(blob) + key_type = read_frame(s) + log.debug('key type: %s', key_type) + curve = read_frame(s) + log.debug('curve name: %s', curve) + point = read_frame(s) + _type, point = point[:1], point[1:] + assert _type == DER_OCTET_STRING + size = len(point) // 2 + assert len(point) == 2 * size + coords = map(bytes2num, [point[:size], point[size:]]) + log.debug('coordinates: %s', coords) + fp = fingerprint(blob) + result = { + 'point': tuple(coords), 'curve': curve, + 'fingerprint': fp, + 'type': key_type, + 'blob': blob, 'size': size + } + return result + +def list_keys(c): + send(c, [0x1, 0xB], '>LB') + buf = io.BytesIO(read_frame(c)) + assert recv(buf, '>B') == (0xC,) + num, = recv(buf, '>L') + for i in range(num): + k = parse_pubkey(read_frame(buf)) + k['comment'] = read_frame(buf) + yield k + +def frame(*msgs): + res = io.BytesIO() + for msg in msgs: + res.write(msg) + msg = res.getvalue() + return pack('L', len(msg)) + msg + +SSH_AGENTC_REQUEST_RSA_IDENTITIES = 1 +SSH_AGENT_RSA_IDENTITIES_ANSWER = 2 + +SSH_AGENTC_REMOVE_ALL_RSA_IDENTITIES = 9 + +SSH2_AGENTC_REQUEST_IDENTITIES = 11 +SSH2_AGENT_IDENTITIES_ANSWER = 12 +SSH2_AGENTC_SIGN_REQUEST = 13 +SSH2_AGENT_SIGN_RESPONSE = 14 +SSH2_AGENTC_ADD_IDENTITY = 17 +SSH2_AGENTC_REMOVE_IDENTITY = 18 +SSH2_AGENTC_REMOVE_ALL_IDENTITIES = 19 + +def pack(fmt, *args): + return struct.pack('>' + fmt, *args) + +def legacy_pubs(buf, keys, signer): + code = pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER) + num = pack('L', 0) # no SSH v1 keys + return frame(code, num) + +def list_pubs(buf, keys, signer): + code = pack('B', SSH2_AGENT_IDENTITIES_ANSWER) + num = pack('L', len(keys)) + log.debug('available keys: %s', [k['name'] for k in keys]) + for i, k in enumerate(keys): + log.debug('%2d) %s', i+1, k['fingerprint']) + pubs = [frame(k['blob']) + frame(k['name']) for k in keys] + return frame(code, num, *pubs) + +def fingerprint(blob): + digest = hashlib.md5(blob).digest() + return ':'.join('{:02x}'.format(c) for c in bytearray(digest)) + +def num2bytes(value, size): + res = [] + for i in range(size): + res.append(value & 0xFF) + value = value >> 8 + assert value == 0 + return bytearray(list(reversed(res))) + +def sign_message(buf, keys, signer): + key = parse_pubkey(read_frame(buf)) + log.debug('looking for %s', key['fingerprint']) + blob = read_frame(buf) + + for k in keys: + if (k['fingerprint']) == (key['fingerprint']): + log.debug('using key %r (%s)', k['name'], k['fingerprint']) + key = k + break + else: + raise ValueError('key not found') + + log.debug('signing %d-byte blob', len(blob)) + r, s = signer(label=k['name'], blob=blob) + signature = (r, s) + + log.debug('signature: %s', signature) + + curve = ecdsa.curves.NIST256p + point = ecdsa.ellipticcurve.Point(curve.curve, *key['point']) + vk = ecdsa.VerifyingKey.from_public_point(point, curve, hashlib.sha256) + success = vk.verify(signature=signature, data=blob, + sigdecode=lambda sig, _: sig) + log.info('signature status: %s', 'OK' if success else 'ERROR') + if not success: + raise ValueError('invalid signature') + + sig_bytes = io.BytesIO() + for x in signature: + sig_bytes.write(frame(b'\x00' + num2bytes(x, key['size']))) + sig_bytes = sig_bytes.getvalue() + log.debug('signature size: %d bytes', len(sig_bytes)) + + data = frame(frame(key['type']), frame(sig_bytes)) + code = pack('B', SSH2_AGENT_SIGN_RESPONSE) + return frame(code, data) + +handlers = { + SSH_AGENTC_REQUEST_RSA_IDENTITIES: legacy_pubs, + SSH2_AGENTC_REQUEST_IDENTITIES: list_pubs, + SSH2_AGENTC_SIGN_REQUEST: sign_message, +} + +def handle_connection(conn, keys, signer): + try: + log.debug('welcome agent') + while True: + msg = read_frame(conn) + buf = io.BytesIO(msg) + code, = recv(buf, '>B') + log.debug('request: %d bytes', len(msg)) + handler = handlers[code] + log.debug('calling %s()', handler.__name__) + reply = handler(buf=buf, keys=keys, signer=signer) + log.debug('reply: %d bytes', len(reply)) + send(conn, reply) + except EOFError: + log.debug('goodbye agent') + except: + log.exception('error') + raise + +DER_OCTET_STRING = b'\x04' + +def load_public_key(filename): + with open(filename) as f: + return parse_public_key(f.read()) + +def parse_public_key(data): + file_type, base64blob, name = data.split() + blob = base64.b64decode(base64blob) + result = parse_pubkey(blob) + result['name'] = name.encode('ascii') + assert result['type'] == file_type.encode('ascii') + log.debug('loaded %s %s', file_type, result['fingerprint']) + return result diff --git a/trezor.py b/trezor.py new file mode 100644 index 0000000..609909d --- /dev/null +++ b/trezor.py @@ -0,0 +1,108 @@ +import io +import base64 +import logging + +from trezorlib.client import TrezorClient +from trezorlib.transport_hid import HidTransport +from trezorlib.types_pb2 import IdentityType + +import ecdsa +import bitcoin +import hashlib + + +import protocol +log = logging.getLogger(__name__) + +curve = ecdsa.NIST256p +hashfunc = hashlib.sha256 + + +def decode_pubkey(pub): + P = curve.curve.p() + A = curve.curve.a() + B = curve.curve.b() + x = bitcoin.decode(pub[1:33], 256) + beta = pow(int(x*x*x+A*x+B), int((P+1)//4), int(P)) + y = (P-beta) if ((beta + bitcoin.from_byte_to_int(pub[0])) % 2) else beta + return (x, y) + + +def export_public_key(pubkey, label): + x, y = decode_pubkey(pubkey) + point = ecdsa.ellipticcurve.Point(curve.curve, x, y) + vk = ecdsa.VerifyingKey.from_public_point(point, curve=curve, + hashfunc=hashfunc) + key_type = 'ecdsa-sha2-nistp256' + curve_name = 'nistp256' + blobs = map(protocol.frame, [key_type, curve_name, '\x04' + vk.to_string()]) + b64 = base64.b64encode(''.join(blobs)) + return '{} {} {}\n'.format(key_type, b64, label) + + +def label_addr(ident): + index = '\x00' * 4 + addr = index + '{}://{}'.format(ident.proto, ident.host) + h = bytearray(hashfunc(addr).digest()) + + address_n = [0] * 5 + address_n[0] = 13 + address_n[1] = h[0] | (h[1] << 8) | (h[2] << 16) | (h[3] << 24) + address_n[2] = h[4] | (h[5] << 8) | (h[6] << 16) | (h[7] << 24) + address_n[3] = h[8] | (h[9] << 8) | (h[10] << 16) | (h[11] << 24) + address_n[4] = h[12] | (h[13] << 8) | (h[14] << 16) | (h[15] << 24) + return [-x for x in address_n] # prime each address component + + +class Client(object): + + proto = 'ssh' + + def __init__(self): + device, = HidTransport.enumerate() + client = TrezorClient(HidTransport(device)) + log.debug('connected to Trezor #%s', client.get_device_id()) + self.client = client + + def close(self): + self.client.close() + + def _get_identity(self, label): + return IdentityType(host=label, proto=self.proto) + + def get_public_key(self, label): + addr = label_addr(self._get_identity(label)) + log.info('getting %r SSH public key from Trezor...', label) + node = self.client.get_public_node(addr) + return node.node.public_key + + def sign_ssh_challenge(self, label, blob): + ident = self._get_identity(label) + msg = parse_ssh_blob(blob) + request = 'user: "{user}"'.format(**msg) + + log.info('confirm %s connection to %r using Trezor...', + request, label) + s = self.client.sign_identity(identity=ident, + challenge_hidden=blob, + challenge_visual=request) + r = protocol.bytes2num(s.signature[:32]) + s = protocol.bytes2num(s.signature[32:]) + return (r, s) + + +def parse_ssh_blob(data): + res = {} + if data: + i = io.BytesIO(data) + res['nonce'] = protocol.read_frame(i) + i.read(1) # TBD + res['user'] = protocol.read_frame(i) + res['conn'] = protocol.read_frame(i) + res['auth'] = protocol.read_frame(i) + i.read(1) # TBD + res['key_type'] = protocol.read_frame(i) + res['pubkey'] = protocol.read_frame(i) + log.debug('%s: user %r via %r (%r)', + res['conn'], res['user'], res['auth'], res['key_type']) + return res