#!/usr/bin/python2

# Copyright (C) 2007-2016 Giampaolo Rodola' <g.rodola@gmail.com>.
# Use of this source code is governed by MIT license that can be
# found in the LICENSE file.

"""
FTP server benchmark script.

In order to run this you must have a listening FTP server with a user
with writing permissions configured.
This is a stand-alone script which does not depend on pyftpdlib.
It just requires python >= 2.6 (and optionally psutil to keep track
of FTP server memory usage).

Example usages:
  ftpbench -u USER -p PASSWORD
  ftpbench -u USER -p PASSWORD -H ftp.domain.com -P 21   # host / port
  ftpbench -u USER -p PASSWORD -b transfer
  ftpbench -u USER -p PASSWORD -b concurrence
  ftpbench -u USER -p PASSWORD -b all
  ftpbench -u USER -p PASSWORD -b concurrence -n 500     # 500 clients
  ftpbench -u USER -p PASSWORD -b concurrence -s 20M     # file size
  ftpbench -u USER -p PASSWORD -b concurrence -k 3521    # memory usage
"""

# Some benchmarks (Linux 3.0.0, Intel core duo - 3.1 Ghz).

# pyftpdlib 1.0.0:
#
#   (starting with 6.7M of memory being used)
#   STOR (client -> server)                              557.97 MB/sec  6.7M
#   RETR (server -> client)                             1613.82 MB/sec  6.8M
#   300 concurrent clients (connect, login)                1.20 secs    8.8M
#   STOR (1 file with 300 idle clients)                  567.52 MB/sec  8.8M
#   RETR (1 file with 300 idle clients)                 1561.41 MB/sec  8.8M
#   300 concurrent clients (RETR 10.0M file)               3.26 secs    10.8M
#   300 concurrent clients (STOR 10.0M file)               8.46 secs    12.6M
#   300 concurrent clients (QUIT)                          0.07 secs
#
#
# proftpd 1.3.4a:
#
#   (starting with 1.4M of memory being used)
#   STOR (client -> server)                              554.67 MB/sec  3.2M
#   RETR (server -> client)                             1517.12 MB/sec  3.2M
#   300 concurrent clients (connect, login)                9.30 secs    568.6M
#   STOR (1 file with 300 idle clients)                  484.11 MB/sec  570.6M
#   RETR (1 file with 300 idle clients)                 1534.61 MB/sec  570.6M
#   300 concurrent clients (RETR 10.0M file)               3.67 secs    568.6M
#   300 concurrent clients (STOR 10.0M file)              11.21 secs    568.7M
#   300 concurrent clients (QUIT)                          0.43 secs
#
#
# vsftpd 2.3.2
#
#   (starting with 352.0K of memory being used)
#   STOR (client -> server)                              607.23 MB/sec  816.0K
#   RETR (server -> client)                             1506.59 MB/sec  816.0K
#   300 concurrent clients (connect, login)               18.91 secs    140.9M
#   STOR (1 file with 300 idle clients)                  618.99 MB/sec  141.4M
#   RETR (1 file with 300 idle clients)                 1402.48 MB/sec  141.4M
#   300 concurrent clients (RETR 10.0M file)               3.64 secs    140.9M
#   300 concurrent clients (STOR 10.0M file)               9.74 secs    140.9M
#   300 concurrent clients (QUIT)                          0.00 secs


from __future__ import division, print_function
import asynchat
import asyncore
import atexit
import contextlib
import ftplib
import optparse
import os
import ssl
import sys
import time
try:
    import resource
except ImportError:
    resource = None

try:
    import psutil
except ImportError:
    psutil = None


# Default settings; main() overwrites most of them from the command line.
HOST = 'localhost'
PORT = 21
USER = None
PASSWORD = None
# name of the scratch file created (and later removed) on the server
TESTFN = "$testfile"
# size of the chunks sent / received on the data connection
BUFFER_LEN = 8192
# psutil.Process instance of the FTP server (only set with -k)
SERVER_PROC = None
TIMEOUT = None
# Size in bytes of the file used by the 'concurrence' benchmark (10M).
# Kept as an integer because every consumer compares it against byte
# counters; main() replaces it with the parsed -s value.
FILE_SIZE = 10 * 1024 * 1024
SSL = False
# Defined here so that helpers reading it work even if main() never ran.
DEBUG = False
PY3 = sys.version_info >= (3, 0)

