#!/usr/bin/python3 -u

# This filter proxy allows fine-grained access whitelists of commands
# (and their arguments) and events on a per-application basis, stored
# in:
#
#     /etc/onion-grater.d/
#
# that are pretty self-explanatory as long as you understand the Tor
# ControlPort language. The format is expressed in YAML where the
# top-level is supposed to be a list, where each element is a
# dictionary looking something like this:
#
#     - name: blabla
#       apparmor-profiles:
#         - path_to_executable_if_that_is_the_name_of_the_apparmor_profile
#         # or
#         - explicit_apparmor_profile_name
#         ...
#       users:
#         - user
#         ...
#       hosts:
#         - host
#         ...
#       commands:
#         command:
#           - command_arg_rule
#           ...
#         ...
#       confs:
#         conf:
#           - conf_arg_rule
#           ...
#         ...
#       events:
#         event:
#           event_option: event_option_value
#           ...
#         ...
#
# `name` (optional) is a string which gives an internal name, useful
# for debugging. When not given, filters will default to the name of
# the file (excluding extension) they were read from (so there can be
# duplicates!). It is advisable to define one filter per file, and
# give helpful filenames instead of using this field.
#
# A filter is matched if for each of the relevant qualifiers at
# least one of the elements match the client. For local (loopback)
# clients the following qualifiers are relevant:
#
# * `apparmor-profiles`: a list of strings, each being the name
#   of the AppArmor profile applied to the binary or script of the client,
#   with `*` matching anything. While this matcher always works for binaries,
#   it only works for scripts with an enabled AppArmor profile (not
#   necessarily enforced, complain mode is good enough).
#
# * `users`: a list of strings, each describing the user of the
#   client with `*` matching anything.
#
# For remote (non-local) clients, the following qualifiers are
# relevant:
#
# * `hosts`: a list of strings, each describing the IPv4 address
#   of the client with `*` matching anything.
#
# A filter can serve both local and remote clients by having
# qualifiers of both types.
#
# `commands` (optional) is a dictionary mapping each command to a list
# of argument rules, where each rule is a dictionary with the
# obligatory `pattern` key, which is a regular expression that is
# matched against the full argument part of the command. The default
# behavior is to just proxy the line through if matched, but it can be
# altered with these keys:
#
# * `replacement`: this rewrites the arguments. The value is a Python
#   format string (str.format()) which will be given the match groups
#   from the match of `pattern`. The rewritten command is then proxied
#   without the need to match any rule. There are also some special
#   patterns that will be replaced as follows:
#
#   - {client-address}: the client's IP address
#   - {client-port}: the client's port
#   - {server-address}: the server's IP address
#   - {server-port}: the server's (listening) port
#
# * `response`: a list of dictionaries, where the `pattern` and
#   `replacement` keys work exactly as for commands arguments, but now
#   for the response. Note that this means that the response is left
#   intact if `pattern` doesn't match it, and if many `pattern`:s
#   match, only the first one (in the order listed) will trigger a
#   replacement.
#
# If a simple regex (as string) is given, it is assumed to be the
# `pattern` which allows a short-hand for this common type of rule.
#
# Note that to allow a command to be run without arguments, the empty
# string must be explicitly given as a `pattern`. Hence, an empty
# argument list does not allow any use of the command.
#
# `confs` (optional) is a dictionary, and it's just syntactic sugar to
# generate GETCONF/SETCONF rules. If a key exists, GETCONF of the
# keyname is allowed, and if it has a non-empty list as value, those
# values are allowed to be set. The empty string means that resetting
# it is allowed. This is very useful for applications that like to
# SETCONF on multiple configurations at the same time.
#
# `events` (optional) is a dictionary where the key represents the
# event. If a key exists the event is allowed. The value is another
# dictionary of options:
#
# * `suppress`: a boolean determining whether we should just fool the
#   client that it has subscribed to the event (i.e. the client
#   request is not filtered) while we suppress them.
#
# * `response`: a list of dictionaries, where the `pattern` and
#   `replacement` keys work exactly as for `response` for commands,
#   but now for the events.
#
# `restrict-stream-events` (optional) is a boolean, and if set any
# STREAM events sent to the client (after it has subscribed to them)
# will be restricted to those belonging to the client itself. This
# option only works for local clients and will be unset for remote
# clients.

