import argparse
import os
import sys

import check_fingerprint
import default_files
import entry
import hashing
import read_file


def _parse_args():
    """Build the command-line parser and return the parsed arguments.

    argparse itself exits the process on --help or on invalid arguments, so
    if this returns the arguments are well-formed.
    """
    parser = argparse.ArgumentParser(
        description="""Search sshwot files for matching fingerprints.""",
        # We want to provide help on --help, but the default behaviour
        # also adds -h, which we don't want
        add_help=False,
    )
    # --help to get help
    parser.add_argument(
        '--help',
        action='help',
        help='show this help message and exit',
    )
    # -p/--port for port, but host is a positional argument
    parser.add_argument(
        '-p', '--port',
        action='store',
        dest='port',
        # Automatically convert to integer
        type=int,
        help='the port associated with the given host',
    )
    # Host and fingerprint are required
    parser.add_argument(
        'host',
        help='the domain to check',
    )
    parser.add_argument(
        'fingerprint',
        help='the fingerprint to check',
    )
    # Input file(s)
    # Don't use argparse.FileType('rb'), since we want to know the names
    parser.add_argument(
        'infiles',
        nargs='*',
        # The text shown for these in the usage
        metavar='sshwot-file',
        help='a sshwot file to search',
    )
    return parser.parse_args()


def _decode_fingerprint(text):
    """Turn a ``SHA256:<base64>`` fingerprint string into its raw bytes.

    The ``SHA256:`` prefix is matched case-insensitively. If it is missing,
    an error is printed to stderr and the process exits with status 1.

    Raises:
        Exception: if the decoded value is not exactly 32 bytes long.
    """
    if text[:7].upper() != 'SHA256:':
        # Errors go to stderr, like the top-level error handler's output
        print('We can only handle sha256 fingerprints (starts with SHA256:)',
              file=sys.stderr)
        sys.exit(1)
    # We encode this, because hashing.base64dec expects bytes
    fingerprint = hashing.base64dec(text[7:].encode())
    # A valid sha256 fingerprint is 32 bytes
    if len(fingerprint) < 32:
        raise Exception('Fingerprint too short')
    if len(fingerprint) > 32:
        raise Exception('Fingerprint too long')
    return fingerprint


def _display_name(path):
    """Return the label used for *path* in the output.

    The label is the file name without its directory and without a trailing
    ``.sshwot`` extension.
    """
    name = os.path.basename(path)
    stem, sep, ext = name.rpartition('.')
    # Only strip a real extension. A file named just "sshwot" or ".sshwot"
    # keeps its full name rather than collapsing to an empty label.
    if sep and stem and ext == 'sshwot':
        return stem
    return name


def _report_file(path, host, port, fingerprint):
    """Check one sshwot file for host/port/fingerprint and print the results.

    Prints one line per entry reported by check_fingerprint.check: matching
    entries (ok), mismatching entries (fail), and entries for other hosts
    that share the same fingerprint.
    """
    name = _display_name(path)
    with open(path, 'rb') as f:
        # All work stays inside the `with`, matching the original, in case
        # read_file.read consumes the file lazily
        entries, _file_comment = read_file.read(f)
        success, fail, same_fingerprint = check_fingerprint.check(
            entries, host, port, fingerprint)
        for match_host, match_port, match_comment in success:
            # Display the host in the same normalized format used internally.
            # normalize_host produces bytes, hence the .decode()
            host_display = entry.normalize_host(match_host, match_port).decode()
            print('[\x1b[32mok\x1b[0m] %s: %s: %s'
                  % (name, host_display, match_comment))
        for fail_host, fail_port, fail_comment in fail:
            host_display = entry.normalize_host(fail_host, fail_port).decode()
            print('[\x1b[31mfail\x1b[0m] %s: %s: %s'
                  % (name, host_display, fail_comment))
        for _, _, same_fingerprint_comment in same_fingerprint:
            print('[same fingerprint] %s: (unknown host): %s'
                  % (name, same_fingerprint_comment))


def main():
    """Parse the command line and report fingerprint matches in each file."""
    # TODO: Do known_hosts files too
    args = _parse_args()
    # Default to port 22
    port = 22 if args.port is None else args.port
    fingerprint = _decode_fingerprint(args.fingerprint)
    # Use the default files if no input files were specified
    infiles = args.infiles or default_files.list_all()
    for path in infiles:
        _report_file(path, args.host, port, fingerprint)


if __name__ == '__main__':
    try:
        main()
    except Exception as err:
        # Top-level boundary: report any failure briefly and exit non-zero
        print('Error: %s' % err, file=sys.stderr)
        sys.exit(1)