add support Trezor SSH agent

nistp521
Roman Zeyde 9 years ago
parent 14162b664f
commit 4f4db9bdd5

@ -0,0 +1,123 @@
#!/usr/bin/env python
import socket
import os
import sys
import subprocess
import argparse
import tempfile
import contextlib
import threading
import logging
log = logging.getLogger(__name__)
import protocol
import trezor
def load_keys(key_files):
keys = []
for f in key_files:
k = protocol.load_public_key(f)
keys.append(k)
return keys
@contextlib.contextmanager
def unix_domain_socket_server(sock_path):
log.debug('serving on SSH_AUTH_SOCK=%s', sock_path)
try:
os.remove(sock_path)
except OSError:
if os.path.exists(sock_path):
raise
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server.bind(sock_path)
server.listen(1)
try:
yield server
finally:
os.remove(sock_path)
def worker_thread(server, keys, signer):
log.debug('worker thread started')
while True:
log.debug('waiting for connection on %s', server.getsockname())
try:
conn, _ = server.accept()
except socket.error as e:
log.debug('server error: %s', e, exc_info=True)
break
with contextlib.closing(conn):
protocol.handle_connection(conn, keys, signer)
log.debug('worker thread stopped')
@contextlib.contextmanager
def spawn(func, **kwargs):
t = threading.Thread(target=func, kwargs=kwargs)
t.start()
yield
t.join()
def run(command, environ):
log.debug('running %r with %r', command, environ)
env = dict(os.environ)
env.update(environ)
p = subprocess.Popen(args=command, env=env)
log.debug('subprocess %d is running', p.pid)
ret = p.wait()
log.debug('subprocess %d exited: %d', p.pid, ret)
return ret
def serve(key_files, command, signer, sock_path=None):
if sock_path is None:
sock_path = tempfile.mktemp(prefix='ssh-agent-')
keys = [protocol.parse_public_key(k) for k in key_files]
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
with unix_domain_socket_server(sock_path) as server:
with spawn(worker_thread, server=server, keys=keys, signer=signer):
ret = run(command=command, environ=environ)
log.debug('closing server')
server.shutdown(socket.SHUT_RD)
log.info('exitcode: %d', ret)
sys.exit(ret)
def main():
fmt = '%(asctime)s %(levelname)-12s %(message)-100s [%(filename)s]'
p = argparse.ArgumentParser()
p.add_argument('-k', '--key-label', dest='labels', action='append')
p.add_argument('-v', '--verbose', action='count', default=0)
p.add_argument('command', type=str, nargs='*')
args = p.parse_args()
verbosity = [logging.WARNING, logging.INFO, logging.DEBUG]
level = verbosity[min(args.verbose, len(verbosity))]
logging.basicConfig(level=level, format=fmt)
client = trezor.Client()
key_files = []
for label in args.labels:
pubkey = client.get_public_key(label=label)
key_files.append(trezor.export_public_key(pubkey=pubkey, label=label))
if not args.command:
sys.stdout.write(''.join(key_files))
return
signer = client.sign_ssh_challenge
try:
serve(key_files=key_files, command=args.command, signer=signer)
except KeyboardInterrupt:
log.info('server stopped')
if __name__ == '__main__':
main()

