sshwot/src/process_known_hosts.py

197 lines
5.6 KiB
Python

import base64
import hashlib
from collections import namedtuple
import entry
# One host read from a known_hosts line: the domain (or IP) it was listed
# under, the port, and the sha256 digest of the public key.
KnownHostsEntry = namedtuple(
    'KnownHostsEntry',
    ('domain', 'port', 'fingerprint'),
)
class KnownHostsSyntaxError(Exception):
    """Raised when a line of a known_hosts file cannot be parsed.

    Attributes:
        string: description of what is wrong.
        line: line number the error belongs to, or None while it is not
            known. It always starts out as None; code that knows the line
            number (such as process_file) assigns it afterwards.
    """

    def __init__(self, string):
        super().__init__(string)
        self.string = string
        self.line = None

    def __str__(self):
        # None is a singleton, so test identity rather than equality
        if self.line is None:
            return self.string
        return 'Line %i: %s' % (self.line, self.string)
class HashedHostError(Exception):
    """Raised for a known_hosts host that this module cannot process.

    process_line raises it for hashed hosts and for wildcard patterns.

    Attributes:
        string: description of what is wrong.
        line: line number the error belongs to, or None while it is not
            known. It always starts out as None; code that knows the line
            number (such as process_file) assigns it afterwards.
    """

    def __init__(self, string):
        super().__init__(string)
        self.string = string
        self.line = None

    def __str__(self):
        # None is a singleton, so test identity rather than equality
        if self.line is None:
            return self.string
        return 'Line %i: %s' % (self.line, self.string)
def is_ip(domain):
    """is_ip(str) → bool
    Sees if a given domain would be a valid v4 or v6 IP

    IPv4 is recognised as four dot-separated decimal numbers in the range
    0…255. IPv6 is recognised as colon-separated groups of up to four hex
    digits, optionally with a single '::' abbreviation."""

    def is_ipv4(domain):
        # IPv4 address has 4 fields separated by .
        fields = domain.split('.')
        if len(fields) != 4:
            return False

        for field in fields:
            # The fields are base-10 integers. Only plain ASCII digits
            # are allowed: int() on its own would also accept spellings
            # such as '+1', ' 1' or '1_0'
            if field == '' or any(c not in '0123456789' for c in field):
                return False
            try:
                value = int(field, 10)
            except ValueError:
                # int() can still refuse extremely long digit strings
                return False
            # The fields are in the range 0…255
            if not 0 <= value <= 255:
                return False
        return True

    def is_hexdigit(c):
        # c is always a single character taken from a field
        return c in '0123456789abcdefABCDEF'

    def is_ipv6(domain):
        # An IPv6 address is represented by 8 groups of 16-bit
        # numbers expressed in hex. They can be either fixed-width
        # or have their leading zeroes removed. A run of zeroes
        # can be abbreviated with ::, but only once per address

        # If we have two or more doublecolons, it's not valid
        # If we have one, we can have anywhere between 3 to 8
        # "fields" separated by ':'
        # If we have zero, we must have exactly 8 fields, none of
        # which may be empty
        doublecolons = domain.count('::')
        if doublecolons > 1:
            return False

        fields = domain.split(':')
        if doublecolons == 1 and len(fields) > 8:
            return False
        if doublecolons == 0 and (len(fields) != 8 or '' in fields):
            return False

        # All of the "fields" must have 0 to 4 hex digits
        return all(
            len(field) <= 4 and all(map(is_hexdigit, field))
            for field in fields
        )

    return is_ipv4(domain) or is_ipv6(domain)