# memory samples (human readable strings) waiting to be printed
server_memory = []
# python >= 2.7.9
SSLWantReadError = getattr(ssl, "SSLWantReadError", object())
SSLWantWriteError = getattr(ssl, "SSLWantWriteError", object())
# python <= 2.7.8
SSL_ERROR_WANT_READ = getattr(ssl, "SSL_ERROR_WANT_READ", object())
SSL_ERROR_WANT_WRITE = getattr(ssl, "SSL_ERROR_WANT_WRITE", object())


if sys.stdout.isatty() and os.name == 'posix':
    # http://goo.gl/6V8Rm
    def hilite(string, ok=True, bold=False):
        """Wrap 'string' in ANSI escape sequences.

        ok=True selects green, ok=False red and ok=None no color at
        all; bold additionally requests the bold attribute.
        """
        codes = []
        if ok is not None:
            codes.append('32' if ok else '31')  # green / red
        if bold:
            codes.append('1')
        return '\x1b[%sm%s\x1b[0m' % (';'.join(codes), string)
else:
    # Output is not an interactive POSIX terminal: leave text untouched.
    def hilite(s, *args, **kwargs):
        """Return 's' unchanged (coloring is disabled)."""
        return s


def print_bench(what, value, unit=""):
    """Print a single benchmark result line.

    'what' is the description (left aligned on 50 columns), 'value'
    the measured number (printed with two decimals) and 'unit' its
    unit. If a server memory sample is pending, the most recent one is
    consumed and appended to the line.
    """
    label = hilite("%-50s" % what, ok=None, bold=0)
    figure = hilite("%8.2f" % value)
    line = "%s %s %-8s" % (label, figure, unit)
    if server_memory:
        line += "%s" % hilite(server_memory.pop())
    print(line.strip())


