#!/usr/bin/python3
#
# lxc-start-ephemeral: Start a copy of a container using an overlay
#
# This python implementation is based on the work done in the original
# shell implementation done by Serge Hallyn in Ubuntu (and other contributors)
#
# (C) Copyright Canonical Ltd. 2012
#
# Authors:
# Stéphane Graber <stgraber@ubuntu.com>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#

import argparse
import gettext
import lxc
import os
import sys
import subprocess
import tempfile

_ = gettext.gettext
gettext.textdomain("lxc-start-ephemeral")


# Other functions
def randomMAC():
    """Return a random MAC address string starting with 00:16:3e.

    The fourth octet is drawn from 0x00-0x7f and the last two octets from
    0x00-0xff. The result is six two-digit lowercase hex octets joined
    with colons.
    """
    import random

    prefix = [0x00, 0x16, 0x3e]
    suffix = [random.randint(0x00, 0x7f),
              random.randint(0x00, 0xff),
              random.randint(0x00, 0xff)]
    return ':'.join("%02x" % octet for octet in prefix + suffix)


def get_rundir():
    """Return the runtime directory to use for the current user.

    When the effective uid is 0 this is "/run". Otherwise it is the value
    of $XDG_RUNTIME_DIR when that variable is set, falling back to a path
    below $HOME.

    Raises:
        Exception: if the effective uid is not 0 and neither
            XDG_RUNTIME_DIR nor HOME is present in the environment.
    """
    running_as_root = os.geteuid() == 0
    if running_as_root:
        return "/run"

    env = os.environ
    # Candidates are tried in order; the first variable that is set wins.
    for variable, template in (("XDG_RUNTIME_DIR", "%s"),
                               ("HOME", "%s/.cache/lxc/run/")):
        if variable in env:
            return template % env[variable]

    raise Exception("Unable to find a runtime directory")


# Begin parsing the command line
#
# The resulting "args" namespace is read by the rest of the script.
# RawTextHelpFormatter keeps the line breaks of the multi-line epilog.
parser = argparse.ArgumentParser(description=_(
                                 "LXC: Start an ephemeral container"),
                                 formatter_class=argparse.RawTextHelpFormatter,
                                 epilog=_("If a COMMAND is given, then the "
                                          """container will run only as long
as the command runs.
If no COMMAND is given, this command will attach to tty1 and stop the
container when exiting (with ctrl-a-q).

If no COMMAND is given and -d is used, the name and IP addresses of the
container will be printed to the console."""))

parser.add_argument("--lxcpath", "-P", dest="lxcpath", metavar="PATH",
                    help=_("Use specified container path"), default=None)

parser.add_argument("--orig", "-o", type=str, required=True,
                    help=_("name of the original container"))

parser.add_argument("--name", "-n", type=str,
                    help=_("name of the target container"))

parser.add_argument("--bdir", "-b", type=str, action="append", default=[],
                    help=_("directory to bind mount into container"))

parser.add_argument("--cdir", "-c", type=str, action="append", default=[],
                    help=_("directory to cow mount into container"))

parser.add_argument("--user", "-u", type=str,
                    help=_("the user to run the command as"))

parser.add_argument("--key", "-S", type=str,
                    help=_("the path to the key to use to connect "
                           "(when using ssh)"))

parser.add_argument("--daemon", "-d", action="store_true",
                    help=_("run in the background"))

# The help text goes through _() like every other option so that it can be
# translated.
parser.add_argument("--storage-type", "-s", type=str, default=None,
                    choices=("tmpfs", "dir"),
                    help=_("type of storage use by the container"))

parser.add_argument("--union-type", "-U", type=str, default="overlayfs",
                    choices=("overlayfs", "aufs"),
                    help=_("type of union (overlayfs or aufs), "
                           "defaults to overlayfs."))

parser.add_argument("--keep-data", "-k", action="store_true",
                    help=_("don't wipe everything clean at the end"))

parser.add_argument("command", metavar='CMD', type=str, nargs="*",
                    help=_("Run specific command in container "
                           "(command as argument)"))

parser.add_argument("--version", action="version", version=lxc.version)

args = parser.parse_args()

## Check that -d and CMD aren't used at the same time
if args.command and args.daemon:
    parser.error(_("You can't use -d and a command at the same time."))

## Pick a default storage type when none was given: "dir" if the data has
## to be kept (-k), "tmpfs" otherwise
if not args.storage_type:
    if args.keep_data:
        args.storage_type = "dir"
    else:
        args.storage_type = "tmpfs"

