diff --git a/setup.py b/setup.py index 09a87e5..fd1fd5b 100755 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup setup( name='trezor', - version='0.6.13', + version='0.7.0', author='Bitcoin TREZOR', author_email='info@bitcointrezor.com', description='Python library for communicating with TREZOR Bitcoin Hardware Wallet', diff --git a/trezorlib/transport.py b/trezorlib/transport.py index dc59619..3e26089 100644 --- a/trezorlib/transport.py +++ b/trezorlib/transport.py @@ -1,5 +1,5 @@ import struct -from . import mapping +import mapping class NotImplementedException(Exception): pass @@ -9,7 +9,6 @@ class ConnectionError(Exception): class Transport(object): def __init__(self, device, *args, **kwargs): - print("Transport constructor") self.device = device self.session_id = 0 self.session_depth = 0 @@ -79,9 +78,7 @@ class Transport(object): if msg_type == 'protobuf': return data else: - print mapping.get_class(msg_type) inst = mapping.get_class(msg_type)() - print inst, data inst.ParseFromString(bytes(data)) return inst @@ -150,7 +147,7 @@ class TransportV1(Transport): headerlen = struct.calcsize(">HL") (msg_type, datalen) = struct.unpack(">HL", chunk[3:3 + headerlen]) except: - raise Exception("Cannot parse header length") + raise Exception("Cannot parse header") data = chunk[3 + headerlen:] return (msg_type, datalen, data) @@ -163,12 +160,72 @@ class TransportV1(Transport): class TransportV2(Transport): def write(self, msg): - ser = msg.SerializeToString() - raise NotImplemented() + data = bytearray(msg.SerializeToString()) + + header1 = struct.pack(">L", self.session_id) + header2 = struct.pack(">LL", mapping.get_type(msg), len(data)) + + data = header2 + data + + first = True + while len(data): + if first: + # Magic characters, header1, header2, data padded to 64 bytes + datalen = 62 - len(header1) + chunk = b'?!' + header1 + data[:datalen] + b'\0' * (datalen - len(data[:datalen])) + else: + # Magic characters, header1, data padded to 64 bytes + datalen = 63 - len(header1) + chunk = b'?' + header1 + data[:datalen] + b'\0' * (datalen - len(data[:datalen])) + + self._write_chunk(chunk) + data = data[datalen:] + first = False def _read(self): - pass + chunk = self._read_chunk() + (session_id, msg_type, datalen, data) = self.parse_first(chunk) + + while len(data) < datalen: + chunk = self._read_chunk() + (session_id2, data) = self.parse_next(chunk) + + if session_id != session_id2: + raise Exception("Session id mismatch") + + data.extend(data) + + # Strip padding zeros + data = data[:datalen] + return (session_id, msg_type, data) + + def parse_first(self, chunk): + if chunk[:2] != b"?!": + raise Exception("Unexpected magic characters") + + try: + headerlen = struct.calcsize(">LLL") + (session_id, msg_type, datalen) = struct.unpack(">LLL", chunk[2:2 + headerlen]) + except: + raise Exception("Cannot parse header") + + data = chunk[2 + headerlen:] + return (session_id, msg_type, datalen, data) + + def parse_next(self, chunk): + if chunk[0:1] != b"?": + raise Exception("Unexpected magic characters") + + try: + headerlen = struct.calcsize(">L") + session_id = struct.unpack(">L", chunk[1:1 + headerlen]) + except: + raise Exception("Cannot parse header") + + data = chunk[1 + headerlen:] + return (session_id, data) + ''' def read_headers(self, read_f): c = read_f.read(2) if c != b"?!": @@ -180,5 +237,5 @@ class TransportV2(Transport): except: raise Exception("Cannot parse header length") - print datalen return (0, msg_type, datalen) + ''' diff --git a/trezorlib/transport_pipe.py b/trezorlib/transport_pipe.py index 5faa219..d59db4e 100644 --- a/trezorlib/transport_pipe.py +++ b/trezorlib/transport_pipe.py @@ -1,12 +1,12 @@ from __future__ import print_function import os from select import select -from .transport import Transport +from transport import TransportV1 """PipeTransport implements fake wire transport over local named pipe. Use this transport for talking with trezor simulator.""" -class PipeTransport(Transport): +class PipeTransport(TransportV1): def __init__(self, device, is_device, *args, **kwargs): self.is_device = is_device # Set True if act as device @@ -39,22 +39,36 @@ class PipeTransport(Transport): os.unlink(self.filename_read) os.unlink(self.filename_write) - def ready_to_read(self): + def _ready_to_read(self): rlist, _, _ = select([self.read_f], [], [], 0) return len(rlist) > 0 - def _write(self, msg, protobuf_msg): + def _write_chunk(self, chunk): + if len(chunk) != 64: + raise Exception("Unexpected data length") + try: - self.write_f.write(msg) + self.write_f.write(chunk) self.write_f.flush() except OSError: print("Error while writing to socket") raise - def _read(self): - try: - (msg_type, datalen) = self._read_headers(self.read_f) - return (msg_type, self.read_f.read(datalen)) - except IOError: - print("Failed to read from device") - raise + def _read_chunk(self): + while True: + try: + data = self.read_f.read(64) + except IOError: + print("Failed to read from device") + raise + + if not len(data): + time.sleep(0.001) + continue + + break + + if len(data) != 64: + raise Exception("Unexpected chunk size: %d" % len(data)) + + return bytearray(data) diff --git a/trezorlib/transport_udp.py b/trezorlib/transport_udp.py index 9c7dd5c..27e95da 100644 --- a/trezorlib/transport_udp.py +++ b/trezorlib/transport_udp.py @@ -3,20 +3,10 @@ import socket from select import select import time -from .transport import Transport, ConnectionError +from .transport import TransportV2, ConnectionError -class FakeRead(object): - # Let's pretend we have a file-like interface - def __init__(self, func): - self.func = func - - def read(self, size): - return self.func(size) - -class UdpTransport(Transport): +class UdpTransport(TransportV2): def __init__(self, device, *args, **kwargs): - self.buffer = '' - device = device.split(':') if len(device) < 2: if not device[0]: @@ -33,13 +23,13 @@ class UdpTransport(Transport): def _open(self): self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket.connect(self.device) + self.socket.settimeout(10) def _close(self): self.socket.close() self.socket = None - self.buffer = '' - def ready_to_read(self): + def _ready_to_read(self): rlist, _, _ = select([self.socket], [], [], 0) return len(rlist) > 0 @@ -49,32 +39,9 @@ class UdpTransport(Transport): self.socket.sendall(chunk) - def _write(self, msg, protobuf_msg): - raise NotImplemented() - - def _read(self): - (session_id, msg_type, datalen) = self._read_headers(FakeRead(self._raw_read)) - return (session_id, msg_type, self._raw_read(datalen)) - - def _raw_read(self, length): - start = time.time() - while len(self.buffer) < length: - data = self.socket.recv(64) - if not len(data): - if time.time() - start > 10: - # Over 10 s of no response, let's check if - # device is still alive - if not self.is_connected(): - raise ConnectionError("Connection failed") - else: - # Restart timer - start = time.time() - - time.sleep(0.001) - continue - - self.buffer += data + def _read_chunk(self): + data = self.socket.recv(64) + if len(data) != 64: + raise Exception("Unexpected chunk size: %d" % len(data)) - ret = self.buffer[:length] - self.buffer = self.buffer[length:] - return ret + return bytearray(data)