#!/usr/bin/env python3
"""
Starts a game in a contained process tree, waits for the game to start,
then gently tries to close the remaining game processes once the main game
has exited.
"""

import os
import sys
import time
import subprocess
import signal
import logging
import ctypes
from ctypes.util import find_library
from lutris.util.monitor import ProcessMonitor
from lutris.util.log import logger

# setproctitle is an optional dependency. When it cannot be imported, the
# name is bound to print instead, so the setproctitle(...) call in main()
# writes the would-be process title to stdout rather than renaming the process.
try:
    from setproctitle import setproctitle
except ImportError:
    setproctitle = print


# Option number handed to prctl(2) by set_child_subreaper() below.
# NOTE(review): expected to mirror PR_SET_CHILD_SUBREAPER from
# <linux/prctl.h> -- confirm against the kernel headers.
PR_SET_CHILD_SUBREAPER = 36


def set_child_subreaper():
    """Mark the current process as a child subreaper through prctl(2).

    A subreaper takes over the job init(1) normally does for its
    descendants: when a descendant is orphaned because its direct parent
    terminated, it is reparented to the closest living ancestor that is a
    subreaper.  From then on getppid() in the orphan returns the
    subreaper's PID, and when the orphan terminates it is the subreaper
    that receives SIGCHLD and can wait(2) on it to collect its
    termination status.

    The flag is not inherited by children created with fork(2) or
    clone(2), but it is preserved across execve(2).

    This is the usual tool for session-management style programs that
    supervise a whole hierarchy of processes and must be told when any of
    them -- a double-forked daemon, for instance -- goes away.  Some
    init(1) frameworks such as systemd(1) use a subreaper for the same
    reason.

    A failing prctl call (return value -1) is only reported on stdout;
    no exception is raised for it.
    """
    libc = ctypes.CDLL(find_library('c'))
    status = libc.prctl(PR_SET_CHILD_SUBREAPER, 1, 0, 0, 0, 0)
    if status != -1:
        return
    print("PR_SET_CHILD_SUBREAPER failed, process watching may fail")


def log(line):
    """Emit one log line.

    This is the single place all wrapper output goes through, so it can be
    re-pointed at any other sink (stdout, file, logging, t2s, Discord, ...).
    The current sink is stdout: the line is written with a trailing newline
    and flushed immediately.
    """
    stream = sys.stdout
    try:
        stream.write(line + "\n")
        stream.flush()
    except BrokenPipeError:
        # The reader of our stdout is gone; drop the line silently.
        return

    # Example of a file sink:
    # with open(os.path.expanduser("~/lutris.log"), "a") as logfile:
    #     logfile.write(line)
    #     logfile.write("\n")


# Seconds slept between two consecutive polls of the monitor / child reaper.
_POLL_INTERVAL = 0.1
# Seconds monitored processes are given to exit after a SIGTERM round
# before SIGTERM is sent again.
_EXIT_GRACE_SECONDS = 60
# Number of polls that make up one grace period.
_EXIT_GRACE_POLLS = round(_EXIT_GRACE_SECONDS / _POLL_INTERVAL)


def _wait_status_to_exit_code(status):
    """Convert a raw wait(2) status word into a shell-style exit code.

    os.wait3() does not return an exit code but a 16-bit status word (exit
    code in the high byte, terminating signal in the low byte).  Passing
    that word straight to sys.exit() is wrong: an exit code of 1 is encoded
    as 256, which the OS truncates to 0.

    Returns the child's exit code when it exited normally, ``128 + signal``
    when it was killed by a signal (the convention shells use), and 1 for
    any other status.
    """
    if os.WIFEXITED(status):
        return os.WEXITSTATUS(status)
    if os.WIFSIGNALED(status):
        return 128 + os.WTERMSIG(status)
    return 1