## Check that -k isn't used with -s tmpfs
if args.keep_data and args.storage_type == "tmpfs":
    parser.error(_("You can't use -k with the tmpfs storage type."))

# Load the orig container
orig = lxc.Container(args.orig, args.lxcpath)
if not orig.defined:
    # Translate the template first and interpolate afterwards; formatting
    # inside _() would look up the already-expanded text in the catalog.
    parser.error(_("Source container '%s' doesn't exist.") % args.orig)

# Create the new container paths
if not args.lxcpath:
    lxc_path = lxc.default_config_path
else:
    lxc_path = args.lxcpath

if args.name:
    # An explicit name must not collide with an existing container directory
    dest_path = "%s/%s" % (lxc_path, args.name)
    if os.path.exists(dest_path):
        parser.error(_("A container named '%s' already exists.") % args.name)
    os.mkdir(dest_path)
else:
    # No name given: let tempfile pick a unique "<orig>-XXXX" directory
    dest_path = tempfile.mkdtemp(prefix="%s-" % args.orig, dir=lxc_path)
os.mkdir(os.path.join(dest_path, "rootfs"))
os.chmod(dest_path, 0o770)

# Setup the new container's configuration: start from the original's
# config file, then override the hostname and the rootfs location
dest = lxc.Container(os.path.basename(dest_path), args.lxcpath)
dest.load_config(orig.config_file_name)
dest.set_config_item("lxc.utsname", dest.name)
dest.set_config_item("lxc.rootfs", os.path.join(dest_path, "rootfs"))
# Use % so the path is actually substituted into the message instead of
# being printed as a second argument next to a literal "%s".
print("setting rootfs to .%s." % os.path.join(dest_path, "rootfs"))
for nic in dest.network:
    # Give every interface that exposes a hwaddr a fresh random MAC
    if hasattr(nic, 'hwaddr'):
        nic.hwaddr = randomMAC()

# List of (source directory, mount point) pairs to overlay-mount; the first
# entry maps the original rootfs onto the new container's rootfs
overlay_dirs = [(orig.get_config_item("lxc.rootfs"), "%s/rootfs/" % dest_path)]

# Generate a new fstab
#
# Copies the original container's fstab into the new container directory.
# Non-bind mounts and bind mounts of /proc and /sys are copied verbatim,
# every other bind mount is turned into an overlay_dirs entry instead.
orig_fstab = orig.get_config_item("lxc.mount")
if orig_fstab:
    dest.set_config_item("lxc.mount", os.path.join(dest_path, "fstab"))

    # These values do not change while the entries are copied
    orig_rootfs = orig.get_config_item("lxc.rootfs")
    dest_rootfs = dest.get_config_item("lxc.rootfs")

    # Absolute rootfs path with a trailing separator, used to reject mount
    # points that resolve outside of the new container's rootfs
    rootfs_prefix = os.path.join(
        os.path.abspath(os.path.join(dest_path, "rootfs")), "")

    with open(orig_fstab, "r") as orig_fd:
        with open(dest.get_config_item("lxc.mount"), "w+") as dest_fd:
            for line in orig_fd.read().split("\n"):
                # Start by replacing any reference to the container rootfs.
                # str.replace returns a new string, so the result has to be
                # rebound. The guard skips empty values, as replacing ""
                # would insert text between every character.
                if orig_rootfs and dest_rootfs:
                    line = line.replace(orig_rootfs, dest_rootfs)

                fields = line.split()

                # Skip invalid entries
                if len(fields) < 4:
                    continue

                # Non-bind mounts are kept as-is
                if "bind" not in fields[3]:
                    dest_fd.write("%s\n" % line)
                    continue

                # Bind mounts of virtual filesystems are also kept as-is
                src_path = fields[0].split("/")
                if len(src_path) > 1 and src_path[1] in ("proc", "sys"):
                    dest_fd.write("%s\n" % line)
                    continue

                # Skip invalid mount points
                dest_mount = os.path.abspath(os.path.join("%s/rootfs/" % (
                                             dest_path), fields[1]))

                if not dest_mount.startswith(rootfs_prefix):
                    print(_("Skipping mount entry '%s' as it's outside "
                            "of the container rootfs.") % line)
                    # Actually skip the entry, as the message says
                    continue

                # Setup an overlay for anything remaining
                overlay_dirs += [(fields[0], dest_mount)]

# Queue every existing --cdir directory for a copy-on-write mount at the
# same absolute path below the new container's rootfs
for cow_dir in args.cdir:
    if os.path.exists(cow_dir):
        cow_source = os.path.abspath(cow_dir)
        cow_target = "%s/rootfs/%s" % (dest_path, cow_source)
        overlay_dirs.append((cow_source, cow_target))
        continue

    print(_("Path '%s' doesn't exist, won't be cow-mounted.") %
          cow_dir)

