onion-grater 29.1 KB
Newer Older
1
#!/usr/bin/python3 -u
2

3
# This filter proxy allows fine-grained access whitelists of commands
# (and their arguments) and events on a per-application basis, stored
# in:
6
#
7
#     /etc/onion-grater.d/
8
#
9 10
# that are pretty self-explanatory as long as you understand the Tor
# ControlPort language. The format is expressed in YAML where the
anonym's avatar
anonym committed
11 12
# top-level is supposed to be a list, where each element is a
# dictionary looking something like this:
13
#
14
#     - name: blabla
15
#       exe-paths:
anonym's avatar
anonym committed
16
#         - path_to_executable
17
#         ...
18
#       users:
anonym's avatar
anonym committed
19 20
#         - user
#         ...
21
#       hosts:
anonym's avatar
anonym committed
22
#         - host
23 24
#         ...
#       commands:
anonym's avatar
anonym committed
25
#         command:
26
#           - command_arg_rule
27 28
#           ...
#         ...
29 30 31
#       confs:
#         conf:
#           - conf_arg_rule
32
#           ...
33
#         ...
34
#       events:
anonym's avatar
anonym committed
35 36 37
#         event:
#           event_option: event_option_value
#           ...
38 39
#         ...
#
40 41 42 43 44 45
# `name` (optional) is a string which gives an internal name, useful
# for debugging. When not given, filters will default to the name of
# the file (excluding extension) they were read from (so there can be
# duplicates!). It is advisable to define one filter per file, and
# give helpful filenames instead of using this field.
#
46
# A filter is matched if for each of the relevant qualifiers at
anonym's avatar
anonym committed
47
# least one of the elements match the client. For local (loopback)
48
# clients the following qualifiers are relevant:
anonym's avatar
anonym committed
49
#
50
# * `exe-paths`: a list of strings, each describing the path to
anonym's avatar
anonym committed
51 52 53 54 55
#   the binary or script of the client with `*` matching
#   anything. While this matcher always works for binaries, it only
#   works for scripts with an enabled AppArmor profile (not
#   necessarily enforced, complain mode is good enough).
#
56
# * `users`: a list of strings, each describing the user of the
anonym's avatar
anonym committed
57 58
#   client with `*` matching anything.
#
59 60
# For remote (non-local) clients, the following qualifiers are
# relevant:
anonym's avatar
anonym committed
61
#
62
# * hosts: a list of strings, each describing the IPv4 address
anonym's avatar
anonym committed
63 64
#   of the client with `*` matching anything.
#
65 66
# A filter can serve both local and remote clients by having
# qualifiers of both types.
anonym's avatar
anonym committed
67 68 69 70 71 72
#
# `commands` (optional) is a dictionary mapping each command name to a
# list of argument rules. Each rule is a dictionary with the obligatory
# `pattern` key, which is a regular expression that is matched against
# the full argument part of the command. The default behavior is to
# just proxy the line through if matched, but it can be altered with
# these keys:
anonym's avatar
anonym committed
73 74 75
#
# * `replacement`: this rewrites the arguments. The value is a Python
#   format string (str.format()) which will be given the match groups
anonym's avatar
anonym committed
76
#   from the match of `pattern`. The rewritten command is then proxied
77 78 79 80 81 82 83
#   without the need to match any rule. There are also some special
#   patterns that will be replaced as follows:
#
#   - {client-address}: the client's IP address
#   - {client-port}: the client's port
#   - {server-address}: the server's IP address
#   - {server-port}: the server's (listening) port
anonym's avatar
anonym committed
84
#
85 86 87 88 89 90
# * `response`: a list of dictionaries, where the `pattern` and
#   `replacement` keys work exactly as for commands arguments, but now
#   for the response. Note that this means that the response is left
#   intact if `pattern` doesn't match it, and if many `pattern`:s
#   match, only the first one (in the order listed) will trigger a
#   replacement.
anonym's avatar
anonym committed
91 92 93
#
# If a simple regex (as string) is given, it is assumed to be the
# `pattern` which allows a short-hand for this common type of rule.
anonym's avatar
anonym committed
94
#
anonym's avatar
anonym committed
95 96 97
# Note that to allow a command to be run without arguments, the empty
# string must be explicitly given as a `pattern`. Hence, an empty
# argument list does not allow any use of the command.
98
#
99 100 101
# `confs` (optional) is a dictionary, and it's just syntactic sugar to
# generate GETCONF/SETCONF rules. If a key exists, GETCONF of the
# keyname is allowed, and if it has a non-empty list as value, those
102
# values are allowed to be set. The empty string means that resetting
103 104 105
# it is allowed. This is very useful for applications that like to
# SETCONF on multiple configurations at the same time.
#
anonym's avatar
anonym committed
106
# `events` (optional) is a dictionary where the key represents the
107
# event. If a key exists the event is allowed. The value is another
anonym's avatar
anonym committed
108 109 110 111 112 113
# dictionary of options:
#
# * `suppress`: a boolean determining whether we should just fool the
#   client that it has subscribed to the event (i.e. the client
#   request is not filtered) while we suppress them.
#
114 115 116 117
# * `response`: a list of dictionaries, where the `pattern` and
#   `replacement` keys work exactly as for `response` for commands,
#   but now for the events.
#
anonym's avatar
anonym committed
118 119 120
# `restrict-stream-events` (optional) is a boolean, and if set any
# STREAM events sent to the client (after it has subscribed to them)
# will be restricted to those belonging to the client itself. This
121 122
# option only works for local clients and will be unset for remote
# clients.
123

