ssh: move related code to a separate subdirectory

nistp521
Roman Zeyde 7 years ago
parent 6c2273387d
commit 257992d04c
No known key found for this signature in database
GPG Key ID: 87CAE5FA46917CBB

@ -1,19 +1,15 @@
"""UNIX-domain socket server for ssh-agent implementation."""
import contextlib
import functools
import logging
import os
import socket
import subprocess
import tempfile
import threading
from . import util
log = logging.getLogger(__name__)
UNIX_SOCKET_TIMEOUT = 0.1
def remove_file(path, remove=os.remove, exists=os.path.exists):
"""Remove file, and raise OSError if still exists."""
@ -114,39 +110,6 @@ def spawn(func, kwargs):
t.join()
@contextlib.contextmanager
def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
"""
Start the ssh-agent server on a UNIX-domain socket.
If no connection is made during the specified timeout,
retry until the context is over.
"""
ssh_version = subprocess.check_output(['ssh', '-V'],
stderr=subprocess.STDOUT)
log.debug('local SSH version: %r', ssh_version)
if sock_path is None:
sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-')
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
device_mutex = threading.Lock()
with unix_domain_socket_server(sock_path) as sock:
sock.settimeout(timeout)
quit_event = threading.Event()
handle_conn = functools.partial(handle_connection,
handler=handler,
mutex=device_mutex)
kwargs = dict(sock=sock,
handle_conn=handle_conn,
quit_event=quit_event)
with spawn(server_thread, kwargs):
try:
yield environ
finally:
log.debug('closing server')
quit_event.set()
def run_process(command, environ):
"""
Run the specified process and wait until it finishes.

@ -1,16 +1,23 @@
"""SSH-agent implementation using hardware authentication devices."""
import argparse
import contextlib
import functools
import logging
import os
import re
import subprocess
import sys
import tempfile
import threading
from .. import client, device, formats, protocol, server, util
from .. import device, formats, server, util
from . import client, protocol
log = logging.getLogger(__name__)
UNIX_SOCKET_TIMEOUT = 0.1
def ssh_args(label):
"""Create SSH command for connecting specified server."""
@ -51,7 +58,7 @@ def create_parser():
default=formats.CURVE_NIST256,
help='specify ECDSA curve name: ' + curve_names)
p.add_argument('--timeout',
default=server.UNIX_SOCKET_TIMEOUT, type=float,
default=UNIX_SOCKET_TIMEOUT, type=float,
help='Timeout for accepting SSH client connections')
p.add_argument('--debug', default=False, action='store_true',
help='Log SSH protocol messages for debugging.')
@ -110,11 +117,44 @@ def git_host(remote_name, attributes):
return '{user}@{host}'.format(**match.groupdict())
@contextlib.contextmanager
def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
"""
Start the ssh-agent server on a UNIX-domain socket.
If no connection is made during the specified timeout,
retry until the context is over.
"""
ssh_version = subprocess.check_output(['ssh', '-V'],
stderr=subprocess.STDOUT)
log.debug('local SSH version: %r', ssh_version)
if sock_path is None:
sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-')
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
device_mutex = threading.Lock()
with server.unix_domain_socket_server(sock_path) as sock:
sock.settimeout(timeout)
quit_event = threading.Event()
handle_conn = functools.partial(server.handle_connection,
handler=handler,
mutex=device_mutex)
kwargs = dict(sock=sock,
handle_conn=handle_conn,
quit_event=quit_event)
with server.spawn(server.server_thread, kwargs):
try:
yield environ
finally:
log.debug('closing server')
quit_event.set()
def run_server(conn, command, debug, timeout):
"""Common code for run_agent and run_git below."""
try:
handler = protocol.Handler(conn=conn, debug=debug)
with server.serve(handler=handler, timeout=timeout) as env:
with serve(handler=handler, timeout=timeout) as env:
return server.run_process(command=command, environ=env)
except KeyboardInterrupt:
log.info('server stopped')

@ -0,0 +1 @@
"""Unit-tests for this package."""

@ -8,7 +8,8 @@ import threading
import mock
import pytest
from .. import protocol, server, util
from .. import server, util
from ..ssh import protocol
def test_socket():
@ -117,12 +118,6 @@ def test_run():
server.run_process([''], environ={})
def test_serve_main():
handler = protocol.Handler(conn=empty_device())
with server.serve(handler=handler, sock_path=None):
pass
def test_remove():
path = 'foo.bar'

Loading…
Cancel
Save