From 004ec12ec3d043427c9e47b0534fa345a399347e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juhani=20Krekel=C3=A4?= Date: Fri, 31 Aug 2018 20:21:08 +0300 Subject: [PATCH] Revamp filtering entries --- src/check_fingerprint.py | 42 ++++++++++++++++++++++++++++------------ src/entry.py | 32 ++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 12 deletions(-) diff --git a/src/check_fingerprint.py b/src/check_fingerprint.py index 76b7fa1..68c89d5 100644 --- a/src/check_fingerprint.py +++ b/src/check_fingerprint.py @@ -1,16 +1,23 @@ import enum +from collections import namedtuple + import entry import hashing -# TODO: Include a thing for checking what hosts match a given fingerprint +# Result(str/None, u16/None, str) +Result = namedtuple('Result', ['domain', 'port', 'comment']) -def check_fingerprint(entries, domain, port, fingerprint): - """check_fingerprint([Entry], str, u16, bytes[32]) → ([str]: successes, [str]: fails) +def check(entries, domain, port, fingerprint): + """check([Entry], str, u16, bytes[32]) → ([Result]: successes, [Result]: fails, [Result]: same_fingerprint) Checks if the given host is found with the given fingerprint. - The successes and fails lists returned by the function have the - comments for the hosts that match and have the same fingerpring and - the hosts that match but have a different fingerprint, respectively""" + + successes contains ones where both the host and the fingerprint match. + + fails contains ones where host matches but the fingerprint doesn't. + + same_fingerprint contains ones where fingerprint matches but the + host doesn't. Their .domain and .port will be None""" assert type(entries) == list and all(type(i) == entry.Entry for i in entries) assert type(domain) == str assert type(port) == int and 0 <= port <= (1<<16) - 1 @@ -18,25 +25,36 @@ def check_fingerprint(entries, domain, port, fingerprint): # Normalize the host here, so we don't have to do it every time we # check for a possible match - normalized_hosts = [entry.normalize_host(domain, port)] + normalized_hosts = {port: entry.normalize_host(domain, port)} # If we are looking at non-22 port, also check the general form of # the host without a port specifier. This seems to be how OpenSSH # does it too if port != 22: - normalized_hosts.append(entry.normalize_host(domain, 22)) + normalized_hosts[22] = entry.normalize_host(domain, 22) successes = [] fails = [] + same_fingerprint = [] for possible_match in entries: - for normalized_host in normalized_hosts: + any_host_matched = False + + for current_port, normalized_host in normalized_hosts.items(): hashed_host = hashing.hash_with_salt(normalized_host, possible_match.salt) if hashed_host == possible_match.hashed_host: if fingerprint == possible_match.fingerprint: # Fingerprint matches, it passes - successes.append(possible_match.comment) + successes.append(Result(domain, current_port, possible_match.comment)) + any_host_matched = True else: # Fingerprint different, it fails - fails.append(possible_match.comment) + fails.append(Result(domain, current_port, possible_match.comment)) - return successes, fails + if not any_host_matched and fingerprint == possible_match.fingerprint: + # Host is not the same, but the fingerprint + # matches + print(possible_match)#debg + same_fingerprint.append(Result(None, None, possible_match.comment)) + + + return successes, fails, same_fingerprint diff --git a/src/entry.py b/src/entry.py index 7be6708..7e0dd4b 100644 --- a/src/entry.py +++ b/src/entry.py @@ -48,3 +48,35 @@ def create_entry(domain, port, fingerprint, comment): raise UnacceptableComment('Comment contains newlines') return Entry(salt, hashed_host, fingerprint, comment) + +def filter_by_host(entries, domain, port): + """filter_by_host([Entry], str, u16) → [Entry] + Return hosts that match given domain and port.""" + assert type(entries) == list and all(type(i) == Entry for i in entries) + assert type(domain) == str + assert type(port) == int and 0 <= port <= (1<<16) - 1 + + # Normalize the host here, so we don't have to do it every time we + # check for a match + normalized_host = normalize_host(domain, port) + + entries = [] + for entry in entries: + hashed_host = hashing.hash_with_salt(normalized_host, entry.salt) + if hashed_host == entry.hashed_host: + entries.append(entry) + + return entries + +def filter_by_fingerprint(entries, fingerprint): + """filter_by_fingerprint([Entry], bytes[32]) → [Entry] + Return hosts that match given fingerprint.""" + assert type(entries) == list and all(type(i) == Entry for i in entries) + assert type(fingerprint) == bytes and len(fingerprint) == 32 + + entries = [] + for entry in entries: + if fingerprint == entry.fingerprint: + entries.append(entry) + + return entries