diff --git a/sshagent/trezor.py b/sshagent/trezor.py index a0c7596..228433e 100644 --- a/sshagent/trezor.py +++ b/sshagent/trezor.py @@ -51,9 +51,11 @@ class Client(object): self.client.clear_session() self.client.close() - def get_public_key(self, label): - identity = self.factory.parse_identity(label) - label = _identity_to_string(identity) # update label after parsing + def get_identity(self, label): + return self.factory.parse_identity(label) + + def get_public_key(self, identity): + label = _identity_to_string(identity) log.info('getting "%s" public key from Trezor...', label) addr = _get_address(identity) node = self.client.get_public_node(addr, self.curve_name) @@ -64,13 +66,12 @@ class Client(object): def sign_ssh_challenge(self, label, blob): identity = self.factory.parse_identity(label) msg = _parse_ssh_blob(blob) - request = 'user: "{user}"'.format(**msg) - log.info('confirm %s connection to %r using Trezor...', - request, label) + log.info('confirm user %s connection to %r using Trezor...', + msg['user'], label) s = self.client.sign_identity(identity=identity, challenge_hidden=blob, - challenge_visual=request, + challenge_visual='', ecdsa_curve_name=self.curve_name) assert len(s.signature) == 65 assert s.signature[0] == b'\x00' diff --git a/sshagent/trezor_agent.py b/sshagent/trezor_agent.py index dfd8034..d8f39ed 100644 --- a/sshagent/trezor_agent.py +++ b/sshagent/trezor_agent.py @@ -17,7 +17,9 @@ def main(): g.add_argument('-v', '--verbose', default=0, action='count') g.add_argument('-q', '--quiet', default=False, action='store_true') - p.add_argument('identity', type=str, + p.add_argument('-p', '--public-key', default=False, action='store_true') + + p.add_argument('-i', '--identity', type=str, help='proto://[user@]host[:port][/path]') p.add_argument('command', type=str, nargs='*', help='command to run under the SSH agent') @@ -32,17 +34,22 @@ def main(): logging.basicConfig(level=loglevel, format=fmt) with trezor.Client(factory=trezor.TrezorLibrary) as client: - public_keys = [client.get_public_key(i) for i in args.identity] + identity = client.get_identity(label=args.identity) + public_key = client.get_public_key(identity=identity) + if args.public_key: + sys.stdout.write(public_key) + return - command = args.command + command, use_shell = args.command, False if not command: - command = os.environ['SHELL'] - log.info('using %r shell', command) + command, use_shell = os.environ['SHELL'], True signer = client.sign_ssh_challenge try: - with server.serve(public_keys=public_keys, signer=signer) as env: - return server.run_process(command=command, environ=env) + with server.serve(public_keys=[public_key], signer=signer) as env: + return server.run_process( + command=command, environ=env, use_shell=use_shell + ) except KeyboardInterrupt: log.info('server stopped')