#! /usr/bin/python3
#
# Copyright (c) 2016 Red Hat, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import fcntl

import os
import pwd
from random import randint
import struct
import subprocess
import sys
import time

try:
    from netifaces import interfaces
except ImportError:
    # Without netifaces only Linux is supported, via the /proc fallback
    # defined below.
    if sys.platform not in ('linux', 'linux2'):
        print("ERROR: Please install netifaces Python library.")
        sys.exit(1)

    def interfaces():
        """Return the device names listed in /proc/net/dev.

        Lines without a colon (the table headers) are skipped; the name is
        the text before the first colon with surrounding blanks removed.
        """
        with open("/proc/net/dev", "r") as netdev:
            return [entry.split(":")[0].strip()
                    for entry in netdev if ":" in entry]

# The Open vSwitch Python bindings are mandatory: if any of these imports
# fails the tool prints an installation hint and exits with status 1.
try:
    from ovs.db import idl
    from ovs import jsonrpc
    from ovs.poller import Poller
    from ovs.stream import Stream
except Exception:
    print("ERROR: Please install the correct Open vSwitch python support")
    print("       libraries (version 2.14.1).")
    print("       Alternatively, check that your PYTHONPATH is pointing to")
    print("       the correct location.")
    sys.exit(1)

# File descriptor of the tap device opened by _install_tap_linux(); stays
# None until a tap has been created.
tapdev_fd = None
# Per-platform hooks, keyed by sys.platform value: functions that create a
# tap device and functions that derive the mirror device name.
_make_taps = {}
_make_mirror_name = {}
IFNAMSIZ_LINUX = 15      # this is the max name size, excluding the null byte.


def _doexec(*args, **kwargs):
    """Start a command and return its subprocess.Popen handle.

    The positional arguments form the command line.  A single argument is
    run through the shell; several arguments are executed directly.  The
    child's stdout is an unbuffered pipe available as ``.stdout`` on the
    returned object.  Keyword arguments are accepted but not used.
    """

    use_shell = (len(args) == 1)
    return subprocess.Popen(args, shell=use_shell, bufsize=0,
                            stdout=subprocess.PIPE)


def _install_tap_linux(tap_name, mtu_value=None):
    """Create a Linux tap device through /dev/net/tun and bring it up.

    The descriptor obtained from /dev/net/tun is stored in the module-level
    ``tapdev_fd`` and is not closed here (presumably so that the device
    stays alive for the lifetime of the process -- confirm before changing).

    tap_name:  name given to the new tap device.
    mtu_value: optional MTU, applied with ``ip link set`` before the device
               is brought up; skipped when None.

    Errors from os.open() and the ioctl calls propagate to the caller; the
    exit status of the ``ip`` commands is not checked.
    """
    global tapdev_fd

    IFF_TAP = 0x0002
    IFF_NO_PI = 0x1000
    TUNSETIFF = 0x400454CA  # This is derived by printf() of TUNSETIFF
    TUNSETOWNER = TUNSETIFF + 2

    tapdev_fd = os.open('/dev/net/tun', os.O_RDWR)
    ifr = struct.pack('16sH', tap_name.encode('utf8'), IFF_TAP | IFF_NO_PI)
    fcntl.ioctl(tapdev_fd, TUNSETIFF, ifr)
    # TUNSETOWNER takes a user id (group ownership is a separate ioctl),
    # so hand it the effective uid rather than the effective gid.
    fcntl.ioctl(tapdev_fd, TUNSETOWNER, os.geteuid())

    time.sleep(1)  # required to give the new device settling time
    if mtu_value is not None:
        pipe = _doexec('ip', 'link', 'set', 'dev', str(tap_name), 'mtu',
                       str(mtu_value))
        pipe.wait()

    pipe = _doexec('ip', 'link', 'set', 'dev', str(tap_name), 'up')
    pipe.wait()


