sshwot/src/main-verify.py

108 lines
3.1 KiB
Python

import argparse
import os
import sys
import check_fingerprint
import default_files
import entry
import hashing
import read_file
def main():
    """Search sshwot files for entries matching a host and fingerprint.

    Command line::

        [--help] [-p PORT] host fingerprint [sshwot-file ...]

    The fingerprint must start with ``SHA256:`` (case-insensitive) followed
    by base64 that decodes to exactly 32 bytes. The port defaults to 22. If
    no sshwot files are given, the paths returned by
    ``default_files.list_all()`` are searched instead.

    For each file, the three result lists returned by
    ``check_fingerprint.check`` are printed one entry per line, prefixed with
    the file's display name and labelled ``ok``, ``fail`` and
    ``same fingerprint`` respectively.

    Exits with status 1 (after printing a message on stderr) if the
    fingerprint lacks the ``SHA256:`` prefix.

    Raises:
        Exception: if the decoded fingerprint is not exactly 32 bytes.
        OSError: if an input file cannot be opened.
        Errors raised by ``hashing.base64dec``, ``read_file.read`` or
        ``check_fingerprint.check`` propagate unchanged.
    """
    # TODO: Do known_hosts files too
    parser = argparse.ArgumentParser(
        description="""Search sshwot files for matching fingerprints.""",
        # We want to provide help on --help, but the default handler
        # also adds -h, which we don't want
        add_help=False,
    )
    # --help to get help
    parser.add_argument(
        '--help',
        action='help',
        help='show this help message and exit',
    )
    # -p/--port for port, but host is a positional argument
    parser.add_argument(
        '-p', '--port',
        action='store',
        dest='port',
        # Automatically convert to integer
        type=int,
        help='the port associated with the given host',
    )
    # Host and fingerprint are required
    parser.add_argument(
        'host',
        help='the domain to check',
    )
    parser.add_argument(
        'fingerprint',
        help='the fingerprint to check',
    )
    # Input file(s)
    # Don't use argparse.FileType('rb'), since we want to know the names
    parser.add_argument(
        'infiles',
        nargs='*',
        # The text shown for these in the usage
        metavar='sshwot-file',
        help='a sshwot file to search',
    )

    # This parses the command line args for us (exiting on bad input or
    # --help). If it returns, we have correct arguments.
    args = parser.parse_args()

    # Default to port 22
    port = 22 if args.port is None else args.port

    # Check the validity of the fingerprint and de-base64 it
    if args.fingerprint[0:7].upper() != 'SHA256:':
        # Report on stderr, consistent with the top-level error handler
        print('We can only handle sha256 fingerprints (starts with SHA256:)',
              file=sys.stderr)
        sys.exit(1)
    # We encode this, because hashing.base64dec expects bytes
    fingerprint = hashing.base64dec(args.fingerprint[7:].encode())
    # A valid sha256 fingerprint is 32 bytes
    if len(fingerprint) < 32:
        raise Exception('Fingerprint too short')
    if len(fingerprint) > 32:
        raise Exception('Fingerprint too long')

    # Use the default files if no input files were specified
    infiles = args.infiles if args.infiles else default_files.list_all()

    for path in infiles:
        # Label results with the file name minus its directory and minus a
        # trailing ".sshwot" extension. os.path.splitext keeps names such as
        # "sshwot" or ".sshwot" intact, whereas splitting on '.' and dropping
        # the last piece reduced both of them to an empty string.
        name = os.path.basename(path)
        root, ext = os.path.splitext(name)
        if ext == '.sshwot':
            name = root

        with open(path, 'rb') as f:
            entries, _file_comment = read_file.read(f)

        success, fail, same_fingerprint = check_fingerprint.check(
            entries, args.host, port, fingerprint)

        # The bracketed labels use ANSI colour escapes (green ok, red fail)
        for match_host, match_port, match_comment in success:
            # Display the host in the same normalized format used internally;
            # normalize_host produces bytes, hence the .decode()
            host_display = entry.normalize_host(match_host, match_port).decode()
            print('[\x1b[32mok\x1b[0m] %s: %s: %s' % (name, host_display, match_comment))
        for fail_host, fail_port, fail_comment in fail:
            host_display = entry.normalize_host(fail_host, fail_port).decode()
            print('[\x1b[31mfail\x1b[0m] %s: %s: %s' % (name, host_display, fail_comment))
        for _, _, same_fingerprint_comment in same_fingerprint:
            print('[same fingerprint] %s: (unknown host): %s' % (name, same_fingerprint_comment))
if __name__ == '__main__':
    # Script entry point. Any ordinary exception escaping main() becomes a
    # one-line "Error: ..." message on stderr and exit status 1, instead of
    # a traceback. SystemExit (e.g. from argparse or --help) is not an
    # Exception subclass, so it still passes through untouched.
    try:
        main()
    except Exception as exc:
        sys.stderr.write('Error: %s\n' % exc)
        sys.exit(1)