diff --git a/sshagent/formats.py b/sshagent/formats.py new file mode 100644 index 0000000..99f29d3 --- /dev/null +++ b/sshagent/formats.py @@ -0,0 +1,80 @@ +import io +import hashlib +import base64 +import ecdsa + +import logging +log = logging.getLogger(__name__) + +from . import util + +def fingerprint(blob): + digest = hashlib.md5(blob).digest() + return ':'.join('{:02x}'.format(c) for c in bytearray(digest)) + +DER_OCTET_STRING = b'\x04' + +curve = ecdsa.NIST256p +hashfunc = hashlib.sha256 + +def parse_pubkey(blob): + s = io.BytesIO(blob) + key_type = util.read_frame(s) + log.debug('key type: %s', key_type) + curve_name = util.read_frame(s) + log.debug('curve name: %s', curve_name) + point = util.read_frame(s) + _type, point = point[:1], point[1:] + assert _type == DER_OCTET_STRING + size = len(point) // 2 + assert len(point) == 2 * size + coords = (util.bytes2num(point[:size]), util.bytes2num(point[size:])) + log.debug('coordinates: %s', coords) + fp = fingerprint(blob) + + point = ecdsa.ellipticcurve.Point(curve.curve, *coords) + vk = ecdsa.VerifyingKey.from_public_point(point, curve, hashfunc) + result = { + 'point': coords, + 'curve': curve_name, + 'fingerprint': fp, + 'type': key_type, + 'blob': blob, + 'size': size, + 'verifying_key': vk + } + return result + +def load_public_key(filename): + with open(filename) as f: + return parse_public_key(f.read()) + +def parse_public_key(data): + file_type, base64blob, name = data.split() + blob = base64.b64decode(base64blob) + result = parse_pubkey(blob) + result['name'] = name.encode('ascii') + assert result['type'] == file_type.encode('ascii') + log.debug('loaded %s %s', file_type, result['fingerprint']) + return result + +def decompress_pubkey(pub): + P = curve.curve.p() + A = curve.curve.a() + B = curve.curve.b() + x = util.bytes2num(pub[1:33]) + beta = pow(int(x*x*x+A*x+B), int((P+1)//4), int(P)) + y = (P-beta) if ((beta + ord(pub[0])) % 2) else beta + return (x, y) + + +def export_public_key(pubkey, label): + x, y = decompress_pubkey(pubkey) + point = ecdsa.ellipticcurve.Point(curve.curve, x, y) + vk = ecdsa.VerifyingKey.from_public_point(point, curve=curve, + hashfunc=hashfunc) + key_type = 'ecdsa-sha2-nistp256' + curve_name = 'nistp256' + blobs = map(util.frame, [key_type, curve_name, '\x04' + vk.to_string()]) + b64 = base64.b64encode(''.join(blobs)) + return '{} {} {}\n'.format(key_type, b64, label) diff --git a/sshagent/protocol.py b/sshagent/protocol.py index c3e5733..07374a1 100644 --- a/sshagent/protocol.py +++ b/sshagent/protocol.py @@ -1,92 +1,11 @@ import io -import struct -import hashlib -import ecdsa -import base64 + +from . import util +from . import formats import logging log = logging.getLogger(__name__) -def send(conn, data, fmt=None): - if fmt: - data = struct.pack(fmt, *data) - conn.sendall(data) - -def recv(conn, size): - try: - fmt = size - size = struct.calcsize(fmt) - except TypeError: - fmt = None - try: - _read = conn.recv - except AttributeError: - _read = conn.read - - res = io.BytesIO() - while size > 0: - buf = _read(size) - if not buf: - raise EOFError - size = size - len(buf) - res.write(buf) - res = res.getvalue() - if fmt: - return struct.unpack(fmt, res) - else: - return res - - -def read_frame(conn): - size, = recv(conn, '>L') - return recv(conn, size) - -def bytes2num(s): - res = 0 - for i, c in enumerate(reversed(bytearray(s))): - res += c << (i * 8) - return res - - -def parse_pubkey(blob): - s = io.BytesIO(blob) - key_type = read_frame(s) - log.debug('key type: %s', key_type) - curve = read_frame(s) - log.debug('curve name: %s', curve) - point = read_frame(s) - _type, point = point[:1], point[1:] - assert _type == DER_OCTET_STRING - size = len(point) // 2 - assert len(point) == 2 * size - coords = map(bytes2num, [point[:size], point[size:]]) - log.debug('coordinates: %s', coords) - fp = fingerprint(blob) - result = { - 'point': tuple(coords), 'curve': curve, - 'fingerprint': fp, - 'type': key_type, - 'blob': blob, 'size': size - } - return result - -def list_keys(c): - send(c, [0x1, 0xB], '>LB') - buf = io.BytesIO(read_frame(c)) - assert recv(buf, '>B') == (0xC,) - num, = recv(buf, '>L') - for i in range(num): - k = parse_pubkey(read_frame(buf)) - k['comment'] = read_frame(buf) - yield k - -def frame(*msgs): - res = io.BytesIO() - for msg in msgs: - res.write(msg) - msg = res.getvalue() - return pack('L', len(msg)) + msg - SSH_AGENTC_REQUEST_RSA_IDENTITIES = 1 SSH_AGENT_RSA_IDENTITIES_ANSWER = 2 @@ -100,39 +19,34 @@ SSH2_AGENTC_ADD_IDENTITY = 17 SSH2_AGENTC_REMOVE_IDENTITY = 18 SSH2_AGENTC_REMOVE_ALL_IDENTITIES = 19 -def pack(fmt, *args): - return struct.pack('>' + fmt, *args) +def list_keys(c): + util.send(c, [0x1, 0xB], '>LB') + buf = io.BytesIO(util.read_frame(c)) + assert util.recv(buf, '>B') == (0xC,) + num, = util.recv(buf, '>L') + for i in range(num): + k = formats.parse_pubkey(util.read_frame(buf)) + k['comment'] = util.read_frame(buf) + yield k def legacy_pubs(buf, keys, signer): - code = pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER) - num = pack('L', 0) # no SSH v1 keys - return frame(code, num) + code = util.pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER) + num = util.pack('L', 0) # no SSH v1 keys + return util.frame(code, num) def list_pubs(buf, keys, signer): - code = pack('B', SSH2_AGENT_IDENTITIES_ANSWER) - num = pack('L', len(keys)) + code = util.pack('B', SSH2_AGENT_IDENTITIES_ANSWER) + num = util.pack('L', len(keys)) log.debug('available keys: %s', [k['name'] for k in keys]) for i, k in enumerate(keys): log.debug('%2d) %s', i+1, k['fingerprint']) - pubs = [frame(k['blob']) + frame(k['name']) for k in keys] - return frame(code, num, *pubs) - -def fingerprint(blob): - digest = hashlib.md5(blob).digest() - return ':'.join('{:02x}'.format(c) for c in bytearray(digest)) - -def num2bytes(value, size): - res = [] - for i in range(size): - res.append(value & 0xFF) - value = value >> 8 - assert value == 0 - return bytearray(list(reversed(res))) + pubs = [util.frame(k['blob']) + util.frame(k['name']) for k in keys] + return util.frame(code, num, *pubs) def sign_message(buf, keys, signer): - key = parse_pubkey(read_frame(buf)) + key = formats.parse_pubkey(util.read_frame(buf)) log.debug('looking for %s', key['fingerprint']) - blob = read_frame(buf) + blob = util.read_frame(buf) for k in keys: if (k['fingerprint']) == (key['fingerprint']): @@ -145,27 +59,23 @@ def sign_message(buf, keys, signer): log.debug('signing %d-byte blob', len(blob)) r, s = signer(label=k['name'], blob=blob) signature = (r, s) - log.debug('signature: %s', signature) - curve = ecdsa.curves.NIST256p - point = ecdsa.ellipticcurve.Point(curve.curve, *key['point']) - vk = ecdsa.VerifyingKey.from_public_point(point, curve, hashlib.sha256) - success = vk.verify(signature=signature, data=blob, - sigdecode=lambda sig, _: sig) + success = key['verifying_key'].verify(signature=signature, data=blob, + sigdecode=lambda sig, _: sig) log.info('signature status: %s', 'OK' if success else 'ERROR') if not success: raise ValueError('invalid signature') sig_bytes = io.BytesIO() for x in signature: - sig_bytes.write(frame(b'\x00' + num2bytes(x, key['size']))) + sig_bytes.write(util.frame(b'\x00' + util.num2bytes(x, key['size']))) sig_bytes = sig_bytes.getvalue() log.debug('signature size: %d bytes', len(sig_bytes)) - data = frame(frame(key['type']), frame(sig_bytes)) - code = pack('B', SSH2_AGENT_SIGN_RESPONSE) - return frame(code, data) + data = util.frame(util.frame(key['type']), util.frame(sig_bytes)) + code = util.pack('B', SSH2_AGENT_SIGN_RESPONSE) + return util.frame(code, data) handlers = { SSH_AGENTC_REQUEST_RSA_IDENTITIES: legacy_pubs, @@ -173,36 +83,12 @@ handlers = { SSH2_AGENTC_SIGN_REQUEST: sign_message, } -def handle_connection(conn, keys, signer): - try: - log.debug('welcome agent') - while True: - msg = read_frame(conn) - buf = io.BytesIO(msg) - code, = recv(buf, '>B') - log.debug('request: %d bytes', len(msg)) - handler = handlers[code] - log.debug('calling %s()', handler.__name__) - reply = handler(buf=buf, keys=keys, signer=signer) - log.debug('reply: %d bytes', len(reply)) - send(conn, reply) - except EOFError: - log.debug('goodbye agent') - except: - log.exception('error') - raise - -DER_OCTET_STRING = b'\x04' - -def load_public_key(filename): - with open(filename) as f: - return parse_public_key(f.read()) - -def parse_public_key(data): - file_type, base64blob, name = data.split() - blob = base64.b64decode(base64blob) - result = parse_pubkey(blob) - result['name'] = name.encode('ascii') - assert result['type'] == file_type.encode('ascii') - log.debug('loaded %s %s', file_type, result['fingerprint']) - return result +def handle_message(msg, keys, signer): + log.debug('request: %d bytes', len(msg)) + buf = io.BytesIO(msg) + code, = util.recv(buf, '>B') + handler = handlers[code] + log.debug('calling %s()', handler.__name__) + reply = handler(buf=buf, keys=keys, signer=signer) + log.debug('reply: %d bytes', len(reply)) + return reply diff --git a/sshagent/server.py b/sshagent/server.py index 81c24e9..9bb78f2 100644 --- a/sshagent/server.py +++ b/sshagent/server.py @@ -1,7 +1,5 @@ -#!/usr/bin/env python import socket import os -import sys import subprocess import tempfile import contextlib @@ -9,8 +7,9 @@ import threading import logging log = logging.getLogger(__name__) -import protocol - +from . import protocol +from . import formats +from . import util @contextlib.contextmanager def unix_domain_socket_server(sock_path): @@ -29,6 +28,18 @@ def unix_domain_socket_server(sock_path): finally: os.remove(sock_path) +def handle_connection(conn, keys, signer): + try: + log.debug('welcome agent') + while True: + msg = util.read_frame(conn) + reply = protocol.handle_message(msg=msg, keys=keys, signer=signer) + util.send(conn, reply) + except EOFError: + log.debug('goodbye agent') + except: + log.exception('error') + raise def server_thread(server, keys, signer): log.debug('server thread started') @@ -40,7 +51,7 @@ def server_thread(server, keys, signer): log.debug('server error: %s', e, exc_info=True) break with contextlib.closing(conn): - protocol.handle_connection(conn, keys, signer) + handle_connection(conn, keys, signer) log.debug('server thread stopped') @@ -70,7 +81,7 @@ def serve(key_files, command, signer, sock_path=None): if sock_path is None: sock_path = tempfile.mktemp(prefix='ssh-agent-') - keys = [protocol.parse_public_key(k) for k in key_files] + keys = [formats.parse_public_key(k) for k in key_files] environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())} with unix_domain_socket_server(sock_path) as server: with spawn(server_thread, server=server, keys=keys, signer=signer): @@ -79,6 +90,4 @@ def serve(key_files, command, signer, sock_path=None): finally: log.debug('closing server') server.shutdown(socket.SHUT_RD) - - log.info('exitcode: %d', ret) - sys.exit(ret) + return ret diff --git a/sshagent/trezor.py b/sshagent/trezor.py index deace85..f0e30e7 100644 --- a/sshagent/trezor.py +++ b/sshagent/trezor.py @@ -1,50 +1,20 @@ import io -import base64 -import logging import binascii from trezorlib.client import TrezorClient from trezorlib.transport_hid import HidTransport from trezorlib.types_pb2 import IdentityType -import ecdsa -import bitcoin -import hashlib - +from . import util +from . import formats -import protocol +import logging log = logging.getLogger(__name__) -curve = ecdsa.NIST256p -hashfunc = hashlib.sha256 - - -def decode_pubkey(pub): - P = curve.curve.p() - A = curve.curve.a() - B = curve.curve.b() - x = bitcoin.decode(pub[1:33], 256) - beta = pow(int(x*x*x+A*x+B), int((P+1)//4), int(P)) - y = (P-beta) if ((beta + bitcoin.from_byte_to_int(pub[0])) % 2) else beta - return (x, y) - - -def export_public_key(pubkey, label): - x, y = decode_pubkey(pubkey) - point = ecdsa.ellipticcurve.Point(curve.curve, x, y) - vk = ecdsa.VerifyingKey.from_public_point(point, curve=curve, - hashfunc=hashfunc) - key_type = 'ecdsa-sha2-nistp256' - curve_name = 'nistp256' - blobs = map(protocol.frame, [key_type, curve_name, '\x04' + vk.to_string()]) - b64 = base64.b64encode(''.join(blobs)) - return '{} {} {}\n'.format(key_type, b64, label) - - def label_addr(ident): index = '\x00' * 4 addr = index + '{}://{}'.format(ident.proto, ident.host) - h = bytearray(hashfunc(addr).digest()) + h = bytearray(formats.hashfunc(addr).digest()) address_n = [0] * 5 address_n[0] = 13 @@ -96,8 +66,8 @@ class Client(object): s = self.client.sign_identity(identity=ident, challenge_hidden=blob, challenge_visual=request) - r = protocol.bytes2num(s.signature[:32]) - s = protocol.bytes2num(s.signature[32:]) + r = util.bytes2num(s.signature[:32]) + s = util.bytes2num(s.signature[32:]) return (r, s) @@ -105,14 +75,14 @@ def parse_ssh_blob(data): res = {} if data: i = io.BytesIO(data) - res['nonce'] = protocol.read_frame(i) + res['nonce'] = util.read_frame(i) i.read(1) # TBD - res['user'] = protocol.read_frame(i) - res['conn'] = protocol.read_frame(i) - res['auth'] = protocol.read_frame(i) + res['user'] = util.read_frame(i) + res['conn'] = util.read_frame(i) + res['auth'] = util.read_frame(i) i.read(1) # TBD - res['key_type'] = protocol.read_frame(i) - res['pubkey'] = protocol.read_frame(i) + res['key_type'] = util.read_frame(i) + res['pubkey'] = util.read_frame(i) log.debug('%s: user %r via %r (%r)', res['conn'], res['user'], res['auth'], res['key_type']) return res diff --git a/sshagent/trezor_agent.py b/sshagent/trezor_agent.py index 2667f89..9b6e477 100644 --- a/sshagent/trezor_agent.py +++ b/sshagent/trezor_agent.py @@ -3,8 +3,9 @@ import argparse import logging log = logging.getLogger(__name__) -import trezor -import server +from . import trezor +from . import server +from . import formats def main(): fmt = '%(asctime)s %(levelname)-12s %(message)-100s [%(filename)s]' @@ -24,7 +25,8 @@ def main(): key_files = [] for label in args.labels: pubkey = client.get_public_key(label=label) - key_files.append(trezor.export_public_key(pubkey=pubkey, label=label)) + key_file = formats.export_public_key(pubkey=pubkey, label=label) + key_files.append(key_file) if not args.command: sys.stdout.write(''.join(key_files)) @@ -32,12 +34,18 @@ def main(): signer = client.sign_ssh_challenge + ret = -1 try: - server.serve(key_files=key_files, command=args.command, signer=signer) + ret = server.serve( + key_files=key_files, + command=args.command, + signer=signer) + log.info('exitcode: %d', ret) except KeyboardInterrupt: log.info('server stopped') except Exception as e: log.warning(e, exc_info=True) + sys.exit(ret) if __name__ == '__main__': main() diff --git a/sshagent/util.py b/sshagent/util.py new file mode 100644 index 0000000..b837b71 --- /dev/null +++ b/sshagent/util.py @@ -0,0 +1,60 @@ +import struct +import io + +def send(conn, data, fmt=None): + if fmt: + data = struct.pack(fmt, *data) + conn.sendall(data) + +def recv(conn, size): + try: + fmt = size + size = struct.calcsize(fmt) + except TypeError: + fmt = None + try: + _read = conn.recv + except AttributeError: + _read = conn.read + + res = io.BytesIO() + while size > 0: + buf = _read(size) + if not buf: + raise EOFError + size = size - len(buf) + res.write(buf) + res = res.getvalue() + if fmt: + return struct.unpack(fmt, res) + else: + return res + + +def read_frame(conn): + size, = recv(conn, '>L') + return recv(conn, size) + +def bytes2num(s): + res = 0 + for i, c in enumerate(reversed(bytearray(s))): + res += c << (i * 8) + return res + +def num2bytes(value, size): + res = [] + for i in range(size): + res.append(value & 0xFF) + value = value >> 8 + assert value == 0 + return bytearray(list(reversed(res))) + +def pack(fmt, *args): + return struct.pack('>' + fmt, *args) + +def frame(*msgs): + res = io.BytesIO() + for msg in msgs: + res.write(msg) + msg = res.getvalue() + return pack('L', len(msg)) + msg