import argparse
import fcntl
import functools
import glob
import ipaddress
import os.path
import re
import socket
import socketserver
import struct
import sys
import textwrap
import time

import psutil
import stem
import stem.control
import stem.connection
import yaml

# Defaults for the command line options parsed in main().
DEFAULT_LISTEN_ADDRESS = 'localhost'
DEFAULT_LISTEN_PORT = 9051
DEFAULT_COOKIE_PATH = '/run/tor/control.authcookie'
DEFAULT_CONTROL_SOCKET_PATH = '/run/tor/control'

class NoRewriteMatch(RuntimeError):
    """
    Error when no matching rewrite rule was found but one was expected.
    """
    pass

def log(msg):
    """
    Write `msg` followed by a newline to stderr and flush immediately,
    so messages show up in order even when stderr is buffered.
    """
    print(msg, file=sys.stderr, flush=True)


def pid_of_laddr(address):
    """
    Return the PID of the first connection reported by psutil whose
    local address equals `address` (an `(ip, port)` pair), or None if
    there is no such connection.
    """
    for conn in psutil.net_connections():
        if conn.laddr == address:
            return conn.pid
    return None


def apparmor_profile_of_pid(pid):
    """
    Return the name of the enabled (complain or enforce mode) AppArmor
    profile confining process `pid`, falling back to the path of the
    process' executable when no such profile is reported.
    """
    # Here we leverage AppArmor's in-kernel solution for determining
    # the exact executable invoked. Looking at /proc/pid/exe when an
    # interpreted script is running will just point to the
    # interpreter's binary, which is not fine-grained enough, but
    # AppArmor will be aware of which script is running for processes
    # using one of its profiles. However, we fallback to /proc/pid/exe
    # in case there is no AppArmor profile, so the only unsupported
    # mode here is unconfined scripts.
    enabled_aa_profile_re = r'^(.+) \((?:complain|enforce)\)$'
    with open('/proc/{}/attr/current'.format(str(pid)), "rb") as fh:
        aa_profile_status = str(fh.read().strip(), 'UTF-8')
    apparmor_profile_match = re.match(enabled_aa_profile_re,
                                      aa_profile_status)
    if apparmor_profile_match:
        return apparmor_profile_match.group(1)
    return psutil.Process(pid).exe()


def get_ip_address(ifname):
    """
    Return the IPv4 address (dotted-quad string) assigned to the network
    interface named `ifname`, as reported by the SIOCGIFADDR ioctl.

    Raises OSError if the ioctl fails (e.g. unknown interface).
    """
    # The socket is only a handle for the ioctl; make sure it is closed
    # again instead of leaking the file descriptor.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        return socket.inet_ntoa(fcntl.ioctl(
            s.fileno(),
            0x8915,  # SIOCGIFADDR
            struct.pack('256s', bytes(ifname[:15], 'utf-8'))
        )[20:24])