anonym's avatar
anonym committed
124
import argparse
125
import fcntl
126
import glob
127
import ipaddress
128
import os.path
anonym's avatar
anonym committed
129 130
import psutil
import re
131
import socket
anonym's avatar
anonym committed
132
import socketserver
133 134
import stem
import stem.control
135
import struct
136
import sys
137
import textwrap
anonym's avatar
anonym committed
138
import yaml
139

140
# Defaults for the command-line options defined in main().
# Address and port this filter proxy listens on.
DEFAULT_LISTEN_ADDRESS = 'localhost'
DEFAULT_LISTEN_PORT = 9051
# Where to find Tor's control authentication cookie and control socket.
DEFAULT_COOKIE_PATH = '/run/tor/control.authcookie'
DEFAULT_CONTROL_SOCKET_PATH = '/run/tor/control'
anonym's avatar
anonym committed
144

145

146
class NoRewriteMatch(RuntimeError):
    """
    Error when no matching rewrite rule was found but one was expected.
    """

152

153 154 155 156 157
def log(msg):
    """Write `msg` as one line to standard error and flush immediately."""
    print(msg, file=sys.stderr, flush=True)


158 159
def pid_of_laddr(address):
    """
    Return the PID recorded for the first connection reported by psutil
    whose local address equals `address`, or None when no connection
    has that local address.
    """
    for conn in psutil.net_connections():
        if conn.laddr == address:
            return conn.pid
    return None


166 167 168 169 170 171
def exe_path_of_pid(pid):
    """
    Return the path identifying what process `pid` is running.

    The AppArmor profile name is used when the process has a profile in
    complain or enforce mode; otherwise the path psutil reports for the
    process executable is returned.
    """
    # Here we leverage AppArmor's in-kernel solution for determining
    # the exact executable invoked. Looking at /proc/pid/exe when an
    # interpreted script is running will just point to the
    # interpreter's binary, which is not fine-grained enough, but
    # AppArmor will be aware of which script is running for processes
    # using one of its profiles. However, we fallback to /proc/pid/exe
    # in case there is no AppArmor profile, so the only unsupported
    # mode here is unconfined scripts.
    enabled_aa_profile_re = r'^(.+) \((?:complain|enforce)\)$'
    with open('/proc/{}/attr/current'.format(str(pid)), "rb") as fh:
        aa_profile_status = str(fh.read().strip(), 'UTF-8')
    profile_match = re.match(enabled_aa_profile_re, aa_profile_status)
    if profile_match:
        return profile_match.group(1)
    return psutil.Process(pid).exe()
183 184