# http://goo.gl/zeJZl
def bytes2human(n, format="%(value).1f%(symbol)s"):
    """Convert a number of bytes into a human readable string.

    'format' is a %-style template which may reference the 'value'
    (the scaled number) and 'symbol' (the unit letter) keys. Units are
    powers of 1024; the biggest unit not exceeding 'n' is used and
    anything below 1024 is expressed in bytes ('B').

    >>> bytes2human(10000)
    '9.8K'
    >>> bytes2human(100001221)
    '95.4M'
    >>> bytes2human(100)
    '100.0B'
    """
    symbols = ('B', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y')
    # symbols[k] stands for 1024 ** k; try the biggest units first.
    for exponent in range(len(symbols) - 1, 0, -1):
        factor = 1 << (exponent * 10)
        if n >= factor:
            return format % dict(symbol=symbols[exponent],
                                 value=float(n) / factor)
    return format % dict(symbol=symbols[0], value=n)


# http://goo.gl/zeJZl
def human2bytes(s):
    """Convert a size string such as '10M' into a number of bytes.

    The string must be made of decimal digits followed by exactly one
    unit letter (B, K, M, G, T, P, E, Z or Y, case insensitive); units
    are powers of 1024.

    Raise ValueError if the string does not have that form. An
    explicit exception is used instead of an assert statement so that
    the check is not lost when python runs with -O.

    >>> human2bytes('1M')
    1048576
    >>> human2bytes('1G')
    1073741824
    """
    symbols = ('B', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y')
    letter = s[-1:].strip().upper()
    num = s[:-1]
    if not num.isdigit() or letter not in symbols:
        raise ValueError("invalid size string %r" % (s,))
    # symbols[k] stands for 1024 ** k
    return int(num) * (1 << (symbols.index(letter) * 10))


def register_memory():
    """Register an approximation of memory used by FTP server process
    and all of its children.

    The value is appended to 'server_memory' as a human readable
    string. Nothing happens if no server process is being tracked
    (SERVER_PROC is None).
    """
    # XXX How to get a reliable representation of memory being used is
    # not clear. (rss - shared) seems kind of ok but we might also use
    # the private working set via get_memory_maps().private*.
    def get_mem(proc):
        # Prefer memory_info(); get_memory_info() is its old name and
        # is only used when the installed psutil lacks the new one.
        read_info = getattr(proc, "memory_info", None)
        if read_info is None:
            read_info = proc.get_memory_info
        mem = read_info()
        if os.name != 'posix':
            # TODO figure out what to do on Windows
            return mem.rss
        # Older psutil versions expose the 'shared' counter only via
        # memory_info_ex(), which newer versions no longer provide.
        if 'shared' not in mem._fields and hasattr(proc, "memory_info_ex"):
            mem = proc.memory_info_ex()
        counter = mem.rss
        if 'shared' in mem._fields:
            counter -= mem.shared
        return counter

    if SERVER_PROC is not None:
        mem = get_mem(SERVER_PROC)
        for child in SERVER_PROC.children():
            mem += get_mem(child)
        server_memory.append(bytes2human(mem))


def timethis(what):
    """Utility function for making simple benchmarks (calculates time calls).
    It can be used either as a context manager or as a decorator.

    When 'what' is callable, a wrapper timing every call of it is
    returned; otherwise 'what' is taken as the description of the
    measurement and a context manager is returned. In both cases the
    elapsed time is reported in seconds through print_bench() once the
    timed code completes without raising.
    """
    # time.clock no longer exists on python >= 3.8; perf_counter is its
    # replacement. The old selection is kept for interpreters which do
    # not provide perf_counter.
    timer = getattr(time, "perf_counter", None)
    if timer is None:
        timer = time.clock if sys.platform == "win32" else time.time

    @contextlib.contextmanager
    def benchmark():
        start = timer()
        yield
        elapsed = timer() - start
        print_bench(what, elapsed, "secs")

    if hasattr(what, "__call__"):
        def timed(*args, **kwargs):
            with benchmark():
                return what(*args, **kwargs)
        return timed
    else:
        return benchmark()


def connect():
    """Open a control connection to HOST:PORT, log in as USER and
    return the resulting ftplib.FTP (or ftplib.FTP_TLS if SSL is set)
    instance.
    """
    factory = ftplib.FTP_TLS if SSL else ftplib.FTP
    session = factory(timeout=TIMEOUT)
    session.connect(HOST, PORT)
    session.login(USER, PASSWORD)
    if SSL:
        # secure data connection
        session.prot_p()
    return session


def retr(ftp):
    """Same as ftplib's retrbinary() but discard the received data.

    The test file (TESTFN) is downloaded in binary mode over the
    'ftp' session and read until the server closes the data
    connection.
    """
    ftp.voidcmd('TYPE I')
    with contextlib.closing(ftp.transfercmd("RETR " + TESTFN)) as conn:
        # An empty read means the server has sent the whole file.
        while conn.recv(BUFFER_LEN):
            pass
    ftp.voidresp()


def stor(ftp=None):
    """Same as ftplib's storbinary() but just sends dummy data
    instead of reading it from a real file.

    Exactly FILE_SIZE bytes are uploaded to TESTFN. If 'ftp' is None a
    new session is opened with connect() and closed with QUIT once the
    upload is done. The session which was used is returned.
    """
    # Only quit sessions created here; a caller's session stays open.
    own_session = ftp is None
    if own_session:
        ftp = connect()
    ftp.voidcmd('TYPE I')
    with contextlib.closing(ftp.transfercmd("STOR " + TESTFN)) as conn:
        chunk = b'x' * BUFFER_LEN
        remaining = FILE_SIZE
        while remaining > 0:
            # Trim the last chunk so that the uploaded file is not
            # bigger than FILE_SIZE.
            if remaining >= BUFFER_LEN:
                data = chunk
            else:
                data = chunk[:remaining]
            remaining -= conn.send(data)
    ftp.voidresp()
    if own_session:
        ftp.quit()
    return ftp


def bytes_per_second(ftp, retr=True):
    """Return the number of bytes transmitted in 1 second.

    With retr=True the test file (TESTFN) is downloaded over and over
    until the time is up, otherwise dummy data is uploaded to it.
    A server memory sample is registered as soon as the first data
    connection is established.
    """
    tot_bytes = 0
    if retr:
        def request_file():
            # Start a new binary download and return its data socket.
            ftp.voidcmd('TYPE I')
            conn = ftp.transfercmd("retr " + TESTFN)
            return conn

        # NOTE: closing() only tracks the first data connection; the
        # ones opened later in the loop are closed explicitly.
        with contextlib.closing(request_file()) as conn:
            register_memory()
            stop_at = time.time() + 1.0
            while stop_at > time.time():
                chunk = conn.recv(BUFFER_LEN)
                if not chunk:
                    # End of file: request the file again. The time
                    # spent doing so is added to the deadline so that
                    # it is not counted as transfer time.
                    a = time.time()
                    # performs at most one extra recv() (presumably to
                    # make sure nothing is left to read)
                    while conn.recv(BUFFER_LEN):
                        break
                    conn.close()
                    ftp.voidresp()
                    conn = request_file()
                    stop_at += time.time() - a
                tot_bytes += len(chunk)

        # Close the data connection currently in use; it may have been
        # interrupted halfway through the file.
        conn.close()
        try:
            ftp.voidresp()
        except (ftplib.error_temp, ftplib.error_perm):
            # presumably the server's complaint about the interrupted
            # transfer; ignored on purpose
            pass
    else:
        ftp.voidcmd('TYPE I')
        with contextlib.closing(ftp.transfercmd("STOR " + TESTFN)) as conn:
            register_memory()
            chunk = b'x' * BUFFER_LEN
            stop_at = time.time() + 1
            while stop_at > time.time():
                # send() returns the number of bytes actually sent
                tot_bytes += conn.send(chunk)
        ftp.voidresp()

    return tot_bytes


def cleanup():
    """Remove the test file from the server, if it is there.

    A permanent or temporary FTP error while listing or deleting is
    reported on stderr instead of being raised.
    """
    session = connect()
    try:
        listing = session.nlst()
        if TESTFN in listing:
            session.delete(TESTFN)
    except (ftplib.error_perm, ftplib.error_temp) as exc:
        text = "could not delete %r test file on cleanup: %r" % (
            TESTFN, exc)
        print(hilite(text, ok=False), file=sys.stderr)
    session.quit()


def bench_stor(ftp=None, title="STOR (client -> server)"):
    """Measure the upload speed and print it, in MB/sec, under 'title'.

    A new session is opened if 'ftp' is None. The session is closed
    with QUIT once the measurement has been printed.
    """
    session = connect() if ftp is None else ftp
    transferred = bytes_per_second(session, retr=False)
    megabytes = round(transferred / 1024.0 / 1024.0, 2)
    print_bench(title, megabytes, "MB/sec")
    session.quit()


def bench_retr(ftp=None, title="RETR (server -> client)"):
    """Measure the download speed and print it, in MB/sec, under 'title'.

    A new session is opened if 'ftp' is None. The session is closed
    with QUIT once the measurement has been printed.
    """
    session = connect() if ftp is None else ftp
    transferred = bytes_per_second(session, retr=True)
    megabytes = round(transferred / 1024.0 / 1024.0, 2)
    print_bench(title, megabytes, "MB/sec")
    session.quit()


def bench_multi(howmany):
    """Run the 'concurrence' benchmark with 'howmany' clients.

    The steps are: connect and log in all clients, single STOR and
    RETR while those clients sit idle, concurrent RETR, concurrent
    STOR and finally concurrent QUIT. The concurrent transfers are
    driven by asyncore.
    """
    # The OS usually sets a limit of 1024 as the maximum number of
    # open file descriptors for the current process.
    # Let's set the highest number possible, just to be sure.
    if howmany > 500 and resource is not None:
        _, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))

    def connect_all():
        # Time how long it takes to connect and log in every client.
        with timethis("%i concurrent clients (connect, login)" % howmany):
            sessions = [connect() for _ in range(howmany)]
            register_memory()
        return sessions

    def run_transfers(sessions, label, command, attach):
        # Issue 'command' on every session, hand each data socket to
        # 'attach' and time the asyncore loop serving all of them.
        with timethis(label):
            for session in sessions:
                session.voidcmd('TYPE I')
                attach(session.transfercmd(command))
            register_memory()
            asyncore.loop(use_poll=True)
        for session in sessions:
            session.voidresp()

    clients = connect_all()
    bench_stor(title="STOR (1 file with %s idle clients)" % len(clients))
    bench_retr(title="RETR (1 file with %s idle clients)" % len(clients))

    # concurrent downloads; the file to fetch is uploaded first
    stor(clients[0])
    run_transfers(
        clients,
        "%s concurrent clients (RETR %s file)" % (
            howmany, bytes2human(FILE_SIZE)),
        "RETR " + TESTFN,
        AsyncReader)

    # concurrent uploads
    run_transfers(
        clients,
        "%s concurrent clients (STOR %s file)" % (
            howmany, bytes2human(FILE_SIZE)),
        "STOR " + TESTFN,
        lambda conn: AsyncWriter(conn, FILE_SIZE))

    # concurrent logouts; only the asyncore loop is timed
    for session in clients:
        AsyncQuit(session.sock)
    with timethis("%i concurrent clients (QUIT)" % howmany):
        asyncore.loop(use_poll=True)


