diff --git a/libagent/ssh/__init__.py b/libagent/ssh/__init__.py index b0af65e..1c9b248 100644 --- a/libagent/ssh/__init__.py +++ b/libagent/ssh/__init__.py @@ -2,6 +2,7 @@ import argparse import contextlib import functools +import io import logging import os import re @@ -133,27 +134,36 @@ def handle_connection_error(func): return wrapper -def parse_config(fname): +def parse_config(contents): """Parse config file into a list of Identity objects.""" - contents = open(fname).read() for identity_str, curve_name in re.findall(r'\<(.*?)\|(.*?)\>', contents): yield device.interface.Identity(identity_str=identity_str, curve_name=curve_name) +def import_public_keys(contents): + """Load (previously exported) SSH public keys from a file's contents.""" + for line in io.StringIO(contents): + # Verify this line represents valid SSH public key + formats.import_public_key(line) + yield line + + class JustInTimeConnection(object): """Connect to the device just before the needed operation.""" - def __init__(self, conn_factory, identities): + def __init__(self, conn_factory, identities, public_keys=None): """Create a JIT connection object.""" self.conn_factory = conn_factory self.identities = identities - self.public_keys = util.memoize(self._public_keys) # a simple cache + self.public_keys_cache = public_keys - def _public_keys(self): + def public_keys(self): """Return a list of SSH public keys (in textual format).""" - conn = self.conn_factory() - return conn.export_public_keys(self.identities) + if not self.public_keys_cache: + conn = self.conn_factory() + self.public_keys_cache = conn.export_public_keys(self.identities) + return self.public_keys_cache def parse_public_keys(self): """Parse SSH public keys into dictionaries.""" @@ -175,8 +185,14 @@ def main(device_type): args = create_agent_parser().parse_args() util.setup_logging(verbosity=args.verbose) + public_keys = None if args.identity.startswith('/'): - identities = list(parse_config(fname=args.identity)) + filename = args.identity + contents = open(filename, 'rb').read().decode('ascii') + # Allow loading previously exported SSH public keys + if filename.endswith('.pub'): + public_keys = list(import_public_keys(contents)) + identities = list(parse_config(contents)) else: identities = [device.interface.Identity( identity_str=args.identity, curve_name=args.ecdsa_curve_name)] @@ -197,7 +213,7 @@ def main(device_type): conn = JustInTimeConnection( conn_factory=lambda: client.Client(device_type()), - identities=identities) + identities=identities, public_keys=public_keys) if command: return run_server(conn=conn, command=command, debug=args.debug, timeout=args.timeout)