#
# This file is part of pysnmp software.
#
# Copyright (c) 2005-2020, Ilya Etingof <etingof@gmail.com>
# License: https://www.pysnmp.com/pysnmp/license.html
#
import os
import sys
import struct
import marshal
import time
import traceback

try:
    import importlib
    import importlib.util
    import importlib.machinery

    try:
        PY_MAGIC_NUMBER = importlib.util.MAGIC_NUMBER
        SOURCE_SUFFIXES = importlib.machinery.SOURCE_SUFFIXES
        BYTECODE_SUFFIXES = importlib.machinery.BYTECODE_SUFFIXES

    except Exception:
        raise ImportError()

except ImportError:
    import imp

    PY_MAGIC_NUMBER = imp.get_magic()
    SOURCE_SUFFIXES = [s[0] for s in imp.get_suffixes() if s[2] == imp.PY_SOURCE]
    BYTECODE_SUFFIXES = [s[0] for s in imp.get_suffixes() if s[2] == imp.PY_COMPILED]

PY_SUFFIXES = SOURCE_SUFFIXES + BYTECODE_SUFFIXES

try:
    from errno import ENOENT
except ImportError:
    ENOENT = -1

from pysnmp import version as pysnmp_version
from pysnmp.smi import error
from pysnmp import debug

# Objects of these types are exported as is by MibBuilder.exportSymbols():
# no MIB label is read from or assigned to them.
classTypes = (type,)


class __AbstractMibSource:
    """Base class for a place MIB modules are loaded from.

    Concrete sources implement the ``_init``, ``_listdir``,
    ``_getTimestamp`` and ``_getData`` hooks; this class builds the public
    MibSource API (``init``, ``listdir``, ``read``, ``fullPath``) on them.
    """

    def __init__(self, srcName):
        self._srcName = srcName
        # None until init() runs; then True (this object is used as is)
        # or the replacement source object returned by _init()
        self.__inited = None
        debug.logger & debug.flagBld and debug.logger("trying %s" % self)

    def __repr__(self):
        return f"{self.__class__.__name__}({self._srcName!r})"

    def _uniqNames(self, files):
        """Return the distinct module names among *files* as a tuple.

        Each file name is stripped of any suffix in PY_SUFFIXES; names
        starting with ``__init__.`` and names with no such suffix are
        skipped.
        """
        u = set()

        for f in files:
            if f.startswith("__init__."):
                continue

            u.update(f[: -len(sfx)] for sfx in PY_SUFFIXES if f.endswith(sfx))

        return tuple(u)

    # MibSource API follows

    def fullPath(self, f="", sfx=""):
        """Return the source location, or the path of file *f* + *sfx* in it."""
        return self._srcName + (f and (os.sep + f + sfx) or "")

    def init(self):
        """Run ``_init()`` once and return the source object to actually use.

        That is ``self`` when ``_init()`` returned ``self``; otherwise it is
        whatever ``_init()`` returned (e.g. a fallback source object).
        """
        if self.__inited is None:
            self.__inited = self._init()
            if self.__inited is self:
                self.__inited = True
        if self.__inited is True:
            return self

        else:
            return self.__inited

    def listdir(self):
        """Return the module names reported by ``_listdir()``."""
        return self._listdir()

    def read(self, f):
        """Load module *f* from this source as a code object.

        A bytecode file (BYTECODE_SUFFIXES) is used when it carries the
        current magic number and its recorded mtime is not older than the
        source file's (SOURCE_SUFFIXES); otherwise the source is compiled.

        :param f: module name without suffix
        :returns: ``(codeObject, suffix)`` where *suffix* is the suffix of
            the file actually used, so ``fullPath(f, suffix)`` names it
        :raises OSError: (ENOENT) when no usable module file is found
        :raises error.MibLoadError: when file access fails with an error
            other than ENOENT
        """
        pycTime = pyTime = -1

        for pycSfx in BYTECODE_SUFFIXES:
            try:
                pycData, pycPath = self._getData(f + pycSfx, "rb")

            except OSError as why:
                if ENOENT == -1 or why.errno == ENOENT:
                    debug.logger & debug.flagBld and debug.logger(
                        f"file {f + pycSfx} access error: {why}"
                    )

                else:
                    raise error.MibLoadError(
                        f"MIB file {f + pycSfx} access error: {why}"
                    )

                continue

            if PY_MAGIC_NUMBER != pycData[:4]:
                debug.logger & debug.flagBld and debug.logger(
                    "bad magic in %s" % pycPath
                )
                continue

            header = self._parsePycHeader(pycData)
            if header is None:
                debug.logger & debug.flagBld and debug.logger(
                    "truncated header in %s" % pycPath
                )
                continue

            pycTime, pycData = header
            debug.logger & debug.flagBld and debug.logger(
                "file %s mtime %d" % (pycPath, pycTime)
            )
            break

        for pySfx in SOURCE_SUFFIXES:
            try:
                pyTime = self._getTimestamp(f + pySfx)

            except OSError as why:
                if ENOENT == -1 or why.errno == ENOENT:
                    debug.logger & debug.flagBld and debug.logger(
                        f"file {f + pySfx} access error: {why}"
                    )

                else:
                    raise error.MibLoadError(
                        f"MIB file {f + pySfx} access error: {why}"
                    )

            else:
                debug.logger & debug.flagBld and debug.logger(
                    "file %s mtime %d" % (f + pySfx, pyTime)
                )
                break

        if pycTime != -1 and pycTime >= pyTime:
            return marshal.loads(pycData), pycSfx

        if pyTime != -1:
            # Read raw bytes so that compile() decodes them the way Python
            # decodes source files (PEP 263 coding cookie, else UTF-8)
            # rather than with a text-mode read's locale encoding
            modData, pyPath = self._getData(f + pySfx, "rb")
            # return the suffix (like the bytecode branch), not the path
            return compile(modData, pyPath, "exec"), pySfx

        raise OSError(ENOENT, "No suitable module found", f)

    @staticmethod
    def _parsePycHeader(pycData):
        """Split a ``.pyc`` image into ``(mtime, marshalled code bytes)``.

        The 4-byte magic number must already have been checked by the
        caller. Returns None when *pycData* is too short for a full header.
        """
        # Python 3.7+ (PEP 552): magic, flags, mtime, source size -> 16 bytes
        # Python 3.3-3.6: magic, mtime, source size -> 12 bytes
        hdrLen = 16 if sys.version_info >= (3, 7) else 12
        if len(pycData) < hdrLen:
            return None

        if hdrLen == 16:
            flags = struct.unpack("<L", pycData[4:8])[0]
            if flags & 0x01:
                # Hash-based pyc: no mtime recorded. Report 0 so that it
                # loses to any source file with a positive mtime and is
                # otherwise used only when no source is available.
                return 0, pycData[hdrLen:]
            mtime = struct.unpack("<L", pycData[8:12])[0]

        else:
            mtime = struct.unpack("<L", pycData[4:8])[0]

        return mtime, pycData[hdrLen:]

    # Interfaces for subclasses
    def _init(self):
        """Prepare the source; return ``self`` or a replacement source."""
        raise NotImplementedError()

    def _listdir(self):
        """Return the module names available in the source."""
        raise NotImplementedError()

    def _getTimestamp(self, f):
        """Return the mtime of file *f*; raise OSError if it is missing."""
        raise NotImplementedError()

    def _getData(self, f, mode):
        """Return ``(contents, path)`` of file *f*; raise OSError if missing."""
        raise NotImplementedError()


