ssh: allow "just-in-time" connection for agent-like behaviour

This would allow launching trezor-agent into the background
during the system startup, and the connecting the device
when the cryptographic operations are required.
nistp521
Roman Zeyde 8 years ago
parent 6672ea9bc4
commit 43c424a402
No known key found for this signature in database
GPG Key ID: 87CAE5FA46917CBB

@ -110,12 +110,10 @@ def git_host(remote_name, attributes):
return '{user}@{host}'.format(**match.groupdict())
def run_server(conn, public_keys, command, debug, timeout):
def run_server(conn, command, debug, timeout):
"""Common code for run_agent and run_git below."""
try:
signer = conn.sign_ssh_challenge
handler = protocol.Handler(keys=public_keys, signer=signer,
debug=debug)
handler = protocol.Handler(conn=conn, debug=debug)
with server.serve(handler=handler, timeout=timeout) as env:
return server.run_process(command=command, environ=env)
except KeyboardInterrupt:
@ -142,13 +140,39 @@ def parse_config(fname):
curve_name=curve_name)
class JustInTimeConnection(object):
"""Connect to the device just before the needed operation."""
def __init__(self, conn_factory, identities):
"""Create a JIT connection object."""
self.conn_factory = conn_factory
self.identities = identities
def public_keys(self):
"""Return a list of SSH public keys (in textual format)."""
conn = self.conn_factory()
return [conn.get_public_key(i) for i in self.identities]
def parse_public_keys(self):
"""Parse SSH public keys into dictionaries."""
public_keys = [formats.import_public_key(pk)
for pk in self.public_keys()]
for pk, identity in zip(public_keys, self.identities):
pk['identity'] = identity
return public_keys
def sign(self, blob, identity):
"""Sign a given blob using the specified identity on the device."""
conn = self.conn_factory()
return conn.sign_ssh_challenge(blob=blob, identity=identity)
@handle_connection_error
def run_agent(client_factory=client.Client):
"""Run ssh-agent using given hardware client factory."""
args = create_agent_parser().parse_args()
util.setup_logging(verbosity=args.verbose)
conn = client_factory(device=device.detect())
if args.identity.startswith('/'):
identities = list(parse_config(fname=args.identity))
else:
@ -158,8 +182,6 @@ def run_agent(client_factory=client.Client):
identity.identity_dict['proto'] = 'ssh'
log.info('identity #%d: %s', index, identity)
public_keys = [conn.get_public_key(i) for i in identities]
if args.connect:
command = ['ssh'] + ssh_args(args.identity) + args.command
elif args.mosh:
@ -171,13 +193,12 @@ def run_agent(client_factory=client.Client):
if use_shell:
command = os.environ['SHELL']
if not command:
for pk in public_keys:
conn = JustInTimeConnection(
conn_factory=lambda: client_factory(device.detect()),
identities=identities)
if command:
return run_server(conn=conn, command=command, debug=args.debug,
timeout=args.timeout)
else:
for pk in conn.public_keys():
sys.stdout.write(pk)
return
public_keys = [formats.import_public_key(pk) for pk in public_keys]
for pk, identity in zip(public_keys, identities):
pk['identity'] = identity
return run_server(conn=conn, public_keys=public_keys, command=command,
debug=args.debug, timeout=args.timeout)

