"""base class for parallel client tests"""
import os
import signal
import sys
import time
import warnings
from contextlib import contextmanager

import pytest
import zmq
from decorator import decorator

from ipyparallel import Client, error
from ipyparallel.tests import add_engines, launchers

# simple tasks for use in apply tests


def segfault():
    """Crash the current process with a segmentation fault.

    Writes a single zero byte to address -1 through ctypes. The import is
    local, presumably so the task is self-contained when run via apply.
    """
    from ctypes import memset

    memset(-1, 0, 1)


def crash():
    """Terminate the current process immediately with exit status 1.

    ``os._exit`` skips normal interpreter shutdown (atexit handlers,
    ``finally`` clauses, stdio flushing), so the process dies ungracefully.
    """
    exit_status = 1
    os._exit(exit_status)


def conditional_crash(condition):
    """Ungracefully exit the process, but only if *condition* is truthy.

    When *condition* is falsy this is a no-op returning ``None``; otherwise
    it calls ``crash()``, which terminates the process via ``os._exit(1)``.
    """
    if condition:
        crash()


def wait(n):
    """Sleep for *n* seconds, then return *n* unchanged.

    ``time`` is imported inside the body rather than taken from the module
    namespace, presumably so the task still works when executed remotely.
    """
    import time as _time

    _time.sleep(n)
    return n


def raiser(eclass):
    """Instantiate *eclass* with no arguments and raise the resulting exception."""
    exc = eclass()
    raise exc


def generate_output():
    """function for testing output

    publishes two outputs of each type, and returns
    a rich displayable object.

    Emits, in order: "stdout" (stdout), "stderr" (stderr), an HTML display,
    "stdout2" (stdout), "stderr2" (stderr), a Math display; then returns
    ``Math("42")``.
    """

    # Imported locally, presumably so the task is self-contained when sent
    # to an engine. ``display`` comes from the public ``IPython.display``
    # module: importing it from ``IPython.core.display`` is deprecated and
    # triggers a DeprecationWarning that could leak into captured output.
    from IPython.display import HTML, Math, display

    print("stdout")
    print("stderr", file=sys.stderr)

    display(HTML("<b>HTML</b>"))

    print("stdout2")
    print("stderr2", file=sys.stderr)

    display(Math(r"\alpha=\beta"))

    return Math("42")


# test decorator for skipping tests when libraries are unavailable
def skip_without(*names):
    """Return a test decorator that skips the test unless every name imports.

    The import check runs each time the decorated test is called. An
    ``ImportError`` skips the test; any other exception raised while
    importing emits a warning and then skips the test as well.
    """

    def _importable(modname):
        # True if ``modname`` imports cleanly; unexpected failures are warned about.
        try:
            __import__(modname)
        except ImportError:
            return False
        except Exception as exc:
            warnings.warn(f"Unexpected exception importing {modname}: {exc}")
            return False
        return True

    @decorator
    def skip_without_names(f, *args, **kwargs):
        """decorator to skip tests in the absence of numpy, etc."""
        for modname in names:
            if not _importable(modname):
                pytest.skip("Test requires %s" % modname)
        return f(*args, **kwargs)

    return skip_without_names


@contextmanager
def raises_remote(etype):
    """Context manager asserting that its body raises a remote *etype* error.

    *etype* may be an exception class or its name as a string. An
    ``AlreadyDisplayedError`` is unwrapped to its ``original_error`` and a
    ``CompositeError`` is re-raised via ``raise_exception()`` before the
    resulting ``RemoteError``'s ``ename`` is compared with the expected name.
    Fails the test if the body raises nothing.
    """
    # allow Exception or 'Exception'
    expected_ename = etype if isinstance(etype, str) else etype.__name__

    try:
        try:
            yield
        except error.AlreadyDisplayedError as displayed:
            displayed.original_error.raise_exception()
        except error.CompositeError as composite:
            composite.raise_exception()
    except error.RemoteError as remote:
        assert (
            expected_ename == remote.ename
        ), f"Should have raised {expected_ename}, but raised {remote.ename}"
    else:
        pytest.fail("should have raised a RemoteError")


# -------------------------------------------------------------------------------
# Classes
# -------------------------------------------------------------------------------