class ZipMibSource(__AbstractMibSource):
    """MIB source named by a dotted Python package name.

    When the package's ``__loader__`` exposes a ``_files`` table (as the
    ZIP importer does), MIB files are read straight out of the archive.
    Otherwise initialization falls back to a DirMibSource: the package's
    directory if it has a ``__file__``, or the name itself taken as a
    directory when the package cannot be imported.
    """

    def _init(self):
        try:
            pkg = __import__(self._srcName, globals(), locals(), ["__init__"])

            if hasattr(pkg, "__loader__") and hasattr(pkg.__loader__, "_files"):
                self.__loader = pkg.__loader__
                self._srcName = self._srcName.replace(".", os.sep)
                return self

            if not hasattr(pkg, "__file__"):
                raise error.MibLoadError(f"{pkg} access error")

            # Dir relative to PYTHONPATH
            return DirMibSource(os.path.dirname(pkg.__file__)).init()

        except ImportError:
            # Dir relative to CWD
            return DirMibSource(self._srcName).init()

    @staticmethod
    def _parseDosTime(dosdate, dostime):
        """Convert MS-DOS date/time words to a timestamp via time.mktime()."""
        year = ((dosdate >> 9) & 0x7F) + 1980
        month = (dosdate >> 5) & 0x0F
        mday = dosdate & 0x1F
        hour = (dostime >> 11) & 0x1F
        minute = (dostime >> 5) & 0x3F
        second = (dostime & 0x1F) * 2  # DOS time stores seconds / 2
        # weekday, yearday and DST flag are left for mktime() to work out
        return time.mktime((year, month, mday, hour, minute, second, -1, -1, -1))

    def _listdir(self):
        # noinspection PyProtectedMember
        archived = self.__loader._files.keys()
        names = [
            base
            for head, base in map(os.path.split, archived)
            if head == self._srcName
        ]
        return tuple(self._uniqNames(names))

    def _getTimestamp(self, f):
        p = os.path.join(self._srcName, f)
        # noinspection PyProtectedMember
        table = self.__loader._files
        if p not in table:
            raise OSError(ENOENT, "No such file in ZIP archive", p)
        entry = table[p]
        # entry[6] is passed as the DOS date word, entry[5] as the time word
        return self._parseDosTime(entry[6], entry[5])

    def _getData(self, f, mode=None):
        p = os.path.join(self._srcName, f)
        try:
            data = self.__loader.get_data(p)
        except Exception as why:  # ZIP code seems to return all kinds of errors
            raise OSError(ENOENT, f"File or ZIP archive {p} access error: {why}")
        return data, p


