diff --git a/neomi.py b/neomi.py index 4961549..0572a78 100644 --- a/neomi.py +++ b/neomi.py @@ -1,5 +1,6 @@ import configparser import enum +import ipaddress import os import pathlib import select @@ -12,6 +13,7 @@ import urllib.parse class default_config: None +default_config.blacklist_file = pathlib.Path(os.environ['HOME']) / 'gopher_blacklist' default_config.charset = 'utf-8' default_config.fallback_mimetype = 'application/octet-stream' default_config.gopher_root = pathlib.Path(os.environ['HOME']) / 'gopher' @@ -33,6 +35,12 @@ def die(message, status = 1): error(message) sys.exit(status) +# log(message) +# Print a log message to stdout +def log(message): + program_name = os.path.basename(sys.argv[0]) + print('%s: %s' % (program_name, message)) + # A base for Exeptions that are used with one argument and that return a string that incorporates said argument class OneArgumentException(Exception): def __init__(self, argument): @@ -433,11 +441,11 @@ def send_header(sock, protocol, status, mimetype, *, config): elif status == Status.error: # Technically -2 means "Try again later", but there is no code for "server blew up" header = b'--2\r\n' - + elif protocol == Protocol.gopher: # Gopher has no header header = b'' - + else: unreachable() @@ -457,7 +465,7 @@ def send_binaryfile(sock, reader, protocol, *, config): left = buffer_max buffer.append(byte) - + # If there was something left in the buffer, flush it if len(buffer) != 0: sock.sendall(buffer) @@ -494,7 +502,7 @@ def send_textfile(sock, reader, protocol, *, config): # Signal end of text sock.sendall(b'.\r\n') - + else: unreachable() @@ -601,6 +609,56 @@ class Threads_controller: with self.threads_lock: self.threads_amount -= 1 +class IPParseError(OneArgumentException): + text = 'Error parsing IP: %s' + +# read_blacklist(blacklist_file) → blacklist +# Reads the contents of the blacklist file into a form usable by ip_in_ranges() +def read_blacklist(blacklist_file): + try: + file = open(str(blacklist_file), 'r') + except FileNotFoundError: + return [] + + lines = file.read().split('\n') + file.close() + + blacklist = [] + for line in lines: + # Comment handling + if '#' in line: + line = line[:line.index('#')] + + # Remove surrounding whitespace + line = line.strip() + + # If an empty line, skip + if line == '': + continue + + try: + ip_range = ipaddress.ip_network(line) + except ValueError: + raise IPParseError('Invalid format: ' + line) + + blacklist.append(ip_range) + + return blacklist + +# ip_in_ranges(ip, ip_ranges) → in_rages +# Checks whether an ip address is in given ranges +def ip_in_ranges(ip, ip_ranges): + try: + ip = ipaddress.ip_address(ip) + except ValueError: + raise IPParseError('Invalid format: ' + line) + + for ip_range in ip_ranges: + if ip in ip_range: + return True + + return False + # listen(config) → (Never returns) # Binds itself to all interfaces on designated port and listens on incoming connections # Spawns worker threads to handle the connections @@ -625,6 +683,9 @@ def listen(config): # Create a controller object for the worker threads threads_controller = Threads_controller() + # Read blacklist of addresses + blacklist = read_blacklist(config.blacklist_file) + while True: # Wait for listening sockets to get activity events = listening.poll() @@ -635,6 +696,13 @@ def listen(config): # Accept and handle the connection conn, addr = s.accept() + # Check if connection is from a blacklisted IP address + if ip_in_ranges(addr[0], blacklist): + # It was, skip event + conn.close() + log('Connection from blacklisted address %s' % addr[0]) + continue + # Set timeout for socket conn.settimeout(config.socket_timeout)