def _make_linux_mirror_name(interface_name):
    """Choose the mirror device name for ``interface_name`` on Linux.

    The usual name is "mi" followed by the interface name.  When that would
    not fit in IFNAMSIZ_LINUX characters, a name made of "ovsmi" and a
    random six-digit number is returned instead.
    """
    longest_plain_name = IFNAMSIZ_LINUX - 2  # room left after the "mi" prefix
    if len(interface_name) <= longest_plain_name:
        return "mi%s" % interface_name
    return "ovsmi%06d" % randint(1, 999999)


# Both spellings of the Linux platform name share the same hooks.
_make_taps.update(linux=_install_tap_linux, linux2=_install_tap_linux)
_make_mirror_name.update(linux=_make_linux_mirror_name,
                         linux2=_make_linux_mirror_name)


def username():
    """Return the login name of the real user running this process.

    When the uid has no entry in the password database (pwd.getpwuid()
    raises KeyError, e.g. for an arbitrary uid inside a container), the
    numeric uid is returned as a string instead of failing.
    """
    uid = os.getuid()
    try:
        return pwd.getpwuid(uid)[0]
    except KeyError:
        return str(uid)


def usage():
    """Print the command-line help and exit with status 0.

    The program name shown in the text is taken from sys.argv[0].
    """
    help_template = """\
%(prog)s: Open vSwitch tcpdump helper.
usage: %(prog)s -i interface [TCPDUMP OPTIONS]
where TCPDUMP OPTIONS represents the options normally passed to tcpdump.

The following options are available:
   -h, --help                 display this help message
   -V, --version              display version information
   --db-sock                  A connection string to reach the Open vSwitch
                              ovsdb-server.
                              Default 'unix:/run/openvswitch/db.sock'
   --dump-cmd                 Command to use for tcpdump (default 'tcpdump')
   -i, --interface            Open vSwitch interface to mirror and tcpdump
   --mirror-to                The name for the mirror port to use (optional)
                              Default 'miINTERFACE'
   --span                     If specified, mirror all ports (optional)
"""
    print(help_template % {'prog': sys.argv[0]})
    sys.exit(0)


class OVSDBException(Exception):
    """Error raised by the OVSDB helper when a database operation fails."""


