From fb0d0a5f6138d3e55786ae8b32b6ea2d231527c9 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Fri, 8 Jan 2016 20:28:38 +0200 Subject: [PATCH] server: stop the server via a threading.Event It seems that Mac OS does not support calling socket.shutdown(socket.SHUT_RD) on a listening socket (see https://github.com/romanz/trezor-agent/issues/6). The following implementation will set the accept() timeout to 0.1s and stop the server if a threading.Event (named "quit_event") is set by the main thread. --- tox.ini | 1 + trezor_agent/server.py | 39 +++++++++++++++++++++--------- trezor_agent/tests/test_server.py | 40 +++++++++++++++---------------- 3 files changed, 49 insertions(+), 31 deletions(-) diff --git a/tox.ini b/tox.ini index 55ac2bd..b4c0024 100644 --- a/tox.ini +++ b/tox.ini @@ -12,3 +12,4 @@ commands= pylint --reports=no --rcfile .pylintrc trezor_agent coverage run --omit='trezor_agent/__main__.py' --source trezor_agent -m py.test -v trezor_agent coverage report + coverage html diff --git a/trezor_agent/server.py b/trezor_agent/server.py index 3e1675a..77af856 100644 --- a/trezor_agent/server.py +++ b/trezor_agent/server.py @@ -12,6 +12,8 @@ from . import util log = logging.getLogger(__name__) +UNIX_SOCKET_TIMEOUT = 0.1 + def remove_file(path, remove=os.remove, exists=os.path.exists): try: @@ -44,19 +46,31 @@ def handle_connection(conn, handler): util.send(conn, reply) except EOFError: log.debug('goodbye agent') - except: - log.exception('error') - raise -def server_thread(server, handler): +def retry(func, exception_type, quit_event): + while True: + if quit_event.is_set(): + raise StopIteration + try: + return func() + except exception_type: + pass + + +def server_thread(server, handler, quit_event): log.debug('server thread started') + + def accept_connection(): + conn, _ = server.accept() + return conn + while True: log.debug('waiting for connection on %s', server.getsockname()) try: - conn, _ = server.accept() - except socket.error as e: - log.debug('server stopped: %s', e) + conn = retry(accept_connection, socket.timeout, quit_event) + except StopIteration: + log.debug('server stopped') break with contextlib.closing(conn): handle_connection(conn, handler) @@ -64,7 +78,7 @@ def server_thread(server, handler): @contextlib.contextmanager -def spawn(func, **kwargs): +def spawn(func, kwargs): t = threading.Thread(target=func, kwargs=kwargs) t.start() yield @@ -72,20 +86,23 @@ def spawn(func, **kwargs): @contextlib.contextmanager -def serve(public_keys, signer, sock_path=None): +def serve(public_keys, signer, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT): if sock_path is None: sock_path = tempfile.mktemp(prefix='ssh-agent-') keys = [formats.import_public_key(k) for k in public_keys] environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())} with unix_domain_socket_server(sock_path) as server: + server.settimeout(timeout) handler = protocol.Handler(keys=keys, signer=signer) - with spawn(server_thread, server=server, handler=handler): + quit_event = threading.Event() + kwargs = dict(server=server, handler=handler, quit_event=quit_event) + with spawn(server_thread, kwargs): try: yield environ finally: log.debug('closing server') - server.shutdown(socket.SHUT_RD) + quit_event.set() def run_process(command, environ, use_shell=False): diff --git a/trezor_agent/tests/test_server.py b/trezor_agent/tests/test_server.py index 1dae44a..64616cc 100644 --- a/trezor_agent/tests/test_server.py +++ b/trezor_agent/tests/test_server.py @@ -1,5 +1,6 @@ import tempfile import socket +import threading import os import io import pytest @@ -16,7 +17,7 @@ def test_socket(): assert not os.path.isfile(path) -class SocketMock(object): +class FakeSocket(object): def __init__(self, data=b''): self.rx = io.BytesIO(data) @@ -34,16 +35,16 @@ class SocketMock(object): def test_handle(): handler = protocol.Handler(keys=[], signer=None) - conn = SocketMock() + conn = FakeSocket() server.handle_connection(conn, handler) msg = bytearray([protocol.SSH_AGENTC_REQUEST_RSA_IDENTITIES]) - conn = SocketMock(util.frame(msg)) + conn = FakeSocket(util.frame(msg)) server.handle_connection(conn, handler) assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00' msg = bytearray([protocol.SSH2_AGENTC_REQUEST_IDENTITIES]) - conn = SocketMock(util.frame(msg)) + conn = FakeSocket(util.frame(msg)) server.handle_connection(conn, handler) assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00' @@ -51,25 +52,24 @@ def test_handle(): server.handle_connection(conn=None, handler=None) -class ServerMock(object): - - def __init__(self, connections, name): - self.connections = connections - self.name = name +def test_server_thread(): - def getsockname(self): - return self.name + connections = [FakeSocket()] + quit_event = threading.Event() - def accept(self): - if self.connections: - return self.connections.pop(), 'address' - raise socket.error('stop') + class FakeServer(object): + def accept(self): # pylint: disable=no-self-use + if connections: + return connections.pop(), 'address' + quit_event.set() + raise socket.timeout() + def getsockname(self): # pylint: disable=no-self-use + return 'fake_server' -def test_server_thread(): - s = ServerMock(connections=[SocketMock()], name='mock') - h = protocol.Handler(keys=[], signer=None) - server.server_thread(s, h) + server.server_thread(server=FakeServer(), + handler=protocol.Handler(keys=[], signer=None), + quit_event=quit_event) def test_spawn(): @@ -78,7 +78,7 @@ def test_spawn(): def thread(x): obj.append(x) - with server.spawn(thread, x=1): + with server.spawn(thread, dict(x=1)): pass assert obj == [1]