class DirMibSource(__AbstractMibSource):
    """MIB source reading module files from a filesystem directory."""

    def _init(self):
        self._srcName = os.path.normpath(self._srcName)
        return self

    def _listdir(self):
        """Return module names in the directory, or () if it can't be listed."""
        try:
            return self._uniqNames(os.listdir(self._srcName))
        except OSError as why:
            debug.logger & debug.flagBld and debug.logger(
                f"listdir() failed for {self._srcName}: {why}"
            )
            return ()

    def _getTimestamp(self, f):
        """Return the integer mtime of file *f*; any stat error -> ENOENT."""
        p = os.path.join(self._srcName, f)
        try:
            # item 8 of a stat result is st_mtime as an integer
            return os.stat(p)[8]
        except OSError as why:
            raise OSError(ENOENT, "No such file: %s" % why, p)

    def _getData(self, f, mode):
        """Return ``(contents, path)`` of file *f* opened with *mode*.

        All access failures are reported as OSError with errno ENOENT.
        """
        p = os.path.join(self._srcName, "*")
        try:
            # exact match against the listing makes lookups case-sensitive
            # even on case-insensitive filesystems
            if f in os.listdir(self._srcName):
                p = os.path.join(self._srcName, f)
                # the with-block closes the file even if read() raises
                with open(p, mode) as fp:
                    return fp.read(), p

        except OSError as why:
            msg = f"File or directory {p} access error: {why}"

        else:
            msg = "No such file or directory: %s" % p

        raise OSError(ENOENT, msg)


