server: stop the server via a threading.Event

It seems that Mac OS does not support calling socket.shutdown(socket.SHUT_RD)
on a listening socket (see https://github.com/romanz/trezor-agent/issues/6).
The following implementation will set the accept() timeout to 0.1s and stop
the server if a threading.Event (named "quit_event") is set by the main thread.
nistp521
Roman Zeyde 9 years ago
parent 7ea20c7009
commit fb0d0a5f61

@ -12,3 +12,4 @@ commands=
pylint --reports=no --rcfile .pylintrc trezor_agent pylint --reports=no --rcfile .pylintrc trezor_agent
coverage run --omit='trezor_agent/__main__.py' --source trezor_agent -m py.test -v trezor_agent coverage run --omit='trezor_agent/__main__.py' --source trezor_agent -m py.test -v trezor_agent
coverage report coverage report
coverage html

@ -12,6 +12,8 @@ from . import util
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
UNIX_SOCKET_TIMEOUT = 0.1
def remove_file(path, remove=os.remove, exists=os.path.exists): def remove_file(path, remove=os.remove, exists=os.path.exists):
try: try:
@ -44,19 +46,31 @@ def handle_connection(conn, handler):
util.send(conn, reply) util.send(conn, reply)
except EOFError: except EOFError:
log.debug('goodbye agent') log.debug('goodbye agent')
except:
log.exception('error')
raise
def server_thread(server, handler): def retry(func, exception_type, quit_event):
while True:
if quit_event.is_set():
raise StopIteration
try:
return func()
except exception_type:
pass
def server_thread(server, handler, quit_event):
log.debug('server thread started') log.debug('server thread started')
def accept_connection():
conn, _ = server.accept()
return conn
while True: while True:
log.debug('waiting for connection on %s', server.getsockname()) log.debug('waiting for connection on %s', server.getsockname())
try: try:
conn, _ = server.accept() conn = retry(accept_connection, socket.timeout, quit_event)
except socket.error as e: except StopIteration:
log.debug('server stopped: %s', e) log.debug('server stopped')
break break
with contextlib.closing(conn): with contextlib.closing(conn):
handle_connection(conn, handler) handle_connection(conn, handler)
@ -64,7 +78,7 @@ def server_thread(server, handler):
@contextlib.contextmanager @contextlib.contextmanager
def spawn(func, **kwargs): def spawn(func, kwargs):
t = threading.Thread(target=func, kwargs=kwargs) t = threading.Thread(target=func, kwargs=kwargs)
t.start() t.start()
yield yield
@ -72,20 +86,23 @@ def spawn(func, **kwargs):
@contextlib.contextmanager @contextlib.contextmanager
def serve(public_keys, signer, sock_path=None): def serve(public_keys, signer, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
if sock_path is None: if sock_path is None:
sock_path = tempfile.mktemp(prefix='ssh-agent-') sock_path = tempfile.mktemp(prefix='ssh-agent-')
keys = [formats.import_public_key(k) for k in public_keys] keys = [formats.import_public_key(k) for k in public_keys]
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())} environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
with unix_domain_socket_server(sock_path) as server: with unix_domain_socket_server(sock_path) as server:
server.settimeout(timeout)
handler = protocol.Handler(keys=keys, signer=signer) handler = protocol.Handler(keys=keys, signer=signer)
with spawn(server_thread, server=server, handler=handler): quit_event = threading.Event()
kwargs = dict(server=server, handler=handler, quit_event=quit_event)
with spawn(server_thread, kwargs):
try: try:
yield environ yield environ
finally: finally:
log.debug('closing server') log.debug('closing server')
server.shutdown(socket.SHUT_RD) quit_event.set()
def run_process(command, environ, use_shell=False): def run_process(command, environ, use_shell=False):

@ -1,5 +1,6 @@
import tempfile import tempfile
import socket import socket
import threading
import os import os
import io import io
import pytest import pytest
@ -16,7 +17,7 @@ def test_socket():
assert not os.path.isfile(path) assert not os.path.isfile(path)
class SocketMock(object): class FakeSocket(object):
def __init__(self, data=b''): def __init__(self, data=b''):
self.rx = io.BytesIO(data) self.rx = io.BytesIO(data)
@ -34,16 +35,16 @@ class SocketMock(object):
def test_handle(): def test_handle():
handler = protocol.Handler(keys=[], signer=None) handler = protocol.Handler(keys=[], signer=None)
conn = SocketMock() conn = FakeSocket()
server.handle_connection(conn, handler) server.handle_connection(conn, handler)
msg = bytearray([protocol.SSH_AGENTC_REQUEST_RSA_IDENTITIES]) msg = bytearray([protocol.SSH_AGENTC_REQUEST_RSA_IDENTITIES])
conn = SocketMock(util.frame(msg)) conn = FakeSocket(util.frame(msg))
server.handle_connection(conn, handler) server.handle_connection(conn, handler)
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00' assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00'
msg = bytearray([protocol.SSH2_AGENTC_REQUEST_IDENTITIES]) msg = bytearray([protocol.SSH2_AGENTC_REQUEST_IDENTITIES])
conn = SocketMock(util.frame(msg)) conn = FakeSocket(util.frame(msg))
server.handle_connection(conn, handler) server.handle_connection(conn, handler)
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00' assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00'
@ -51,25 +52,24 @@ def test_handle():
server.handle_connection(conn=None, handler=None) server.handle_connection(conn=None, handler=None)
class ServerMock(object): def test_server_thread():
def __init__(self, connections, name):
self.connections = connections
self.name = name
def getsockname(self): connections = [FakeSocket()]
return self.name quit_event = threading.Event()
def accept(self): class FakeServer(object):
if self.connections: def accept(self): # pylint: disable=no-self-use
return self.connections.pop(), 'address' if connections:
raise socket.error('stop') return connections.pop(), 'address'
quit_event.set()
raise socket.timeout()
def getsockname(self): # pylint: disable=no-self-use
return 'fake_server'
def test_server_thread(): server.server_thread(server=FakeServer(),
s = ServerMock(connections=[SocketMock()], name='mock') handler=protocol.Handler(keys=[], signer=None),
h = protocol.Handler(keys=[], signer=None) quit_event=quit_event)
server.server_thread(s, h)
def test_spawn(): def test_spawn():
@ -78,7 +78,7 @@ def test_spawn():
def thread(x): def thread(x):
obj.append(x) obj.append(x)
with server.spawn(thread, x=1): with server.spawn(thread, dict(x=1)):
pass pass
assert obj == [1] assert obj == [1]

Loading…
Cancel
Save