Add blacklist feature

This commit is contained in:
Juhani Haverinen 2016-08-05 16:18:33 +03:00
parent 3adb47efde
commit b10d070701
1 changed files with 72 additions and 4 deletions

View File

@ -1,5 +1,6 @@
import configparser
import enum
import ipaddress
import os
import pathlib
import select
@ -12,6 +13,7 @@ import urllib.parse
class default_config: None
default_config.blacklist_file = pathlib.Path(os.environ['HOME']) / 'gopher_blacklist'
default_config.charset = 'utf-8'
default_config.fallback_mimetype = 'application/octet-stream'
default_config.gopher_root = pathlib.Path(os.environ['HOME']) / 'gopher'
@ -33,6 +35,12 @@ def die(message, status = 1):
error(message)
sys.exit(status)
# log(message)
# Print a log message to stdout
def log(message):
program_name = os.path.basename(sys.argv[0])
print('%s: %s' % (program_name, message))
# A base for Exeptions that are used with one argument and that return a string that incorporates said argument
class OneArgumentException(Exception):
def __init__(self, argument):
@ -433,11 +441,11 @@ def send_header(sock, protocol, status, mimetype, *, config):
elif status == Status.error:
# Technically -2 means "Try again later", but there is no code for "server blew up"
header = b'--2\r\n'
elif protocol == Protocol.gopher:
# Gopher has no header
header = b''
else:
unreachable()
@ -457,7 +465,7 @@ def send_binaryfile(sock, reader, protocol, *, config):
left = buffer_max
buffer.append(byte)
# If there was something left in the buffer, flush it
if len(buffer) != 0:
sock.sendall(buffer)
@ -494,7 +502,7 @@ def send_textfile(sock, reader, protocol, *, config):
# Signal end of text
sock.sendall(b'.\r\n')
else:
unreachable()
@ -601,6 +609,56 @@ class Threads_controller:
with self.threads_lock:
self.threads_amount -= 1
class IPParseError(OneArgumentException):
text = 'Error parsing IP: %s'
# read_blacklist(blacklist_file) → blacklist
# Reads the contents of the blacklist file into a form usable by ip_in_ranges()
def read_blacklist(blacklist_file):
try:
file = open(str(blacklist_file), 'r')
except FileNotFoundError:
return []
lines = file.read().split('\n')
file.close()
blacklist = []
for line in lines:
# Comment handling
if '#' in line:
line = line[:line.index('#')]
# Remove surrounding whitespace
line = line.strip()
# If an empty line, skip
if line == '':
continue
try:
ip_range = ipaddress.ip_network(line)
except ValueError:
raise IPParseError('Invalid format: ' + line)
blacklist.append(ip_range)
return blacklist
# ip_in_ranges(ip, ip_ranges) → in_rages
# Checks whether an ip address is in given ranges
def ip_in_ranges(ip, ip_ranges):
try:
ip = ipaddress.ip_address(ip)
except ValueError:
raise IPParseError('Invalid format: ' + line)
for ip_range in ip_ranges:
if ip in ip_range:
return True
return False
# listen(config) → (Never returns)
# Binds itself to all interfaces on designated port and listens on incoming connections
# Spawns worker threads to handle the connections
@ -625,6 +683,9 @@ def listen(config):
# Create a controller object for the worker threads
threads_controller = Threads_controller()
# Read blacklist of addresses
blacklist = read_blacklist(config.blacklist_file)
while True:
# Wait for listening sockets to get activity
events = listening.poll()
@ -635,6 +696,13 @@ def listen(config):
# Accept and handle the connection
conn, addr = s.accept()
# Check if connection is from a blacklisted IP address
if ip_in_ranges(addr[0], blacklist):
# It was, skip event
conn.close()
log('Connection from blacklisted address %s' % addr[0])
continue
# Set timeout for socket
conn.settimeout(config.socket_timeout)