"""Tests for the ``server`` module: sockets, connections and helpers."""
import functools
import io
import os
import socket
import tempfile
import threading

import mock
import pytest

from .. import protocol, server, util


def test_socket():
    """The socket path must be gone once the server context exits."""
    # mktemp() only yields an unused name; nothing is created at `path` here.
    path = tempfile.mktemp()
    with server.unix_domain_socket_server(path):
        pass
    # os.path.isfile() is only True for regular files, so it would report
    # False for a leftover socket node as well; exists() actually checks
    # that the path has been cleaned up.
    assert not os.path.exists(path)


class FakeSocket(object):
    """In-memory socket stand-in: reads come from `rx`, writes go to `tx`."""

    def __init__(self, data=b''):
        self.rx = io.BytesIO(data)  # bytes served by recv()
        self.tx = io.BytesIO()      # bytes collected from sendall()

    def sendall(self, data):
        """Record outgoing `data` in the `tx` buffer."""
        self.tx.write(data)

    def recv(self, size):
        """Return up to `size` bytes from the `rx` buffer."""
        return self.rx.read(size)

    def close(self):
        """No-op; present so the fake matches the socket interface."""

    def settimeout(self, value):
        """No-op; present so the fake matches the socket interface."""


def empty_device():
    """Return a mock device whose `parse_public_keys()` yields no keys."""
    c = mock.Mock(spec=['parse_public_keys'])
    c.parse_public_keys.return_value = []
    return c


def test_handle():
    """Drive `server.handle_connection` with several agent requests."""
    mutex = threading.Lock()
    handler = protocol.Handler(conn=empty_device())

    # A connection that carries no data at all.
    conn = FakeSocket()
    server.handle_connection(conn, handler, mutex)

    msg = bytearray([protocol.msg_code('SSH_AGENTC_REQUEST_RSA_IDENTITIES')])
    conn = FakeSocket(util.frame(msg))
    server.handle_connection(conn, handler, mutex)
    assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00'

    msg = bytearray([protocol.msg_code('SSH2_AGENTC_REQUEST_IDENTITIES')])
    conn = FakeSocket(util.frame(msg))
    server.handle_connection(conn, handler, mutex)
    assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00'

    msg = bytearray([protocol.msg_code('SSH2_AGENTC_ADD_IDENTITY')])
    conn = FakeSocket(util.frame(msg))
    server.handle_connection(conn, handler, mutex)
    # Rewind the transmit buffer so the reply frame can be read back.
    conn.tx.seek(0)
    reply = util.read_frame(conn.tx)
    assert reply == util.pack('B', protocol.msg_code('SSH_AGENT_FAILURE'))

    # recv() raising must not propagate out of handle_connection;
    # otherwise this call would fail the test.
    conn_mock = mock.Mock(spec=FakeSocket)
    conn_mock.recv.side_effect = [Exception, EOFError]
    server.handle_connection(conn=conn_mock, handler=None, mutex=mutex)


def test_server_thread():
    """Run `server.server_thread` against a fake listening socket."""
    connections = [FakeSocket()]
    quit_event = threading.Event()

    class FakeServer(object):
        """Hands out the queued connection once, then requests shutdown."""

        def accept(self):  # pylint: disable=no-self-use
            if connections:
                return connections.pop(), 'address'
            # Nothing left to serve: signal the loop to stop and time out.
            quit_event.set()
            raise socket.timeout()

        def getsockname(self):  # pylint: disable=no-self-use
            return 'fake_server'

    # NOTE: no trailing comma here -- it would turn `handler` into a 1-tuple
    # instead of a protocol.Handler instance.
    handler = protocol.Handler(conn=empty_device())
    handle_conn = functools.partial(server.handle_connection,
                                    handler=handler,
                                    mutex=None)
    server.server_thread(sock=FakeServer(),
                         handle_conn=handle_conn,
                         quit_event=quit_event)


def test_spawn():
    """`server.spawn` must invoke the target with the given kwargs."""
    obj = []

    def thread(x):
        obj.append(x)

    with server.spawn(thread, dict(x=1)):
        pass

    assert obj == [1]


def test_run():
    """`server.run_process` returns the command's exit status."""
    assert server.run_process(['true'], environ={}) == 0
    assert server.run_process(['false'], environ={}) == 1
    # The exit status is taken from the environment passed to the process.
    assert server.run_process(command=['bash', '-c', 'exit $X'],
                              environ={'X': '42'}) == 42

    # An empty program name cannot be executed.
    with pytest.raises(OSError):
        server.run_process([''], environ={})


def test_serve_main():
    """`server.serve` can be entered and exited with `sock_path=None`."""
    handler = protocol.Handler(conn=empty_device())
    with server.serve(handler=handler, sock_path=None):
        pass


def test_remove():
    """`server.remove_file` delegates removal and filters OSError."""
    path = 'foo.bar'

    def remove(p):
        assert p == path

    server.remove_file(path, remove=remove)

    def remove_raise(_):
        raise OSError('boom')

    # The error is swallowed when the path is reported as not existing...
    server.remove_file(path, remove=remove_raise, exists=lambda _: False)

    # ...and propagated when the path is reported as still existing.
    with pytest.raises(OSError):
        server.remove_file(path, remove=remove_raise, exists=lambda _: True)