class FilteredControlPortProxySession:
    """
    Class used to deal with a single session, delegated from the handler
    (FilteredControlPortProxyHandler). Its main job is proxying the traffic
    between the client and the real control port, blocking or rewriting
    it as dictated by the filter rule set.
    """

    # Limit the length of a line, to prevent DoS attacks trying to
    # crash this filter proxy by sending infinitely long lines.
    MAX_LINESIZE = 10*1024

    def __init__(self, handler):
        """Copy the per-connection state collected by `handler`."""
        self.allowed_commands = handler.allowed_commands
        self.allowed_events = handler.allowed_events
        self.client_address = handler.client_address
        self.client_pid = handler.client_pid
        self.controller = handler.controller
        self.debug_log = handler.debug_log
        self.filter_name = handler.filter_name
        self.restrict_stream_events = handler.restrict_stream_events
        self.rfile = handler.rfile
        self.server_address = handler.server_address
        self.wfile = handler.wfile
        # IDs of the streams known to belong to this client.
        self.client_streams = set()
        # List of (listener, event name) pairs registered with stem.
        self.subscribed_event_listeners = []

    def debug_log_send(self, line):
        """Log data sent to the client, if response printing is on."""
        if global_args.print_responses:
            self.debug_log(line, format_multiline=True, sep=': <- ')

    def debug_log_recv(self, line):
        """Log data received from the client, if request printing is on."""
        if global_args.print_requests:
            self.debug_log(line, format_multiline=True, sep=': -> ')

    def debug_log_rewrite(self, kind, old, new):
        """
        Log that `old` was rewritten to `new`. `kind` is one of
        'command', 'received event' or 'response'; other kinds are
        ignored, and so are rewrites that changed nothing.
        """
        # A rewritten command is part of the request direction, while
        # rewritten events and responses are part of the response
        # direction, so each follows the corresponding print option.
        if kind == 'command':
            enabled = global_args.print_requests
        elif kind in ('received event', 'response'):
            enabled = global_args.print_responses
        else:
            enabled = False
        if not enabled:
            return
        if new != old:
            old = textwrap.indent(old.strip(), ' '*4)
            new = textwrap.indent(new.strip(), ' '*4)
            self.debug_log("rewrote {}:\n{}\nto:\n{}".format(kind, old, new),
                           format_multiline=False)

    def respond(self, line, raw=False):
        """
        Send `line` to the client. Unless `raw` is set a CRLF is
        appended. Whitespace-only lines are not sent.
        """
        if line.isspace():
            return
        self.debug_log_send(line)
        self.wfile.write(bytes(line, 'ascii'))
        if not raw:
            self.wfile.write(bytes("\r\n", 'ascii'))
        self.wfile.flush()

    def get_rule(self, cmd, arg_str):
        """
        Return the first rule allowed for `cmd` whose `pattern` matches
        the whole of `arg_str`, or None if there is none.
        """
        allowed_args = self.allowed_commands.get(cmd, [])
        # fullmatch() anchors the pattern as a whole. Appending "$" to
        # the pattern would only anchor the last alternative of a
        # top-level alternation ("a|b$"), letting "a<anything>" through.
        return next((rule for rule in allowed_args
                     if re.fullmatch(rule['pattern'], arg_str)), None)

    def proxy_line(self, line, args_rewriter=None, response_rewriter=None):
        """
        Send `line` to the real control port and relay the response to
        the client, applying the optional rewriters to the line and to
        the response respectively.
        """
        if args_rewriter:
            new_line = args_rewriter(line)
            self.debug_log_rewrite('command', line, new_line)
            line = new_line
        response = self.controller.msg(line.strip()).raw_content()
        if response_rewriter:
            new_response = response_rewriter(response)
            self.debug_log_rewrite('response', response, new_response)
            response = new_response
        self.respond(response, raw=True)

    def filter_line(self, line):
        """Reject `line`, telling the client that it was filtered."""
        self.debug_log("command filtered: {}".format(line))
        self.respond("510 Command filtered")

    def rewrite_line(self, replacers, line):
        """
        Rewrite `line` with the first replacer (a dictionary with the
        `pattern` and `replacement` keys) whose pattern matches the whole
        line, ignoring a trailing line terminator, which is kept as-is.

        Raises NoRewriteMatch if no replacer matches.
        """
        builtin_replacers = {
            'client-address': self.client_address[0],
            'client-port':    str(self.client_address[1]),
            'server-address': self.server_address[0],
            'server-port':    str(self.server_address[1]),
        }
        terminator = ''
        for ending in ("\r\n", "\n"):
            if line.endswith(ending):
                terminator = ending
                line = line[:-len(ending)]
                break
        for r in replacers:
            match = re.fullmatch(r['pattern'], line)
            if match:
                return r['replacement'].format(
                    *match.groups(), **builtin_replacers
                ) + terminator
        raise NoRewriteMatch()

    def rewrite_matched_line(self, replacers, line):
        """Like rewrite_line(), but return `line` as-is if nothing matches."""
        try:
            return self.rewrite_line(replacers, line)
        except NoRewriteMatch:
            return line

    def rewrite_matched_lines(self, replacers, lines):
        """
        Apply rewrite_matched_line() to each CRLF-separated line of
        `lines` and return the result, terminated by a CRLF.
        """
        split_lines = lines.strip().split("\r\n")
        return "\r\n".join([self.rewrite_matched_line(replacers, line)
                            for line in split_lines]) + "\r\n"

    def event_cb(self, event, event_rewriter=None):
        """
        Relay `event` to the client, after optionally rewriting it with
        `event_rewriter`. With `restrict-stream-events`, STREAM events
        are dropped unless they belong to a stream of this client.
        """
        if self.restrict_stream_events and \
           isinstance(event, stem.response.events.StreamEvent) and \
           not global_args.disable_filtering:
            if event.id not in self.client_streams:
                # A stream is attributed to the client if the process
                # owning the stream's source address is the client.
                if event.status in [stem.StreamStatus.NEW,
                                    stem.StreamStatus.NEWRESOLVE] and \
                   self.client_pid == pid_of_laddr((event.source_address,
                                                    event.source_port)):
                    self.client_streams.add(event.id)
                else:
                    return
            elif event.status in [stem.StreamStatus.FAILED,
                                  stem.StreamStatus.CLOSED]:
                self.client_streams.remove(event.id)
        raw_event_content = event.raw_content()
        if event_rewriter:
            new_raw_event_content = event_rewriter(raw_event_content)
            self.debug_log_rewrite(
                'received event', raw_event_content, new_raw_event_content
            )
            raw_event_content = new_raw_event_content
            if raw_event_content.strip() == '':
                return
        self.respond(raw_event_content, raw=True)

    def update_event_subscriptions(self, events):
        """
        Make the set of subscribed events equal to `events` (a list of
        upper-case event names): listeners for events no longer wanted
        are removed, listeners for new events are added (unless the
        event is suppressed by its rule), and "250 OK" is sent.
        """
        # Iterate over a copy since we remove entries while looping.
        for listener, event in list(self.subscribed_event_listeners):
            if event not in events:
                self.controller.remove_event_listener(listener)
                self.subscribed_event_listeners.remove((listener, event))
                if global_args.print_responses:
                    self.debug_log("unsubscribed from event '{}'".format(event))
        for event in events:
            if any(event == event_ for _, event_ in self.subscribed_event_listeners):
                if global_args.print_responses:
                    self.debug_log("already subscribed to event '{}'"
                                   .format(event))
                continue
            rule = self.allowed_events.get(event, {}) or {}
            if not rule.get('suppress', False) or \
               global_args.disable_filtering:
                # Bind the rewriter to the callback now (with partial())
                # instead of via closures over loop variables, which
                # would make every callback use the last event's rule.
                event_rewriter = None
                if 'response' in rule:
                    event_rewriter = functools.partial(
                        self.rewrite_matched_line, rule['response']
                    )
                event_listener = functools.partial(
                    self.event_cb, event_rewriter=event_rewriter
                )
                self.controller.add_event_listener(
                    event_listener, getattr(stem.control.EventType, event)
                )
                self.subscribed_event_listeners.append(
                    (event_listener, event)
                )
                if global_args.print_responses:
                    self.debug_log("subscribed to event '{}'".format(event))
            else:
                if global_args.print_responses:
                    self.debug_log("suppressed subscription to event '{}'"
                                   .format(event))
        self.respond("250 OK")

    def handle(self):
        """
        Main loop of the session: read commands from the client until
        it quits or closes the connection, and answer, proxy, rewrite
        or filter each of them according to the rule set.
        """
        while True:
            binary_line = self.rfile.readline(self.MAX_LINESIZE)
            if binary_line == b'':
                # Deal with clients that close the socket without a QUIT.
                break
            try:
                line = str(binary_line, 'ascii')
            except UnicodeDecodeError:
                # Don't let a client kill the session with non-ASCII
                # bytes; treat it like any other bad line.
                self.debug_log("received bad line (escapes made explicit): " +
                               repr(binary_line))
                continue
            if line.isspace():
                self.debug_log('ignoring received empty (or whitespace-only) '
                               + 'line')
                continue
            match = re.match(
                r'(?P<cmd>\S+)(?P<cmd_arg_sep>\s*)(?P<arg_str>[^\r\n]*)\r?\n$',
                line
            )
            if not match:
                self.debug_log("received bad line (escapes made explicit): " +
                               repr(line))
                # Hopefully the next line is ok...
                continue
            self.debug_log_recv(line)
            raw_cmd     = match.group('cmd')
            cmd_arg_sep = match.group('cmd_arg_sep')
            arg_str     = match.group('arg_str')
            args = arg_str.split()
            # The control language doesn't care about case for the
            # command, so rules are looked up by the upper-case name.
            cmd = raw_cmd.upper()

            if cmd == "PROTOCOLINFO":
                # Stem calls PROTOCOLINFO before authenticating. Tell the
                # client that there is no authentication.
                self.respond("250-PROTOCOLINFO 1")
                self.respond("250-AUTH METHODS=NULL")
                self.respond("250-VERSION Tor=\"{}\""
                             .format(self.controller.get_version()))
                self.respond("250 OK")

            elif cmd == "AUTHENTICATE":
                # We have already authenticated, and the filtered port is
                # access-restricted according to our filter instead.
                self.respond("250 OK")

            elif cmd == "QUIT":
                self.respond("250 closing connection")
                break

            elif cmd == "SETEVENTS":
                # The control language doesn't care about case for
                # the event type.
                events = [event.upper() for event in args]
                if not global_args.disable_filtering and \
                   any(event not in self.allowed_events for event in events):
                    self.filter_line(line)
                else:
                    self.update_event_subscriptions(events)

            else:
                rule = self.get_rule(cmd, arg_str)
                if rule is None and global_args.disable_filtering:
                    rule = {}
                if rule is not None:
                    args_rewriter = None
                    response_rewriter = None

                    if 'response' in rule:
                        def _response_rewriter(lines):
                            return self.rewrite_matched_lines(rule['response'],
                                                              lines)
                        response_rewriter = _response_rewriter

                    if 'replacement' in rule:
                        def _args_rewriter(line):
                            # Only the arguments are rewritten; the
                            # command, the white spaces separating it
                            # from the arguments, and the line ending
                            # are kept exactly as the client sent them,
                            # to not rewrite the line unnecessarily.
                            # The rule was selected by matching
                            # `arg_str`, so this works no matter which
                            # case the client used for the command.
                            head_len = len(raw_cmd) + len(cmd_arg_sep)
                            tail = line[head_len + len(arg_str):]
                            return line[:head_len] + \
                                self.rewrite_line([rule], arg_str) + tail
                        args_rewriter = _args_rewriter

                    self.proxy_line(line, args_rewriter=args_rewriter,
                                    response_rewriter=response_rewriter)
                else:
                    self.filter_line(line)


