@ -110,12 +110,10 @@ def git_host(remote_name, attributes):
return ' {user} @ {host} ' . format ( * * match . groupdict ( ) )
def run_server ( conn , public_keys, command, debug , timeout ) :
def run_server ( conn , command, debug , timeout ) :
""" Common code for run_agent and run_git below. """
try :
signer = conn . sign_ssh_challenge
handler = protocol . Handler ( keys = public_keys , signer = signer ,
debug = debug )
handler = protocol . Handler ( conn = conn , debug = debug )
with server . serve ( handler = handler , timeout = timeout ) as env :
return server . run_process ( command = command , environ = env )
except KeyboardInterrupt :
@ -142,13 +140,39 @@ def parse_config(fname):
curve_name = curve_name )
class JustInTimeConnection ( object ) :
""" Connect to the device just before the needed operation. """
def __init__ ( self , conn_factory , identities ) :
""" Create a JIT connection object. """
self . conn_factory = conn_factory
self . identities = identities
def public_keys ( self ) :
""" Return a list of SSH public keys (in textual format). """
conn = self . conn_factory ( )
return [ conn . get_public_key ( i ) for i in self . identities ]
def parse_public_keys ( self ) :
""" Parse SSH public keys into dictionaries. """
public_keys = [ formats . import_public_key ( pk )
for pk in self . public_keys ( ) ]
for pk , identity in zip ( public_keys , self . identities ) :
pk [ ' identity ' ] = identity
return public_keys
def sign ( self , blob , identity ) :
""" Sign a given blob using the specified identity on the device. """
conn = self . conn_factory ( )
return conn . sign_ssh_challenge ( blob = blob , identity = identity )
@handle_connection_error
def run_agent ( client_factory = client . Client ) :
""" Run ssh-agent using given hardware client factory. """
args = create_agent_parser ( ) . parse_args ( )
util . setup_logging ( verbosity = args . verbose )
conn = client_factory ( device = device . detect ( ) )
if args . identity . startswith ( ' / ' ) :
identities = list ( parse_config ( fname = args . identity ) )
else :
@ -158,8 +182,6 @@ def run_agent(client_factory=client.Client):
identity . identity_dict [ ' proto ' ] = ' ssh '
log . info ( ' identity # %d : %s ' , index , identity )
public_keys = [ conn . get_public_key ( i ) for i in identities ]
if args . connect :
command = [ ' ssh ' ] + ssh_args ( args . identity ) + args . command
elif args . mosh :
@ -171,13 +193,12 @@ def run_agent(client_factory=client.Client):
if use_shell :
command = os . environ [ ' SHELL ' ]
if not command :
for pk in public_keys :
conn = JustInTimeConnection (
conn_factory = lambda : client_factory ( device . detect ( ) ) ,
identities = identities )
if command :
return run_server ( conn = conn , command = command , debug = args . debug ,
timeout = args . timeout )
else :
for pk in conn . public_keys ( ) :
sys . stdout . write ( pk )
return
public_keys = [ formats . import_public_key ( pk ) for pk in public_keys ]
for pk , identity in zip ( public_keys , identities ) :
pk [ ' identity ' ] = identity
return run_server ( conn = conn , public_keys = public_keys , command = command ,
debug = args . debug , timeout = args . timeout )