diff --git a/.gitignore b/.gitignore index ef98c70..88e1ce7 100644 --- a/.gitignore +++ b/.gitignore @@ -2,5 +2,7 @@ __pycache__ *.pyc *.swp *.sshwot +*.sshwot.* build sshwot-export-known-hosts +sshwot-filter diff --git a/Makefile b/Makefile index 3a55b13..4cb4933 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,10 @@ -BINS:=sshwot-export-known-hosts +BINS:=sshwot-export-known-hosts sshwot-filter SSHWOT_EXPORT_KNOWN_HOSTS_MAIN:=src/main-export-known-hosts.py -SSHWOT_EXPORT_KNOWN_HOSTS_DEPS:=$(SSHWOT_EXPORT_KNOWN_HOSTS_MAIN) src/entry.py src/hashing.py src/process_known_hosts.py src/write_file.py +SSHWOT_EXPORT_KNOWN_HOSTS_DEPS:=src/entry.py src/hashing.py src/process_known_hosts.py src/write_file.py + +SSHWOT_FILTER_MAIN:=src/main-filter.py +SSHWOT_FILTER_DEPS:=src/entry.py src/hashing.py src/read_file.py src/write_file.py all: $(BINS) @@ -14,6 +17,15 @@ sshwot-export-known-hosts: $(SSHWOT_EXPORT_KNOWN_HOSTS_MAIN) $(SSHWOT_EXPORT_KNO cat build/$@.zip >> $@ chmod +x $@ +sshwot-filter: $(SSHWOT_FILTER_MAIN) $(SSHWOT_FILTER_DEPS) + mkdir -p build/$@ + cp $(SSHWOT_FILTER_DEPS) build/$@/ + cp $(SSHWOT_FILTER_MAIN) build/$@/__main__.py + zip --quiet --junk-paths build/$@.zip build/$@/*.py + echo '#!/usr/bin/env python3' > $@ + cat build/$@.zip >> $@ + chmod +x $@ + .PHONY: all clean distclean buildclean clean: diff --git a/src/entry.py b/src/entry.py index 7e0dd4b..857c2e1 100644 --- a/src/entry.py +++ b/src/entry.py @@ -7,11 +7,10 @@ Entry = namedtuple('Entry', ['salt', 'hashed_host', 'fingerprint', 'comment']) class UnacceptableComment(Exception): pass -def normalize_host(domain, port): - """normalize_host(str, u16) → bytes - Tranform a domain into the format in which it will be hashed""" +def normalize_domain(domain): + """normalize_domain(str) → bytes + Changes the domain into a normalized form (punycoded, lowercase)""" assert type(domain) == str - assert type(port) == int and 0 <= port <= (1<<16) - 1 # We want to have domain names reasonably normalized. 
This is why we # convert all internationalized domain names to punycode and @@ -20,13 +19,25 @@ def normalize_host(domain, port): # that way we don't have to worry about Unicode case mapping: in # case of IDN the IDNA codec handles that for us, and in case of an # ASCII domain it passes through the IDNA unmodified - normalized_host = domain.encode('idna').lower() + return domain.encode('idna').lower() - # If the port is not :22, we store [host]:port instead +def normalize_host(domain, port): + """normalize_host(str, u16) → bytes + Transform a host into the format in which it will be hashed. + + Main difference between this and normalize_domain() is that this one + includes the port and produces stuff like [domain]:port if port is not + 22""" + assert type(domain) == str + assert type(port) == int and 0 <= port <= (1<<16) - 1 + + normalized_domain = normalize_domain(domain) + + # If the port is not :22, we store [domain]:port instead if port != 22: - normalized_host = b'[%s]:%i' % (normalized_host, port) + normalized_domain = b'[%s]:%i' % (normalized_domain, port) - return normalized_host + return normalized_domain def create_entry(domain, port, fingerprint, comment): """create_entry(str, u16, bytes[32], str) → Entry @@ -60,13 +71,13 @@ def filter_by_host(entries, domain, port): # check for a match normalized_host = normalize_host(domain, port) - entries = [] + matches = [] for entry in entries: hashed_host = hashing.hash_with_salt(normalized_host, entry.salt) if hashed_host == entry.hashed_host: - entries.append(entry) + matches.append(entry) - return entries + return matches def filter_by_fingerprint(entries, fingerprint): """filter_by_fingerprint([Entry], bytes[32]) → [Entry] @@ -74,9 +85,9 @@ def filter_by_fingerprint(entries, fingerprint): assert type(entries) == list and all(type(i) == Entry for i in entries) assert type(fingerprint) == bytes and len(fingerprint) == 32 - entries = [] + matches = [] for entry in entries: if fingerprint == entry.fingerprint: - 
entries.append(entry)
+            matches.append(entry)
 
-    return entries
+    return matches
diff --git a/src/main-filter.py b/src/main-filter.py
new file mode 100644
index 0000000..11f9a25
--- /dev/null
+++ b/src/main-filter.py
@@ -0,0 +1,130 @@
+import argparse
+import sys
+
+import entry
+import hashing
+import read_file
+import write_file
+
+def main():
+    """Filter sshwot file(s) by host and/or fingerprint.
+
+    Parses the command line, reads every given sshwot file, keeps the
+    entries that match the requested host (with port, default 22) and/or
+    SHA256 fingerprint, and writes the matches in sshwot format to the
+    -o file or to stdout. Exits with status 1 on invalid arguments."""
+    # TODO: Default location to search
+    parser = argparse.ArgumentParser(
+        description = """Search sshwot file(s) for given host and/or
+        fingerprint.""",
+        # We want to provide help on --help, but the default thing
+        # also adds -h, which we don't want
+        add_help = False
+    )
+
+    # --help to get help
+    parser.add_argument('--help',
+        action = 'help',
+        help = 'show this help message and exit'
+    )
+
+    # -o can be used to write the results to a file instead of stdout
+    parser.add_argument('-o',
+        # Given one argument, we open a file of that name and store it
+        # to outfile, which will be sys.stdout.buffer otherwise.
+        # We use .buffer since we're going to write binary data
+        action = 'store',
+        dest = 'outfile',
+        type = argparse.FileType('wb'),
+        default = sys.stdout.buffer,
+        # This is what will be displayed in the help after -o
+        metavar = 'outfile',
+        help = 'write the sshwot file to a given file instead of the stdout'
+    )
+
+    # -h/--host for domain, -p/--port for port (default port is 22)
+    parser.add_argument('-h', '--host',
+        action = 'store',
+        dest = 'host',
+        help = 'the domain to filter for'
+    )
+    parser.add_argument('-p', '--port',
+        action = 'store',
+        dest = 'port',
+        # Automatically convert to integer
+        type = int,
+        help = 'the port associated with the given host'
+    )
+
+    # -f/--fingerprint for fingerprint
+    parser.add_argument('-f', '--fingerprint',
+        action = 'store',
+        dest = 'fingerprint',
+        help = 'the fingerprint to filter for'
+    )
+
+    # Input file(s)
+    parser.add_argument('infiles',
+        nargs = '*',
+        type = argparse.FileType('rb'),
+        # The text shown for these in the usage
+        metavar = 'sshwot-file',
+        help = 'a sshwot file to search'
+    )
+
+    # This automatically parses the command line args for us. If it
+    # returns, we have correct arguments
+    args = parser.parse_args()
+
+    # Ensure we have a host if port was specified
+    if args.port is not None and args.host is None:
+        print('If you specify a port you need to specify a domain too', file=sys.stderr)
+        sys.exit(1)
+
+    # Ensure we're filtering for something
+    if args.host is None and args.fingerprint is None:
+        print('Specify a host and/or a fingerprint to filter for', file=sys.stderr)
+        sys.exit(1)
+
+    # Default to port 22
+    port = 22 if args.port is None else args.port
+
+    # Check the validity of the fingerprint and de-base64 it, if we have it
+    if args.fingerprint is not None:
+        if args.fingerprint[0:7].upper() != 'SHA256:':
+            # Diagnostics go to stderr: stdout may be carrying the sshwot output
+            print('We can only handle sha256 fingerprints (starts with SHA256:)', file=sys.stderr)
+            sys.exit(1)
+        # We encode this, because hashing.base64dec expects bytes
+        fingerprint = hashing.base64dec(args.fingerprint[7:].encode())
+        # entry.filter_by_fingerprint() asserts a 32-byte fingerprint, so
+        # reject anything else here with a readable message
+        if len(fingerprint) != 32:
+            print('The fingerprint must decode to exactly 32 bytes', file=sys.stderr)
+            sys.exit(1)
+
+    else:
+        fingerprint = None
+
+    matches = []
+    for infile in args.infiles:
+        entries, file_comment = read_file.read(infile)
+
+        # Filter by host if it's present
+        if args.host is not None:
+            entries = entry.filter_by_host(entries, args.host, port)
+        # Filter by fingerprint if it's present
+        if fingerprint is not None:
+            entries = entry.filter_by_fingerprint(entries, fingerprint)
+
+        matches.extend(entries)
+
+    # Print the matches in sshwot format
+    write_file.write(args.outfile, matches)
+
+if __name__ == '__main__':
+    try:
+        main()
+    except Exception as err:
+        print('Error: %s' % err, file=sys.stderr)
+        sys.exit(1)