class FilteredControlPortProxyHandler(socketserver.StreamRequestHandler):
    """
    Class handling each control port connection and collecting information
    about the origin and using it to find a matching filter rule set. It
    then delegates the session handling (the actual filtering) to a
    FilteredControlPortProxySession object.
    """

    def debug_log(self, line, format_multiline=False, sep=': '):
        """
        Log `line` prefixed with the client description and `sep`. With
        `format_multiline`, multi-line messages are indented and marked.
        """
        line = line.strip()
        if format_multiline and "\n" in line:
            sep += "(multi-line)\n"
            line = textwrap.indent(line, ' '*4)
        log(self.client_desc + sep + line)

    def setup(self):
        """
        Initialize the per-connection state and (re)load all filters
        from /etc/onion-grater.d/*.yml. Files that cannot be used are
        logged and skipped.
        """
        # Zero-argument super(): super(type(self), self) recurses
        # forever as soon as this class is subclassed.
        super().setup()
        self.allowed_commands = {}
        self.allowed_events = {}
        self.client_desc = None
        self.client_pid = None
        self.client_streams = set()
        self.controller = None
        self.filter_name = None
        self.filters = []
        self.restrict_stream_events = False
        self.server_address = self.server.server_address
        self.subscribed_event_listeners = []
        for filter_file in glob.glob('/etc/onion-grater.d/*.yml'):
            try:
                with open(filter_file, "rb") as fh:
                    filters = yaml.safe_load(fh.read())
            except (yaml.parser.ParserError, yaml.scanner.ScannerError) as err:
                log("filter '{}' has bad YAML and was not loaded: {}"
                    .format(filter_file, str(err)))
                continue
            if not isinstance(filters, list):
                # An empty file loads as None; anything else that is not
                # a list cannot be a list of filters.
                if filters is not None:
                    log("filter '{}' is not a list of filters and was not "
                        "loaded".format(filter_file))
                continue
            # Filters without an explicit name default to the file name
            # (excluding extension).
            name = re.sub(r'\.yml$', '', os.path.basename(filter_file))
            for filter_ in filters:
                if 'name' not in filter_:
                    filter_['name'] = name
            self.filters += filters

    def add_allowed_commands(self, commands):
        """
        Register the rules in `commands`, a dictionary mapping each
        command to a list of rules (or None). Rules given as plain
        strings are expanded to `{'pattern': rule}`.
        """
        for cmd, allowed_args in commands.items():
            # An empty argument list allows nothing, but will
            # make some code below easier than if it can be
            # None as well.
            if allowed_args is None:
                allowed_args = []
            for i, allowed_arg in enumerate(allowed_args):
                if isinstance(allowed_arg, str):
                    allowed_args[i] = {'pattern': allowed_arg}
            self.allowed_commands[cmd.upper()] = allowed_args

    def add_allowed_confs_commands(self, confs):
        """
        Generate the GETCONF/SETCONF rules for `confs`, a dictionary
        mapping each configuration key to the list of values allowed to
        be set (the empty string meaning that resetting is allowed).
        Nothing is added for a command that would get an empty rule.
        """
        separator = r"\s*|\s*"
        list_confs = [(key, values) for key, values in confs.items()
                      if isinstance(values, list)]
        setconf_reset_part = separator.join(
            key for key, values in list_confs if '' in values
        )
        setconf_assignment_part = separator.join(
            "{}=({})".format(key, "|".join(values))
            for key, values in list_confs if len(values) > 0
        )
        setconf_parts = [part
                         for part in (setconf_reset_part,
                                      setconf_assignment_part)
                         if part]
        rules = []
        # Without any conf there is nothing to allow; the pattern "()"
        # would allow GETCONF without arguments for every filter.
        if confs:
            rules.append(
                ('GETCONF', {'pattern': "(" + "|".join(confs) + ")"})
            )
        if setconf_parts:
            rules.append(
                ('SETCONF',
                 {'pattern': "({})+".format(separator.join(setconf_parts))})
            )
        for cmd, rule in rules:
            self.allowed_commands.setdefault(cmd, []).append(rule)

    def add_allowed_events(self, events):
        """
        Register the events in `events`, a dictionary mapping each event
        to its options (or None).
        """
        for event, opts in events.items():
            # Same as for the `commands` argument list, let's
            # add an empty dict to simplify later code.
            if opts is None:
                opts = {}
            self.allowed_events[event.upper()] = opts

    def match_and_parse_filter(self, matchers):
        """
        Find the filter matching `matchers`, a list of (qualifier,
        value) pairs, and load its rules. A filter matches if, for each
        qualifier, one of its listed values equals the value or is `*`.
        Nothing is loaded if no filter matches.

        Raises RuntimeError if more than one filter matches.
        """
        matched_filters = [filter_ for filter_ in self.filters
                           if all(any(val == expected_val or val == '*'
                                      # A qualifier without value is None
                                      for val in filter_.get(key) or [])
                                  for key, expected_val in matchers)]
        if len(matched_filters) == 0:
            return
        elif len(matched_filters) > 1:
            raise RuntimeError('multiple filters matched: ' +
                               ', '.join(str(filter_.get('name'))
                                         for filter_ in matched_filters))
        matched_filter = matched_filters[0]
        self.filter_name = matched_filter['name']
        commands = matched_filter.get('commands', {}) or {}
        self.add_allowed_commands(commands)
        confs = matched_filter.get('confs', {}) or {}
        self.add_allowed_confs_commands(confs)
        events = matched_filter.get('events', {}) or {}
        self.add_allowed_events(events)
        self.restrict_stream_events = bool(matched_filter.get(
            'restrict-stream-events', False
        ))

    def connect_to_real_control_port(self):
        """
        Connect to Tor's control socket, retrying until it succeeds, and
        return the cookie-authenticated controller.
        """
        controller = None
        tries = 0
        # If tor isn't running this would just loop endlessly as fast
        # as possible, so let's rate limit it so it at least cannot
        # become a performance issue.
        while not controller:
            if tries >= 3:
                time.sleep(1)
            controller = stem.connection.connect(control_socket=global_args.control_socket_path)
            tries += 1
        stem.connection.authenticate_cookie(controller, cookie_path=global_args.control_cookie_path)
        return controller

    def handle(self):
        """
        Identify the client, load the matching filter (if any) and run
        the session until the client or Tor disconnects.
        """
        client_host = self.client_address[0]
        local_connection = ipaddress.ip_address(client_host).is_loopback
        if local_connection:
            self.client_pid = pid_of_laddr(self.client_address)
            # Deal with the race between looking up the PID, and the
            # client being killed before we find the PID.
            if not self.client_pid:
                return
            client_apparmor_profile = apparmor_profile_of_pid(self.client_pid)
            client_user = psutil.Process(self.client_pid).username()
            matchers = [
                ('apparmor-profiles', client_apparmor_profile),
                ('users',             client_user),
            ]
        else:
            self.client_pid = None
            matchers = [
                ('hosts', client_host),
            ]
        self.match_and_parse_filter(matchers)
        if local_connection:
            self.client_desc = '{aa_profile} (pid: {pid}, user: {user}, ' \
                               'port: {port}, filter: {filter_name})'.format(
                                   aa_profile=client_apparmor_profile,
                                   pid=self.client_pid,
                                   user=client_user,
                                   port=self.client_address[1],
                                   filter_name=self.filter_name
                               )
        else:
            self.client_desc = '{1}:{2} (filter: {0})'.format(
                self.filter_name, *self.client_address
            )
        if self.restrict_stream_events and not local_connection:
            self.debug_log(
                "filter '{}' has `restrict-stream-events` set "
                "and we are remote so the option was disabled"
                .format(self.filter_name)
            )
            self.restrict_stream_events = False

        # The filter name is only set when a filter matched; filters
        # being loaded doesn't imply that one of them matched.
        if self.filter_name is None:
            status = 'no matching filter found, using an empty one'
        else:
            status = 'loaded filter: {}'.format(self.filter_name)
        log('{} connected: {}'.format(self.client_desc, status))
        if global_args.debug:
            log('Final rules:')
            log(yaml.dump({
                'commands': self.allowed_commands,
                'events': self.allowed_events,
                'restrict-stream-events': self.restrict_stream_events,
            }))
        disconnect_reason = "client quit"
        try:
            self.controller = self.connect_to_real_control_port()
            session = FilteredControlPortProxySession(self)
            session.handle()
        except (ConnectionResetError, BrokenPipeError) as err:
            # Handle clients disconnecting abruptly
            disconnect_reason = str(err)
        except stem.SocketClosed:
            # Handle Tor closing its socket abruptly. This must come
            # before stem.SocketError, which is its base class.
            disconnect_reason = "Tor closed its socket"
        except stem.SocketError:
            # Handle client closing its socket abruptly
            disconnect_reason = "Client closed its socket"
        finally:
            if self.controller:
                self.controller.close()
            log('{} disconnected: {}'.format(self.client_desc,
                                             disconnect_reason))