@ -0,0 +1,208 @@
import io
import struct
import hashlib
import ecdsa
import base64
import logging
log = logging.getLogger(__name__)
def send(conn, data, fmt=None):
if fmt:
data = struct.pack(fmt, *data)
conn.sendall(data)
def recv(conn, size):
try:
fmt = size
size = struct.calcsize(fmt)
except TypeError:
fmt = None
try:
_read = conn.recv
except AttributeError:
_read = conn.read
res = io.BytesIO()
while size > 0:
buf = _read(size)
if not buf:
raise EOFError
size = size - len(buf)
res.write(buf)
res = res.getvalue()
if fmt:
return struct.unpack(fmt, res)
else:
return res
def read_frame(conn):
size, = recv(conn, '>L')
return recv(conn, size)
def bytes2num(s):
res = 0
for i, c in enumerate(reversed(bytearray(s))):
res += c << (i * 8)
return res
def parse_pubkey(blob):
s = io.BytesIO(blob)
key_type = read_frame(s)
log.debug('key type: %s', key_type)
curve = read_frame(s)
log.debug('curve name: %s', curve)
point = read_frame(s)
_type, point = point[:1], point[1:]
assert _type == DER_OCTET_STRING
size = len(point) // 2
assert len(point) == 2 * size
coords = map(bytes2num, [point[:size], point[size:]])
log.debug('coordinates: %s', coords)
fp = fingerprint(blob)
result = {
'point': tuple(coords), 'curve': curve,
'fingerprint': fp,
'type': key_type,
'blob': blob, 'size': size
}
return result
def list_keys(c):
send(c, [0x1, 0xB], '>LB')
buf = io.BytesIO(read_frame(c))
assert recv(buf, '>B') == (0xC,)
num, = recv(buf, '>L')
for i in range(num):
k = parse_pubkey(read_frame(buf))
k['comment'] = read_frame(buf)
yield k
def frame(*msgs):
res = io.BytesIO()
for msg in msgs:
res.write(msg)
msg = res.getvalue()
return pack('L', len(msg)) + msg
SSH_AGENTC_REQUEST_RSA_IDENTITIES = 1
SSH_AGENT_RSA_IDENTITIES_ANSWER = 2
SSH_AGENTC_REMOVE_ALL_RSA_IDENTITIES = 9
SSH2_AGENTC_REQUEST_IDENTITIES = 11
SSH2_AGENT_IDENTITIES_ANSWER = 12
SSH2_AGENTC_SIGN_REQUEST = 13
SSH2_AGENT_SIGN_RESPONSE = 14
SSH2_AGENTC_ADD_IDENTITY = 17
SSH2_AGENTC_REMOVE_IDENTITY = 18
SSH2_AGENTC_REMOVE_ALL_IDENTITIES = 19
def pack(fmt, *args):
return struct.pack('>' + fmt, *args)
def legacy_pubs(buf, keys, signer):
code = pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER)
num = pack('L', 0) # no SSH v1 keys
return frame(code, num)
def list_pubs(buf, keys, signer):
code = pack('B', SSH2_AGENT_IDENTITIES_ANSWER)
num = pack('L', len(keys))
log.debug('available keys: %s', [k['name'] for k in keys])
for i, k in enumerate(keys):
log.debug('%2d) %s', i+1, k['fingerprint'])
pubs = [frame(k['blob']) + frame(k['name']) for k in keys]
return frame(code, num, *pubs)
def fingerprint(blob):
digest = hashlib.md5(blob).digest()
return ':'.join('{:02x}'.format(c) for c in bytearray(digest))
def num2bytes(value, size):
res = []
for i in range(size):
res.append(value & 0xFF)
value = value >> 8
assert value == 0
return bytearray(list(reversed(res)))
def sign_message(buf, keys, signer):
key = parse_pubkey(read_frame(buf))
log.debug('looking for %s', key['fingerprint'])
blob = read_frame(buf)
for k in keys:
if (k['fingerprint']) == (key['fingerprint']):
log.debug('using key %r (%s)', k['name'], k['fingerprint'])
key = k
break
else:
raise ValueError('key not found')
log.debug('signing %d-byte blob', len(blob))
r, s = signer(label=k['name'], blob=blob)
signature = (r, s)
log.debug('signature: %s', signature)
curve = ecdsa.curves.NIST256p
point = ecdsa.ellipticcurve.Point(curve.curve, *key['point'])
vk = ecdsa.VerifyingKey.from_public_point(point, curve, hashlib.sha256)
success = vk.verify(signature=signature, data=blob,
sigdecode=lambda sig, _: sig)
log.info('signature status: %s', 'OK' if success else 'ERROR')
if not success:
raise ValueError('invalid signature')
sig_bytes = io.BytesIO()
for x in signature:
sig_bytes.write(frame(b'\x00' + num2bytes(x, key['size'])))
sig_bytes = sig_bytes.getvalue()
log.debug('signature size: %d bytes', len(sig_bytes))
data = frame(frame(key['type']), frame(sig_bytes))
code = pack('B', SSH2_AGENT_SIGN_RESPONSE)
return frame(code, data)
handlers = {
SSH_AGENTC_REQUEST_RSA_IDENTITIES: legacy_pubs,
SSH2_AGENTC_REQUEST_IDENTITIES: list_pubs,
SSH2_AGENTC_SIGN_REQUEST: sign_message,
}
def handle_connection(conn, keys, signer):
try:
log.debug('welcome agent')
while True:
msg = read_frame(conn)
buf = io.BytesIO(msg)
code, = recv(buf, '>B')
log.debug('request: %d bytes', len(msg))
handler = handlers[code]
log.debug('calling %s()', handler.__name__)
reply = handler(buf=buf, keys=keys, signer=signer)
log.debug('reply: %d bytes', len(reply))
send(conn, reply)
except EOFError:
log.debug('goodbye agent')
except:
log.exception('error')
raise
DER_OCTET_STRING = b'\x04'
def load_public_key(filename):
with open(filename) as f:
return parse_public_key(f.read())
def parse_public_key(data):
file_type, base64blob, name = data.split()
blob = base64.b64decode(base64blob)
result = parse_pubkey(blob)
result['name'] = name.encode('ascii')
assert result['type'] == file_type.encode('ascii')
log.debug('loaded %s %s', file_type, result['fingerprint'])
return result