@contextlib.contextmanager
def handle_ssl_want_rw_errs():
    """Context manager swallowing the SSL "want read" / "want write"
    errors raised in its body (they are printed if DEBUG is set).
    Any other ssl.SSLError is propagated.
    """
    try:
        yield
    except (SSLWantReadError, SSLWantWriteError) as exc:
        if DEBUG:
            print(exc)
    except ssl.SSLError as exc:
        wanted = (SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)
        if exc.args[0] not in wanted:
            raise
        if DEBUG:
            print(exc)


class AsyncReader(asyncore.dispatcher):
    """Just read data from a connected socket, asynchronously.

    The received data is discarded; the dispatcher closes itself when
    the peer closes the connection.
    """

    def __init__(self, sock):
        asyncore.dispatcher.__init__(self, sock)

    def handle_read(self):
        if SSL:
            # Stays None if a "want read / write" error was swallowed
            # by the context manager, i.e. nothing was received.
            chunk = None
            with handle_ssl_want_rw_errs():
                chunk = self.socket.recv(65536)
            if chunk is None:
                # not an EOF: just wait for the next read event
                return
        else:
            chunk = self.socket.recv(65536)
        if not chunk:
            self.close()

    def handle_close(self):
        self.close()

    def handle_error(self):
        # let the exception being handled propagate out of the loop
        raise


