#!/usr/bin/python3
#
# lxc-ls: List containers
#
# This python implementation is based on the work done in the original
# shell implementation done by Serge Hallyn in Ubuntu (and other contributors)
#
# (C) Copyright Canonical Ltd. 2012
#
# Authors:
# Stéphane Graber <stgraber@ubuntu.com>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#

import argparse
import gettext
import json
import lxc
import os
import re
import shutil
import tempfile
import sys

# Conventional alias used to mark user-visible strings for translation;
# messages are looked up in the "lxc-ls" text domain.
_ = gettext.gettext
gettext.textdomain("lxc-ls")

# Required for containers without python
# NOTE(review): presumably imported eagerly so the codec is already loaded
# before code runs inside a container that ships no Python -- confirm.
import encodings.ascii
# Bare reference to the module; presumably keeps the import from being
# reported as unused.
assert encodings.ascii

# Constants
# Container directory returned by get_root_path() when lxc.conf does not
# provide an lxc.lxcpath value.
LXCPATH = "/var/lib/lxc"
# Prefix of the per-container lock path ("<RUNTIME_PATH>/lock/lxc/...") that
# get_containers() removes around the scan of a stopped container.
RUNTIME_PATH = "/run"


# Functions used later on
def batch(iterable, cols=1):
    """Yield the rows of a column-major layout of a sequence.

    The items are arranged in ``cols`` columns that are filled top to
    bottom, one column after the other.  Each yielded list is one row of
    that layout, read left to right; trailing rows are shorter when the
    items do not fill the grid.

    iterable -- a sized, indexable sequence (len() and [] are used)
    cols     -- number of columns in the layout
    """
    from math import ceil

    total = len(iterable)
    rows = ceil(total / cols)

    for row in range(rows):
        # The item in column c of this row sits at offset row + c * rows
        positions = (row + column * rows for column in range(cols))
        yield [iterable[pos] for pos in positions if pos < total]


def get_terminal_size():
    """Return the terminal size as a ``(columns, lines)`` tuple of ints.

    The size is queried with the TIOCGWINSZ ioctl on file descriptors 0, 1
    and 2 in turn, then on the controlling terminal.  When none of them
    answers, the LINES and COLUMNS environment variables are used; a
    missing or non-integer value falls back to 25 lines / 80 columns.
    """
    import os
    env = os.environ

    def ioctl_GWINSZ(fd):
        # Returns the unpacked (lines, columns) pair, or None when the ioctl
        # fails or the required modules cannot be imported.
        try:
            import fcntl
            import termios
            import struct
            return struct.unpack('hh', fcntl.ioctl(fd, termios.TIOCGWINSZ,
                                 '1234'))
        except Exception:
            return None

    def env_int(name, default):
        # A garbage value in the environment must not abort the listing
        try:
            return int(env.get(name, default))
        except ValueError:
            return default

    cr = ioctl_GWINSZ(0) or ioctl_GWINSZ(1) or ioctl_GWINSZ(2)
    if not cr:
        try:
            fd = os.open(os.ctermid(), os.O_RDONLY)
            try:
                cr = ioctl_GWINSZ(fd)
            finally:
                os.close(fd)
        except (AttributeError, OSError):
            # No controlling terminal (or no ctermid on this platform)
            pass

    if not cr:
        cr = (env_int('LINES', 25), env_int('COLUMNS', 80))

    return int(cr[1]), int(cr[0])


def get_root_path(path):
    """Return the container directory configured below the root ``path``.

    Reads ``<path>/etc/lxc/lxc.conf`` and returns the value of the first
    line that starts with ``lxc.lxcpath`` and contains an ``=``.  LXCPATH is
    returned when the file is missing or cannot be opened, or when it holds
    no such line.

    path -- filesystem root to look under
    """
    global_conf = "%s/etc/lxc/lxc.conf" % path
    try:
        with open(global_conf, "r") as fd:
            for line in fd:
                if not line.startswith("lxc.lxcpath"):
                    continue
                # Split on the first "=" only, so a value that itself
                # contains "=" is returned whole.
                key, separator, value = line.partition("=")
                if separator:
                    return value.strip()
    except OSError:
        # Missing or unreadable configuration: use the default below
        pass
    return LXCPATH


# Constants
# Every column name accepted by -F/--fancy-format
FIELDS = ("name", "state", "interfaces", "ipv4", "ipv6", "autostart", "pid",
          "memory", "ram", "swap", "groups")
# Columns used by the fancy listing when -F/--fancy-format is not given
DEFAULT_FIELDS = ("name", "state", "ipv4", "ipv6", "groups", "autostart")

