|
|
|
@ -5,6 +5,7 @@ import io
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
import re
|
|
|
|
|
import signal
|
|
|
|
|
import subprocess
|
|
|
|
|
import sys
|
|
|
|
|
import tempfile
|
|
|
|
@ -12,6 +13,7 @@ import threading
|
|
|
|
|
|
|
|
|
|
import pkg_resources
|
|
|
|
|
import configargparse
|
|
|
|
|
import daemon
|
|
|
|
|
|
|
|
|
|
from .. import device, formats, server, util
|
|
|
|
|
from . import client, protocol
|
|
|
|
@ -80,6 +82,8 @@ def create_agent_parser(device_type):
|
|
|
|
|
help='log SSH protocol messages for debugging.')
|
|
|
|
|
|
|
|
|
|
g = p.add_mutually_exclusive_group()
|
|
|
|
|
g.add_argument('-d', '--daemonize', default=False, action='store_true',
|
|
|
|
|
help='Daemonize the agent and print its UNIX socket path')
|
|
|
|
|
g.add_argument('-s', '--shell', default=False, action='store_true',
|
|
|
|
|
help=('run ${SHELL} as subprocess under SSH agent, allowing '
|
|
|
|
|
'regular SSH-based tools to be used in the shell'))
|
|
|
|
@ -96,7 +100,7 @@ def create_agent_parser(device_type):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
|
|
|
|
|
def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
|
|
|
|
|
def serve(handler, sock_path, timeout=UNIX_SOCKET_TIMEOUT):
|
|
|
|
|
"""
|
|
|
|
|
Start the ssh-agent server on a UNIX-domain socket.
|
|
|
|
|
|
|
|
|
@ -106,9 +110,6 @@ def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
|
|
|
|
|
ssh_version = subprocess.check_output(['ssh', '-V'],
|
|
|
|
|
stderr=subprocess.STDOUT)
|
|
|
|
|
log.debug('local SSH version: %r', ssh_version)
|
|
|
|
|
if sock_path is None:
|
|
|
|
|
sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-')
|
|
|
|
|
|
|
|
|
|
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
|
|
|
|
|
device_mutex = threading.Lock()
|
|
|
|
|
with server.unix_domain_socket_server(sock_path) as sock:
|
|
|
|
@ -128,12 +129,17 @@ def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
|
|
|
|
|
quit_event.set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_server(conn, command, debug, timeout):
|
|
|
|
|
def run_server(conn, command, sock_path, debug, timeout):
|
|
|
|
|
"""Common code for run_agent and run_git below."""
|
|
|
|
|
try:
|
|
|
|
|
handler = protocol.Handler(conn=conn, debug=debug)
|
|
|
|
|
with serve(handler=handler, timeout=timeout) as env:
|
|
|
|
|
return server.run_process(command=command, environ=env)
|
|
|
|
|
with serve(handler=handler, sock_path=sock_path,
|
|
|
|
|
timeout=timeout) as env:
|
|
|
|
|
if command is None:
|
|
|
|
|
signal.pause() # wait for signal
|
|
|
|
|
return 0
|
|
|
|
|
else:
|
|
|
|
|
return server.run_process(command=command, environ=env)
|
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
|
log.info('server stopped')
|
|
|
|
|
|
|
|
|
@ -195,6 +201,11 @@ class JustInTimeConnection(object):
|
|
|
|
|
return conn.sign_ssh_challenge(blob=blob, identity=identity)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
|
|
|
|
|
def _dummy_context():
|
|
|
|
|
yield
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@handle_connection_error
|
|
|
|
|
def main(device_type):
|
|
|
|
|
"""Run ssh-agent using given hardware client factory."""
|
|
|
|
@ -216,10 +227,21 @@ def main(device_type):
|
|
|
|
|
identity.identity_dict['proto'] = u'ssh'
|
|
|
|
|
log.info('identity #%d: %s', index, identity.to_string())
|
|
|
|
|
|
|
|
|
|
sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-')
|
|
|
|
|
|
|
|
|
|
command = None
|
|
|
|
|
context = _dummy_context()
|
|
|
|
|
if args.connect:
|
|
|
|
|
command = ['ssh'] + ssh_args(args.identity) + args.command
|
|
|
|
|
elif args.mosh:
|
|
|
|
|
command = ['mosh'] + mosh_args(args.identity) + args.command
|
|
|
|
|
elif args.daemonize:
|
|
|
|
|
msg = ('SSH_AUTH_SOCK={0}; export SSH_AUTH_SOCK;\n'
|
|
|
|
|
'SSH_AGENT_PID={1}; export SSH_AGENT_PID;\n'
|
|
|
|
|
'echo Agent pid {1};\n'.format(sock_path, os.getpid()))
|
|
|
|
|
sys.stdout.write(msg)
|
|
|
|
|
sys.stdout.flush()
|
|
|
|
|
context = daemon.DaemonContext()
|
|
|
|
|
else:
|
|
|
|
|
command = args.command
|
|
|
|
|
|
|
|
|
@ -231,9 +253,11 @@ def main(device_type):
|
|
|
|
|
conn = JustInTimeConnection(
|
|
|
|
|
conn_factory=lambda: client.Client(device_type()),
|
|
|
|
|
identities=identities, public_keys=public_keys)
|
|
|
|
|
if command:
|
|
|
|
|
return run_server(conn=conn, command=command, debug=args.debug,
|
|
|
|
|
timeout=args.timeout)
|
|
|
|
|
|
|
|
|
|
if command or args.daemonize:
|
|
|
|
|
with context:
|
|
|
|
|
return run_server(conn=conn, command=command, sock_path=sock_path,
|
|
|
|
|
debug=args.debug, timeout=args.timeout)
|
|
|
|
|
else:
|
|
|
|
|
for pk in conn.public_keys():
|
|
|
|
|
sys.stdout.write(pk)
|
|
|
|
|