class AsyncWriter(asyncore.dispatcher):
    """Just write dummy data to a connected socket, asynchronously.

    'size' is the total number of bytes to send; the dispatcher
    closes itself once that many bytes have been sent.
    """

    def __init__(self, sock, size):
        asyncore.dispatcher.__init__(self, sock)
        self.size = size
        self.sent = 0
        self.chunk = b'x' * BUFFER_LEN

    def handle_write(self):
        remaining = self.size - self.sent
        if remaining > 0:
            # Trim the last chunk so that no more than 'size' bytes
            # are sent overall.
            if remaining >= len(self.chunk):
                data = self.chunk
            else:
                data = self.chunk[:remaining]
            if SSL:
                with handle_ssl_want_rw_errs():
                    self.sent += asyncore.dispatcher.send(self, data)
            else:
                self.sent += asyncore.dispatcher.send(self, data)
        if self.sent >= self.size:
            self.handle_close()

    def handle_close(self):
        # Same as AsyncReader: close without the "unhandled close
        # event" warning logged by the base class implementation.
        self.close()

    def handle_error(self):
        # let the exception being handled propagate out of the loop
        raise


class AsyncQuit(asynchat.async_chat):
    """Send QUIT over an already connected control socket and call
    handle_close() as soon as a CRLF-terminated reply line is received.
    """

    def __init__(self, sock):
        asynchat.async_chat.__init__(self, sock)
        # pieces of the server reply received so far
        self.in_buffer = []
        self.set_terminator(b'\r\n')
        self.push(b'QUIT\r\n')

    def handle_error(self):
        # let the exception being handled propagate out of the loop
        raise

    def found_terminator(self):
        self.handle_close()

    def collect_incoming_data(self, data):
        self.in_buffer.append(data)


class OptFormatter(optparse.IndentedHelpFormatter):
    """Help formatter printing each option on a line of its own,
    followed by its indented help text, and the epilog as it is.
    """

    def format_option(self, option):
        """Return the help entry for 'option'."""
        pieces = ['  %s\n' % self.option_strings[option]]
        if option.help:
            pieces.append('     %s\n\n' % self.expand_default(option))
        return ''.join(pieces)

    def format_epilog(self, s):
        """Return the epilog with leading whitespace removed only
        (the default formatter would re-wrap it).
        """
        return s.lstrip()