class OVSDB(object):
    """Client for the Open_vSwitch database used to manage ports and mirrors.

    An instance owns one IDL connection (``_idl_conn``) and tracks at most
    one open transaction (``_txn``).  The transaction handle is dropped as
    soon as a commit is attempted, so a failed or retried commit never
    prevents the next operation from starting a new transaction.
    """

    @staticmethod
    def wait_for_db_change(idl):
        """Run the IDL connection until its contents change.

        ``idl`` is the IDL connection object; inside this method the name
        shadows the ``ovs.db.idl`` module.

        Raises Exception('Retry Timeout') when ten seconds pass first.
        """
        seq = idl.change_seqno
        stop = time.time() + 10
        while idl.change_seqno == seq and not idl.run():
            poller = Poller()
            idl.wait(poller)
            # Also wake up at the deadline: without a timer the poller only
            # returns on IDL activity, so the timeout check below might not
            # run on time.
            poller.timer_wait(int((stop - time.time()) * 1000) + 1)
            poller.block()
            if time.time() >= stop:
                raise Exception('Retry Timeout')

    def __init__(self, db_sock):
        """Fetch the schema from ``db_sock``, connect and do a first sync."""
        self._db_sock = db_sock
        self._txn = None
        schema = self._get_schema()
        schema.register_all()
        self._idl_conn = idl.Idl(db_sock, schema)
        OVSDB.wait_for_db_change(self._idl_conn)  # Initial Sync with DB

    def _get_schema(self):
        """Retrieve the Open_vSwitch schema over a one-off JSON-RPC session.

        Raises Exception when the connection or the request fails.
        """
        error, strm = Stream.open_block(Stream.open(self._db_sock))
        if error:
            raise Exception("Unable to connect to %s" % self._db_sock)
        rpc = jsonrpc.Connection(strm)
        req = jsonrpc.Message.create_request('get_schema', ['Open_vSwitch'])
        error, resp = rpc.transact_block(req)
        rpc.close()

        if error or resp.error:
            raise Exception('Unable to retrieve schema.')
        return idl.SchemaHelper(None, resp.result)

    def get_table(self, table_name):
        """Return the IDL table object called ``table_name``."""
        return self._idl_conn.tables[table_name]

    def _start_txn(self):
        """Open and return a new transaction.

        Raises OVSDBException when a transaction is already open.
        """
        if self._txn is not None:
            raise OVSDBException("ERROR: A transaction was started already")
        self._idl_conn.change_seqno += 1
        self._txn = idl.Transaction(self._idl_conn)
        return self._txn

    def _complete_txn(self, try_again_fn):
        """Commit the open transaction.

        Returns False when the commit ended with an error, the result of
        ``try_again_fn(self)`` when the database asked for a retry, and
        None otherwise.

        Raises OVSDBException when no transaction is open.
        """
        if self._txn is None:
            raise OVSDBException("ERROR: Not in a transaction")
        # Forget the transaction before committing: whatever the outcome it
        # cannot be reused, and both the retry callback and any later
        # cleanup must be able to start a new one.
        txn, self._txn = self._txn, None
        status = txn.commit_block()
        if status is idl.Transaction.TRY_AGAIN:
            if self._idl_conn._session.rpc.status != 0:
                self._idl_conn.force_reconnect()
                OVSDB.wait_for_db_change(self._idl_conn)
            return try_again_fn(self)
        elif status is idl.Transaction.ERROR:
            return False
        return None

    def _find_row(self, table_name, find):
        """Return the first row of ``table_name`` accepted by ``find``.

        Returns None when no row matches.
        """
        return next(
            (row for row in self.get_table(table_name).rows.values()
             if find(row)), None)

    def _find_row_by_name(self, table_name, value):
        """Return the row of ``table_name`` whose name is ``value``, or None.
        """
        return self._find_row(table_name, lambda row: row.name == value)

    def _bridge_row(self, bridge_name):
        """Return the Bridge row called ``bridge_name``.

        Raises OVSDBException when there is no such bridge.
        """
        br = self._find_row_by_name('Bridge', bridge_name)
        if not br:
            raise OVSDBException('Bad bridge name %s' % bridge_name)
        return br

    def port_exists(self, port_name):
        """Return True when a Port row called ``port_name`` exists."""
        return bool(self._find_row_by_name('Port', port_name))

    def port_bridge(self, port_name):
        """Return the name of the bridge whose ports include ``port_name``.

        Raises OVSDBException when the port or its bridge cannot be found.
        """
        try:
            port = self._find_row_by_name('Port', port_name)
            br = self._find_row('Bridge', lambda x: port in x.ports)
            return br.name
        except Exception:
            raise OVSDBException('Unable to find port %s bridge' % port_name)

    def interface_mtu(self, intf_name):
        """Return the first ``mtu`` value of ``intf_name``, or None.

        None is returned whenever the value cannot be read, for instance
        when the interface does not exist or has no mtu recorded.
        """
        try:
            intf = self._find_row_by_name('Interface', intf_name)
            return intf.mtu[0]
        except Exception:
            return None

    def interface_exists(self, intf_name):
        """Return True when an Interface row called ``intf_name`` exists."""
        return bool(self._find_row_by_name('Interface', intf_name))

    def mirror_exists(self, mirror_name):
        """Return True when a Mirror row called ``mirror_name`` exists."""
        return bool(self._find_row_by_name('Mirror', mirror_name))

    def interface_uuid(self, intf_name):
        """Return the uuid of the Interface row called ``intf_name``.

        Raises OVSDBException when there is no such interface.
        """
        row = self._find_row_by_name('Interface', intf_name)
        if bool(row):
            return row.uuid
        raise OVSDBException('No such interface: %s' % intf_name)

    def make_interface(self, intf_name, execute_transaction=True):
        """Create an Interface row called ``intf_name``.

        When the interface already exists its uuid is returned and no
        transaction is opened.  With ``execute_transaction`` False the new,
        uncommitted row is returned and the transaction is left open for
        the caller to extend and commit.  Otherwise the transaction is
        committed and the uuid of the inserted row is returned.

        Raises OVSDBException when the commit fails.
        """
        if self.interface_exists(intf_name):
            print("INFO: Interface exists.")
            return self.interface_uuid(intf_name)

        txn = self._start_txn()
        tmp_row = txn.insert(self.get_table('Interface'))
        tmp_row.name = intf_name

        if not execute_transaction:
            return tmp_row

        def try_again(db_entity):
            return db_entity.make_interface(intf_name)

        txn.add_comment("ovs-tcpdump: user=%s,create_intf=%s"
                        % (username(), intf_name))
        retried = self._complete_txn(try_again)
        if retried is False:
            raise OVSDBException('Unable to create Interface %s: %s' %
                                 (intf_name, txn.get_error()))
        if retried is not None:
            # The commit was redone from scratch; its result is the one
            # that describes what is now in the database.
            return retried
        return txn.get_insert_uuid(tmp_row.uuid)

    def destroy_port(self, port_name, bridge_name):
        """Remove the port called ``port_name`` from bridge ``bridge_name``.

        Does nothing when no interface of that name exists.

        Raises OVSDBException when the bridge is unknown or the commit
        fails.
        """
        if not self.interface_exists(port_name):
            return
        # Resolve the bridge first so that a bad name cannot leave a
        # transaction open.
        br = self._bridge_row(bridge_name)
        txn = self._start_txn()
        br.ports = [port for port in br.ports if port.name != port_name]

        def try_again(db_entity):
            return db_entity.destroy_port(port_name, bridge_name)

        txn.add_comment("ovs-tcpdump: user=%s,destroy_port=%s"
                        % (username(), port_name))
        status = self._complete_txn(try_again)
        if status is False:
            raise OVSDBException('unable to delete Port %s: %s' %
                                 (port_name, txn.get_error()))

    def destroy_mirror(self, intf_name, bridge_name):
        """Remove the mirror ``m_<intf_name>`` from bridge ``bridge_name``.

        Does nothing when that mirror does not exist.

        Raises OVSDBException when the bridge is unknown or the commit
        fails.
        """
        mirror_name = 'm_%s' % intf_name
        if not self.mirror_exists(mirror_name):
            return
        mirror_row = self._find_row_by_name('Mirror', mirror_name)
        br = self._bridge_row(bridge_name)
        txn = self._start_txn()
        br.mirrors = [mirror for mirror in br.mirrors
                      if mirror.uuid != mirror_row.uuid]

        def try_again(db_entity):
            # destroy_mirror() derives the mirror name itself, so the retry
            # must be given the interface name again, not the mirror name.
            return db_entity.destroy_mirror(intf_name, bridge_name)

        txn.add_comment("ovs-tcpdump: user=%s,destroy_mirror=%s"
                        % (username(), mirror_name))
        status = self._complete_txn(try_again)
        if status is False:
            raise OVSDBException('Unable to delete Mirror %s: %s' %
                                 (mirror_name, txn.get_error()))

    def make_port(self, port_name, bridge_name):
        """Create port ``port_name`` with a same-named interface on a bridge.

        Returns the uuid of the inserted Port row.

        Raises OVSDBException when the bridge is unknown, when an interface
        called ``port_name`` already exists, or when the commit fails.
        """
        # Resolve the bridge first so that a bad name cannot leave a
        # transaction open.
        br = self._bridge_row(bridge_name)

        iface_row = self.make_interface(port_name, False)
        txn = self._txn
        if txn is None:
            # make_interface() found an existing Interface row and thus
            # opened no transaction; there is no new row to attach.
            raise OVSDBException('Interface %s already exists' % port_name)

        port = txn.insert(self.get_table('Port'))
        port.name = port_name

        br.verify('ports')
        ports = getattr(br, 'ports', [])
        ports.append(port)
        br.ports = ports

        port.verify('interfaces')
        ifaces = getattr(port, 'interfaces', [])
        ifaces.append(iface_row)
        port.interfaces = ifaces

        def try_again(db_entity):
            return db_entity.make_port(port_name, bridge_name)

        txn.add_comment("ovs-tcpdump: user=%s,create_port=%s"
                        % (username(), port_name))
        retried = self._complete_txn(try_again)
        if retried is False:
            raise OVSDBException('Unable to create Port %s: %s' %
                                 (port_name, txn.get_error()))
        if retried is not None:
            return retried
        return txn.get_insert_uuid(port.uuid)

    def bridge_mirror(self, intf_name, mirror_intf_name, br_name,
                      mirror_select_all=False):
        """Create mirror ``m_<intf_name>`` on bridge ``br_name``.

        Port ``intf_name`` is selected as both source and destination and
        port ``mirror_intf_name`` receives the mirrored traffic.
        ``mirror_select_all`` is stored in the mirror's ``select_all``
        column.  Returns the uuid of the inserted Mirror row.

        Raises OVSDBException when a port or the bridge is unknown or the
        commit fails.
        """
        mirror_name = 'm_%s' % intf_name

        # Resolve every row before opening the transaction so that a
        # missing one cannot leave it dangling.
        mirrored_port = self._find_row_by_name('Port', intf_name)
        if mirrored_port is None:
            raise OVSDBException('No such port: %s' % intf_name)
        output_port = self._find_row_by_name('Port', mirror_intf_name)
        if output_port is None:
            raise OVSDBException('No such port: %s' % mirror_intf_name)
        br = self._bridge_row(br_name)

        txn = self._start_txn()
        mirror = txn.insert(self.get_table('Mirror'))
        mirror.name = mirror_name

        mirror.select_all = mirror_select_all

        mirror.verify('select_dst_port')
        dst_port = getattr(mirror, 'select_dst_port', [])
        dst_port.append(mirrored_port)
        mirror.select_dst_port = dst_port

        mirror.verify('select_src_port')
        src_port = getattr(mirror, 'select_src_port', [])
        src_port.append(mirrored_port)
        mirror.select_src_port = src_port

        mirror.verify('output_port')
        out_port = getattr(mirror, 'output_port', [])
        out_port.append(output_port.uuid)
        mirror.output_port = out_port

        br.verify('mirrors')
        mirrors = getattr(br, 'mirrors', [])
        mirrors.append(mirror.uuid)
        br.mirrors = mirrors

        def try_again(db_entity):
            return db_entity.bridge_mirror(intf_name, mirror_intf_name,
                                           br_name, mirror_select_all)

        txn.add_comment("ovs-tcpdump: user=%s,create_mirror=%s"
                        % (username(), mirror_name))
        retried = self._complete_txn(try_again)
        if retried is False:
            raise OVSDBException('Unable to create Mirror %s: %s' %
                                 (mirror_name, txn.get_error()))
        if retried is not None:
            return retried
        return txn.get_insert_uuid(mirror.uuid)