@ -0,0 +1,108 @@
import io
import base64
import logging
from trezorlib.client import TrezorClient
from trezorlib.transport_hid import HidTransport
from trezorlib.types_pb2 import IdentityType
import ecdsa
import bitcoin
import hashlib
import protocol
log = logging.getLogger(__name__)
curve = ecdsa.NIST256p
hashfunc = hashlib.sha256
def decode_pubkey(pub):
P = curve.curve.p()
A = curve.curve.a()
B = curve.curve.b()
x = bitcoin.decode(pub[1:33], 256)
beta = pow(int(x*x*x+A*x+B), int((P+1)//4), int(P))
y = (P-beta) if ((beta + bitcoin.from_byte_to_int(pub[0])) % 2) else beta
return (x, y)
def export_public_key(pubkey, label):
x, y = decode_pubkey(pubkey)
point = ecdsa.ellipticcurve.Point(curve.curve, x, y)
vk = ecdsa.VerifyingKey.from_public_point(point, curve=curve,
hashfunc=hashfunc)
key_type = 'ecdsa-sha2-nistp256'
curve_name = 'nistp256'
blobs = map(protocol.frame, [key_type, curve_name, '\x04' + vk.to_string()])
b64 = base64.b64encode(''.join(blobs))
return '{} {} {}\n'.format(key_type, b64, label)
def label_addr(ident):
index = '\x00' * 4
addr = index + '{}://{}'.format(ident.proto, ident.host)
h = bytearray(hashfunc(addr).digest())
address_n = [0] * 5
address_n[0] = 13
address_n[1] = h[0] | (h[1] << 8) | (h[2] << 16) | (h[3] << 24)
address_n[2] = h[4] | (h[5] << 8) | (h[6] << 16) | (h[7] << 24)
address_n[3] = h[8] | (h[9] << 8) | (h[10] << 16) | (h[11] << 24)
address_n[4] = h[12] | (h[13] << 8) | (h[14] << 16) | (h[15] << 24)
return [-x for x in address_n] # prime each address component
class Client(object):
proto = 'ssh'
def __init__(self):
device, = HidTransport.enumerate()
client = TrezorClient(HidTransport(device))
log.debug('connected to Trezor #%s', client.get_device_id())
self.client = client
def close(self):
self.client.close()
def _get_identity(self, label):
return IdentityType(host=label, proto=self.proto)
def get_public_key(self, label):
addr = label_addr(self._get_identity(label))
log.info('getting %r SSH public key from Trezor...', label)
node = self.client.get_public_node(addr)
return node.node.public_key
def sign_ssh_challenge(self, label, blob):
ident = self._get_identity(label)
msg = parse_ssh_blob(blob)
request = 'user: "{user}"'.format(**msg)
log.info('confirm %s connection to %r using Trezor...',
request, label)
s = self.client.sign_identity(identity=ident,
challenge_hidden=blob,
challenge_visual=request)
r = protocol.bytes2num(s.signature[:32])
s = protocol.bytes2num(s.signature[32:])
return (r, s)
def parse_ssh_blob(data):
res = {}
if data:
i = io.BytesIO(data)
res['nonce'] = protocol.read_frame(i)
i.read(1) # TBD
res['user'] = protocol.read_frame(i)
res['conn'] = protocol.read_frame(i)
res['auth'] = protocol.read_frame(i)
i.read(1) # TBD
res['key_type'] = protocol.read_frame(i)
res['pubkey'] = protocol.read_frame(i)
log.debug('%s: user %r via %r (%r)',
res['conn'], res['user'], res['auth'], res['key_type'])
return res
Loading…
Cancel
Save