class FilteredControlPortProxy(socketserver.ThreadingTCPServer):
    """
    Simple subclass just setting some defaults differently.
    """

    # So we can restart when the listening port is in TIME_WAIT state
    # after an abrupt shutdown.
    allow_reuse_address = True
    # So all server threads immediately quit when the main thread
    # quits.
    daemon_threads = True


def main():
    """
    Parse the command line, publish the result as the module-level
    `global_args`, and serve the filtered control port until
    interrupted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--listen-address",
        type=str, metavar='ADDR', default=DEFAULT_LISTEN_ADDRESS,
        help="specifies the address on which the server listens " +
             "(default: {})".format(DEFAULT_LISTEN_ADDRESS)
    )
    parser.add_argument(
        "--listen-port",
        type=int, metavar='PORT', default=DEFAULT_LISTEN_PORT,
        help="specifies the port on which the server listens " +
             "(default: {})".format(DEFAULT_LISTEN_PORT)
    )
    parser.add_argument(
        "--listen-interface",
        type=str, metavar='INTERFACE',
        help="specifies the interface on which the server listens " +
             "(default: NULL)"
    )
    parser.add_argument(
        "--control-cookie-path",
        type=str, metavar='PATH', default=DEFAULT_COOKIE_PATH,
        help="specifies the path to Tor's control authentication cookie " +
             "(default: {})".format(DEFAULT_COOKIE_PATH)
    )
    parser.add_argument(
        "--control-socket-path",
        type=str, metavar='PATH', default=DEFAULT_CONTROL_SOCKET_PATH,
        help="specifies the path to Tor's control socket " +
             "(default: {})".format(DEFAULT_CONTROL_SOCKET_PATH)
    )
    parser.add_argument(
        "--complain",
        action='store_true', default=False,
        help="disables all filtering and just prints the commands sent " +
             "by the client"
    )
    parser.add_argument(
        "--debug",
        action='store_true', default=False,
        help="prints all requests and responses"
    )
    # We put the argparse results in the global scope since it's
    # awkward to extend socketserver so additional data can be sent to
    # the request handler, where we need access to the arguments.
    global global_args
    global_args = parser.parse_args()
    # Deal with overlapping functionality between arguments
    global_args.disable_filtering = global_args.complain
    global_args.print_requests = global_args.complain or global_args.debug
    global_args.print_responses = global_args.debug
    # A listen interface, when given, takes precedence over the listen
    # address.
    if global_args.listen_interface:
        ip_address = get_ip_address(global_args.listen_interface)
        if global_args.debug:
            log("IP address for interface {} : {}".format(
                global_args.listen_interface,ip_address))
    else:
        ip_address = global_args.listen_address
    address = (ip_address, global_args.listen_port)
    server = FilteredControlPortProxy(address, FilteredControlPortProxyHandler)
    log("Tor control port filter started, listening on {}:{}".format(*address))
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        pass
    finally:
        # Release the listening socket on the way out.
        server.server_close()


if __name__ == "__main__":
    main()