185 186 187 188 189 190 191 192 193
def get_ip_address(ifname):
    """
    Return the IPv4 address of network interface `ifname` as a
    dotted-quad string, queried with the SIOCGIFADDR ioctl.
    """
    # The socket is only a handle for the ioctl; close it when done so
    # the file descriptor is not leaked.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        ifreq = fcntl.ioctl(
            s.fileno(),
            0x8915,  # SIOCGIFADDR
            struct.pack('256s', bytes(ifname[:15], 'utf-8'))
        )
    return socket.inet_ntoa(ifreq[20:24])


194
class FilteredControlPortProxySession:
    """
    Class used to deal with a single session, delegated from the handler
    (FilteredControlPortProxyHandler). Its main job is proxying the traffic
    between the client and the real control port, blocking or rewriting
    it as dictated by the filter rule set.
    """

    # Limit the length of a line, to prevent DoS attacks trying to
    # crash this filter proxy by sending infinitely long lines.
    MAX_LINESIZE = 10*1024

    def __init__(self, handler):
        """Copy the per-connection state this session needs from `handler`."""
        self.allowed_commands = handler.allowed_commands
        self.allowed_events = handler.allowed_events
        self.client_address = handler.client_address
        self.client_pid = handler.client_pid
        self.controller = handler.controller
        self.debug_log = handler.debug_log
        self.filter_name = handler.filter_name
        self.restrict_stream_events = handler.restrict_stream_events
        self.rfile = handler.rfile
        self.server_address = handler.server_address
        self.wfile = handler.wfile
        # IDs of the streams that were attributed to this client.
        self.client_streams = set()
        # (listener, event name) pairs currently registered with the
        # controller on behalf of this client.
        self.subscribed_event_listeners = []

    def debug_log_send(self, line):
        """Log `line` as sent to the client, if response printing is on."""
        if global_args.print_responses:
            self.debug_log(line, format_multiline=True, sep=': <- ')

    def debug_log_recv(self, line):
        """Log `line` as received from the client, if request printing is on."""
        if global_args.print_requests:
            self.debug_log(line, format_multiline=True, sep=': -> ')

    def debug_log_rewrite(self, kind, old, new):
        """
        Log that `old` was rewritten to `new`, when they differ and
        logging for `kind` ('command', 'received event' or 'response')
        is enabled. Any other `kind` is never logged.
        """
        # NOTE(review): command rewrites are gated on print_responses
        # and response/event rewrites on print_requests, which looks
        # swapped -- behavior kept as-is, confirm before changing.
        if kind == 'command':
            enabled = global_args.print_responses
        elif kind in ('received event', 'response'):
            enabled = global_args.print_requests
        else:
            enabled = False
        if not enabled or new == old:
            return
        old = textwrap.indent(old.strip(), ' '*4)
        new = textwrap.indent(new.strip(), ' '*4)
        self.debug_log("rewrote {}:\n{}\nto:\n{}".format(kind, old, new),
                       format_multiline=False)

    def respond(self, line, raw=False):
        """
        Send `line` to the client. Whitespace-only lines are dropped.
        Unless `raw` is set, a CRLF terminator is appended.
        """
        if line.isspace():
            return
        self.debug_log_send(line)
        self.wfile.write(bytes(line, 'ascii'))
        if not raw:
            self.wfile.write(bytes("\r\n", 'ascii'))
        self.wfile.flush()

    def get_rule(self, cmd, arg_str):
        """
        Return the first rule allowed for `cmd` whose pattern matches
        the whole of `arg_str`, or None if there is none.
        """
        # fullmatch() rather than match(pattern + "$"): with the latter
        # the "$" only binds to the last top-level alternative, so a
        # pattern like 'a|b' would also accept 'a<anything>'.
        for rule in self.allowed_commands.get(cmd, []):
            if re.fullmatch(rule['pattern'], arg_str):
                return rule
        return None

    def proxy_line(self, line, args_rewriter=None, response_rewriter=None):
        """
        Send `line` to the real control port and relay the response to
        the client, applying the optional rewriters to the line and the
        response respectively.
        """
        if args_rewriter:
            new_line = args_rewriter(line)
            self.debug_log_rewrite('command', line, new_line)
            line = new_line
        response = self.controller.msg(line.strip()).raw_content()
        if response_rewriter:
            new_response = response_rewriter(response)
            self.debug_log_rewrite('response', response, new_response)
            response = new_response
        self.respond(response, raw=True)

    def filter_line(self, line):
        """Reject `line`: log it and tell the client it was filtered."""
        self.debug_log("command filtered: {}".format(line))
        self.respond("510 Command filtered")

    def rewrite_line(self, replacers, line):
        """
        Rewrite `line` with the first replacer whose `pattern` matches
        it completely. The replacement is a str.format() string given
        the match groups positionally plus the built-in named values
        below. A trailing CRLF is preserved.

        Raises NoRewriteMatch if no replacer matches.
        """
        builtin_replacers = {
            'client-address': self.client_address[0],
            'client-port':    str(self.client_address[1]),
            'server-address': self.server_address[0],
            'server-port':    str(self.server_address[1]),
        }
        terminator = ''
        if line.endswith("\r\n"):
            terminator = "\r\n"
            line = line[:-2]
        for replacer in replacers:
            # See get_rule() for why fullmatch() is used.
            match = re.fullmatch(replacer['pattern'], line)
            if match:
                return replacer['replacement'].format(
                    *match.groups(), **builtin_replacers
                ) + terminator
        raise NoRewriteMatch()

    def rewrite_matched_line(self, replacers, line):
        """Like rewrite_line(), but return `line` unchanged on no match."""
        try:
            return self.rewrite_line(replacers, line)
        except NoRewriteMatch:
            return line

    def rewrite_matched_lines(self, replacers, lines):
        """Apply rewrite_matched_line() to each CRLF-separated line."""
        rewritten = [self.rewrite_matched_line(replacers, line)
                     for line in lines.strip().split("\r\n")]
        return "\r\n".join(rewritten) + "\r\n"

    def _rewrite_command_line(self, line, arg_str, rule):
        """
        Return `line` with its argument part `arg_str` rewritten by
        `rule`. The command word, separator and line terminator are
        kept exactly as the client sent them.
        """
        # Only the arguments are matched, exactly like get_rule() did,
        # so the command's letter case cannot make the match fail.
        body = line.rstrip("\r\n")
        terminator = line[len(body):]
        prefix = body[:len(body) - len(arg_str)]
        new_args = self.rewrite_line([rule], arg_str)
        if new_args and not prefix[-1:].isspace():
            # The command came without arguments but the replacement
            # adds some, so a separator is needed.
            prefix += ' '
        return prefix + new_args + terminator

    def event_cb(self, event, event_rewriter=None):
        """
        Relay `event` to the client. With `restrict-stream-events`
        active, STREAM events not belonging to this client are dropped.
        If `event_rewriter` is given it is applied to the raw event,
        and an event rewritten to nothing is dropped.
        """
        if self.restrict_stream_events and \
           isinstance(event, stem.response.events.StreamEvent) and \
           not global_args.disable_filtering:
            if event.id not in self.client_streams:
                # A stream is attributed to the client when it is new
                # and its source address belongs to the client's PID.
                is_new = event.status in [stem.StreamStatus.NEW,
                                          stem.StreamStatus.NEWRESOLVE]
                if not is_new or self.client_pid != pid_of_laddr(
                        (event.source_address, event.source_port)):
                    return
                self.client_streams.add(event.id)
            elif event.status in [stem.StreamStatus.FAILED,
                                  stem.StreamStatus.CLOSED]:
                self.client_streams.remove(event.id)
        raw_event_content = event.raw_content()
        if event_rewriter:
            new_raw_event_content = event_rewriter(raw_event_content)
            self.debug_log_rewrite(
                'received event', raw_event_content, new_raw_event_content
            )
            raw_event_content = new_raw_event_content
            if raw_event_content.strip() == '':
                return
        self.respond(raw_event_content, raw=True)

    def _make_event_listener(self, rule):
        """
        Build the controller callback for an event governed by `rule`.
        """
        # Built in its own scope so that each listener keeps its own
        # rewriter; closures created directly in the subscription loop
        # would all see the values of the last iteration.
        event_rewriter = None
        if 'response' in rule:
            replacers = rule['response']

            def event_rewriter(raw_event):
                return self.rewrite_matched_line(replacers, raw_event)

        def listener(event):
            self.event_cb(event, event_rewriter=event_rewriter)
        return listener

    def update_event_subscriptions(self, events):
        """
        Make the set of subscribed events equal to `events` (upper-case
        event names): drop listeners for events no longer wanted, add
        listeners for new ones (unless suppressed by the filter), and
        acknowledge with "250 OK".
        """
        # Iterate over a copy since entries are removed along the way.
        for subscription in list(self.subscribed_event_listeners):
            listener, event = subscription
            if event in events:
                continue
            self.controller.remove_event_listener(listener)
            self.subscribed_event_listeners.remove(subscription)
            if global_args.print_responses:
                self.debug_log("unsubscribed from event '{}'".format(event))
        for event in events:
            if any(event == subscribed
                   for _, subscribed in self.subscribed_event_listeners):
                if global_args.print_responses:
                    self.debug_log("already subscribed to event '{}'"
                                   .format(event))
                continue
            rule = self.allowed_events.get(event, {}) or {}
            if rule.get('suppress', False) and \
               not global_args.disable_filtering:
                if global_args.print_responses:
                    self.debug_log("suppressed subscription to event '{}'"
                                   .format(event))
                continue
            listener = self._make_event_listener(rule)
            self.controller.add_event_listener(
                listener, getattr(stem.control.EventType, event)
            )
            self.subscribed_event_listeners.append((listener, event))
            if global_args.print_responses:
                self.debug_log("subscribed to event '{}'".format(event))
        self.respond("250 OK")

    def handle(self):
        """
        Serve the client until it sends QUIT or closes the connection:
        read one line at a time, answer PROTOCOLINFO/AUTHENTICATE/QUIT
        locally, manage SETEVENTS subscriptions, and proxy or filter
        every other command according to the rule set.
        """
        while True:
            binary_line = self.rfile.readline(self.MAX_LINESIZE)
            if binary_line == b'':
                # Deal with clients that close the socket without a QUIT.
                break
            try:
                line = str(binary_line, 'ascii')
            except UnicodeDecodeError:
                # Treat undecodable input like any other bad line
                # instead of letting the exception end the session.
                self.debug_log("received bad line (escapes made explicit): " +
                               repr(binary_line))
                continue
            if line.isspace():
                self.debug_log('ignoring received empty (or whitespace-only) '
                               + 'line')
                continue
            match = re.match(
                r'(?P<cmd>\S+)(?P<cmd_arg_sep>\s*)(?P<arg_str>[^\r\n]*)\r?\n$',
                line
            )
            if not match:
                self.debug_log("received bad line (escapes made explicit): " +
                               repr(line))
                # Hopefully the next line is ok...
                continue
            self.debug_log_recv(line)
            cmd = match.group('cmd').upper()
            arg_str = match.group('arg_str')
            args = arg_str.split()

            if cmd == "PROTOCOLINFO":
                # Stem calls PROTOCOLINFO before authenticating. Tell the
                # client that there is no authentication.
                self.respond("250-PROTOCOLINFO 1")
                self.respond("250-AUTH METHODS=NULL")
                self.respond("250-VERSION Tor=\"{}\""
                             .format(self.controller.get_version()))
                self.respond("250 OK")

            elif cmd == "AUTHENTICATE":
                # We have already authenticated, and the filtered port is
                # access-restricted according to our filter instead.
                self.respond("250 OK")

            elif cmd == "QUIT":
                self.respond("250 closing connection")
                break

            elif cmd == "SETEVENTS":
                # The control language doesn't care about case for
                # the event type.
                events = [event.upper() for event in args]
                if not global_args.disable_filtering and \
                   any(event not in self.allowed_events for event in events):
                    self.filter_line(line)
                else:
                    self.update_event_subscriptions(events)

            else:
                rule = self.get_rule(cmd, arg_str)
                if rule is None and global_args.disable_filtering:
                    # No rule matched but filtering is off: proxy as-is.
                    rule = {}
                if rule is None:
                    self.filter_line(line)
                    continue

                args_rewriter = None
                response_rewriter = None

                if 'response' in rule:
                    def response_rewriter(lines, replacers=rule['response']):
                        return self.rewrite_matched_lines(replacers, lines)

                if 'replacement' in rule:
                    def args_rewriter(line, rule=rule, arg_str=arg_str):
                        return self._rewrite_command_line(line, arg_str, rule)

                self.proxy_line(line, args_rewriter=args_rewriter,
                                response_rewriter=response_rewriter)