# Begin parsing the command line
# The translatable template is looked up first and interpolated afterwards,
# so the message catalog is keyed on the template rather than on the
# already-formatted text.
parser = argparse.ArgumentParser(description=_("LXC: List containers"),
                                 formatter_class=argparse.RawTextHelpFormatter,
                                 epilog=_("""Valid fancy-format fields:
  %s

Default fancy-format fields:
  %s\n""") % (", ".join(FIELDS), ", ".join(DEFAULT_FIELDS)))

parser.add_argument("-1", dest="one", action="store_true",
                    help=_("list one container per line (default when piped)"))

parser.add_argument("-P", "--lxcpath", dest="lxcpath", metavar="PATH",
                    help=_("Use specified container path"),
                    default=lxc.default_config_path)

parser.add_argument("--active", action="store_true",
                    help=_("list only active containers"))

# The three state switches accumulate into the same args.state list
parser.add_argument("--frozen", dest="state", action="append_const",
                    const="FROZEN", help=_("list only frozen containers"))

parser.add_argument("--running", dest="state", action="append_const",
                    const="RUNNING", help=_("list only running containers"))

parser.add_argument("--stopped", dest="state", action="append_const",
                    const="STOPPED", help=_("list only stopped containers"))

parser.add_argument("-f", "--fancy", action="store_true",
                    help=_("use fancy output"))

parser.add_argument("-F", "--fancy-format", type=str,
                    default=",".join(DEFAULT_FIELDS),
                    help=_("comma separated list of fields to show"))

parser.add_argument("-g", "--groups", type=str, action="append",
                    metavar="GROUPS",
                    help=_("groups (comma separated) the container must "
                           "be a member of"))

parser.add_argument("--nesting", dest="nesting", action="store_true",
                    help=_("show nested containers"))

parser.add_argument("filter", metavar='FILTER', type=str, nargs="?",
                    help=_("regexp to be applied on the container list"))

parser.add_argument("--version", action="version", version=lxc.version)

args = parser.parse_args()

# --active selects the RUNNING and FROZEN states, plus UNKNOWN (the state
# reported for containers that are not controllable)
if args.active:
    if not args.state:
        args.state = []
    args.state += ["RUNNING", "FROZEN", "UNKNOWN"]

# If the output is piped, default to --one
if not sys.stdout.isatty():
    args.one = True

# Turn args.fancy_format into a list
args.fancy_format = args.fancy_format.strip().split(",")

# Reject unknown columns; sorted so the message is deterministic
invalid_fields = set(args.fancy_format) - set(FIELDS)
if invalid_fields:
    parser.error(_("Invalid field(s): %s") %
                 ", ".join(sorted(invalid_fields)))

# Basic checks
## Check for setns
SUPPORT_SETNS = os.path.exists("/proc/self/ns")
SUPPORT_SETNS_NET = False
SUPPORT_SETNS_PID = False
if SUPPORT_SETNS:
    SUPPORT_SETNS_NET = os.path.exists("/proc/self/ns/net")
    SUPPORT_SETNS_PID = os.path.exists("/proc/self/ns/pid")

## Nesting requires setns to pid and net ns
if args.nesting:
    if not SUPPORT_SETNS:
        parser.error(_("Showing nested containers requires setns support "
                       "which your kernel doesn't support."))

    if not SUPPORT_SETNS_NET:
        parser.error(_("Showing nested containers requires setns to the "
                       "network namespace which your kernel doesn't support."))

    if not SUPPORT_SETNS_PID:
        parser.error(_("Showing nested containers requires setns to the "
                       "PID namespace which your kernel doesn't support."))

# Set the actual lxcpath value (covers an explicitly empty -P argument)
if not args.lxcpath:
    args.lxcpath = lxc.default_config_path