def argv_tuples(lst):
    """Yield ``(item, following_item)`` for every element of ``lst``.

    The second member is None for the last element, which lets a caller
    look one argument ahead while walking a command line.
    """
    lookahead = iter(lst)
    next(lookahead, None)  # keep the look-ahead one element in front

    for item in lst:
        yield item, next(lookahead, None)


def py_which(executable):
    """Return True when ``executable`` names a program that can be run.

    The lookup is delegated to shutil.which(): it searches the directories
    of the PATH environment variable (or a default search path when PATH is
    unset, instead of failing), accepts a path given directly, and does not
    mistake a directory for an executable.
    """
    import shutil  # local import keeps this helper self-contained

    return shutil.which(executable) is not None


def main():
    """Mirror an Open vSwitch port to a local device and run tcpdump on it.

    Steps: parse the command line (unrecognised arguments are passed on to
    the dump command), make sure the mirror device exists (creating a tap
    where the platform supports it), add the mirror port and the mirror to
    the bridge of the requested port, relay the dump command's output, and
    finally remove the mirror and the port again.

    Exits with status 0 after a clean run and 1 on any reported error.
    """
    db_sock = 'unix:/run/openvswitch/db.sock'
    interface = None
    tcpdargs = []

    skip_next = False
    mirror_interface = None
    mirror_select_all = False
    dump_cmd = 'tcpdump'

    for cur, nxt in argv_tuples(sys.argv[1:]):
        if skip_next:
            skip_next = False
            continue
        if cur in ['-h', '--help']:
            usage()
        elif cur in ['-V', '--version']:
            print("ovs-tcpdump (Open vSwitch) 2.14.1")
            sys.exit(0)
        elif cur in ['--db-sock', '--dump-cmd', '-i', '--interface',
                     '--mirror-to']:
            # These options consume the following argument; refuse to go
            # on with None when it is missing.
            if nxt is None:
                print("Error: option '%s' requires an argument" % cur)
                sys.exit(1)
            if cur == '--db-sock':
                db_sock = nxt
            elif cur == '--dump-cmd':
                dump_cmd = nxt
            elif cur == '--mirror-to':
                mirror_interface = nxt
            else:
                interface = nxt
            skip_next = True
            continue
        elif cur in ['--span']:
            mirror_select_all = True
            continue
        tcpdargs.append(cur)

    if interface is None:
        print("Error: must at least specify an interface with '-i' option")
        sys.exit(1)

    if not py_which(dump_cmd):
        print("Error: unable to execute '%s' (check PATH)" % dump_cmd)
        sys.exit(1)

    if '-l' not in tcpdargs:
        tcpdargs.insert(0, '-l')

    if '-vv' in tcpdargs:
        print("TCPDUMP Args: %s" % ' '.join(tcpdargs))

    ovsdb = OVSDB(db_sock)
    if mirror_interface is None:
        mirror_interface = "mi%s" % interface
        if sys.platform in _make_mirror_name:
            mirror_interface = _make_mirror_name[sys.platform](interface)

    if sys.platform in _make_taps and \
       mirror_interface not in interfaces():
        _make_taps[sys.platform](mirror_interface,
                                 ovsdb.interface_mtu(interface))

    if mirror_interface not in interfaces():
        print("ERROR: Please create an interface called `%s`" %
              mirror_interface)
        print("See your OS guide for how to do this.")
        print("Ex: ip link add %s type veth peer name %s" %
              (mirror_interface, mirror_interface + "2"))
        sys.exit(1)

    if not ovsdb.port_exists(interface):
        print("ERROR: Port %s does not exist." % interface)
        sys.exit(1)
    if ovsdb.port_exists(mirror_interface):
        print("ERROR: Mirror port (%s) exists for port %s." %
              (mirror_interface, interface))
        sys.exit(1)
    try:
        ovsdb.make_port(mirror_interface, ovsdb.port_bridge(interface))
        ovsdb.bridge_mirror(interface, mirror_interface,
                            ovsdb.port_bridge(interface),
                            mirror_select_all)
    except OVSDBException as oe:
        print("ERROR: Unable to properly setup the mirror: %s." % str(oe))
        try:
            ovsdb.destroy_port(mirror_interface, ovsdb.port_bridge(interface))
        except Exception:
            # Best effort only: the setup error above is what gets reported.
            pass
        sys.exit(1)

    pipes = _doexec(*([dump_cmd, '-i', mirror_interface] + tcpdargs))
    try:
        try:
            # Relay output until the pipe reports end-of-file.  Only a
            # truly empty read ends the loop, so blank lines in the dump
            # are printed rather than mistaken for the end of the capture,
            # and output still buffered when the command exits is not
            # lost.  Undecodable bytes are replaced instead of aborting.
            for line in iter(pipes.stdout.readline, b''):
                print(line.rstrip(b'\n').decode('utf-8', 'replace'))
        except KeyboardInterrupt:
            # Ctrl-C is the normal way to stop the capture.
            pass
    finally:
        # Always undo the database changes, whatever ended the capture.
        if pipes.poll() is None:
            pipes.terminate()
        try:
            bridge = ovsdb.port_bridge(interface)
            ovsdb.destroy_mirror(interface, bridge)
            ovsdb.destroy_port(mirror_interface, bridge)
        except Exception:
            print("Unable to tear down the ports and mirrors.")
            print("Please use ovs-vsctl to remove the ports and mirrors "
                  "created.")
            print(" ex: ovs-vsctl --db=%s del-port %s" % (db_sock,
                                                          mirror_interface))
            sys.exit(1)

    sys.exit(0)


# Run the tool only when this file is executed as a script.
if __name__ == '__main__':
    main()

# Local variables:
# mode: python
# End:
