|
|
|
@ -2,6 +2,7 @@
|
|
|
|
|
import argparse
|
|
|
|
|
import contextlib
|
|
|
|
|
import functools
|
|
|
|
|
import io
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
import re
|
|
|
|
@ -133,27 +134,36 @@ def handle_connection_error(func):
|
|
|
|
|
return wrapper
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_config(fname):
|
|
|
|
|
def parse_config(contents):
|
|
|
|
|
"""Parse config file into a list of Identity objects."""
|
|
|
|
|
contents = open(fname).read()
|
|
|
|
|
for identity_str, curve_name in re.findall(r'\<(.*?)\|(.*?)\>', contents):
|
|
|
|
|
yield device.interface.Identity(identity_str=identity_str,
|
|
|
|
|
curve_name=curve_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def import_public_keys(contents):
|
|
|
|
|
"""Load (previously exported) SSH public keys from a file's contents."""
|
|
|
|
|
for line in io.StringIO(contents):
|
|
|
|
|
# Verify this line represents valid SSH public key
|
|
|
|
|
formats.import_public_key(line)
|
|
|
|
|
yield line
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class JustInTimeConnection(object):
|
|
|
|
|
"""Connect to the device just before the needed operation."""
|
|
|
|
|
|
|
|
|
|
def __init__(self, conn_factory, identities):
|
|
|
|
|
def __init__(self, conn_factory, identities, public_keys=None):
|
|
|
|
|
"""Create a JIT connection object."""
|
|
|
|
|
self.conn_factory = conn_factory
|
|
|
|
|
self.identities = identities
|
|
|
|
|
self.public_keys = util.memoize(self._public_keys) # a simple cache
|
|
|
|
|
self.public_keys_cache = public_keys
|
|
|
|
|
|
|
|
|
|
def _public_keys(self):
|
|
|
|
|
def public_keys(self):
|
|
|
|
|
"""Return a list of SSH public keys (in textual format)."""
|
|
|
|
|
conn = self.conn_factory()
|
|
|
|
|
return conn.export_public_keys(self.identities)
|
|
|
|
|
if not self.public_keys_cache:
|
|
|
|
|
conn = self.conn_factory()
|
|
|
|
|
self.public_keys_cache = conn.export_public_keys(self.identities)
|
|
|
|
|
return self.public_keys_cache
|
|
|
|
|
|
|
|
|
|
def parse_public_keys(self):
|
|
|
|
|
"""Parse SSH public keys into dictionaries."""
|
|
|
|
@ -175,8 +185,14 @@ def main(device_type):
|
|
|
|
|
args = create_agent_parser().parse_args()
|
|
|
|
|
util.setup_logging(verbosity=args.verbose)
|
|
|
|
|
|
|
|
|
|
public_keys = None
|
|
|
|
|
if args.identity.startswith('/'):
|
|
|
|
|
identities = list(parse_config(fname=args.identity))
|
|
|
|
|
filename = args.identity
|
|
|
|
|
contents = open(filename, 'rb').read().decode('ascii')
|
|
|
|
|
# Allow loading previously exported SSH public keys
|
|
|
|
|
if filename.endswith('.pub'):
|
|
|
|
|
public_keys = list(import_public_keys(contents))
|
|
|
|
|
identities = list(parse_config(contents))
|
|
|
|
|
else:
|
|
|
|
|
identities = [device.interface.Identity(
|
|
|
|
|
identity_str=args.identity, curve_name=args.ecdsa_curve_name)]
|
|
|
|
@ -197,7 +213,7 @@ def main(device_type):
|
|
|
|
|
|
|
|
|
|
conn = JustInTimeConnection(
|
|
|
|
|
conn_factory=lambda: client.Client(device_type()),
|
|
|
|
|
identities=identities)
|
|
|
|
|
identities=identities, public_keys=public_keys)
|
|
|
|
|
if command:
|
|
|
|
|
return run_server(conn=conn, command=command, debug=args.debug,
|
|
|
|
|
timeout=args.timeout)
|
|
|
|
|