from collections import namedtuple

import hashing

# Entry(bytes[32], bytes[32], bytes[32], str)
Entry = namedtuple('Entry', ['salt', 'hashed_host', 'fingerprint', 'comment'])

# Highest value that fits in an unsigned 16-bit port number.
_MAX_PORT = (1 << 16) - 1


class UnacceptableComment(Exception):
    """Raised by create_entry() when the comment contains a newline."""


def _is_port(port):
    """_is_port(object) → bool

    True if port is exactly an int (bool is rejected, since the check is
    on the exact type) within the u16 range 0..65535."""
    return type(port) == int and 0 <= port <= _MAX_PORT


def _is_entry_list(entries):
    """_is_entry_list(object) → bool

    True if entries is exactly a list whose items are all Entry tuples."""
    return type(entries) == list and all(type(i) == Entry for i in entries)


def normalize_domain(domain):
    """normalize_domain(str) → bytes

    Changes the domain into a normalized form (punycoded, lowercase).

    The str is encoded with the 'idna' codec and the resulting bytes are
    lowercased. The codec may raise UnicodeError for names it rejects;
    that error is not caught here."""
    assert type(domain) == str

    # We want to have domain names reasonably normalized. This is why we
    # convert all internationalized domain names to punycode and
    # lowercase all domains.
    # The reason the lowercasing happens after the punycoding is because
    # that way we don't have to worry about Unicode case mapping: in
    # case of IDN the IDNA codec handles that for us, and in case of an
    # ASCII domain it passes through the IDNA unmodified
    return domain.encode('idna').lower()


def normalize_host(domain, port):
    """normalize_host(str, u16) → bytes

    Transform a host into the format in which it will be hashed.

    Main difference between this and normalize_domain() is that this one
    includes the port: for port 22 the result is just the normalized
    domain, for any other port it is b'[domain]:port'."""
    assert type(domain) == str
    assert _is_port(port)

    normalized_domain = normalize_domain(domain)

    # Port 22 is stored as the bare domain
    if port == 22:
        return normalized_domain

    # If the port is not :22, we store [domain]:port instead
    return b'[%s]:%i' % (normalized_domain, port)


def create_entry(domain, port, fingerprint, comment):
    """create_entry(str, u16, bytes[32], str) → Entry

    Given unprocessed host, a binary fingerprint and a comment, creates
    an entry describing it.

    The host is normalized with normalize_host() and passed to
    hashing.hash_host(), whose (salt, hashed_host) result is stored in
    the entry alongside the fingerprint and comment, both unchanged.

    Raises UnacceptableComment if the comment contains '\\n'."""
    assert type(domain) == str
    assert _is_port(port)
    assert type(fingerprint) == bytes and len(fingerprint) == 32
    assert type(comment) == str

    # Comment must not include newlines. This is checked before any
    # hashing, so an unusable comment is rejected without doing the
    # salt/hash work first.
    if '\n' in comment:
        raise UnacceptableComment('Comment contains newlines')

    # Normalize the host before hashing
    normalized_host = normalize_host(domain, port)

    # Hash the host and store the salt
    salt, hashed_host = hashing.hash_host(normalized_host)

    return Entry(salt, hashed_host, fingerprint, comment)


def filter_by_host(entries, domain, port):
    """filter_by_host([Entry], str, u16) → [Entry]

    Return hosts that match given domain and port.

    An entry matches when hashing.hash_with_salt() of the normalized
    host with that entry's salt equals the entry's hashed_host. The
    result is a new list, in the same order as entries."""
    assert _is_entry_list(entries)
    assert type(domain) == str
    assert _is_port(port)

    # Normalize the host here, so we don't have to do it every time we
    # check for a match
    normalized_host = normalize_host(domain, port)

    # Every entry has its own salt, so the host has to be rehashed for
    # each entry before comparing
    return [
        entry for entry in entries
        if hashing.hash_with_salt(normalized_host, entry.salt)
        == entry.hashed_host
    ]


def filter_by_fingerprint(entries, fingerprint):
    """filter_by_fingerprint([Entry], bytes[32]) → [Entry]

    Return hosts that match given fingerprint.

    The result is a new list, in the same order as entries, of the
    entries whose fingerprint equals the given one."""
    assert _is_entry_list(entries)
    assert type(fingerprint) == bytes and len(fingerprint) == 32

    return [entry for entry in entries if fingerprint == entry.fingerprint]