def process_line(line, ignore_ips):
    """process_line(str, bool) → [KnownHostsEntry]
    Given a string containing one line of .ssh/known_hosts file, create
    a list of Entries based on it.
    If ignore_ips is True, only create entries for domain names.

    Empty lines, comment-only lines and lines marked @cert-authority or
    @revoked produce an empty list.

    Raises KnownHostsSyntaxError if the line has the wrong number of
    fields, the public key is not valid base64, a host is empty, or a
    [domain]:port host is malformed.
    Raises HashedHostError if a host is hashed or contains a wildcard."""
    assert isinstance(line, str)
    assert isinstance(ignore_ips, bool)

    # Remove trailing newline (endswith is safe on an empty string)
    if line.endswith('\n'):
        line = line[:-1]

    # Remove comments if any: keep what is in front of the '#'
    comment_start = line.find('#')
    if comment_start != -1:
        line = line[:comment_start]

    # Just skip over empty lines, including ones that only have
    # whitespace left after the comment was removed
    if line.strip() == '':
        return []

    fields = line.split(' ')

    # Also skip over @cert-authority and @revoked lines
    # TODO: Handle @revoked somehow?
    if fields[0] in ('@cert-authority', '@revoked'):
        return []

    # Each line has host(s), algorithm, public key, and possibly one
    # more optional field
    if len(fields) not in (3, 4):
        raise KnownHostsSyntaxError('Weird number of fields on a line (%i)' % len(fields))

    # The algorithm (fields[1]) is not needed to build the entries
    hosts = fields[0]
    public_key = fields[2]

    # Generate public key fingerprint
    # The key is stored base64 encoded, so decode it first
    try:
        public_key_binary = base64.b64decode(public_key, validate = True)
    except (ValueError, base64.binascii.Error) as err:
        raise KnownHostsSyntaxError('Malformed public key: %s' % public_key) from err

    # Fingerprint is sha256 hash of the public key
    fingerprint = hashlib.sha256(public_key_binary).digest()

    # There can be several hosts separated with a comma
    known_host_entries = []
    for host in hosts.split(','):
        # A host can't be empty
        if host == '':
            raise KnownHostsSyntaxError('An empty host')

        # If the host begins with '|' it's hashed
        # We cannot deal with those
        if host[0] == '|':
            raise HashedHostError('Cannot deal with hashed hosts')

        # If the host has '*' or '?' it's a wild card
        # We cannot deal with those
        if '*' in host or '?' in host:
            raise HashedHostError('Cannot deal with wildcards')

        # If the host begins with '[' it's a nonstandard port
        # The format will be [domain]:port
        # Extract both
        # Otherwise, default to port 22
        if host[0] == '[':
            host_and_port = host[1:].split(']:')
            if len(host_and_port) != 2:
                raise KnownHostsSyntaxError('Unrecognized host format: ' + host)
            domain, port_string = host_and_port
            try:
                port = int(port_string)
            except ValueError as err:
                # Report the text that failed to parse; no port number
                # exists at this point
                raise KnownHostsSyntaxError('Malformed port: %s' % port_string) from err
        else:
            domain = host
            port = 22

        # As we have now extracted the domain, we can check if we
        # need to throw it out
        if ignore_ips and is_ip(domain):
            continue

        known_host_entries.append(KnownHostsEntry(domain, port, fingerprint))

    return known_host_entries
def process_file(f, ignore_ips=True):
    """process_file(file(r), bool) → [KnownHostsEntry]
    Reads a file in the .ssh/known_hosts format line by line and
    collects the entries of every line into one list.
    If ignore_ips is True, only create entries for domain names.

    A KnownHostsSyntaxError or HashedHostError coming from a line gets
    its line attribute set to that line's number before it is raised
    again."""
    collected = []

    # Counting from 1 makes the index the line number directly
    for linenum, text in enumerate(f, 1):
        try:
            parsed = process_line(text, ignore_ips)
        except (KnownHostsSyntaxError, HashedHostError) as err:
            err.line = linenum
            raise err
        collected.extend(parsed)

    return collected
def known_hosts_to_entry(known_hosts_entry, comment=''):
    """known_hosts_to_entry(KnownHostsEntry, str) → Entry
    Builds an entry suitable for a .sshwot file out of one that was
    read from known_hosts, by handing its domain, port and fingerprint
    together with the comment to entry.create_entry"""
    return entry.create_entry(
        known_hosts_entry.domain,
        known_hosts_entry.port,
        known_hosts_entry.fingerprint,
        comment,
    )