diff --git a/src/entry.py b/src/entry.py index 81fcec2..7be6708 100644 --- a/src/entry.py +++ b/src/entry.py @@ -2,7 +2,7 @@ from collections import namedtuple import hashing -# Entry(bytes[32], bytes[32], bytes[32], bytes[0…2¹⁶-1]) +# Entry(bytes[32], bytes[32], bytes[32], str) Entry = namedtuple('Entry', ['salt', 'hashed_host', 'fingerprint', 'comment']) class UnacceptableComment(Exception): pass @@ -47,10 +47,4 @@ def create_entry(domain, port, fingerprint, comment): if '\n' in comment: raise UnacceptableComment('Comment contains newlines') - comment_encoded = comment.encode('utf-8') - - # Comment may be at max 2¹⁶-1 bytes long - if len(comment_encoded) >= 1<<16: - raise UnacceptableComment('Comment length of %i bytes is too long' % len(comment_encoded)) - - return Entry(salt, hashed_host, fingerprint, comment_encoded) + return Entry(salt, hashed_host, fingerprint, comment) diff --git a/src/read_file.py b/src/read_file.py index 0f9ce86..6158a5b 100644 --- a/src/read_file.py +++ b/src/read_file.py @@ -1,67 +1,113 @@ +import base64 + import entry class FileFormatError(Exception): pass class VersionMismatch(Exception): pass -def check_header(f): - """check_header(file(rb)) - Throw an error if the header isn't good""" - # Magic is b'WOT' - magic = f.read(3) - if magic != b'WOT': +def parse_header(header): + """parse_header(bytes) → str + Throw an error if the header isn't good and return the file comment + (if any) if it is""" + assert type(header) == bytes + + magic = header[0:6] + if magic != b'SSHWOT': raise FileFormatError('Invalid magic') # Version 0 is the current one - version = f.read(1) + version = header[6:7] if version == b'': - raise FileFormatError('Unexpected end of file') - if version != b'\0': + raise FileFormatError('No newline after header') + if version != b'0': raise VersionMismatch('Version %i not supported' % version[0]) -def read_entry(f): - """read_entry(file(rb)) → Entry / None - Returns None if the end of file has been reached""" - # u8[32]: salt - salt = f.read(32) - if len(salt) == 0: - # End of file has been reached, return None to mark that - return None - elif len(salt) != 32: - raise FileFormatError('Unexpected end of file') + # See if we have a comment + if header[7:8] == b' ': + # It says we have + if header[8:9] == b'\n': + # No, we don't, but we do have a space telling we + # have. The header is malformed + raise FileFormatError('Missing comment or spurious space in the header') + else: + # Yes, we do + # Check it ends with a newline + if header[-1] != 0x0a: + raise FileFormatError('Missing newline at the end of the header') - # u8[32]: hashed_host - hashed_host = f.read(32) - if len(hashed_host) != 32: - raise FileFormatError('Unexpected end of file') + try: + file_comment = header[8:-1].decode('utf-8') + except UnicodeDecodeError: + raise FileFormatError('Comment is not valid utf-8') - # u8[32]: fingerprint - fingerprint = f.read(32) - if len(fingerprint) != 32: - raise FileFormatError('Unexpected end of file') + return file_comment - # u16le: comment_length - length_bytes = f.read(2) - if len(length_bytes) != 2: - raise FileFormatError('Unexpected end of file') - comment_length = length_bytes[0] | length_bytes[1] << 8 + elif header[7:8] == b'\n': + # No, we have newline + return '' - # u8[comment_length]: comment - comment = f.read(comment_length) - if len(comment) != comment_length: - raise FileFormatError('Unexpected end of file') + else: + # No, we have something else + raise FileFormatError("Expected a space or a newline but got '%s' instead" % header[7:].decode('utf-8')) + +def parse_entry(line): + """parse_entry(bytes) → Entry""" + assert type(line) == bytes + + def extract_b64_field(rest): + """extract_b64_field(bytes) → (bytes: decoded_field, bytes:rest)""" + field_b64 = rest[0:44] + if len(field_b64) != 44: + raise FileFormatError('Unexpected end of line') + try: + field = base64.b64decode(field_b64, validate = True) + except (ValueError, base64.binascii.Error) as err: + raise FileFormatError('Malformed base64 string: %s' % field_b64.decode('utf-8')) from err + + return field, rest[44:] + + salt, rest = extract_b64_field(line) + hashed_host, rest = extract_b64_field(rest) + fingerprint, rest = extract_b64_field(rest) + + # What do we have after that? + if rest[0:1] == b' ': + # A comment? + if rest[1:2] == b'\n': + # No, but it says we have. It's malformed + raise FileFormatError('Missing comment or spurious space in the entry') + else: + # Yes. Make sure it ends in a newline + if rest[-1] != 0x0a: + raise FileFormatError('No newline after entry') + + try: + comment = rest[1:-1].decode('utf-8') + except UnicodeDecodeError: + raise FileFormatError('Comment is not valid utf-8') + + elif rest[0:1] == b'\n': + # A newline + comment = '' + + else: + # Something else + raise FileFormatError('Expected a space or a newline but got "%s" instead' % rest.decode('utf-8')) return entry.Entry(salt, hashed_host, fingerprint, comment) def read(f): - """read_file(file(rb)) → [Entry]""" - check_header(f) + """read(file(rb)) → ([Entry]: entries, str: file_comment)""" + lines = [line for line in f] + + if len(lines) == 0: + raise FileFormatError('Missing header') + + file_comment = parse_header(lines[0]) entries = [] - while True: - # Read until we reach the end of file - entry = read_entry(f) - if entry is None: break - entries.append(entry) + for line in lines[1:]: + entries.append(parse_entry(line)) - return entries + return entries, file_comment diff --git a/src/write_file.py b/src/write_file.py index bdf8620..60614eb 100644 --- a/src/write_file.py +++ b/src/write_file.py @@ -1,39 +1,53 @@ -def write_header(f): - """write_header(file(wb)) +import base64 + +def write_header(f, file_comment): + """write_header(file(wb), str) Writes the header to the given file.""" - # b'WOT' magic - f.write(b'WOT') + assert type(file_comment) == str + # b'SSHWOT' magic + f.write(b'SSHWOT') # Version number - f.write(bytes([0])) + f.write(b'0') + # b' ' + file_comment, if there is one + if len(file_comment) > 0: + f.write(b' ') + assert b'\n' not in file_comment + f.write(file_comment) + # End of header marked with b'\n' + f.write(b'\n') def write_entry(f, salt, hashed_host, fingerprint, comment): - """write_entry(file(wb), bytes[32], bytes[32], bytes[32], bytes[0…2¹⁶-1]) + """write_entry(file(wb), bytes[32], bytes[32], bytes[32], str) Writes an entry to the given file.""" assert type(salt) == bytes and len(salt) == 32 assert type(hashed_host) == bytes and len(hashed_host) == 32 assert type(fingerprint) == bytes and len(fingerprint) == 32 - assert type(comment) == bytes and 0 <= len(comment) <= (1<<16) - 1 + assert type(comment) == str - # u8[32]: salt - f.write(salt) + # base64 encoded (44 bytes): salt + f.write(base64.b64encode(salt)) - # u8[32]: hashed_host - f.write(hashed_host) + # base64 encoded (44 bytes): hashed_host + f.write(base64.b64encode(hashed_host)) - # u8[32]: fingerprint - f.write(fingerprint) + # base64 encoded (44 bytes): fingerprint + f.write(base64.b64encode(fingerprint)) - # u16le: len(comment) - comment_len = len(comment) - f.write(bytes([comment_len & 0xff, comment_len >> 8])) + # b' ' + comment, if there is one + if len(comment) > 0: + f.write(b' ') + assert '\n' not in comment + f.write(comment.encode('utf-8')) - # u8[]: comment - f.write(comment) + # End of entry marked with b'\n' + f.write(b'\n') -def write(f, entries): - """write(file(wb), [Entry]) +def write(f, entries, file_comment = ''): + """write(file(wb), [Entry], str) Creates a file containing all of the entries""" - write_header(f) + assert type(file_comment) == str + + write_header(f, file_comment) for entry in entries: write_entry(f, entry.salt, entry.hashed_host, entry.fingerprint, entry.comment)