From a60a627b8ac13afda501f6e20a8b5c917b8b8ae2 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Wed, 17 Jun 2015 16:52:11 +0300 Subject: [PATCH] server: serve should be a context manager --- sshagent/server.py | 32 ++++++++++++++--------------- sshagent/trezor_agent.py | 44 +++++++++++++++++----------------------- 2 files changed, 35 insertions(+), 41 deletions(-) diff --git a/sshagent/server.py b/sshagent/server.py index 488173d..1e741f5 100644 --- a/sshagent/server.py +++ b/sshagent/server.py @@ -68,21 +68,8 @@ def spawn(func, **kwargs): t.join() -def run(command, environ): - log.debug('running %r with %r', command, environ) - env = dict(os.environ) - env.update(environ) - try: - p = subprocess.Popen(args=command, env=env) - except OSError as e: - raise OSError('cannot run %r: %s' % (command, e)) - log.debug('subprocess %d is running', p.pid) - ret = p.wait() - log.debug('subprocess %d exited: %d', p.pid, ret) - return ret - - -def serve(key_files, command, signer, sock_path=None): +@contextlib.contextmanager +def serve(key_files, signer, sock_path=None): if sock_path is None: sock_path = tempfile.mktemp(prefix='ssh-agent-') @@ -91,8 +78,21 @@ def serve(key_files, command, signer, sock_path=None): with unix_domain_socket_server(sock_path) as server: with spawn(server_thread, server=server, keys=keys, signer=signer): try: - ret = run(command=command, environ=environ) + yield environ finally: log.debug('closing server') server.shutdown(socket.SHUT_RD) + + +def run_process(command, environ): + log.debug('running %r with %r', command, environ) + env = dict(os.environ) + env.update(environ) + try: + p = subprocess.Popen(args=command, env=env) + except OSError as e: + raise OSError('cannot run %r: %s' % (command, e)) + log.debug('subprocess %d is running', p.pid) + ret = p.wait() + log.debug('subprocess %d exited: %d', p.pid, ret) return ret diff --git a/sshagent/trezor_agent.py b/sshagent/trezor_agent.py index 46aa82f..97d56eb 100644 --- a/sshagent/trezor_agent.py +++ b/sshagent/trezor_agent.py @@ -22,30 +22,24 @@ def main(): level = verbosity[min(args.verbose, len(verbosity) - 1)] logging.basicConfig(level=level, format=fmt) - client = trezor.Client(factory=trezor.TrezorLibrary) - - key_files = [] - for label in args.labels: - pubkey = client.get_public_key(label=label) - key_file = formats.export_public_key(pubkey=pubkey, label=label) - key_files.append(key_file) - - if not args.command: - sys.stdout.write(''.join(key_files)) - return - - signer = client.sign_ssh_challenge - - ret = -1 - try: - ret = server.serve( - key_files=key_files, - command=args.command, - signer=signer) - log.info('exitcode: %d', ret) - except KeyboardInterrupt: - log.info('server stopped') - sys.exit(ret) + with trezor.Client(factory=trezor.TrezorLibrary) as client: + key_files = [] + for label in args.labels: + pubkey = client.get_public_key(label=label) + key_file = formats.export_public_key(pubkey=pubkey, label=label) + key_files.append(key_file) + + if not args.command: + sys.stdout.write(''.join(key_files)) + return + + signer = client.sign_ssh_challenge + + try: + with server.serve(key_files=key_files, signer=signer) as env: + return server.run_process(command=args.command, environ=env) + except KeyboardInterrupt: + log.info('server stopped') if __name__ == '__main__': - main() + sys.exit(main())