451 452


453
class FilteredControlPortProxyHandler(socketserver.StreamRequestHandler):
    """
    Class handing each control port connection and collecting information
    about the origin and using it to find a matching filter rule set. It
    then delegates the session handling (the actual filtering) to a
    FilteredControlPortProxySession object.
    """

    def debug_log(self, line, format_multiline=False, sep=': '):
        """
        Log `line` prefixed with the client description and `sep`. With
        `format_multiline`, multi-line text is indented and flagged.
        """
        line = line.strip()
        if format_multiline and "\n" in line:
            sep += "(multi-line)\n"
            line = textwrap.indent(line, ' '*4)
        log(self.client_desc + sep + line)

    def setup(self):
        """
        Initialize the per-connection state and load every filter file
        found in /etc/onion-grater.d/.
        """
        # Zero-argument super(): super(type(self), self) would recurse
        # forever if this class were ever subclassed.
        super().setup()
        self.allowed_commands = {}
        self.allowed_events = {}
        self.client_desc = None
        self.client_pid = None
        self.client_streams = set()
        self.controller = None
        self.filter_name = None
        self.filters = []
        self.restrict_stream_events = False
        self.server_address = self.server.server_address
        self.subscribed_event_listeners = []
        for filter_file in glob.glob('/etc/onion-grater.d/*.yml'):
            self.filters += self._load_filter_file(filter_file)

    def _load_filter_file(self, filter_file):
        """
        Return the list of filters defined in `filter_file`, each given
        a default `name` derived from the file name. A file that cannot
        be used is logged and contributes no filters.
        """
        try:
            with open(filter_file, "rb") as fh:
                filters = yaml.safe_load(fh.read())
        except yaml.YAMLError as err:
            log("filter '{}' has bad YAML and was not loaded: {}"
                .format(filter_file, str(err)))
            return []
        if filters is None:
            # Empty file: nothing to load.
            return []
        if not isinstance(filters, list) or \
           not all(isinstance(filter_, dict) for filter_ in filters):
            log("filter '{}' is not a list of dictionaries and was not loaded"
                .format(filter_file))
            return []
        name = re.sub(r'\.yml$', '', os.path.basename(filter_file))
        for filter_ in filters:
            # Only fill in the name when the filter does not set one.
            filter_.setdefault('name', name)
        return filters

    def add_allowed_commands(self, commands):
        """
        Register the argument rules of each command in `commands`.
        Rules given as a plain string are expanded to {'pattern': str}.
        """
        for cmd, allowed_args in commands.items():
            # An empty argument list allows nothing, but will
            # make some code below easier than if it can be
            # None as well.
            self.allowed_commands[cmd.upper()] = [
                {'pattern': arg_rule} if isinstance(arg_rule, str)
                else arg_rule
                for arg_rule in (allowed_args or [])
            ]

    def add_allowed_confs_commands(self, confs):
        """
        Generate GETCONF/SETCONF rules from `confs`: every key may be
        read, a key whose value list contains '' may be reset, and a
        key with a non-empty value list may be set to those values.
        """
        if not confs:
            # Nothing to allow; in particular, do not add a GETCONF
            # rule matching the empty argument string.
            return
        getconf_rule = {'pattern': "(" + "|".join(confs) + ")"}
        self.allowed_commands.setdefault('GETCONF', []).append(getconf_rule)

        resettable = [key for key, values in confs.items()
                      if isinstance(values, list) and '' in values]
        assignable = [
            "{}=({})".format(key, "|".join(str(value) for value in values))
            for key, values in confs.items()
            if isinstance(values, list) and len(values) > 0
        ]
        setconf_parts = resettable + assignable
        if setconf_parts:
            setconf_rule = {
                'pattern': "({})+".format(r"\s*|\s*".join(setconf_parts))
            }
            self.allowed_commands.setdefault('SETCONF', []) \
                                 .append(setconf_rule)

    def add_allowed_events(self, events):
        """Register each event in `events` with its options dictionary."""
        for event, opts in events.items():
            # Same as for the `commands` argument list, let's
            # add an empty dict to simplify later code.
            if opts is None:
                opts = {}
            self.allowed_events[event.upper()] = opts

    def match_and_parse_filter(self, matchers):
        """
        Find the filter matching all `matchers`, a list of
        (qualifier, expected value) pairs, and load its rules. A
        qualifier matches if one of its values equals the expected
        value or is '*'. Does nothing when no filter matches.

        Raises RuntimeError if more than one filter matches.
        """
        matched_filters = [
            filter_ for filter_ in self.filters
            if all(any(val == expected_val or val == '*'
                       for val in filter_.get(key, []) or [])
                   for key, expected_val in matchers)
        ]
        if not matched_filters:
            return
        if len(matched_filters) > 1:
            raise RuntimeError(
                'multiple filters matched: ' +
                ', '.join(str(filter_['name']) for filter_ in matched_filters)
            )
        matched_filter = matched_filters[0]
        self.filter_name = matched_filter['name']
        self.add_allowed_commands(matched_filter.get('commands', {}) or {})
        self.add_allowed_confs_commands(matched_filter.get('confs', {}) or {})
        self.add_allowed_events(matched_filter.get('events', {}) or {})
        self.restrict_stream_events = bool(matched_filter.get(
            'restrict-stream-events', False
        ))

    def connect_to_real_control_port(self):
        """
        Return a stem controller connected to Tor's control socket and
        authenticated with the control cookie.
        """
        with open(global_args.control_cookie_path, "rb") as f:
            cookie = f.read()
        controller = stem.control.Controller.from_socket_file(
            global_args.control_socket_path
        )
        try:
            controller.authenticate(cookie)
        except Exception:
            # Do not leak the connection if authentication fails.
            controller.close()
            raise
        return controller

    def handle(self):
        """
        Identify the client, select its filter, and run a filtering
        session against the real control port until it disconnects.
        """
        client_host = self.client_address[0]
        local_connection = ipaddress.ip_address(client_host).is_loopback
        if local_connection:
            self.client_pid = pid_of_laddr(self.client_address)
            # Deal with the race between looking up the PID, and the
            # client being killed before we find the PID.
            if not self.client_pid:
                return
            try:
                client_exe_path = exe_path_of_pid(self.client_pid)
                client_user = psutil.Process(self.client_pid).username()
            except (psutil.NoSuchProcess, FileNotFoundError,
                    ProcessLookupError):
                # Same race: the client died after we found its PID.
                return
            matchers = [
                ('exe-paths', client_exe_path),
                ('users',     client_user),
            ]
        else:
            self.client_pid = None
            matchers = [
                ('hosts', client_host),
            ]
        self.match_and_parse_filter(matchers)
        if local_connection:
            self.client_desc = '{exe} (pid: {pid}, user: {user}, ' \
                               'port: {port}, filter: {filter_name})'.format(
                                   exe=client_exe_path,
                                   pid=self.client_pid,
                                   user=client_user,
                                   port=self.client_address[1],
                                   filter_name=self.filter_name
                               )
        else:
            self.client_desc = '{1}:{2} (filter: {0})'.format(
                self.filter_name, *self.client_address
            )
        if self.restrict_stream_events and not local_connection:
            self.debug_log(
                "filter '{}' has `restrict-stream-events` set "
                "and we are remote so the option was disabled"
                .format(self.filter_name)
            )
            self.restrict_stream_events = False

        # filter_name is only set when a filter matched this client.
        if self.filter_name is None:
            status = 'no matching filter found, using an empty one'
        else:
            status = 'loaded filter: {}'.format(self.filter_name)
        log('{} connected: {}'.format(self.client_desc, status))
        if global_args.debug:
            log('Final rules:')
            log(yaml.dump({
                'commands': self.allowed_commands,
                'events': self.allowed_events,
                'restrict-stream-events': self.restrict_stream_events,
            }))
        disconnect_reason = "client quit"
        try:
            self.controller = self.connect_to_real_control_port()
            session = FilteredControlPortProxySession(self)
            session.handle()
        except (ConnectionResetError, BrokenPipeError) as err:
            # Handle clients disconnecting abruptly
            disconnect_reason = str(err)
        except stem.SocketClosed:
            # Handle Tor closing its socket abruptly. This must come
            # before stem.SocketError, which is its base class.
            disconnect_reason = "Tor closed its socket"
        except stem.SocketError:
            # Handle client closing its socket abruptly
            disconnect_reason = "Client closed its socket"
        finally:
            if self.controller:
                self.controller.close()
            log('{} disconnected: {}'.format(self.client_desc,
                                             disconnect_reason))
