Implement sshwot-filter

This commit is contained in:
Juhani Krekelä 2018-09-01 21:38:08 +03:00
parent f390da248c
commit b40f61b0b5
4 changed files with 159 additions and 16 deletions

2
.gitignore vendored
View File

@ -2,5 +2,7 @@ __pycache__
*.pyc
*.swp
*.sshwot
*.sshwot.*
build
sshwot-export-known-hosts
sshwot-filter

View File

@ -1,7 +1,10 @@
BINS:=sshwot-export-known-hosts
BINS:=sshwot-export-known-hosts sshwot-filter
SSHWOT_EXPORT_KNOWN_HOSTS_MAIN:=src/main-export-known-hosts.py
SSHWOT_EXPORT_KNOWN_HOSTS_DEPS:=$(SSHWOT_EXPORT_KNOWN_HOSTS_MAIN) src/entry.py src/hashing.py src/process_known_hosts.py src/write_file.py
SSHWOT_EXPORT_KNOWN_HOSTS_DEPS:=src/entry.py src/hashing.py src/process_known_hosts.py src/write_file.py
SSHWOT_FILTER_MAIN:=src/main-filter.py
SSHWOT_FILTER_DEPS:=src/entry.py src/hashing.py src/read_file.py src/write_file.py
all: $(BINS)
@ -14,6 +17,15 @@ sshwot-export-known-hosts: $(SSHWOT_EXPORT_KNOWN_HOSTS_MAIN) $(SSHWOT_EXPORT_KNO
cat build/$@.zip >> $@
chmod +x $@
sshwot-filter: $(SSHWOT_FILTER_MAIN) $(SSHWOT_FILTER_DEPS)
mkdir -p build/$@
cp $(SSHWOT_FILTER_DEPS) build/$@/
cp $(SSHWOT_FILTER_MAIN) build/$@/__main__.py
zip --quiet --junk-paths build/$@.zip build/$@/*.py
echo '#!/usr/bin/env python3' > $@
cat build/$@.zip >> $@
chmod +x $@
.PHONY: all clean distclean buildclean
clean:

View File

@ -7,11 +7,10 @@ Entry = namedtuple('Entry', ['salt', 'hashed_host', 'fingerprint', 'comment'])
class UnacceptableComment(Exception): pass
def normalize_host(domain, port):
"""normalize_host(str, u16) → bytes
Transform a domain into the format in which it will be hashed"""
def normalize_domain(domain):
"""normalize_domain(str) → bytes
Changes the domain into a normalized form (punycoded, lowercase)"""
assert type(domain) == str
assert type(port) == int and 0 <= port <= (1<<16) - 1
# We want to have domain names reasonably normalized. This is why we
# convert all internationalized domain names to punycode and
@ -20,13 +19,25 @@ def normalize_host(domain, port):
# that way we don't have to worry about Unicode case mapping: in
# case of IDN the IDNA codec handles that for us, and in case of an
# ASCII domain it passes through the IDNA unmodified
normalized_host = domain.encode('idna').lower()
return domain.encode('idna').lower()
# If the port is not :22, we store [host]:port instead
def normalize_host(domain, port):
"""normalize_host(str, u16) → bytes
Transform a host into the format in which it will be hashed.
Main difference between this and normalize_domain() is that this one
includes the port and produces stuff like [domain]:port if port is not
22"""
assert type(domain) == str
assert type(port) == int and 0 <= port <= (1<<16) - 1
normalized_domain = normalize_domain(domain)
# If the port is not :22, we store [domain]:port instead
if port != 22:
normalized_host = b'[%s]:%i' % (normalized_host, port)
normalized_domain = b'[%s]:%i' % (normalized_domain, port)
return normalized_host
return normalized_domain
def create_entry(domain, port, fingerprint, comment):
"""create_entry(str, u16, bytes[32], str) → Entry
@ -60,13 +71,13 @@ def filter_by_host(entries, domain, port):
# check for a match
normalized_host = normalize_host(domain, port)
entries = []
matches = []
for entry in entries:
hashed_host = hashing.hash_with_salt(normalized_host, entry.salt)
if hashed_host == entry.hashed_host:
entries.append(entry)
matches.append(entry)
return entries
return matches
def filter_by_fingerprint(entries, fingerprint):
"""filter_by_fingerprint([Entry], bytes[32]) → [Entry]
@ -74,9 +85,9 @@ def filter_by_fingerprint(entries, fingerprint):
assert type(entries) == list and all(type(i) == Entry for i in entries)
assert type(fingerprint) == bytes and len(fingerprint) == 32
entries = []
matches = []
for entry in entries:
if fingerprint == entry.fingerprint:
entries.append(entry)
matches.append(entry)
return entries
return matches

118
src/main-filter.py Normal file
View File

@ -0,0 +1,118 @@
import argparse
import sys
import entry
import hashing
import read_file
import write_file
def main():
    """Entry point of sshwot-filter.

    Parses the command line, reads every given sshwot file, keeps only
    the entries matching the requested host (and port) and/or
    fingerprint, and writes the collected matches in sshwot format to
    the output file (stdout unless -o is given).

    All diagnostics go to stderr, since stdout may be carrying the
    binary sshwot output. Exits with status 1 if the arguments are
    inconsistent or invalid."""
    # TODO: Default location to search
    parser = argparse.ArgumentParser(
        description = """Search sshwot file(s) for given host and/or
        fingerprint.""",
        # We want to provide help on --help, but the default thing
        # also adds -h, which we don't want (-h is used for --host)
        add_help = False
    )

    # --help to get help
    parser.add_argument('--help',
        action = 'help',
        help = 'show this help message and exit'
    )

    # -o can be used to write the results to a file instead of stdout
    parser.add_argument('-o',
        # Given one argument, we open a file of that name and store it
        # to outfile, which will be sys.stdout.buffer otherwise.
        # We use .buffer since we're going to write binary data
        action = 'store',
        dest = 'outfile',
        type = argparse.FileType('wb'),
        default = sys.stdout.buffer,
        # This is what will be displayed in the help after -o
        metavar = 'outfile',
        help = 'write the sshwot file to a given file instead of the stdout'
    )

    # -h/--host for domain, -p/--port for port (default port is 22)
    parser.add_argument('-h', '--host',
        action = 'store',
        dest = 'host',
        help = 'the domain to filter for'
    )
    parser.add_argument('-p', '--port',
        action = 'store',
        dest = 'port',
        # Automatically convert to integer
        type = int,
        help = 'the port associated with the given host'
    )

    # -f/--fingerprint for fingerprint
    parser.add_argument('-f', '--fingerprint',
        action = 'store',
        dest = 'fingerprint',
        help = 'the fingerprint to filter for'
    )

    # Input file(s)
    parser.add_argument('infiles',
        nargs = '*',
        type = argparse.FileType('rb'),
        # The text shown for these in the usage
        metavar = 'sshwot-file',
        help = 'a sshwot file to search'
    )

    # This automatically parses the command line args for us. If it
    # returns, we have syntactically correct arguments
    args = parser.parse_args()

    # Ensure we have a host if port was specified
    if args.port is not None and args.host is None:
        print('If you specify a port you need to specify a domain too', file=sys.stderr)
        sys.exit(1)

    # Ensure we're filtering for something
    if args.host is None and args.fingerprint is None:
        print('Specify a host and/or a fingerprint to filter for', file=sys.stderr)
        sys.exit(1)

    # Default to port 22
    port = 22 if args.port is None else args.port

    # The host normalization only asserts that the port is a u16, which
    # would surface as an empty error message; reject bad ports here
    if not 0 <= port <= 0xFFFF:
        print('Port must be between 0 and 65535', file=sys.stderr)
        sys.exit(1)

    # Check the validity of the fingerprint and de-base64 it, if we have it
    if args.fingerprint is not None:
        if args.fingerprint[0:7].upper() != 'SHA256:':
            print('We can only handle sha256 fingerprints (starts with SHA256:)', file=sys.stderr)
            sys.exit(1)

        # We encode this, because hashing.base64dec expects bytes
        fingerprint = hashing.base64dec(args.fingerprint[7:].encode())

        # Filtering by fingerprint only asserts that it is exactly 32
        # bytes, so give a readable error for anything else
        if len(fingerprint) != 32:
            print('A sha256 fingerprint must decode to exactly 32 bytes', file=sys.stderr)
            sys.exit(1)
    else:
        fingerprint = None

    matches = []
    for infile in args.infiles:
        # The per-file comment is not needed for filtering
        entries, _file_comment = read_file.read(infile)

        # Filter by host if it's present
        if args.host is not None:
            entries = entry.filter_by_host(entries, args.host, port)

        # Filter by fingerprint if it's present
        if fingerprint is not None:
            entries = entry.filter_by_fingerprint(entries, fingerprint)

        matches.extend(entries)

    # Print the matches in sshwot format
    write_file.write(args.outfile, matches)
if __name__ == '__main__':
    # Top-level boundary: report any uncaught error as a single line on
    # stderr and exit non-zero instead of showing a traceback
    try:
        main()
    except Exception as failure:
        message = 'Error: %s' % failure
        print(message, file=sys.stderr)
        raise SystemExit(1)