import base64
import binascii
import hashlib
from collections import namedtuple

import entry

# One host/key association read from a known_hosts line:
#   domain      - host name (or IP address) exactly as written in the file
#   port        - TCP port as an int; 22 unless the [host]:port form was used
#   fingerprint - raw SHA-256 digest (bytes) of the decoded public key
KnownHostsEntry = namedtuple('KnownHostsEntry', ['domain', 'port', 'fingerprint'])

# Characters accepted in an IPv4 octet and in an IPv6 group respectively.
_DECIMAL_DIGITS = frozenset('0123456789')
_HEX_DIGITS = frozenset('0123456789abcdefABCDEF')

# Port used for hosts that are not written in the [domain]:port form.
_DEFAULT_PORT = 22


class _LineTaggedError(Exception):
    """Common behaviour of the parse errors raised by this module.

    ``string`` holds the message. ``line`` starts out as None and is set by
    process_file() to the 1-indexed line number once it is known; str() then
    prefixes the message with that line number.
    """

    def __init__(self, string):
        super().__init__(string)
        self.string = string
        self.line = None

    def __str__(self):
        if self.line is None:
            return self.string
        return 'Line %i: %s' % (self.line, self.string)


class KnownHostsSyntaxError(_LineTaggedError):
    """Raised when a known_hosts line cannot be parsed."""


class HashedHostError(_LineTaggedError):
    """Raised for host forms that cannot be turned into entries.

    process_line() raises it both for hashed hosts (leading '|') and for
    wildcard patterns (containing '*' or '?').
    """


def _is_ipv4(domain):
    """Return True if domain is four dot-separated decimal numbers in 0…255."""
    fields = domain.split('.')
    if len(fields) != 4:
        return False
    for field in fields:
        # Only plain ASCII digits: int() alone would also accept signs,
        # surrounding whitespace and underscores.
        if not field or not _DECIMAL_DIGITS.issuperset(field):
            return False
        try:
            value = int(field, 10)
        except ValueError:
            # Kept as a safety net (e.g. interpreter limits on huge literals)
            return False
        if value > 255:
            return False
    return True


def _is_ipv6(domain):
    """Return True if domain looks like a colon-separated IPv6 address.

    An IPv6 address is written as 8 groups of up to 4 hex digits. A run of
    zero groups can be abbreviated with '::', but only once per address:
      - two or more '::' is never valid
      - with one '::' there may be at most 8 ':'-separated fields
      - with no '::' there must be exactly 8 fields
    Every field must consist of 0 to 4 hex digits.
    """
    doublecolons = domain.count('::')
    if doublecolons > 1:
        return False

    fields = domain.split(':')
    if doublecolons == 1 and len(fields) > 8:
        return False
    if doublecolons == 0 and len(fields) != 8:
        return False

    return all(len(field) <= 4 and _HEX_DIGITS.issuperset(field)
               for field in fields)


def is_ip(domain):
    """is_ip(str) → bool

    Sees if a given domain would be a valid v4 or v6 IP."""
    return _is_ipv4(domain) or _is_ipv6(domain)


def _split_host_port(host):
    """Split one host field into (domain, port).

    A host beginning with '[' uses the nonstandard-port form [domain]:port;
    anything else is a bare domain on the default port 22.

    Raises KnownHostsSyntaxError if the bracketed form is malformed or the
    port is not an integer.
    """
    if not host.startswith('['):
        return host, _DEFAULT_PORT

    host_and_port = host[1:].split(']:')
    if len(host_and_port) != 2:
        raise KnownHostsSyntaxError('Unrecognized host format: ' + host)

    domain, port_text = host_and_port
    try:
        port = int(port_text)
    except ValueError as err:
        # Report the offending text; no port value exists at this point
        raise KnownHostsSyntaxError('Malformed port: %s' % port_text) from err
    return domain, port


def process_line(line, ignore_ips):
    """process_line(str, bool) → [KnownHostsEntry]

    Given a string containing one line of .ssh/known_hosts file, create a
    list of entries based on it. If ignore_ips is True, only create entries
    for domain names.

    Returns an empty list for blank lines, comment lines and lines marked
    @cert-authority or @revoked.

    Raises KnownHostsSyntaxError for lines with fewer than three fields, an
    undecodable public key, an empty host, or a malformed [domain]:port
    host. Raises HashedHostError for hashed hosts and wildcard patterns.
    """
    assert isinstance(line, str)
    assert isinstance(ignore_ips, bool)

    # Keep what precedes the first '#' (the rest is a comment), then drop
    # the trailing newline and any other surrounding whitespace
    line = line.partition('#')[0].strip()

    # Just skip over empty lines
    if not line:
        return []

    # Each line has host(s), algorithm, public key, and possibly one more
    # optional field (a free-form comment, which may itself contain
    # spaces). Fields may be separated by any run of whitespace.
    fields = line.split(None, 3)

    # Also skip over @cert-authority and @revoked lines
    # TODO: Handle @revoked somehow?
    if fields[0] in ('@cert-authority', '@revoked'):
        return []

    if len(fields) < 3:
        raise KnownHostsSyntaxError('Weird number of fields on a line (%i)' % len(fields))

    hosts, _algorithm, public_key = fields[:3]

    # Generate public key fingerprint
    # The key is stored base64 encoded, so decode it first
    try:
        public_key_binary = base64.b64decode(public_key, validate=True)
    except (ValueError, binascii.Error) as err:
        raise KnownHostsSyntaxError('Malformed public key: %s' % public_key) from err

    # Fingerprint is sha256 hash of the public key
    fingerprint = hashlib.sha256(public_key_binary).digest()

    # There can be several hosts separated with a comma
    known_host_entries = []
    for host in hosts.split(','):
        # A host can't be empty
        if not host:
            raise KnownHostsSyntaxError('An empty host')

        # If the host begins with '|' it's hashed
        # We cannot deal with those
        if host.startswith('|'):
            raise HashedHostError('Cannot deal with hashed hosts')

        # If the host has '*' or '?' it's a wild card
        # We cannot deal with those
        if '*' in host or '?' in host:
            raise HashedHostError('Cannot deal with wildcards')

        domain, port = _split_host_port(host)

        # As we have now extracted the domain, we can check if we
        # need to throw it out
        if ignore_ips and is_ip(domain):
            continue

        known_host_entries.append(KnownHostsEntry(domain, port, fingerprint))

    return known_host_entries


def process_file(f, ignore_ips=True):
    """process_file(file(r), bool) → [KnownHostsEntry]

    Given a file in the .ssh/known_hosts format, create a list of entries.
    If ignore_ips is True, only create entries for domain names.

    f is iterated line by line. Any KnownHostsSyntaxError or
    HashedHostError raised for a line is re-raised with its ``line``
    attribute set to that line's 1-indexed number.
    """
    entries = []
    for linenum, line in enumerate(f, start=1):
        try:
            entries.extend(process_line(line, ignore_ips))
        except (KnownHostsSyntaxError, HashedHostError) as err:
            err.line = linenum
            raise
    return entries


def known_hosts_to_entry(known_hosts_entry, comment=''):
    """known_hosts_to_entry(KnownHostsEntry, str) → Entry

    Converts an entry that's been read from known_hosts to one that can be
    written to a .sshwot file, by passing its domain, port and fingerprint
    together with comment to entry.create_entry()."""
    return entry.create_entry(
        known_hosts_entry.domain,
        known_hosts_entry.port,
        known_hosts_entry.fingerprint,
        comment,
    )