197 lines
5.6 KiB
Python
197 lines
5.6 KiB
Python
import base64
|
|
import hashlib
|
|
|
|
from collections import namedtuple
|
|
|
|
import entry
|
|
|
|
# A single (domain, port, fingerprint) record extracted from a known_hosts
# line; instances are built positionally in that field order.
KnownHostsEntry = namedtuple('KnownHostsEntry', 'domain port fingerprint')
|
|
|
|
class KnownHostsSyntaxError(Exception):
    """Raised when a line of a known_hosts file cannot be parsed.

    The `line` attribute starts out as None; process_file sets it to the
    1-based number of the offending line, and str() then prefixes the
    message with it."""

    def __init__(self, string):
        super().__init__(string)
        # Human-readable description of the problem
        self.string = string
        # 1-based line number within the file, or None when unknown
        self.line = None

    def __str__(self):
        if self.line is None:
            return self.string
        return 'Line %i: %s' % (self.line, self.string)
|
|
|
|
class HashedHostError(Exception):
    """Raised for known_hosts host forms that cannot be converted into
    entries, such as hashed host names or wildcard patterns.

    Like KnownHostsSyntaxError, `line` starts out as None and may be set to
    the 1-based number of the offending line; str() then prefixes it."""

    def __init__(self, string):
        super().__init__(string)
        # Human-readable description of the problem
        self.string = string
        # 1-based line number within the file, or None when unknown
        self.line = None

    def __str__(self):
        if self.line is None:
            return self.string
        return 'Line %i: %s' % (self.line, self.string)
|
|
|
|
def is_ip(domain):
    """is_ip(str) → bool

    Sees if a given domain would be a valid v4 or v6 IP.

    Parsing is delegated to the standard library's ipaddress module, which
    enforces the 0…255 octet range, rejects signs/whitespace in IPv4
    fields, and implements the full IPv6 '::' compression rules. Whether
    IPv4 octets with leading zeros are accepted depends on the Python
    version's ipaddress module."""
    import ipaddress

    try:
        ipaddress.ip_address(domain)
    except ValueError:
        # AddressValueError and the generic "does not appear to be an
        # IPv4 or IPv6 address" error are both ValueError subclasses
        return False
    return True
|
|
|
|
def _parse_host(host):
    """Split one comma-separated host pattern into a (domain, port) pair.

    Plain hosts get the default port 22; the "[domain]:port" form carries
    an explicit port. Raises KnownHostsSyntaxError for malformed patterns
    and HashedHostError for forms that cannot be converted to an entry."""
    # A host can't be empty (e.g. "a,,b" or a trailing comma)
    if host == '':
        raise KnownHostsSyntaxError('An empty host')

    # A leading '|' marks a hashed host name; the real name is not recoverable
    if host.startswith('|'):
        raise HashedHostError('Cannot deal with hashed hosts')

    # '*' and '?' are wildcards, so the pattern does not name a single host
    if '*' in host or '?' in host:
        raise HashedHostError('Cannot deal with wildcards')

    # A leading '!' negates the pattern; it does not name a host to trust
    if host.startswith('!'):
        raise HashedHostError('Cannot deal with negated patterns')

    if not host.startswith('['):
        return host, 22

    # Nonstandard port, written as [domain]:port
    host_and_port = host[1:].split(']:')
    if len(host_and_port) != 2 or host_and_port[0] == '':
        raise KnownHostsSyntaxError('Unrecognized host format: ' + host)

    domain, port_text = host_and_port
    try:
        port = int(port_text)
    except ValueError:
        # Report the text we failed to parse; there is no integer to show
        raise KnownHostsSyntaxError('Malformed port: %s' % port_text) from None
    if not 0 < port <= 65535:
        raise KnownHostsSyntaxError('Malformed port: %s' % port_text)

    return domain, port


def process_line(line, ignore_ips):
    """process_line(str, bool) → [KnownHostsEntry]

    Given a string containing one line of .ssh/known_hosts file, create
    a list of entries based on it.

    If ignore_ips is True, only create entries for domain names.

    Empty lines, comment lines and @cert-authority / @revoked lines yield
    an empty list. Raises KnownHostsSyntaxError for malformed lines and
    HashedHostError for host patterns that cannot be converted (hashed,
    wildcard or negated hosts)."""
    assert isinstance(line, str)
    assert isinstance(ignore_ips, bool)

    # Drop any comment (everything from '#' on) plus surrounding whitespace,
    # which also removes the trailing newline; safe on an empty string
    line = line.partition('#')[0].strip()

    # Just skip over empty lines
    if line == '':
        return []

    # Hosts, key type, base64 key, and optionally a free-form comment that
    # may itself contain whitespace, so split at most three times. Fields
    # may be separated by any run of spaces or tabs.
    fields = line.split(None, 3)

    # Also skip over @cert-authority and @revoked lines
    # TODO: Handle @revoked somehow?
    if fields[0] in ('@cert-authority', '@revoked'):
        return []

    if len(fields) < 3:
        raise KnownHostsSyntaxError('Weird number of fields on a line (%i)' % len(fields))

    hosts, _key_type, public_key = fields[:3]

    # The key is stored base64 encoded; binascii.Error is a ValueError
    # subclass, so this also covers invalid characters and padding
    try:
        public_key_binary = base64.b64decode(public_key, validate=True)
    except ValueError as err:
        raise KnownHostsSyntaxError('Malformed public key: %s' % public_key) from err

    # Fingerprint is the sha256 hash of the raw public key
    fingerprint = hashlib.sha256(public_key_binary).digest()

    # There can be several hosts separated with a comma, all sharing the key
    known_host_entries = []
    for host in hosts.split(','):
        domain, port = _parse_host(host)

        # Now that the domain is extracted, decide whether to throw it out
        if ignore_ips and is_ip(domain):
            continue

        known_host_entries.append(KnownHostsEntry(domain, port, fingerprint))

    return known_host_entries
|
|
|
|
def process_file(f, ignore_ips=True):
    """process_file(file(r), bool) → [KnownHostsEntry]

    Given a file in the .ssh/known_hosts format, create a list of
    entries.

    If ignore_ips is True, only create entries for domain names.

    `f` may be any iterable of lines. A KnownHostsSyntaxError or
    HashedHostError raised for a line is re-raised with its `line`
    attribute set to that line's 1-based number."""
    entries = []
    for linenum, line in enumerate(f, start=1):
        try:
            entries.extend(process_line(line, ignore_ips))
        except (KnownHostsSyntaxError, HashedHostError) as err:
            # Record where the problem is, then re-raise unchanged
            err.line = linenum
            raise

    return entries
|
|
|
|
def known_hosts_to_entry(known_hosts_entry, comment=''):
    """known_hosts_to_entry(KnownHostsEntry, str) → Entry

    Turn a record parsed from known_hosts into an Entry that can be
    written to a .sshwot file. The domain, port and fingerprint are passed
    to entry.create_entry in that order, followed by `comment`."""
    return entry.create_entry(
        known_hosts_entry.domain,
        known_hosts_entry.port,
        known_hosts_entry.fingerprint,
        comment,
    )
|