647 648


649
class FilteredControlPortProxy(socketserver.ThreadingTCPServer):
    """
    Simple subclass just setting some defaults differently.
    """

    # So we can restart when the listening port is in TIME_WAIT state
    # after an abrupt shutdown.
    allow_reuse_address = True
    # So all server threads immediately quit when the main thread
    # quits.
    daemon_threads = True


662
def main():
    """
    Parse the command line, publish the options in the module-level
    `global_args`, and serve the filter proxy until interrupted.
    """
    # We put the argparse results in the global scope since it's
    # awkward to extend socketserver so additional data can be sent to
    # the request handler, where we need access to the arguments.
    global global_args

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--listen-address",
        type=str, metavar='ADDR', default=DEFAULT_LISTEN_ADDRESS,
        help="specifies the address on which the server listens " +
             "(default: {})".format(DEFAULT_LISTEN_ADDRESS)
    )
    parser.add_argument(
        "--listen-port",
        type=int, metavar='PORT', default=DEFAULT_LISTEN_PORT,
        help="specifies the port on which the server listens " +
             "(default: {})".format(DEFAULT_LISTEN_PORT)
    )
    parser.add_argument(
        "--listen-interface",
        type=str, metavar='INTERFACE',
        help="specifies the interface on which the server listens " +
             "(default: NULL)"
    )
    parser.add_argument(
        "--control-cookie-path",
        type=str, metavar='PATH', default=DEFAULT_COOKIE_PATH,
        help="specifies the path to Tor's control authentication cookie " +
             "(default: {})".format(DEFAULT_COOKIE_PATH)
    )
    parser.add_argument(
        "--control-socket-path",
        type=str, metavar='PATH', default=DEFAULT_CONTROL_SOCKET_PATH,
        help="specifies the path to Tor's control socket " +
             "(default: {})".format(DEFAULT_CONTROL_SOCKET_PATH)
    )
    parser.add_argument(
        "--complain",
        action='store_true', default=False,
        help="disables all filtering and just prints the commands sent " +
             "by the client"
    )
    parser.add_argument(
        "--debug",
        action='store_true', default=False,
        help="prints all requests and responses"
    )
    global_args = parser.parse_args()

    # Deal with overlapping functionality between arguments
    global_args.disable_filtering = global_args.complain
    global_args.print_requests = global_args.complain or global_args.debug
    global_args.print_responses = global_args.debug

    # An explicit interface takes precedence over the listen address.
    if global_args.listen_interface:
        ip_address = get_ip_address(global_args.listen_interface)
        if global_args.debug:
            log("IP address for interface {} : {}".format(
                global_args.listen_interface, ip_address))
    else:
        ip_address = global_args.listen_address
    address = (ip_address, global_args.listen_port)

    server = FilteredControlPortProxy(address, FilteredControlPortProxyHandler)
    log("Tor control port filter started, listening on {}:{}".format(*address))
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        pass
    finally:
        # Release the listening socket on the way out.
        server.server_close()
729 730 731


# Only start the proxy when run as a script, not when imported.
if __name__ == "__main__":
    main()