# Detect which overlay implementation is available: the new one is listed
# as "overlay" in /proc/filesystems and requires a workdir, the older
# "overlayfs" does not.
with open("/proc/filesystems", "r") as fs_list:
    have_new_overlay = any(fs_entry == "nodev\toverlay\n"
                           for fs_entry in fs_list)

# Generate pre-mount script
#
# Writes an executable shell script into the container directory and
# registers it as the lxc.hook.pre-mount hook. The script sets up one union
# mount per overlay_dirs entry, bind-mounts the --bdir directories and, the
# first time it runs, replaces the original container name with the new one
# in a few network configuration files.
with open(os.path.join(dest_path, "pre-mount"), "w+") as fd:
    os.fchmod(fd.fileno(), 0o755)
    fd.write("""#!/bin/sh
LXC_DIR="%s"
LXC_BASE="%s"
LXC_NAME="%s"
""" % (dest_path, orig.name, dest.name))

    # All delta/work directories live below this directory
    tmpdir = "%s/tmpfs" % dest_path

    # "count" gives every overlay its own delta<N>/work<N> directories
    for count, entry in enumerate(overlay_dirs):
        # NOTE(review): the tmpfs mount line is emitted once per overlay
        # entry, exactly as before - confirm whether one mount would do.
        fd.write("mkdir -p %s\n" % (tmpdir))
        if args.storage_type == "tmpfs":
            fd.write("mount -n -t tmpfs -o mode=0755 none %s\n" % (tmpdir))
        deltdir = "%s/delta%s" % (tmpdir, count)
        workdir = "%s/work%s" % (tmpdir, count)
        fd.write("mkdir -p %s %s\n" % (deltdir, entry[1]))
        if have_new_overlay:
            fd.write("mkdir -p %s\n" % workdir)

        # Copy the ACLs of the source onto the delta directory and the
        # mount point; "|| true" keeps the hook going if that fails
        fd.write("getfacl -a %s | setfacl --set-file=- %s || true\n" %
                 (entry[0], deltdir))
        fd.write("getfacl -a %s | setfacl --set-file=- %s || true\n" %
                 (entry[0], entry[1]))

        if args.union_type == "overlayfs":
            if have_new_overlay:
                fd.write("mount -n -t overlay"
                         " -oupperdir=%s,lowerdir=%s,workdir=%s none %s\n" % (
                             deltdir,
                             entry[0],
                             workdir,
                             entry[1]))
            else:
                fd.write("mount -n -t overlayfs"
                         " -oupperdir=%s,lowerdir=%s none %s\n" % (
                             deltdir,
                             entry[0],
                             entry[1]))
        elif args.union_type == "aufs":
            xino_path = "/dev/shm/aufs.xino"
            # Make sure the directory that will hold the xino file exists.
            # dirname (not basename) is needed here: basename would create
            # a directory named "aufs.xino" in the current directory.
            os.makedirs(os.path.dirname(xino_path), exist_ok=True)

            fd.write("mount -n -t aufs "
                     "-o br=%s=rw:%s=ro,noplink,xino=%s none %s\n" % (
                         deltdir,
                         entry[0],
                         xino_path,
                         entry[1]))

    # Plain bind mounts requested with --bdir, mounted at the same absolute
    # path below the container's rootfs
    for entry in args.bdir:
        if not os.path.exists(entry):
            print(_("Path '%s' doesn't exist, won't be bind-mounted.") %
                  entry)
        else:
            src_path = os.path.abspath(entry)
            dst_path = "%s/rootfs/%s" % (dest_path, src_path)
            fd.write("mkdir -p %s\nmount -n --bind %s %s\n" % (
                     dst_path, src_path, dst_path))

    # One-time rename of the original container name in the new rootfs; the
    # "configured" marker file prevents it from running again
    fd.write("""
[ -e $LXC_DIR/configured ] && exit 0
for file in $LXC_DIR/rootfs/etc/hostname \\
            $LXC_DIR/rootfs/etc/hosts \\
            $LXC_DIR/rootfs/etc/sysconfig/network \\
            $LXC_DIR/rootfs/etc/sysconfig/network-scripts/ifcfg-eth0; do
        [ -f "$file" ] && sed -i -e "s/$LXC_BASE/$LXC_NAME/" $file
done
touch $LXC_DIR/configured
""")

dest.set_config_item("lxc.hook.pre-mount",
                     os.path.join(dest_path, "pre-mount"))

