Allow firmware update by version or latest from releases.json

pull/1/head
Chris Rico 9 years ago committed by Pavol Rusnak
parent a20f976721
commit 6475f98b1e
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D

@ -20,22 +20,22 @@ def parse_args(commands):
# parser.add_argument('-d', '--debug', dest='debug', action='store_true', help='Enable low-level debugging') # parser.add_argument('-d', '--debug', dest='debug', action='store_true', help='Enable low-level debugging')
cmdparser = parser.add_subparsers(title='Available commands') cmdparser = parser.add_subparsers(title='Available commands')
for cmd in commands._list_commands(): for cmd in commands._list_commands():
func = object.__getattribute__(commands, cmd) func = object.__getattribute__(commands, cmd)
try: try:
arguments = func.arguments arguments = func.arguments
except AttributeError: except AttributeError:
arguments = ((('params',), {'nargs': '*'}),) arguments = ((('params',), {'nargs': '*'}),)
item = cmdparser.add_parser(cmd, help=func.help) item = cmdparser.add_parser(cmd, help=func.help)
for arg in arguments: for arg in arguments:
item.add_argument(*arg[0], **arg[1]) item.add_argument(*arg[0], **arg[1])
item.set_defaults(func=func) item.set_defaults(func=func)
item.set_defaults(cmd=cmd) item.set_defaults(cmd=cmd)
return parser.parse_args() return parser.parse_args()
def get_transport(transport_string, path, **kwargs): def get_transport(transport_string, path, **kwargs):
@ -54,7 +54,7 @@ def get_transport(transport_string, path, **kwargs):
return HidTransport(d, **kwargs) return HidTransport(d, **kwargs)
raise Exception("Device not found") raise Exception("Device not found")
if transport_string == 'serial': if transport_string == 'serial':
from trezorlib.transport_serial import SerialTransport from trezorlib.transport_serial import SerialTransport
return SerialTransport(path, **kwargs) return SerialTransport(path, **kwargs)
@ -62,7 +62,7 @@ def get_transport(transport_string, path, **kwargs):
if transport_string == 'pipe': if transport_string == 'pipe':
from trezorlib.transport_pipe import PipeTransport from trezorlib.transport_pipe import PipeTransport
return PipeTransport(path, is_device=False, **kwargs) return PipeTransport(path, is_device=False, **kwargs)
if transport_string == 'socket': if transport_string == 'socket':
from trezorlib.transport_socket import SocketTransportClient from trezorlib.transport_socket import SocketTransportClient
return SocketTransportClient(path, **kwargs) return SocketTransportClient(path, **kwargs)
@ -70,29 +70,29 @@ def get_transport(transport_string, path, **kwargs):
if transport_string == 'bridge': if transport_string == 'bridge':
from trezorlib.transport_bridge import BridgeTransport from trezorlib.transport_bridge import BridgeTransport
return BridgeTransport(path, **kwargs) return BridgeTransport(path, **kwargs)
if transport_string == 'fake': if transport_string == 'fake':
from trezorlib.transport_fake import FakeTransport from trezorlib.transport_fake import FakeTransport
return FakeTransport(path, **kwargs) return FakeTransport(path, **kwargs)
raise NotImplemented("Unknown transport") raise NotImplemented("Unknown transport")
class Commands(object): class Commands(object):
def __init__(self, client): def __init__(self, client):
self.client = client self.client = client
@classmethod @classmethod
def _list_commands(cls): def _list_commands(cls):
return [ x for x in dir(cls) if not x.startswith('_') ] return [ x for x in dir(cls) if not x.startswith('_') ]
def list(self, args): def list(self, args):
# Fake method for advertising 'list' command # Fake method for advertising 'list' command
pass pass
def get_address(self, args): def get_address(self, args):
address_n = self.client.expand_path(args.n) address_n = self.client.expand_path(args.n)
return self.client.get_address(args.coin, address_n, args.show_display) return self.client.get_address(args.coin, address_n, args.show_display)
def get_entropy(self, args): def get_entropy(self, args):
return binascii.hexlify(self.client.get_entropy(args.size)) return binascii.hexlify(self.client.get_entropy(args.size))
@ -108,7 +108,7 @@ class Commands(object):
def get_public_node(self, args): def get_public_node(self, args):
address_n = self.client.expand_path(args.n) address_n = self.client.expand_path(args.n)
return self.client.get_public_node(address_n) return self.client.get_public_node(address_n)
def set_label(self, args): def set_label(self, args):
return self.client.apply_settings(label=args.label) return self.client.apply_settings(label=args.label)
@ -203,9 +203,6 @@ class Commands(object):
return ret return ret
def firmware_update(self, args): def firmware_update(self, args):
if not args.file and not args.url:
raise Exception("Must provide firmware filename or URL")
if args.file: if args.file:
fp = open(args.file, 'r') fp = open(args.file, 'r')
elif args.url: elif args.url:
@ -213,18 +210,31 @@ class Commands(object):
resp = urllib.urlretrieve(args.url) resp = urllib.urlretrieve(args.url)
fp = open(resp[0], 'r') fp = open(resp[0], 'r')
urllib.urlcleanup() # We still keep file pointer open urllib.urlcleanup() # We still keep file pointer open
else:
resp = urllib.urlopen("https://mytrezor.com/data/firmware/releases.json")
releases = json.load(resp)
version = lambda r: r['version']
version_string = lambda r: ".".join(map(str, version(r)))
if args.version:
release = next((r for r in releases if version_string(r) == args.version))
else:
release = max(releases, key=version)
print "No file, url, or version given. Fetching latest version: %s" % version_string(release)
print "Firmware fingerprint: %s" % release['fingerprint']
args.url = release['url']
return self.firmware_update(args)
if fp.read(8) == '54525a52': if fp.read(8) == '54525a52':
print "Converting firmware to binary" print "Converting firmware to binary"
fp.seek(0) fp.seek(0)
fp_old = fp fp_old = fp
fp = tempfile.TemporaryFile() fp = tempfile.TemporaryFile()
fp.write(binascii.unhexlify(fp_old.read())) fp.write(binascii.unhexlify(fp_old.read()))
fp_old.close() fp_old.close()
fp.seek(0) fp.seek(0)
if fp.read(4) != 'TRZR': if fp.read(4) != 'TRZR':
raise Exception("Trezor firmware header expected") raise Exception("Trezor firmware header expected")
@ -262,7 +272,7 @@ class Commands(object):
(('-n', '-address'), {'type': str}), (('-n', '-address'), {'type': str}),
(('-d', '--show-display'), {'action': 'store_true', 'default': False}), (('-d', '--show-display'), {'action': 'store_true', 'default': False}),
) )
get_entropy.arguments = ( get_entropy.arguments = (
(('size',), {'type': int}), (('size',), {'type': int}),
) )
@ -277,7 +287,7 @@ class Commands(object):
(('-p', '--pin-protection'), {'action': 'store_true', 'default': False}), (('-p', '--pin-protection'), {'action': 'store_true', 'default': False}),
(('-r', '--passphrase-protection'), {'action': 'store_true', 'default': False}), (('-r', '--passphrase-protection'), {'action': 'store_true', 'default': False}),
) )
set_label.arguments = ( set_label.arguments = (
(('-l', '--label',), {'type': str, 'default': ''}), (('-l', '--label',), {'type': str, 'default': ''}),
# (('-c', '--clear'), {'action': 'store_true', 'default': False}) # (('-c', '--clear'), {'action': 'store_true', 'default': False})
@ -289,7 +299,7 @@ class Commands(object):
change_pin.arguments = ( change_pin.arguments = (
(('-r', '--remove'), {'action': 'store_true', 'default': False}), (('-r', '--remove'), {'action': 'store_true', 'default': False}),
) )
wipe_device.arguments = () wipe_device.arguments = ()
recovery_device.arguments = ( recovery_device.arguments = (
@ -358,6 +368,7 @@ class Commands(object):
firmware_update.arguments = ( firmware_update.arguments = (
(('-f', '--file'), {'type': str}), (('-f', '--file'), {'type': str}),
(('-u', '--url'), {'type': str}), (('-u', '--url'), {'type': str}),
(('-n', '--version'), {'type': str}),
) )
def list_usb(): def list_usb():
@ -418,7 +429,7 @@ def qt_pin_func(input_text, message=None):
# let's fallback to default pin_func implementation # let's fallback to default pin_func implementation
return pin_func(input_text, message) return pin_func(input_text, message)
''' '''
def main(): def main():
args = parse_args(Commands) args = parse_args(Commands)
@ -441,13 +452,13 @@ def main():
client = TrezorClient(transport) client = TrezorClient(transport)
cmds = Commands(client) cmds = Commands(client)
res = args.func(cmds, args) res = args.func(cmds, args)
if args.json: if args.json:
print json.dumps(res, sort_keys=True, indent=4) print json.dumps(res, sort_keys=True, indent=4)
else: else:
print res print res
if __name__ == '__main__': if __name__ == '__main__':
main() main()

Loading…
Cancel
Save