# List of containers, stored as dictionaries
def get_containers(fd=None, base="/", root=False):
    """Return the containers found under ``base`` as a list of dictionaries.

    Every entry has a 'name' key.  When fancy output or nesting was
    requested, an entry also carries one string per column listed in
    args.fancy_format.  Entries for nested containers additionally carry
    'nesting_parent' (ancestor names, outermost first) and
    'nesting_real_name' (the name without the parent prefix), and their
    'name' becomes "<parent>/<name>".

    fd   -- when not None, the list is written as JSON to this file
            descriptor and None is returned; this is how the recursive
            call made through container.attach_wait() hands back its result
    base -- directory prefixed to every container path that is scanned
    root -- True for the top-level call: only args.lxcpath is scanned and
            args.filter is applied to the container names.  Otherwise the
            path reported by get_root_path(base) is scanned as well and no
            name filter is applied.
    """
    containers = []

    paths = [args.lxcpath]

    if not root:
        paths.append(get_root_path(base))

    # Generate a unique list of valid paths
    paths = set([os.path.normpath("%s/%s" % (base, path)) for path in paths])

    for path in paths:
        if not os.access(path, os.R_OK):
            continue

        for container_name in lxc.list_containers(config_path=path):
            entry = {}
            entry['name'] = container_name

            # Apply filter
            if root and args.filter and \
               not re.match(args.filter, container_name):
                continue

            # Return before grabbing the object (non-root)
            if not args.state and not args.fancy and not args.nesting \
                    and not args.groups:
                containers.append(entry)
                continue

            # A container that cannot be instantiated is left out
            try:
                container = lxc.Container(container_name, path)
            except Exception:
                continue

            # groups is only read below under these same conditions
            if args.groups or "autostart" in args.fancy_format \
                    or "groups" in args.fancy_format:
                try:
                    groups = container.get_config_item("lxc.group")
                except KeyError:
                    groups = []

            if args.groups:
                set_has = set(groups)

                # Keep the container if it belongs to every group of at
                # least one -g argument
                for group in args.groups:
                    set_must = set(group.split(","))
                    if not set_must - set_has:
                        break
                else:
                    continue

            if container.controllable:
                state = container.state
            else:
                state = 'UNKNOWN'

            # Filter by status
            if args.state and state not in args.state:
                continue

            # Nothing more is needed if we're not printing some fancy output
            if not args.fancy and not args.nesting:
                containers.append(entry)
                continue

            # Some extra field we may want
            if 'state' in args.fancy_format:
                entry['state'] = state

            if 'pid' in args.fancy_format:
                entry['pid'] = "-"
                if state == 'UNKNOWN':
                    entry['pid'] = state
                elif container.init_pid != -1:
                    entry['pid'] = str(container.init_pid)

            if 'groups' in args.fancy_format:
                entry['groups'] = "-"
                if len(groups) > 0:
                    entry['groups'] = ", ".join(groups)

            if 'autostart' in args.fancy_format:
                entry['autostart'] = "NO"
                try:
                    if container.get_config_item("lxc.start.auto") == "1":
                        if len(groups) > 0:
                            entry['autostart'] = "BY-GROUP"
                        else:
                            entry['autostart'] = "YES"
                except KeyError:
                    pass

            if 'memory' in args.fancy_format or \
               'ram' in args.fancy_format or \
               'swap' in args.fancy_format:

                if container.running:
                    # A cgroup value that cannot be read counts as 0
                    try:
                        memory_ram = int(container.get_cgroup_item(
                            "memory.usage_in_bytes"))
                    except Exception:
                        memory_ram = 0

                    # The memsw value is reduced by the RAM figure to get
                    # the swap figure
                    try:
                        memory_swap = int(container.get_cgroup_item(
                            "memory.memsw.usage_in_bytes")) - memory_ram
                    except Exception:
                        memory_swap = 0
                else:
                    memory_ram = 0
                    memory_swap = 0

                memory_total = memory_ram + memory_swap

            if 'memory' in args.fancy_format:
                if container.running:
                    entry['memory'] = "%sMB" % round(memory_total / 1048576, 2)
                else:
                    entry['memory'] = "-"

            if 'ram' in args.fancy_format:
                if container.running:
                    entry['ram'] = "%sMB" % round(memory_ram / 1048576, 2)
                else:
                    entry['ram'] = "-"

            if 'swap' in args.fancy_format:
                if container.running:
                    entry['swap'] = "%sMB" % round(memory_swap / 1048576, 2)
                else:
                    entry['swap'] = "-"

            # Get the IPs
            for family, protocol in {'inet': 'ipv4', 'inet6': 'ipv6'}.items():
                if protocol in args.fancy_format:
                    entry[protocol] = "-"

                    if state == 'UNKNOWN':
                        entry[protocol] = state
                        continue

                    if container.running:
                        if not SUPPORT_SETNS_NET:
                            entry[protocol] = 'UNKNOWN'
                            continue

                        ips = container.get_ips(family=family)
                        if ips:
                            entry[protocol] = ", ".join(ips)

            # Get the interfaces
            if 'interfaces' in args.fancy_format:
                entry['interfaces'] = "-"

                if state == 'UNKNOWN' or (container.running and
                                          not SUPPORT_SETNS_NET):
                    entry['interfaces'] = "UNKNOWN"
                elif container.running:
                    interfaces = container.get_interfaces()
                    if interfaces:
                        entry['interfaces'] = ", ".join(interfaces)

            # Nested containers
            if args.nesting:
                if container.running:
                    # Recursive call in container namespace
                    # The temporary file is unlinked right away; only its
                    # descriptor is used to carry the JSON result back.
                    temp_fd, temp_file = tempfile.mkstemp()
                    os.remove(temp_file)

                    container.attach_wait(get_containers, temp_fd)

                    with os.fdopen(temp_fd, "r") as json_file:
                        json_file.seek(0)

                        # Missing or malformed output means no nested
                        # containers
                        try:
                            sub_containers = json.loads(json_file.read())
                        except Exception:
                            sub_containers = []
                else:
                    def clear_lock():
                        # Best-effort removal of the lock path of this
                        # container; failures are ignored
                        try:
                            lock_path = "%s/lock/lxc/%s/%s" % (RUNTIME_PATH,
                                                               path,
                                                               entry['name'])
                            if os.path.exists(lock_path):
                                if os.path.isdir(lock_path):
                                    shutil.rmtree(lock_path)
                                else:
                                    os.remove(lock_path)
                        except Exception:
                            pass

                    clear_lock()

                    # Recursive call using container rootfs
                    sub_containers = get_containers(
                        base="%s/%s" % (
                            base, container.get_config_item("lxc.rootfs")))

                    clear_lock()

                for sub in sub_containers:
                    if 'nesting_parent' not in sub:
                        sub['nesting_parent'] = []
                    sub['nesting_parent'].insert(0, entry['name'])
                    sub['nesting_real_name'] = sub.get('nesting_real_name',
                                                       sub['name'])
                    sub['name'] = "%s/%s" % (entry['name'], sub['name'])
                    containers.append(sub)

            # Append the container
            containers.append(entry)

    # Compare against None: descriptor 0 is a valid target too
    if fd is not None:
        # Closing explicitly guarantees the JSON is flushed to the
        # descriptor before returning, instead of relying on garbage
        # collection of the file object.
        with os.fdopen(fd, "w+") as json_file:
            json_file.write(json.dumps(containers))
        return

    return containers