# Generate post-stop script
#
# Unless --keep-data was given, register a hook that removes the whole
# container directory. The configuration is saved in either case.
if not args.keep_data:
    post_stop_path = os.path.join(dest_path, "post-stop")
    with open(post_stop_path, "w+") as hook:
        os.fchmod(hook.fileno(), 0o755)
        hook.write("""#!/bin/sh
[ -d "%s" ] && rm -Rf "%s"
""" % (dest_path, dest_path))

    dest.set_config_item("lxc.hook.post-stop", post_stop_path)

dest.save_config()

# Start the container and give it up to 5 seconds to reach RUNNING
started = dest.start() and dest.wait("RUNNING", timeout=5)
if not started:
    print(_("The container '%s' failed to start.") % dest.name)
    dest.stop()
    if dest.defined:
        dest.destroy()
    sys.exit(1)

# Deal with the case where we just attach to the container's console:
# neither a command nor -d was given
console_only = not (args.command or args.daemon)
if console_only:
    dest.console()
    # Ask for a clean shutdown first, force a stop if that fails
    clean_shutdown = dest.shutdown(timeout=5)
    if not clean_shutdown:
        dest.stop()
    sys.exit(0)

# Try to get the IP addresses
ips = dest.get_ips(timeout=10)

# Deal with the case where we just print info about the container
if args.daemon:
    address_lines = [" - %s" % address for address in ips]
    if not address_lines:
        address_lines = [" - %s" % _("No address could be found")]

    print(_("""The ephemeral container is now started.

You can enter it from the command line with: lxc-console -n %s
The following IP addresses have be found in the container:
%s""") % (dest.name, "\n".join(address_lines)))
    sys.exit(0)

# Now deal with the case where we want to run a command in the container;
# without any address we give up and clean the container up
if not ips:
    print(_("Failed to get an IP for container '%s'.") % dest.name)
    dest.stop()
    if dest.defined:
        dest.destroy()
    sys.exit(1)

# Run the command: through lxc attach when /proc/self/ns/pid exists,
# through ssh to one of the container's addresses otherwise
if os.path.exists("/proc/self/ns/pid"):
    def attach_as_user(command):
        """Switch to the requested user, then run command.

        Passed as the callback to dest.attach_wait (presumably executed
        inside the container - confirm against the lxc bindings). Looks up
        --user (default "root") with getent, drops to that user's gid,
        groups and uid, changes to its home directory and sets HOME.
        Prints an error and exits with status 1 if any of that fails.

        Returns whatever lxc.attach_run_command(command) returns.
        """
        username = args.user if args.user else "root"

        try:
            line = subprocess.check_output(
                ["getent", "passwd", username],
                universal_newlines=True).rstrip("\n")

            # passwd entry: name:password:uid:gid:gecos:home:shell.
            # The unused fields must not be unpacked into "_": that would
            # make "_" a local variable and break the _() call below.
            pw_fields = line.split(":", 6)
            pw_uid = int(pw_fields[2])
            pw_gid = int(pw_fields[3])
            pw_dir = pw_fields[5]

            # Group identity has to be changed before the uid is dropped
            os.setgid(pw_gid)
            os.initgroups(username, pw_gid)
            os.setuid(pw_uid)
            os.chdir(pw_dir)
            os.environ['HOME'] = pw_dir
        except Exception:
            # Any failure (lookup, parsing, permission) ends the attempt;
            # SystemExit and KeyboardInterrupt are left alone.
            print(_("Unable to switch to user: %s") % username)
            sys.exit(1)

        return lxc.attach_run_command(command)

    retval = dest.attach_wait(attach_as_user, args.command,
                              env_policy=lxc.LXC_ATTACH_CLEAR_ENV)

else:
    # Host key checking is disabled and no known_hosts file is used for
    # these connections
    cmd = ["ssh",
           "-o", "StrictHostKeyChecking=no",
           "-o", "UserKnownHostsFile=/dev/null"]

    if args.user:
        cmd += ["-l", args.user]

    if args.key:
        cmd += ["-i", args.key]

    # Try every address in turn; 255 is treated as a connection failure,
    # any other status is the command's result and ends the loop
    for ip in ips:
        ssh_cmd = cmd + [ip] + args.command
        retval = subprocess.call(ssh_cmd, universal_newlines=True)
        if retval == 255:
            print(_("SSH failed to connect, trying next IP address."))
            continue

        if retval != 0:
            print(_("Command returned with non-zero return code: %s") % retval)
        break

# Shutdown the container, forcing a stop if the clean shutdown fails
if not dest.shutdown(timeout=5):
    dest.stop()

sys.exit(retval)