def main():
    """Command line entry point.

    Parse the options, store them in the module level settings, make
    sure the user is allowed to write on the server and run the
    requested benchmark. The test file is removed at interpreter exit.
    The script exits with an error message if the credentials are
    missing or if the benchmark name or the file size are not valid.
    """
    global HOST, PORT, USER, PASSWORD, SERVER_PROC, TIMEOUT, SSL, FILE_SIZE, \
        DEBUG
    USAGE = "%s -u USERNAME -p PASSWORD [-H] [-P] [-b] [-n] [-s] [-k] " \
            "[-t] [-d] [-S]" % (os.path.basename(__file__))
    parser = optparse.OptionParser(usage=USAGE,
                                   epilog=__doc__[__doc__.find('Example'):],
                                   formatter=OptFormatter())
    parser.add_option('-u', '--user', dest='user', help='username')
    parser.add_option('-p', '--pass', dest='password', help='password')
    parser.add_option('-H', '--host', dest='host', default=HOST,
                      help='hostname')
    parser.add_option('-P', '--port', dest='port', default=PORT, help='port',
                      type=int)
    parser.add_option('-b', '--benchmark', dest='benchmark',
                      default='transfer',
                      help="benchmark type ('transfer', 'download', 'upload', "
                           "'concurrence', 'all')")
    parser.add_option('-n', '--clients', dest='clients', default=200,
                      type="int",
                      help="number of concurrent clients used by "
                           "'concurrence' benchmark")
    parser.add_option('-s', '--filesize', dest='filesize', default="10M",
                      help="file size used by 'concurrence' benchmark "
                           "(e.g. '10M')")
    parser.add_option('-k', '--pid', dest='pid', default=None, type="int",
                      help="the PID of the FTP server process, to track its "
                           "memory usage")
    parser.add_option('-t', '--timeout', dest='timeout',
                      default=TIMEOUT, type="int", help="the socket timeout")
    parser.add_option('-d', '--debug', action='store_true', dest='debug',
                      help="whether to print debugging info")
    parser.add_option('-S', '--ssl', action='store_true', dest='ssl',
                      help="whether to use FTPS")

    options, args = parser.parse_args()
    if not options.user or not options.password:
        sys.exit(USAGE)
    # Reject an unknown benchmark before touching the server.
    benchmarks = ('download', 'upload', 'transfer', 'concurrence', 'all')
    if options.benchmark not in benchmarks:
        sys.exit("invalid 'benchmark' parameter %r" % options.benchmark)

    USER = options.user
    PASSWORD = options.password
    HOST = options.host
    PORT = options.port
    TIMEOUT = options.timeout
    SSL = bool(options.ssl)
    DEBUG = options.debug
    if SSL and sys.version_info < (2, 7):
        sys.exit("--ssl option requires python >= 2.7")
    try:
        FILE_SIZE = human2bytes(options.filesize)
    except (ValueError, AssertionError):
        parser.error("invalid file size %r" % options.filesize)
    if options.pid is not None:
        if psutil is None:
            # the PID is passed with -k (-p is the password)
            raise ImportError("-k option requires psutil module")
        SERVER_PROC = psutil.Process(options.pid)

    # before starting make sure we have write permissions
    ftp = connect()
    conn = ftp.transfercmd("STOR " + TESTFN)
    conn.close()
    ftp.voidresp()
    ftp.delete(TESTFN)
    ftp.quit()
    atexit.register(cleanup)

    # start benchmark
    if SERVER_PROC is not None:
        register_memory()
        print("(starting with %s of memory being used)" % (
            hilite(server_memory.pop())))
    if options.benchmark == 'download':
        # upload the file to download first
        stor()
        bench_retr()
    elif options.benchmark == 'upload':
        bench_stor()
    elif options.benchmark == 'transfer':
        bench_stor()
        bench_retr()
    elif options.benchmark == 'concurrence':
        bench_multi(options.clients)
    elif options.benchmark == 'all':
        bench_stor()
        bench_retr()
        bench_multi(options.clients)


# Run the benchmark only when this file is executed as a script.
if __name__ == '__main__':
    main()