def main():
    """Runs a command independently from the Lutris client

    Expected ``sys.argv`` layout::

        SCRIPT PROC_TITLE INCLUDE_COUNT EXCLUDE_COUNT [INCLUDE...] [EXCLUDE...] COMMAND...

    INCLUDE_COUNT / EXCLUDE_COUNT say how many of the following arguments
    are handed to ProcessMonitor as its first / second argument; everything
    after them is the command line that gets executed.

    The function becomes a child subreaper, starts the command, forwards
    SIGTERM/SIGINT to the game processes (a second signal escalates to
    SIGKILL on every process the monitor reports), reaps children while the
    monitor reports the game as alive, then keeps sending SIGTERM to the
    remaining monitored processes until none is left.

    Exits through sys.exit() with the exit code of the initially started
    process (0 if that process was never reaped).  If the command cannot
    be found, a message is logged and the function returns normally.

    Raises:
        ValueError: fewer than three arguments follow the script name, or
            one of the two counts is not an integer.
    """
    # pylint: disable=too-many-branches,too-many-statements
    # TODO: refactor
    set_child_subreaper()
    _, proc_title, include_proc_count, exclude_proc_count, *args = sys.argv

    setproctitle("lutris-wrapper: " + proc_title)

    # No real argument parsing: the two counts delimit the include and
    # exclude lists, whatever remains is the command to run.
    include_proc_count = int(include_proc_count)
    exclude_proc_count = int(exclude_proc_count)
    include_procs, args = args[:include_proc_count], args[include_proc_count:]
    exclude_procs, args = args[:exclude_proc_count], args[exclude_proc_count:]

    # Drop PYTHONPATH from our environment before anything is spawned
    # (presumably so the launched command does not inherit it).
    os.environ.pop("PYTHONPATH", None)
    monitor = ProcessMonitor(include_procs, exclude_procs)

    def hard_sig_handler(_signum, _frame):
        """Second-signal handler: SIGKILL every process the monitor lists."""
        log("Caught another signal, sending SIGKILL.")
        for _ in range(3):  # just in case we race a new process.
            for child in monitor.iterate_all_processes():
                try:
                    os.kill(child.pid, signal.SIGKILL)
                except ProcessLookupError:  # process already dead
                    pass
        log("--killed processes--")

    def sig_handler(signum, _frame):
        """First-signal handler: pass the signal on to the game processes."""
        log("Caught signal %s" % signum)
        # From now on, another SIGTERM/SIGINT escalates to SIGKILL.
        signal.signal(signal.SIGTERM, hard_sig_handler)
        signal.signal(signal.SIGINT, hard_sig_handler)
        for child in monitor.iterate_game_processes():
            log("passing along signal to PID %s" % child.pid)
            try:
                os.kill(child.pid, signum)
            except ProcessLookupError:  # process already dead
                pass
        log("--terminated processes--")

    signal.signal(signal.SIGTERM, sig_handler)
    signal.signal(signal.SIGINT, sig_handler)

    log("Running %s" % " ".join(args))
    returncode = None
    try:
        initial_pid = subprocess.Popen(args).pid
    except FileNotFoundError:
        log("Failed to execute process. Check that the file exists")
        return

    log("Initial process has started with pid %d" % initial_pid)

    class NoMoreChildren(Exception):
        "Raised when async_reap_children finds no children left"

    def async_reap_children():
        """
        Attempts to reap zombie child processes. Thanks to setting
        ourselves as a subreaper, we are assigned zombie processes
        that our children orphan and so we are responsible for
        clearing them.

        This is also how we determine what our main process' exit
        code was so that we can forward it to our caller.

        Raises NoMoreChildren when there is no child left to wait for.
        """
        nonlocal returncode

        while True:
            try:
                dead_pid, dead_status, _ = os.wait3(os.WNOHANG)
            except ChildProcessError:
                # No processes remain. No need to check monitor.
                raise NoMoreChildren from None

            if dead_pid == 0:
                # Children exist, but none of them has terminated yet.
                break

            if returncode is None and dead_pid == initial_pid:
                # wait3 hands back a raw status word, not an exit code.
                returncode = _wait_status_to_exit_code(dead_status)
                log("Initial process has exited (return code: %s)" % returncode)

    try:
        # While we are inside this try..except, if at the time of any
        # call to async_reap_children there are no children left, we
        # will skip the rest of our cleanup logic, since with no
        # children remaining, there's nothing left to wait for.
        #
        # This behavior doesn't help with ignoring "system processes",
        # so it's more of a shortcut out of this code than it is
        # essential for correctness.

        # The initial wait loop:
        #  the initial process may have been excluded. Wait for the game
        #  to be considered "started".
        if not monitor.is_game_alive():
            log("Waiting for game to start (first non-excluded process started)")
            while not monitor.is_game_alive():
                async_reap_children()
                time.sleep(_POLL_INTERVAL)

        # The main wait loop:
        #  The game is running. Our process is now just waiting around
        #  for processes to exit, waking up every .1s to reap child
        #  processes.
        log("Start monitoring process.")
        while monitor.is_game_alive():
            async_reap_children()
            time.sleep(_POLL_INTERVAL)

        log("Monitored process exited.")
        async_reap_children()

        # The exit wait loop:
        #  The game is no longer running. We ask monitored processes
        #  to exit and wait up to _EXIT_GRACE_SECONDS before sending
        #  more SIGTERMs.
        while monitor.are_monitored_processes_alive():
            async_reap_children()
            for child in monitor.iterate_monitored_processes():
                log("Sending SIGTERM to PID %s (pid %s)" % (child.name, child.pid))
                try:
                    os.kill(child.pid, signal.SIGTERM)
                except ProcessLookupError:  # process already dead
                    pass

            # Give the processes one grace period to clean up.
            async_reap_children()
            for attempt in range(_EXIT_GRACE_POLLS):
                if not monitor.are_monitored_processes_alive():
                    break

                if attempt == 0:
                    log("Waiting up to %dsec for processes to exit." % _EXIT_GRACE_SECONDS)

                async_reap_children()
                time.sleep(_POLL_INTERVAL)

        async_reap_children()
        log("All monitored processes have exited.")

    except NoMoreChildren:
        log("All children have exited.")

    if returncode is None:
        # The initial process was never reaped by us; report success.
        returncode = 0
        log("Monitored process didn't return an exit code.")
    log("Exit with returncode %s" % returncode)
    sys.exit(returncode)


if __name__ == "__main__":
    # Directory containing this script, with symlinks resolved.
    # NOTE(review): the lutris imports at the top of the file have already
    # run by the time sys.path is adjusted here -- confirm this ordering is
    # intended.
    LAUNCH_PATH = os.path.dirname(os.path.realpath(__file__))
    if not os.path.isdir(os.path.join(LAUNCH_PATH, "../lutris")):
        # No "lutris" directory next to the script's parent: put
        # ../lib/lutris at the front of the import path instead.
        sys.path.insert(0, os.path.normpath(os.path.join(LAUNCH_PATH, "../lib/lutris")))
    else:
        # A sibling "lutris" directory exists (presumably a source checkout):
        # log at DEBUG level, do not write .pyc files, and put the parent
        # directory at the front of the import path.
        logger.setLevel(logging.DEBUG)
        sys.dont_write_bytecode = True
        SOURCE_PATH = os.path.normpath(os.path.join(LAUNCH_PATH, '..'))
        sys.path.insert(0, SOURCE_PATH)

    main()
