94 lines
3.1 KiB
Python
94 lines
3.1 KiB
Python
from collections import namedtuple
|
|
|
|
import hashing
|
|
|
|
# Entry(bytes[32], bytes[32], bytes[32], str)
# One stored record:
#   salt        -- salt that was used when hashing the host
#   hashed_host -- salted hash of the normalized host (see create_entry)
#   fingerprint -- binary fingerprint of the host
#   comment     -- free-form text; create_entry rejects comments with newlines
Entry = namedtuple('Entry', 'salt hashed_host fingerprint comment')
|
|
|
|
class UnacceptableComment(Exception):
    """Raised by create_entry() when the supplied comment cannot be stored,
    for example because it contains a newline."""
|
|
|
|
def normalize_domain(domain):
    """normalize_domain(str) → bytes

    Return the domain in normalized form: IDNA (punycode) encoded and
    lowercased."""
    assert type(domain) is str

    # Domain names should be reasonably normalized, so internationalized
    # names are converted to punycode and everything is lowercased.
    #
    # Punycode comes first and lowercasing second. In this order Unicode
    # case mapping never has to be handled here. For an internationalized
    # name the IDNA codec takes care of it. An all-ASCII name passes through
    # the codec unchanged, so a plain bytes.lower() normalizes its case.
    ascii_form = domain.encode('idna')
    return ascii_form.lower()
|
|
|
|
def normalize_host(domain, port):
    """normalize_host(str, u16) → bytes

    Transform a host into the form in which it will be hashed.

    Unlike normalize_domain(), this also takes the port into account.
    Port 22 yields just the normalized domain; any other port yields
    b'[domain]:port'."""
    assert type(domain) is str
    assert type(port) is int and 0 <= port <= 0xFFFF

    host = normalize_domain(domain)

    # Port 22 is stored as the bare domain; every other port uses the
    # bracketed [domain]:port form instead
    return host if port == 22 else b'[%s]:%i' % (host, port)
|
|
|
|
def create_entry(domain, port, fingerprint, comment):
    """create_entry(str, u16, bytes[32], str) → Entry

    Given an unprocessed host (domain and port), a binary fingerprint and a
    comment, create an entry describing it.

    The host is normalized with normalize_host() and hashed by
    hashing.hash_host(). Only the resulting salt and hash are stored, not
    the host itself.

    Raises UnacceptableComment if the comment contains a line break (LF or
    CR)."""
    assert type(domain) == str
    assert type(port) == int and 0 <= port <= (1<<16) - 1
    assert type(fingerprint) == bytes and len(fingerprint) == 32
    assert type(comment) == str

    # Validate the comment before any normalization or hashing work, so a
    # rejected entry does not waste a hash computation.
    # A lone '\r' is rejected as well as '\n'. Python's universal-newlines
    # text mode also treats it as a line terminator. NOTE(review): this
    # assumes entries are serialized one per line, as the newline check
    # implies; confirm against the writer/reader.
    if '\n' in comment or '\r' in comment:
        raise UnacceptableComment('Comment contains newlines')

    # Normalize the host before hashing
    normalized_host = normalize_host(domain, port)

    # Hash the host and keep the salt so it can be re-hashed for matching
    salt, hashed_host = hashing.hash_host(normalized_host)

    return Entry(salt, hashed_host, fingerprint, comment)
|
|
|
|
def filter_by_host(entries, domain, port):
    """filter_by_host([Entry], str, u16) → [Entry]

    Return the entries whose hashed host matches the given domain and port,
    in their original order.

    Every entry carries its own salt, so the normalized host must be
    re-hashed with each entry's salt. The normalization itself is done
    only once."""
    assert type(entries) == list and all(type(i) == Entry for i in entries)
    assert type(domain) == str
    assert type(port) == int and 0 <= port <= (1<<16) - 1

    # Normalize the host once up front instead of once per entry
    normalized_host = normalize_host(domain, port)

    return [
        entry
        for entry in entries
        if hashing.hash_with_salt(normalized_host, entry.salt)
        == entry.hashed_host
    ]
|
|
|
|
def filter_by_fingerprint(entries, fingerprint):
    """filter_by_fingerprint([Entry], bytes[32]) → [Entry]

    Return the entries whose fingerprint equals the given one, in their
    original order."""
    assert type(entries) == list and all(type(i) == Entry for i in entries)
    assert type(fingerprint) == bytes and len(fingerprint) == 32

    return [entry for entry in entries if fingerprint == entry.fingerprint]
|