@pytest.mark.usefixtures("cluster")
class ClusterTestCase:
    """Base class for tests that run against the ``cluster`` fixture.

    ``setup_method`` adds engines until at least ``engine_count`` exist,
    connects ``self.client`` and clears every engine namespace;
    ``teardown_method`` closes the client and destroys the zmq context.
    Where ``SIGALRM`` exists, each test is also bounded by a watchdog that
    raises ``TimeoutError`` after ``test_timeout`` seconds.
    """

    # default timeout (seconds) substituted by ``client_wait``
    timeout = 10
    # minimum number of engines required before each test starts
    engine_count = 2
    # per-test watchdog limit in seconds (only where SIGALRM exists)
    test_timeout = 30

    def add_engines(self, n=1, block=True):
        """add multiple engines to our cluster

        The new engines are recorded in ``self.engines``; if *block* is true,
        wait until they have connected.
        """
        self.engines.extend(add_engines(n))
        if block:
            self.wait_on_engines()

    def minimum_engines(self, n=1, block=True):
        """add engines until there are at least n connected

        Any newly started engines are recorded in ``self.engines``; if
        *block* is true, wait until they have connected.
        """
        self.engines.extend(add_engines(n, total=True))
        if block:
            self.wait_on_engines()

    def wait_on_engines(self, timeout=5):
        """Wait for the base engines plus all of ``self.engines`` to connect.

        Raises AssertionError if fewer engines are registered afterwards.
        """
        n = len(self.engines) + self.base_engine_count
        self.client.wait_for_engines(n, timeout=timeout)

        assert len(self.client.ids) >= n, "waiting for engines timed out"

    def client_wait(self, client, jobs=None, timeout=-1):
        """``Client.wait`` wrapper with a finite default timeout to avoid hangs.

        A *timeout* of ``None`` or a negative value is replaced by
        ``self.timeout``.
        """
        if timeout is None or timeout < 0:
            timeout = self.timeout
        return Client.wait(client, jobs, timeout)

    def connect_client(self):
        """Connect a Client to the 'iptest' profile using this test's zmq context.

        The returned client's ``wait`` is replaced with ``client_wait`` so it
        defaults to a finite timeout.
        """
        c = Client(profile='iptest', context=self.context)
        c.wait = lambda *a, **kw: self.client_wait(c, *a, **kw)
        return c

    def assertRaisesRemote(self, etype, f, *args, **kwargs):
        """Assert that ``f(*args, **kwargs)`` raises a remote *etype* error."""
        with raises_remote(etype):
            f(*args, **kwargs)

    def _wait_for(self, f, timeout=10):
        """Poll ``f()`` every 0.1s until it is true or *timeout* seconds pass.

        Never raises on timeout; only prints a warning.
        """
        tic = time.time()
        while time.time() <= tic + timeout:
            if f():
                return
            time.sleep(0.1)
        if not f():
            print("Warning: Awaited condition never arrived")

    @staticmethod
    def _disarm_alarm():
        """Cancel any pending SIGALRM watchdog and restore the default handler."""
        if hasattr(signal, 'SIGALRM'):
            signal.alarm(0)
            signal.signal(signal.SIGALRM, signal.SIG_DFL)

    @staticmethod
    def _prune_launchers():
        """Drop entries of the shared ``launchers`` list whose ``poll()`` is not None."""
        # Slice assignment rebuilds the list in place, so every holder of this
        # shared list sees the result. Removing items while iterating the same
        # list would skip the element following each removal.
        launchers[:] = [e for e in launchers if e.poll() is None]

    def setup_method(self):
        """Arm the per-test watchdog, ensure engines exist, connect a clean client."""
        self.context = zmq.Context.instance()
        if hasattr(signal, 'SIGALRM'):
            # use sigalarm for test timeout
            def _sigalarm(sig, frame):
                raise TimeoutError(
                    f"test did not finish in {self.test_timeout} seconds"
                )

            signal.signal(signal.SIGALRM, _sigalarm)
            signal.alarm(self.test_timeout)

        try:
            add_engines(self.engine_count, total=True)

            self.client = self.connect_client()
            # start every test with clean engine namespaces:
            self.client.clear(block=True)
            self.base_engine_count = len(self.client.ids)
            self.engines = []
        except BaseException:
            # pytest does not run teardown_method when setup_method fails, so
            # disarm the watchdog here; otherwise it would fire later inside
            # whichever unrelated test happens to be running.
            self._disarm_alarm()
            raise

    def teardown_method(self):
        """Close the client and zmq context, always disarming the watchdog."""
        try:
            try:
                if len(self.client):
                    self.client[:].use_pickle()

                # close fds: forget launchers whose process has finished
                self._prune_launchers()

                # allow flushing of incoming messages to prevent crash on socket close
                self.client.wait(timeout=2)
            finally:
                self.client.close()
        finally:
            # Runs even if cleanup above failed, so the alarm cannot outlive this test.
            self._disarm_alarm()
            self.context.destroy()
