server: stop the server via a threading.Event

It seems that Mac OS does not support calling socket.shutdown(socket.SHUT_RD)
on a listening socket (see https://github.com/romanz/trezor-agent/issues/6).
The following implementation will set the accept() timeout to 0.1s and stop
the server if a threading.Event (named "quit_event") is set by the main thread.
nistp521
Roman Zeyde 8 years ago
parent 7ea20c7009
commit fb0d0a5f61

@ -12,3 +12,4 @@ commands=
pylint --reports=no --rcfile .pylintrc trezor_agent
coverage run --omit='trezor_agent/__main__.py' --source trezor_agent -m py.test -v trezor_agent
coverage report
coverage html

@ -12,6 +12,8 @@ from . import util
log = logging.getLogger(__name__)
UNIX_SOCKET_TIMEOUT = 0.1
def remove_file(path, remove=os.remove, exists=os.path.exists):
try:
@ -44,19 +46,31 @@ def handle_connection(conn, handler):
util.send(conn, reply)
except EOFError:
log.debug('goodbye agent')
except:
log.exception('error')
raise
def server_thread(server, handler):
def retry(func, exception_type, quit_event):
while True:
if quit_event.is_set():
raise StopIteration
try:
return func()
except exception_type:
pass
def server_thread(server, handler, quit_event):
log.debug('server thread started')
def accept_connection():
conn, _ = server.accept()
return conn
while True:
log.debug('waiting for connection on %s', server.getsockname())
try:
conn, _ = server.accept()
except socket.error as e:
log.debug('server stopped: %s', e)
conn = retry(accept_connection, socket.timeout, quit_event)
except StopIteration:
log.debug('server stopped')
break
with contextlib.closing(conn):
handle_connection(conn, handler)
@ -64,7 +78,7 @@ def server_thread(server, handler):
@contextlib.contextmanager
def spawn(func, **kwargs):
def spawn(func, kwargs):
t = threading.Thread(target=func, kwargs=kwargs)
t.start()
yield
@ -72,20 +86,23 @@ def spawn(func, **kwargs):
@contextlib.contextmanager
def serve(public_keys, signer, sock_path=None):
def serve(public_keys, signer, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
if sock_path is None:
sock_path = tempfile.mktemp(prefix='ssh-agent-')
keys = [formats.import_public_key(k) for k in public_keys]
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
with unix_domain_socket_server(sock_path) as server:
server.settimeout(timeout)
handler = protocol.Handler(keys=keys, signer=signer)
with spawn(server_thread, server=server, handler=handler):
quit_event = threading.Event()
kwargs = dict(server=server, handler=handler, quit_event=quit_event)
with spawn(server_thread, kwargs):
try:
yield environ
finally:
log.debug('closing server')
server.shutdown(socket.SHUT_RD)
quit_event.set()
def run_process(command, environ, use_shell=False):

@ -1,5 +1,6 @@
import tempfile
import socket
import threading
import os
import io
import pytest
@ -16,7 +17,7 @@ def test_socket():
assert not os.path.isfile(path)
class SocketMock(object):
class FakeSocket(object):
def __init__(self, data=b''):
self.rx = io.BytesIO(data)
@ -34,16 +35,16 @@ class SocketMock(object):
def test_handle():
handler = protocol.Handler(keys=[], signer=None)
conn = SocketMock()
conn = FakeSocket()
server.handle_connection(conn, handler)
msg = bytearray([protocol.SSH_AGENTC_REQUEST_RSA_IDENTITIES])
conn = SocketMock(util.frame(msg))
conn = FakeSocket(util.frame(msg))
server.handle_connection(conn, handler)
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00'
msg = bytearray([protocol.SSH2_AGENTC_REQUEST_IDENTITIES])
conn = SocketMock(util.frame(msg))
conn = FakeSocket(util.frame(msg))
server.handle_connection(conn, handler)
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00'
@ -51,25 +52,24 @@ def test_handle():
server.handle_connection(conn=None, handler=None)
class ServerMock(object):
def __init__(self, connections, name):
self.connections = connections
self.name = name
def test_server_thread():
def getsockname(self):
return self.name
connections = [FakeSocket()]
quit_event = threading.Event()
def accept(self):
if self.connections:
return self.connections.pop(), 'address'
raise socket.error('stop')
class FakeServer(object):
def accept(self): # pylint: disable=no-self-use
if connections:
return connections.pop(), 'address'
quit_event.set()
raise socket.timeout()
def getsockname(self): # pylint: disable=no-self-use
return 'fake_server'
def test_server_thread():
s = ServerMock(connections=[SocketMock()], name='mock')
h = protocol.Handler(keys=[], signer=None)
server.server_thread(s, h)
server.server_thread(server=FakeServer(),
handler=protocol.Handler(keys=[], signer=None),
quit_event=quit_event)
def test_spawn():
@ -78,7 +78,7 @@ def test_spawn():
def thread(x):
obj.append(x)
with server.spawn(thread, x=1):
with server.spawn(thread, dict(x=1)):
pass
assert obj == [1]

Loading…
Cancel
Save