@ -71,14 +71,13 @@ def _legacy_pubs(buf):
class Handler(object):
"""ssh-agent protocol handler."""
def __init__(self, keys, signer, debug=False):
def __init__(self, conn, debug=False):
"""
Create a protocol handler with specified public keys.
Use specified signer function to sign SSH authentication requests.
"""
self.public_keys = keys
self.signer = signer
self.conn = conn
self.debug = debug
self.methods = {
@ -107,7 +106,7 @@ class Handler(object):
def list_pubs(self, buf):
"""SSH v2 public keys are serialized and returned."""
assert not buf.read()
keys = self.public_keys
keys = self.conn.parse_public_keys()
code = util.pack('B', msg_code('SSH2_AGENT_IDENTITIES_ANSWER'))
num = util.pack('L', len(keys))
log.debug('available keys: %s', [k['name'] for k in keys])
@ -129,7 +128,7 @@ class Handler(object):
assert util.read_frame(buf) == b''
assert not buf.read()
for k in self.public_keys:
for k in self.conn.parse_public_keys():
if (k['fingerprint']) == (key['fingerprint']):
log.debug('using key %r (%s)', k['name'], k['fingerprint'])
key = k
@ -140,7 +139,7 @@ class Handler(object):
label = key['name'].decode('ascii') # label should be a string
log.debug('signing %d-byte blob with "%s" key', len(blob), label)
try:
signature = self.signer(blob=blob, identity=key['identity'])
signature = self.conn.sign(blob=blob, identity=key['identity'])
except IOError:
return failure()
log.debug('signature: %r', signature)

@ -1,3 +1,4 @@
import mock
import pytest
from .. import device, formats, protocol
@ -15,16 +16,23 @@ NIST256_SIGN_MSG = b'\r\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\
NIST256_SIGN_REPLY = b'\x00\x00\x00j\x0e\x00\x00\x00e\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00J\x00\x00\x00!\x00\x88G!\x0c\n\x16:\xbeF\xbe\xb9\xd2\xa9&e\x89\xad\xc4}\x10\xf8\xbc\xdc\xef\x0e\x8d_\x8a6.\xb6\x1f\x00\x00\x00!\x00q\xf0\x16>,\x9a\xde\xe7(\xd6\xd7\x93\x1f\xed\xf9\x94ddw\xfe\xbdq\x13\xbb\xfc\xa9K\xea\x9dC\xa1\xe9' # nopep8
def fake_connection(keys, signer):
c = mock.Mock()
c.parse_public_keys.return_value = keys
c.sign = signer
return c
def test_list():
key = formats.import_public_key(NIST256_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
h = protocol.Handler(keys=[key], signer=None)
h = protocol.Handler(fake_connection(keys=[key], signer=None))
reply = h.handle(LIST_MSG)
assert reply == LIST_NIST256_REPLY
def test_unsupported():
h = protocol.Handler(keys=[], signer=None)
h = protocol.Handler(fake_connection(keys=[], signer=None))
reply = h.handle(b'\x09')
assert reply == b'\x00\x00\x00\x01\x05'
@ -38,13 +46,13 @@ def ecdsa_signer(identity, blob):
def test_ecdsa_sign():
key = formats.import_public_key(NIST256_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
h = protocol.Handler(keys=[key], signer=ecdsa_signer)
h = protocol.Handler(fake_connection(keys=[key], signer=ecdsa_signer))
reply = h.handle(NIST256_SIGN_MSG)
assert reply == NIST256_SIGN_REPLY
def test_sign_missing():
h = protocol.Handler(keys=[], signer=ecdsa_signer)
h = protocol.Handler(fake_connection(keys=[], signer=ecdsa_signer))
with pytest.raises(KeyError):
h.handle(NIST256_SIGN_MSG)
@ -57,7 +65,7 @@ def test_sign_wrong():
key = formats.import_public_key(NIST256_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
h = protocol.Handler(keys=[key], signer=wrong_signature)
h = protocol.Handler(fake_connection(keys=[key], signer=wrong_signature))
with pytest.raises(ValueError):
h.handle(NIST256_SIGN_MSG)
@ -68,7 +76,7 @@ def test_sign_cancel():
key = formats.import_public_key(NIST256_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
h = protocol.Handler(keys=[key], signer=cancel_signature)
h = protocol.Handler(fake_connection(keys=[key], signer=cancel_signature))
assert h.handle(NIST256_SIGN_MSG) == protocol.failure()
@ -89,6 +97,6 @@ def ed25519_signer(identity, blob):
def test_ed25519_sign():
key = formats.import_public_key(ED25519_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'ed25519')
h = protocol.Handler(keys=[key], signer=ed25519_signer)
h = protocol.Handler(fake_connection(keys=[key], signer=ed25519_signer))
reply = h.handle(ED25519_SIGN_MSG)
assert reply == ED25519_SIGN_REPLY

@ -37,10 +37,16 @@ class FakeSocket(object):
pass
def empty_device():
c = mock.Mock(spec=['parse_public_keys'])
c.parse_public_keys.return_value = []
return c
def test_handle():
mutex = threading.Lock()
handler = protocol.Handler(keys=[], signer=None)
handler = protocol.Handler(conn=empty_device())
conn = FakeSocket()
server.handle_connection(conn, handler, mutex)
@ -67,7 +73,6 @@ def test_handle():
def test_server_thread():
connections = [FakeSocket()]
quit_event = threading.Event()
@ -81,8 +86,10 @@ def test_server_thread():
def getsockname(self): # pylint: disable=no-self-use
return 'fake_server'
handler = protocol.Handler(keys=[], signer=None),
handle_conn = functools.partial(server.handle_connection, handler=handler)
handler = protocol.Handler(conn=empty_device()),
handle_conn = functools.partial(server.handle_connection,
handler=handler,
mutex=None)
server.server_thread(sock=FakeServer(),
handle_conn=handle_conn,
quit_event=quit_event)
@ -111,7 +118,7 @@ def test_run():
def test_serve_main():
handler = protocol.Handler(keys=[], signer=None)
handler = protocol.Handler(conn=empty_device())
with server.serve(handler=handler, sock_path=None):
pass

Loading…
Cancel
Save