#!/usr/bin/python3 -u
# This filter proxy allows fine-grained access whitelists of commands
# (and their arguments) and events on a per-application basis, stored
# in:
#     /etc/onion-grater.d/
# that are pretty self-explanatory as long as you understand the Tor
# ControlPort language. The format is expressed in YAML where the
# top-level is supposed to be a list, where each element is a
# dictionary looking something like this:
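#     - name: a_name_for_this_filter
#       apparmor-profiles:
#         - path_to_executable_if_that_is_the_name_of_the_apparmor_profile
#         # or
#         - explicit_apparmor_profile_name
#         ...
#       users:
#         - user
#         ...
#       hosts:
#         - host
#         ...
#       commands:
#         command:
#           - command_arg_rule
#           ...
#         ...
#       confs:
#         conf:
#           - conf_arg_rule
#           ...
#         ...
#       events:
#         event:
#           event_option: event_option_value
#           ...
#         ...
#       restrict-stream-events: bool
#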
# `name` (optional) is a string which gives an internal name, useful
# for debugging. When not given, filters will default to the name of
# the file (excluding extension) they were read from (so there can be
# duplicates!). It is advisable to define one filter per file, and
# give helpful filenames instead of using this field.
#
# A filter is matched if for each of the relevant qualifiers at
# least one of the elements match the client. For local (loopback)
# clients the following qualifiers are relevant:
# * `apparmor-profiles`: a list of strings, each being the name
#   of the AppArmor profile applied to the binary or script of the client,
#   with `*` matching anything. While this matcher always works for binaries,
#   it only works for scripts with an enabled AppArmor profile (not
#   necessarily enforced, complain mode is good enough).
#
# * `users`: a list of strings, each describing the user of the
#   client with `*` matching anything.
#
# For remote (non-local) clients, the following qualifiers are
# relevant:
# * `hosts`: a list of strings, each describing the IPv4 address
#   of the client with `*` matching anything.
#
# A filter can serve both local and remote clients by having
# qualifiers of both types.
#
# `commands` (optional) is a list where each item is a dictionary with
# the obligatory `pattern` key, which is a regular expression that is
# matched against the full argument part of the command. The default
# behavior is to just proxy the line through if matched, but it can be
# altered with these keys:
#
# * `replacement`: this rewrites the arguments. The value is a Python
#   format string (str.format()) which will be given the match groups
#   from the match of `pattern`. The rewritten command is then proxied
#   without the need to match any rule. There are also some special
#   patterns that will be replaced as follows:
#
#   - {client-address}: the client's IP address
#   - {client-port}: the client's port
#   - {server-address}: the server's IP address
#   - {server-port}: the server's (listening) port
# * `response`: a list of dictionaries, where the `pattern` and
#   `replacement` keys work exactly as for commands arguments, but now
#   for the response. Note that this means that the response is left
#   intact if `pattern` doesn't match it, and if many `pattern`:s
#   match, only the first one (in the order listed) will trigger a
#   replacement.
#
# If a simple regex (as string) is given, it is assumed to be the
# `pattern` which allows a short-hand for this common type of rule.
# Note that to allow a command to be run without arguments, the empty
# string must be explicitly given as a `pattern`. Hence, an empty
# argument list does not allow any use of the command.
# `confs` (optional) is a dictionary, and it's just syntactic sugar to
# generate GETCONF/SETCONF rules. If a key exists, GETCONF of the
# keyname is allowed, and if it has a non-empty list as value, those
# values are allowed to be set. The empty string means that resetting
# it is allowed. This is very useful for applications that like to
# SETCONF on multiple configurations at the same time.
#
# `events` (optional) is a dictionary where the key represents the
# event. If a key exists the event is allowed. The value is another
# dictionary of options:
#
# * `suppress`: a boolean determining whether we should just fool the
#   client that it has subscribed to the event (i.e. the client
#   request is not filtered) while we suppress them.
#
# * `response`: a dictionary, where the `pattern` and `replacement`
#   keys work exactly as for `response` for commands, but now for the
#   events.
#
# `restrict-stream-events` (optional) is a boolean, and if set any
# STREAM events sent to the client (after it has subscribed to them)
# will be restricted to those belonging to the client itself. This
# option only works for local clients and will be unset for remote
# clients.
import argparse
import fcntl
import glob
import ipaddress
import os
import re
import socket
import socketserver
import struct
import sys
import textwrap
import time

import psutil
import stem
import stem.control
import yaml
# Defaults for the command-line options defined in the argument parser.
DEFAULT_LISTEN_ADDRESS = 'localhost'
# The --listen-port option refers to this constant, so it must exist.
# 9051 is Tor's conventional ControlPort number.
# NOTE(review): confirm the value against the deployed configuration.
DEFAULT_LISTEN_PORT = 9051
DEFAULT_COOKIE_PATH = '/run/tor/control.authcookie'
DEFAULT_CONTROL_SOCKET_PATH = '/run/tor/control'
class NoRewriteMatch(RuntimeError):
    """
    Raised when a rewrite was expected but none of the rewrite rules
    matched the given line.
    """
    pass
def log(msg):
    """
    Print `msg` as one line on standard error and flush it right away.
    """
    print(msg, file=sys.stderr, flush=True)


def pid_of_laddr(address):
    """
    Return the PID of the first connection reported by psutil whose
    local address equals `address`, or None when there is no such
    connection.
    """
    try:
        return next(conn for conn in psutil.net_connections()
                    if conn.laddr == address).pid
    except StopIteration:
        return None


def apparmor_profile_of_pid(pid):
    """
    Return the name of the AppArmor profile applied to process `pid`
    when that profile is in complain or enforce mode; otherwise return
    the path of the process' executable as reported by psutil.
    """
    # Here we leverage AppArmor's in-kernel solution for determining
    # the exact executable invoked. Looking at /proc/pid/exe when an
    # interpreted script is running will just point to the
    # interpreter's binary, which is not fine-grained enough, but
    # AppArmor will be aware of which script is running for processes
    # using one of its profiles. However, we fallback to /proc/pid/exe
    # in case there is no AppArmor profile, so the only unsupported
    # mode here is unconfined scripts.
    enabled_aa_profile_re = r'^(.+) \((?:complain|enforce)\)$'
    with open('/proc/{}/attr/current'.format(str(pid)), "rb") as fh:
        aa_profile_status = str(fh.read().strip(), 'UTF-8')
    apparmor_profile_match = re.match(enabled_aa_profile_re, aa_profile_status)
    if apparmor_profile_match:
        return apparmor_profile_match.group(1)
    return psutil.Process(pid).exe()
def get_ip_address(ifname):
    """
    Return the IPv4 address of the network interface `ifname` as a
    dotted-quad string.

    Only the first 15 characters of `ifname` are used. An OSError from
    the ioctl (for instance for an unknown interface) is propagated to
    the caller.
    """
    # The request buffer carries the interface name, NUL-padded to 256
    # bytes; the reply has the address in bytes 20..23.
    request = struct.pack('256s', bytes(ifname[:15], 'utf-8'))
    # The socket is only a handle for the ioctl; the context manager
    # makes sure it is closed again, also when the ioctl fails.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        reply = fcntl.ioctl(
            s.fileno(),
            0x8915,  # SIOCGIFADDR
            request
        )
    return socket.inet_ntoa(reply[20:24])


class FilteredControlPortProxySession:
    """
    Class used to deal with a single session, delegated from the handler
    (FilteredControlPortProxyHandler). Its main job is proxying the traffic
    between the client and the real control port, blocking or rewriting
    it as dictated by the filter rule set.
    """

    # Limit the length of a line, to prevent DoS attacks trying to
    # crash this filter proxy by sending infinitely long lines.
    MAX_LINESIZE = 10*1024

    def __init__(self, handler):
        self.allowed_commands = handler.allowed_commands
        self.allowed_events = handler.allowed_events
        self.client_address = handler.client_address
        self.client_pid = handler.client_pid
        self.controller = handler.controller
        self.debug_log = handler.debug_log
        self.filter_name = handler.filter_name
        self.restrict_stream_events = handler.restrict_stream_events
        self.rfile = handler.rfile
        self.server_address = handler.server_address
        self.wfile = handler.wfile
        self.client_streams = set()
        self.subscribed_event_listeners = []
    def debug_log_send(self, line):
        """
        Log `line` as sent to the client, but only when printing of
        responses is enabled.
        """
        if not global_args.print_responses:
            return
        self.debug_log(line, format_multiline=True, sep=': <- ')
    def debug_log_recv(self, line):
        """
        Log `line` as received from the client, but only when printing
        of requests is enabled.
        """
        # Gate on the request flag, mirroring how debug_log_send() gates
        # on the response flag; otherwise every request is logged even
        # without --complain or --debug.
        if global_args.print_requests:
            self.debug_log(line, format_multiline=True, sep=': -> ')
    def debug_log_rewrite(self, kind, old, new):
        """
        Log that text of the given `kind` ('command', 'received event'
        or 'response') was rewritten from `old` to `new`.

        Nothing is logged for an unknown kind, when the printing flag
        consulted for that kind is off, or when `new` equals `old`.
        """
        # NOTE(review): commands are gated on print_responses while
        # events/responses are gated on print_requests, which looks
        # swapped -- confirm that this is intended.
        if kind == 'command':
            enabled = global_args.print_responses
        elif kind in ('received event', 'response'):
            enabled = global_args.print_requests
        else:
            enabled = False
        if not enabled or new == old:
            return
        before = textwrap.indent(old.strip(), ' '*4)
        after = textwrap.indent(new.strip(), ' '*4)
        self.debug_log("rewrote {}:\n{}\nto:\n{}".format(kind, before, after),
                       format_multiline=False)
    def respond(self, line, raw=False):
        if line.isspace():
            return
        self.debug_log_send(line)
        self.wfile.write(bytes(line, 'ascii'))
        if not raw:
            self.wfile.write(bytes("\r\n", 'ascii'))
        self.wfile.flush()
    def get_rule(self, cmd, arg_str):
        allowed_args = self.allowed_commands.get(cmd, [])
        return next((rule for rule in allowed_args
                     if re.match(rule['pattern'] + "$", arg_str)), None)
    def proxy_line(self, line, args_rewriter=None, response_rewriter=None):
        if args_rewriter:
            new_line = args_rewriter(line)
            self.debug_log_rewrite('command', line, new_line)
        response = self.controller.msg(line.strip()).raw_content()
            new_response = response_rewriter(response)
            self.debug_log_rewrite('response', response, new_response)
            response = new_response
        self.respond(response, raw=True)
    def filter_line(self, line):
        self.debug_log("command filtered: {}".format(line))
        self.respond("510 Command filtered")

    def rewrite_line(self, replacers, line):
        builtin_replacers = {
            'client-address': self.client_address[0],
            'client-port':    str(self.client_address[1]),
            'server-address': self.server_address[0],
            'server-port':    str(self.server_address[1]),
        terminator = ''
        if line[-2:] == "\r\n":
            terminator = "\r\n"
            line = line[:-2]
        for r in replacers:
            match = re.match(r['pattern'] + "$", line)
            if match:
                return r['replacement'].format(
                    *match.groups(), **builtin_replacers
                ) + terminator
    def rewrite_matched_line(self, replacers, line):
            return self.rewrite_line(replacers, line)
    def rewrite_matched_lines(self, replacers, lines):
        split_lines = lines.strip().split("\r\n")
        return "\r\n".join([self.rewrite_matched_line(replacers, line)
                            for line in split_lines]) + "\r\n"
    def event_cb(self, event, event_rewriter=None):
        """
        Relay `event` to the client, optionally rewritten by
        `event_rewriter`.

        When stream events are restricted (and filtering is enabled), a
        stream event is only relayed if its stream is tracked as
        belonging to the client. A stream becomes tracked when its
        NEW/NEWRESOLVE event originates from an address owned by the
        client's PID, and stops being tracked on FAILED/CLOSED. An event
        that is rewritten to nothing but whitespace is dropped.
        """
        restricted = (self.restrict_stream_events
                      and isinstance(event, stem.response.events.StreamEvent)
                      and not global_args.disable_filtering)
        if restricted:
            if event.id in self.client_streams:
                if event.status in [stem.StreamStatus.FAILED,
                                    stem.StreamStatus.CLOSED]:
                    self.client_streams.remove(event.id)
            else:
                opens_stream = event.status in [stem.StreamStatus.NEW,
                                                stem.StreamStatus.NEWRESOLVE]
                if not (opens_stream
                        and self.client_pid == pid_of_laddr(
                            (event.source_address, event.source_port))):
                    return
                self.client_streams.add(event.id)
        content = event.raw_content()
        if event_rewriter:
            rewritten = event_rewriter(content)
            self.debug_log_rewrite('received event', content, rewritten)
            content = rewritten
            if content.strip() == '':
                return
        self.respond(content, raw=True)
    def update_event_subscriptions(self, events):
        """
        Make the set of events forwarded to the client equal to `events`
        and acknowledge with "250 OK".

        Listeners for events no longer requested are removed. Events
        already subscribed to are left alone. For every other event a
        listener is registered with the controller, unless the event's
        rule asks for it to be suppressed and filtering is enabled. If
        the rule has a `response` entry, the listener rewrites events
        with it.
        """
        # Iterate over a snapshot: entries are removed from the live list
        # inside the loop, which would otherwise skip elements.
        for listener, event in list(self.subscribed_event_listeners):
            if event not in events:
                self.controller.remove_event_listener(listener)
                self.subscribed_event_listeners.remove((listener, event))
                if global_args.print_responses:
                    self.debug_log("unsubscribed from event '{}'".format(event))
        for event in events:
            if any(event == event_ for _, event_ in self.subscribed_event_listeners):
                if global_args.print_responses:
                    self.debug_log("already subscribed to event '{}'"
                                   .format(event))
                # Do not register a second listener for the same event.
                continue
            rule = self.allowed_events.get(event, {}) or {}
            if not rule.get('suppress', False) or \
               global_args.disable_filtering:
                event_rewriter = None
                if 'response' in rule:
                    # Default arguments bind the current values; a plain
                    # closure would see whatever the loop assigned last.
                    def _event_rewriter(line, replacers=rule['response']):
                        return self.rewrite_matched_line(replacers, line)
                    event_rewriter = _event_rewriter

                def _event_cb(event, event_rewriter=event_rewriter):
                    self.event_cb(event, event_rewriter=event_rewriter)
                self.controller.add_event_listener(
                    _event_cb, getattr(stem.control.EventType, event)
                )
                self.subscribed_event_listeners.append((_event_cb, event))
                if global_args.print_responses:
                    self.debug_log("subscribed to event '{}'".format(event))
            else:
                if global_args.print_responses:
                    self.debug_log("suppressed subscription to event '{}'"
                                   .format(event))
        self.respond("250 OK")
    def handle(self):
        """
        Serve the client line by line until it sends QUIT or closes the
        connection.

        PROTOCOLINFO, AUTHENTICATE and QUIT are answered locally.
        SETEVENTS is checked against the allowed events. Every other
        command is proxied if a rule matches its arguments (or if
        filtering is disabled), applying the rule's argument and
        response rewrites, and is refused otherwise.
        """
        while True:
            binary_line = self.rfile.readline(self.MAX_LINESIZE)
            if binary_line == b'':
                # Deal with clients that close the socket without a QUIT.
                break
            line = str(binary_line, 'ascii')
            if line.isspace():
                self.debug_log('ignoring received empty (or whitespace-only) '
                               + 'line')
                continue
            match = re.match(
                r'(?P<cmd>\S+)(?P<cmd_arg_sep>\s*)(?P<arg_str>[^\r\n]*)\r?\n$',
                line
            )
            if not match:
                self.debug_log("received bad line (escapes made explicit): " +
                               repr(line))
                # Hopefully the next line is ok...
                continue
            self.debug_log_recv(line)
            raw_cmd     = match.group('cmd')
            cmd_arg_sep = match.group('cmd_arg_sep')
            arg_str     = match.group('arg_str')
            args = arg_str.split()
            cmd = raw_cmd.upper()

            if cmd == "PROTOCOLINFO":
                # Stem calls PROTOCOLINFO before authenticating. Tell the
                # client that there is no authentication.
                self.respond("250-PROTOCOLINFO 1")
                self.respond("250-AUTH METHODS=NULL")
                self.respond("250-VERSION Tor=\"{}\""
                             .format(self.controller.get_version()))
                self.respond("250 OK")

            elif cmd == "AUTHENTICATE":
                # We have already authenticated, and the filtered port is
                # access-restricted according to our filter instead.
                self.respond("250 OK")

            elif cmd == "QUIT":
                self.respond("250 closing connection")
                break

            elif cmd == "SETEVENTS":
                # The control language doesn't care about case for
                # the event type.
                events = [event.upper() for event in args]
                if not global_args.disable_filtering and \
                   any(event not in self.allowed_events for event in events):
                    self.filter_line(line)
                else:
                    self.update_event_subscriptions(events)

            else:
                rule = self.get_rule(cmd, arg_str)
                if rule is None and global_args.disable_filtering:
                    # With filtering disabled everything is proxied as-is,
                    # which an empty rule expresses.
                    rule = {}
                if rule is not None:
                    args_rewriter = None
                    response_rewriter = None

                    if 'response' in rule:
                        def _response_rewriter(lines):
                            return self.rewrite_matched_lines(rule['response'],
                                                              lines)
                        response_rewriter = _response_rewriter

                    if 'replacement' in rule:
                        def _args_rewriter(line):
                            # We also want to match the command in `line`
                            # and add it back to the replacement string.
                            # We make sure to keep the exact white spaces
                            # separating the command and arguments, to not
                            # rewrite the line unnecessarily.
                            #
                            # The prefix uses the command exactly as the
                            # client spelled it (the rule lookup used the
                            # upper-cased form), so lower-case commands
                            # match too. It is escaped so it is taken
                            # literally by both the regex and
                            # str.format(). The rule's pattern is grouped
                            # so that a top-level alternation stays
                            # attached to the prefix.
                            prefix = raw_cmd + cmd_arg_sep
                            literal_prefix = \
                                prefix.replace('{', '{{').replace('}', '}}')
                            replacer = {
                                'pattern':     '{}(?:{})'.format(
                                    re.escape(prefix), rule['pattern']
                                ),
                                'replacement': literal_prefix
                                               + rule['replacement']
                            }
                            return self.rewrite_line([replacer], line)
                        args_rewriter = _args_rewriter

                    self.proxy_line(line, args_rewriter=args_rewriter,
                                    response_rewriter=response_rewriter)
                else:
                    self.filter_line(line)
class FilteredControlPortProxyHandler(socketserver.StreamRequestHandler):
    """
    Class handing each control port connection and collecting information
    about the origin and using it to find a matching filter rule set. It
    then delegates the session handling (the actual filtering) to a
    FilteredControlPortProxySession object.
    """

    def debug_log(self, line, format_multiline=False, sep=': '):
        """
        Log `line`, stripped of surrounding whitespace, prefixed with
        the client description and `sep`. With `format_multiline` set, a
        text containing newlines is indented and announced as
        "(multi-line)".
        """
        text = line.strip()
        if format_multiline and "\n" in text:
            sep = sep + "(multi-line)\n"
            text = textwrap.indent(text, ' '*4)
        log(self.client_desc + sep + text)

    def setup(self):
        """
        Initialise the per-connection state and load all filters from
        /etc/onion-grater.d/*.yml into `self.filters`.

        A filter without a `name` key is named after its file (without
        the .yml extension). Files with bad YAML are logged and skipped.
        Files that contain no filters contribute nothing.
        """
        # Zero-argument super(): super(type(self), self) would recurse
        # endlessly if this class were ever subclassed.
        super().setup()
        self.allowed_commands = {}
        self.allowed_events = {}
        self.client_desc = None
        self.client_pid = None
        self.client_streams = set()
        self.controller = None
        self.filter_name = None
        self.restrict_stream_events = False
        self.server_address = self.server.server_address
        self.subscribed_event_listeners = []
        self.filters = []
        for filter_file in glob.glob('/etc/onion-grater.d/*.yml'):
            try:
                with open(filter_file, "rb") as fh:
                    filters = yaml.safe_load(fh.read())
            except (yaml.parser.ParserError, yaml.scanner.ScannerError) as err:
                log("filter '{}' has bad YAML and was not loaded: {}"
                    .format(filter_file, str(err)))
                continue
            name = re.sub(r'\.yml$', '', os.path.basename(filter_file))
            # An empty file loads as None; treat it as "no filters".
            for filter_ in filters or []:
                # Only fall back to the file name when the filter did not
                # name itself.
                if 'name' not in filter_:
                    filter_['name'] = name
                self.filters.append(filter_)
    def add_allowed_commands(self, commands):
        for cmd in commands:
            allowed_args = commands[cmd]
            # An empty argument list allows nothing, but will
            # make some code below easier than if it can be
            # None as well.
            if allowed_args is None:
                allowed_args = []
            for i in range(len(allowed_args)):
                if isinstance(allowed_args[i], str):
                    allowed_args[i] = {'pattern': allowed_args[i]}
            self.allowed_commands[cmd.upper()] = allowed_args
    def add_allowed_confs_commands(self, confs):
        combined_getconf_rule = {'pattern': "(" + "|".join([
            key for key in confs]) + ")"}
        setconf_reset_part = "\s*|\s*".join([
            key for key in confs
            if isinstance(confs[key], list) and '' in confs[key]]
        )
        setconf_assignment_part = "\s*|\s*".join([
            "{}=({})".format(
                key, "|".join(confs[key])
            )
            for key in confs
            if isinstance(confs[key], list) and len(confs[key]) > 0])
        setconf_parts = []
        for part in [setconf_reset_part, setconf_assignment_part]:
            if part and part != '':
                setconf_parts.append(part)
        combined_setconf_rule = {
            'pattern': "({})+".format("\s*|\s*".join(setconf_parts))
        }
        for cmd, rule in [('GETCONF', combined_getconf_rule),
                          ('SETCONF', combined_setconf_rule)]:
            if rule['pattern'] != "()+":
                if cmd not in self.allowed_commands:
                    self.allowed_commands[cmd] = []
                self.allowed_commands[cmd].append(rule)
    def add_allowed_events(self, events):
        for event in events:
            opts = events[event]
            # Same as for the `commands` argument list, let's
            # add an empty dict to simplify later code.
            self.allowed_events[event.upper()] = opts

    def match_and_parse_filter(self, matchers):
        matched_filters = [filter_ for filter_ in self.filters
                           if all(any(val == expected_val or val == '*'
                                      for val in filter_.get(key, []))
                                  for key, expected_val in matchers)]
        if len(matched_filters) == 0:
        elif len(matched_filters) > 1:
            raise RuntimeError('multiple filters matched: ' +
                               ', '.join(matched_filters))
        matched_filter = matched_filters[0]
        self.filter_name = matched_filter['name']
        commands = matched_filter.get('commands', {}) or {}
        self.add_allowed_commands(commands)
        confs = matched_filter.get('confs', {}) or {}
        self.add_allowed_confs_commands(confs)
        events = matched_filter.get('events', {}) or {}
        self.add_allowed_events(events)
        self.restrict_stream_events = bool(matched_filter.get(
            'restrict-stream-events', False
        ))

    def connect_to_real_control_port(self):
        """
        Connect to Tor's control socket, retrying until a controller is
        obtained, authenticate with the cookie file and return the
        controller.
        """
        controller = None
        tries = 0
        # If tor isn't running this would just loop endlessly as fast
        # as possible, so let's rate limit it so it at least cannot
        # become a performance issue.
        # NOTE(review): this relies on stem.connection.connect()
        # returning a falsy value when it cannot connect -- confirm.
        while not controller:
            if tries >= 3:
                time.sleep(1)
            controller = stem.connection.connect(control_socket=global_args.control_socket_path)
            tries += 1
        stem.connection.authenticate_cookie(controller, cookie_path=global_args.control_cookie_path)
        return controller

    def handle(self):
        """
        Identify the client, load the filter matching it and run a
        filtering session against the real control port.

        Loopback clients are matched on their AppArmor profile and user;
        other clients on their host address. The connection to Tor is
        closed, and the disconnect reason logged, however the session
        ends.
        """
        client_host = self.client_address[0]
        local_connection = ipaddress.ip_address(client_host).is_loopback
        if local_connection:
            self.client_pid = pid_of_laddr(self.client_address)
            # Deal with the race between looking up the PID, and the
            # client being killed before we find the PID.
            if not self.client_pid:
                return
            client_apparmor_profile = apparmor_profile_of_pid(self.client_pid)
            client_user = psutil.Process(self.client_pid).username()
            matchers = [
                ('apparmor-profiles', client_apparmor_profile),
                ('users',             client_user),
            ]
        else:
            self.client_pid = None
            matchers = [
                ('hosts', client_host),
            ]
        self.match_and_parse_filter(matchers)
        if local_connection:
            self.client_desc = '{aa_profile} (pid: {pid}, user: {user}, ' \
                               'port: {port}, filter: {filter_name})'.format(
                                   aa_profile=client_apparmor_profile,
                                   pid=self.client_pid,
                                   user=client_user,
                                   port=self.client_address[1],
                                   filter_name=self.filter_name
                               )
        else:
            self.client_desc = '{1}:{2} (filter: {0})'.format(
                self.filter_name, *self.client_address
            )
        if self.restrict_stream_events and not local_connection:
            self.debug_log(
                "filter '{}' has `restrict-stream-events` set "
                "and we are remote so the option was disabled"
                .format(self.filter_name)
            )
            self.restrict_stream_events = False

        # The filter name is only set once a filter has matched. Checking
        # whether any filter file was loaded would report "loaded filter"
        # for clients that matched none of them.
        if self.filter_name is None:
            status = 'no matching filter found, using an empty one'
        else:
            status = 'loaded filter: {}'.format(self.filter_name)
        log('{} connected: {}'.format(self.client_desc, status))
        if global_args.debug:
            log('Final rules:')
            log(yaml.dump({
                'commands': self.allowed_commands,
                'events': self.allowed_events,
                'restrict-stream-events': self.restrict_stream_events,
            }))

        disconnect_reason = "client quit"
        try:
            self.controller = self.connect_to_real_control_port()
            session = FilteredControlPortProxySession(self)
            session.handle()
        except (ConnectionResetError, BrokenPipeError) as err:
            # Handle clients disconnecting abruptly
            disconnect_reason = str(err)
        except stem.SocketClosed:
            # Handle Tor closing its socket abruptly. This must come
            # before stem.SocketError, of which it is a subclass.
            disconnect_reason = "Tor closed its socket"
        except stem.SocketError:
            # Handle client closing its socket abruptly
            disconnect_reason = "Client closed its socket"
        finally:
            if self.controller:
                self.controller.close()
            log('{} disconnected: {}'.format(self.client_desc,
                                             disconnect_reason))
class FilteredControlPortProxy(socketserver.ThreadingTCPServer):
    """
    Simple subclass just setting some defaults differently.
    """

    # So we can restart when the listening port is in TIME_WAIT state
    # after an abrupt shutdown.
    allow_reuse_address = True
    # So all server threads immediately quit when the main thread
    # quits.
    daemon_threads = True


def main():
    """
    Parse the command line, publish the result as the module-level
    `global_args`, and serve filtered control port connections until
    interrupted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--listen-address",
        type=str, metavar='ADDR', default=DEFAULT_LISTEN_ADDRESS,
        help="specifies the address on which the server listens " +
             "(default: {})".format(DEFAULT_LISTEN_ADDRESS)
    )
    parser.add_argument(
        "--listen-port",
        type=int, metavar='PORT', default=DEFAULT_LISTEN_PORT,
        help="specifies the port on which the server listens " +
             "(default: {})".format(DEFAULT_LISTEN_PORT)
    )
    parser.add_argument(
        "--listen-interface",
        type=str, metavar='INTERFACE',
        help="specifies the interface on which the server listens " +
             "(default: NULL)"
    )
    parser.add_argument(
        "--control-cookie-path",
        type=str, metavar='PATH', default=DEFAULT_COOKIE_PATH,
        help="specifies the path to Tor's control authentication cookie " +
             "(default: {})".format(DEFAULT_COOKIE_PATH)
    )
    parser.add_argument(
        "--control-socket-path",
        type=str, metavar='PATH', default=DEFAULT_CONTROL_SOCKET_PATH,
        help="specifies the path to Tor's control socket " +
             "(default: {})".format(DEFAULT_CONTROL_SOCKET_PATH)
    )
    parser.add_argument(
        "--complain",
        action='store_true', default=False,
        help="disables all filtering and just prints the commands sent " +
             "by the client"
    )
    parser.add_argument(
        "--debug",
        action='store_true', default=False,
        help="prints all requests and responses"
    )
    # We put the argparse results in the global scope since it's
    # awkward to extend socketserver so additional data can be sent to
    # the request handler, where we need access to the arguments.
    global global_args
    global_args = parser.parse_args()
    # Deal with overlapping functionality between arguments
    global_args.disable_filtering = global_args.complain
    global_args.print_requests = global_args.complain or global_args.debug
    global_args.print_responses = global_args.debug
    if global_args.listen_interface:
        ip_address = get_ip_address(global_args.listen_interface)
        if global_args.debug:
            log("IP address for interface {} : {}".format(
                global_args.listen_interface, ip_address))
    else:
        ip_address = global_args.listen_address
    address = (ip_address, global_args.listen_port)
    server = FilteredControlPortProxy(address, FilteredControlPortProxyHandler)
    log("Tor control port filter started, listening on {}:{}".format(*address))
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        pass
    finally:
        # Release the listening socket on the way out.
        server.server_close()


if __name__ == "__main__":
    main()