class MibBuilder:
    """Load and unload pysnmp MIB modules and keep a registry of their symbols.

    MIB modules are Python code read from an ordered list of MIB sources
    and executed with this builder bound to the global name ``mibBuilder``.
    Exported symbols are kept in ``mibSymbols`` as
    ``{modName: {symName: symObj}}``.
    """

    defaultCoreMibs = os.pathsep.join(("pysnmp.smi.mibs.instances", "pysnmp.smi.mibs"))
    defaultMiscMibs = "pysnmp_mibs"

    # symbol name exported without MIB label handling (see exportSymbols)
    moduleID = "PYSNMP_MODULE_ID"

    # passed to the MIB compiler as genTexts
    loadTexts = False

    # MIB modules can use this to select the features they can use
    version = pysnmp_version

    def __init__(self):
        self.lastBuildId = self._autoName = 0
        sources = []
        # locations from the environment replace defaultMiscMibs
        for ev in "PYSNMP_MIB_PKGS", "PYSNMP_MIB_DIRS", "PYSNMP_MIB_DIR":
            if ev in os.environ:
                for m in os.environ[ev].split(os.pathsep):
                    sources.append(ZipMibSource(m))
        if not sources and self.defaultMiscMibs:
            for m in self.defaultMiscMibs.split(os.pathsep):
                sources.append(ZipMibSource(m))
        # core MIB sources always go in front (each one inserted at index 0)
        for m in self.defaultCoreMibs.split(os.pathsep):
            sources.insert(0, ZipMibSource(m))
        self.mibSymbols = {}
        self.__mibSources = []
        # modName -> path of the module file it was loaded from
        self.__modSeen = {}
        # paths of module files that have been executed
        self.__modPathsSeen = set()
        self.__mibCompiler = None
        self.setMibSources(*sources)

    # MIB compiler management

    def getMibCompiler(self):
        return self.__mibCompiler

    def setMibCompiler(self, mibCompiler, destDir):
        """Use *mibCompiler* for missing modules; search *destDir* for output."""
        self.addMibSources(DirMibSource(destDir))
        self.__mibCompiler = mibCompiler
        return self

    # MIB modules management

    def addMibSources(self, *mibSources):
        """Initialize *mibSources* and append them to the search list."""
        self.__mibSources.extend([s.init() for s in mibSources])
        debug.logger & debug.flagBld and debug.logger(
            f"addMibSources: new MIB sources {self.__mibSources}"
        )

    def setMibSources(self, *mibSources):
        """Initialize *mibSources* and make them the whole search list."""
        self.__mibSources = [s.init() for s in mibSources]
        debug.logger & debug.flagBld and debug.logger(
            f"setMibSources: new MIB sources {self.__mibSources}"
        )

    def getMibSources(self):
        return tuple(self.__mibSources)

    # Legacy/compatibility methods (won't work for .eggs)
    def setMibPath(self, *mibPaths):
        self.setMibSources(*[DirMibSource(x) for x in mibPaths])

    def getMibPath(self):
        """Return the search list as paths; every source must be a directory."""
        paths = ()
        for mibSource in self.getMibSources():
            if isinstance(mibSource, DirMibSource):
                paths += (mibSource.fullPath(),)
            else:
                raise error.MibLoadError(
                    f"MIB source is not a plain directory: {mibSource}"
                )
        return paths

    def loadModule(self, modName, **userCtx):
        """Load and execute MIB modules as Python code

        MIB sources are tried in order; the first one that yields the
        module wins. A module file that was already executed is not
        executed again.

        :param modName: MIB module name
        :param userCtx: made visible to the module as global ``userCtx``
        :returns: self
        :raises error.MibLoadError: executing the module code failed
        :raises error.MibNotFoundError: no MIB source yielded the module
        """
        for mibSource in self.__mibSources:
            debug.logger & debug.flagBld and debug.logger(
                f"loadModule: trying {modName} at {mibSource}"
            )
            try:
                codeObj, sfx = mibSource.read(modName)

            except OSError:
                debug.logger & debug.flagBld and debug.logger(
                    f"loadModule: read {modName} from {mibSource} failed: {sys.exc_info()[1]}"
                )
                continue

            modPath = mibSource.fullPath(modName, sfx)

            if modPath in self.__modPathsSeen:
                debug.logger & debug.flagBld and debug.logger(
                    "loadModule: seen %s" % modPath
                )
                break

            else:
                self.__modPathsSeen.add(modPath)

            debug.logger & debug.flagBld and debug.logger(
                "loadModule: evaluating %s" % modPath
            )

            # globals the module code runs with
            g = {"mibBuilder": self, "userCtx": userCtx}

            try:
                exec(codeObj, g)

            except Exception:
                # forget the path so a later attempt can execute it again
                self.__modPathsSeen.remove(modPath)
                # join the traceback lines instead of embedding a list repr
                tb = "".join(traceback.format_exception(*sys.exc_info()))
                raise error.MibLoadError(f"MIB module '{modPath}' load error: {tb}")

            self.__modSeen[modName] = modPath

            debug.logger & debug.flagBld and debug.logger(
                "loadModule: loaded %s" % modPath
            )

            break

        if modName not in self.__modSeen:
            raise error.MibNotFoundError(
                'MIB file "{}" not found in search path ({})'.format(
                    modName and modName + ".py[co]",
                    ", ".join([str(x) for x in self.__mibSources]),
                )
            )

        return self

    def loadModules(self, *modNames, **userCtx):
        """Load (optionally, compiling) pysnmp MIB modules

        With no *modNames*, every module listed by every MIB source is
        loaded. A module that is not found is handed to the MIB compiler,
        if one is set, and loading is then retried.

        :returns: self
        :raises error.MibNotFoundError: nothing to load, or compilation
            reported failed/missing modules
        """
        # Build a list of available modules
        if not modNames:
            modNames = {}
            for mibSource in self.__mibSources:
                for modName in mibSource.listdir():
                    modNames[modName] = None
            modNames = list(modNames)

        if not modNames:
            raise error.MibNotFoundError(f"No MIB module to load at {self}")

        for modName in modNames:
            try:
                self.loadModule(modName, **userCtx)

            except error.MibNotFoundError:
                # NOTE(review): without a MIB compiler the error is not
                # propagated here; callers such as importSymbols() check
                # mibSymbols afterwards
                if self.__mibCompiler:
                    debug.logger & debug.flagBld and debug.logger(
                        "loadModules: calling MIB compiler for %s" % modName
                    )
                    status = self.__mibCompiler.compile(
                        modName, genTexts=self.loadTexts
                    )
                    errs = "; ".join(
                        [
                            hasattr(x, "error") and str(x.error) or x
                            for x in status.values()
                            if x in ("failed", "missing")
                        ]
                    )
                    if errs:
                        raise error.MibNotFoundError(
                            f"{modName} compilation error(s): {errs}"
                        )

                    # compilation succeeded, MIB might load now
                    self.loadModule(modName, **userCtx)

        return self

    def unloadModules(self, *modNames):
        """Unexport all symbols of *modNames* (default: every module).

        :returns: self
        :raises error.MibNotFoundError: a module has no exported symbols
        """
        if not modNames:
            modNames = list(self.mibSymbols.keys())
        for modName in modNames:
            if modName not in self.mibSymbols:
                raise error.MibNotFoundError(f"No module {modName} at {self}")
            self.unexportSymbols(modName)
            # Symbols may have been exported directly via exportSymbols()
            # without loadModule(); no module path was recorded then.
            modPath = self.__modSeen.pop(modName, None)
            if modPath is not None:
                self.__modPathsSeen.discard(modPath)

            debug.logger & debug.flagBld and debug.logger("unloadModules: %s" % modName)

        return self

    def importSymbols(self, modName, *symNames, **userCtx):
        """Return a tuple of the objects named *symNames* from *modName*.

        The module is loaded on demand (only when at least one symbol
        name is requested).

        :raises error.SmiError: empty module name or unknown symbol
        :raises error.MibNotFoundError: the module could not be loaded
        """
        if not modName:
            raise error.SmiError("importSymbols: empty MIB module name")
        r = ()
        for symName in symNames:
            if modName not in self.mibSymbols:
                self.loadModules(modName, **userCtx)
            if modName not in self.mibSymbols:
                raise error.MibNotFoundError(f"No module {modName} loaded at {self}")
            if symName not in self.mibSymbols[modName]:
                raise error.SmiError(f"No symbol {modName}::{symName} at {self}")
            r = r + (self.mibSymbols[modName][symName],)
        return r

    def exportSymbols(self, modName, *anonymousSyms, **namedSyms):
        """Register symbols under *modName*.

        Anonymous symbols get generated ``__pysnmp_<N>`` names. A named
        symbol that is not ``moduleID`` and not a class is stored under its
        own label when it has one; otherwise it is labeled with its name.

        :raises error.SmiError: a named symbol is already exported
        """
        if modName not in self.mibSymbols:
            self.mibSymbols[modName] = {}
        mibSymbols = self.mibSymbols[modName]

        for symObj in anonymousSyms:
            debug.logger & debug.flagBld and debug.logger(
                "exportSymbols: anonymous symbol %s::__pysnmp_%ld"
                % (modName, self._autoName)
            )
            mibSymbols["__pysnmp_%ld" % self._autoName] = symObj
            self._autoName += 1
        for symName, symObj in namedSyms.items():
            if symName in mibSymbols:
                raise error.SmiError(f"Symbol {symName} already exported at {modName}")

            if symName != self.moduleID and not isinstance(symObj, classTypes):
                label = symObj.getLabel()
                if label:
                    symName = label
                else:
                    symObj.setLabel(symName)

            mibSymbols[symName] = symObj

            debug.logger & debug.flagBld and debug.logger(
                f"exportSymbols: symbol {modName}::{symName}"
            )

        self.lastBuildId += 1

    def unexportSymbols(self, modName, *symNames):
        """Remove *symNames* (default: all) from *modName*.

        The module entry itself is dropped once it has no symbols left.

        :raises error.SmiError: unknown module or symbol
        """
        if modName not in self.mibSymbols:
            raise error.SmiError(f"No module {modName} at {self}")
        mibSymbols = self.mibSymbols[modName]
        if not symNames:
            symNames = list(mibSymbols.keys())
        for symName in symNames:
            if symName not in mibSymbols:
                raise error.SmiError(f"No symbol {modName}::{symName} at {self}")
            del mibSymbols[symName]

            debug.logger & debug.flagBld and debug.logger(
                f"unexportSymbols: symbol {modName}::{symName}"
            )

        if not self.mibSymbols[modName]:
            del self.mibSymbols[modName]

        self.lastBuildId += 1