containers = get_containers(root=True)

# Print the list
## Standard list with one entry per line
if not args.fancy and args.one:
    for container in sorted(containers,
                            key=lambda container: container['name']):
        print(container['name'])
    sys.exit(0)

## Standard list with multiple entries per line
if not args.fancy and not args.one:
    # Get the longest name and extra simple list
    container_names = [container['name'] for container in containers]
    field_maxlength = max([0] + [len(name) for name in container_names])

    # Figure out how many we can put per line (always at least one, even
    # when the reported width is zero or smaller than a single name)
    width = get_terminal_size()[0]
    entries = max(1, width // (field_maxlength + 2))

    for line in batch(sorted(container_names), entries):
        # Every name is left-aligned in a cell two characters wider than
        # the longest name
        print("".join(name.ljust(field_maxlength + 2) for name in line))

## Fancy listing
if args.fancy:
    def fancy_value(container, field):
        """Return the text printed in column ``field`` for ``container``.

        The name of a nested container is shown as its real name behind a
        " \\_ " marker, indented by four spaces per nesting level beyond
        the first.  Every other value is printed as stored.
        """
        if field == 'name' and 'nesting_real_name' in container:
            prefix = " " * ((len(container['nesting_parent']) - 1) * 4)
            return prefix + " \\_ " + container['nesting_real_name']
        return container[field]

    field_maxlength = {}

    # Get the maximum length per field, starting from the header width
    for field in args.fancy_format:
        field_maxlength[field] = len(field)

    for container in containers:
        for field in args.fancy_format:
            fieldlen = len(fancy_value(container, field))
            if fieldlen > field_maxlength[field]:
                field_maxlength[field] = fieldlen

    # Generate the line format string based on the maximum length and
    # a 2 character padding
    line_format = ""
    for index, field in enumerate(args.fancy_format):
        line_format += "{fields[%s]:%s}" % (index, field_maxlength[field] + 2)

    # Get the line length minus the padding of the last field.  Walk the
    # requested columns rather than the dictionary keys so that a column
    # requested more than once is counted each time it is printed.
    line_length = -2
    for field in args.fancy_format:
        line_length += field_maxlength[field] + 2

    # Print header
    print(line_format.format(fields=[header.upper()
                                     for header in args.fancy_format]))
    print("-" * line_length)

    # Print the entries
    for container in sorted(containers,
                            key=lambda container: container['name']):
        fields = [fancy_value(container, field)
                  for field in args.fancy_format]

        print(line_format.format(fields=fields))
