#!/opt/local/bin/python3.13
#
# The ZeekControl interactive shell.

import logging
import os.path
import sys
import time

# Make the ZeekControl Python modules importable: every library directory
# that exists on this system is put at the front of sys.path.
for path in (
    "/opt/local/lib/zeekctl",
    "/opt/local/lib/zeek/python",
    "/opt/local/lib/zeek/python/zeekctl",
):
    if not os.path.isdir(path):
        continue
    # Each insert goes to index 0, so later entries end up ahead of earlier ones.
    sys.path.insert(0, path)

from ZeekControl import util, utilcurses, version, zeekcmd
from ZeekControl.exceptions import CommandSyntaxError, ZeekControlError
from ZeekControl.zeekctl import ZeekCtl


# Main command loop.
class ZeekCtlCmdLoop(zeekcmd.ExitValueCmd):
    prompt = "[ZeekControl] > "

    def __init__(self, zeekctl_class=ZeekCtl, interactive=False, cmd=""):
        # Set up the shell: initialise the command-loop base class, create the
        # ZeekControl backend with this object as its UI, and remember whether
        # the shell runs interactively.
        zeekcmd.ExitValueCmd.__init__(self)
        self.zeekctl = zeekctl_class(ui=self)
        self.interactive = interactive

        # Commands for which the "run zeekctl install" reminder is skipped:
        # cron (to avoid receiving annoying emails) and the install/deploy
        # commands themselves.
        quiet_cmds = ("cron", "install", "deploy")
        if cmd in quiet_cmds:
            return

        # Warn user to do zeekctl install, if needed.
        self.zeekctl.warn_zeekctl_install()

    def finish(self):
        self.zeekctl.finish()

    def info(self, text):
        print(text)
        logging.info(text)

    def warn(self, text):
        self.info(f"Warning: {text}")

    def error(self, text):
        print(f"Error: {text}", file=sys.stderr)
        logging.info(text)

    def err(self, text):
        print(text, file=sys.stderr)
        logging.info(text)

    def default(self, line):
        strlist = line.split()
        cmd = strlist[0]
        cmdargs = " ".join(strlist[1:])

        results = self.zeekctl.plugincmd(cmd, cmdargs)

        if results.unknowncmd:
            self.error(f"unknown command '{cmd}'")

            if not self.interactive:
                self.do_help(None)

        return results.ok

    def emptyline(self):
        pass

    def precmd(self, line):
        logging.debug(line)
        return line

    def postcmd(self, stop, line):
        logging.debug("done")
        return stop

    def do_EOF(self, args):
        self._stopping = True
        return True

    def do_exit(self, args):
        """Terminates the shell."""
        self._stopping = True
        return True

    def do_quit(self, args):
        """Terminates the shell."""
        self._stopping = True
        return True

    def do_nodes(self, args):
        """Prints a list of all configured nodes.

        Note that the env_vars attribute includes the set of environment
        variables from the 'env_vars' option in both 'node.cfg' and
        'zeekctl.cfg' and also those set by any plugins."""

        if args:
            raise CommandSyntaxError("the nodes command does not take any arguments")

        results = self.zeekctl.nodes()
        for node, success, data in results.get_node_data():
            self.info(data["description"])

        return results.ok

    def do_config(self, args):
        """Prints all configuration options with their current values."""
        if args:
            raise CommandSyntaxError("the config command does not take any arguments")

        results = self.zeekctl.get_config()
        for key, val in results.keyval:
            self.info(f"{key} = {val}")

        return results.ok

    def do_install(self, args):
        """- [--local]

        Reinstalls on all nodes, including all configuration files and
        local policy scripts.

        The ``--local`` option is intended for testing or debugging.  It
        causes only the local host to be installed (i.e., no changes pushed
        out to any other hosts in the Zeek cluster).  Normally all nodes
        should be reinstalled at the same time, as any inconsistencies between
        them will lead to strange effects.

        This command must be executed after *all* changes to any part of
        the ZeekControl configuration or after upgrading to a new version
        of Zeek or ZeekControl, otherwise the modifications will not take effect.
        Before executing ``install``, it is recommended to verify the
        configuration with check_.  Note that when using the deploy command
        there is no need to first use the install command, because deploy
        automatically runs install before restarting the nodes."""

        local = False

        for arg in args.split():
            if arg == "--local":
                local = True
            else:
                raise CommandSyntaxError(
                    f"invalid argument for the install command: {arg}"
                )

        results = self.zeekctl.install(local)
        return results.ok

    def do_start(self, args):
        """- [<nodes>]

        Starts the given nodes, or all nodes if none are specified. Nodes
        already running are left untouched.
        """

        results = self.zeekctl.start(node_list=args)

        return results.ok

    def do_stop(self, args):
        """- [<nodes>]

        Stops the given nodes, or all nodes if none are specified. Nodes that
        are in the "crashed" state are reset to the "stopped" state, and
        nodes that are "stopped" are left untouched.
        """
        results = self.zeekctl.stop(node_list=args)

        return results.ok

    def do_restart(self, args):
        """- [--clean] [<nodes>]

        Restarts the given nodes, or all nodes if none are specified. The
        effect is the same as first executing stop_ followed
        by a start_, giving the same nodes in both cases.

        If ``--clean`` is given, the installation is reset into a clean state
        before restarting. More precisely, a ``restart --clean`` turns into
        the command sequence stop_, cleanup_, check_, install_, and
        start_.
        """
        clean = False
        if args.startswith("--clean"):
            args = args[7:]
            clean = True

        results = self.zeekctl.restart(clean=clean, node_list=args)
        return results.ok

    def do_deploy(self, args):
        """
        Checks for errors in Zeek policy scripts, then does an install followed
        by a restart on all nodes.  This command should be run after any
        changes to Zeek policy scripts or the zeekctl configuration, and after
        Zeek is upgraded or even just recompiled.

        This command is equivalent to running the check_, install_, and
        restart_ commands, in that order.
        """
        if args:
            raise CommandSyntaxError("the deploy command does not take any arguments")

        results = self.zeekctl.deploy()

        return results.ok

    def do_status(self, args):
        """- [<nodes>]

        Prints the current status of the given nodes.

        For each node, the information shown includes the node's name and type,
        the host where the node will run, the status, the PID, and the
        date/time when the node was started.  The status column will usually
        show a status of either "stopped" or "running".  A status of
        "crashed" means that ZeekControl verified that Zeek is no longer
        running, but was expected to be running."""

        success = True
        results = self.zeekctl.status(node_list=args)

        typewidth = 7
        hostwidth = 16
        data = results.get_node_data()
        if data and data[0][2]["type"] == "standalone":
            # In standalone mode, we need a wider "type" column.
            typewidth = 10
            hostwidth = 13

        showall = False
        if data:
            showall = "peers" in data[0][2]

        if showall:
            colfmt = "{name:<12} {type:<{0}} {host:<{1}} {status:<9} {pid:<6} {peers:<6} {started}"
        else:
            colfmt = "{name:<12} {type:<{0}} {host:<{1}} {status:<9} {pid:<6} {started}"

        hdrlist = ["name", "type", "host", "status", "pid", "peers", "started"]
        header = {x: x.title() for x in hdrlist}
        self.info(colfmt.format(typewidth, hostwidth, **header))

        colfmtstopped = "{name:<12} {type:<{0}} {host:<{1}} {status}"

        for data in results.get_node_data():
            node_info = data[2]
            mycolfmt = colfmt if node_info["pid"] else colfmtstopped

            self.info(mycolfmt.format(typewidth, hostwidth, **node_info))

            # Return status code of True only if all nodes are running
            if node_info["status"] != "running":
                success = False

        return success

    def _do_top_once(self, args):
        results = self.zeekctl.top(args)

        typewidth = 7
        hostwidth = 16
        data = results.get_node_data()
        if data:
            procinfo = data[0][2]["procs"]
            if procinfo["type"] == "standalone":
                # In standalone mode, we need a wider "type" column.
                typewidth = 10
                hostwidth = 13

        lines = [
            "{:<12s} {:<{}s} {:<{}s} {:<7s} {:<6s} {:<4s} {:<5s} {:s}".format(
                "Name",
                "Type",
                typewidth,
                "Host",
                hostwidth,
                "Pid",
                "VSize",
                "Rss",
                "Cpu",
                "Cmd",
            )
        ]
        for data in results.get_node_data():
            procinfo = data[2]["procs"]
            msg = ["{:<12s}".format(procinfo["name"])]
            msg.append("{:<{}s}".format(procinfo["type"], typewidth))
            msg.append("{:<{}s}".format(procinfo["host"], hostwidth))
            if procinfo["error"]:
                msg.append("<{:s}>".format(procinfo["error"]))
            else:
                msg.append("{:<7s}".format(str(procinfo["pid"])))
                msg.append("{:<6s}".format(util.number_unit_str(procinfo["vsize"])))
                msg.append("{:<4s}".format(util.number_unit_str(procinfo["rss"])))
                msg.append("{:>3s}% ".format(procinfo["cpu"]))
                msg.append("{:s}".format(procinfo["cmd"]))

            lines.append(" ".join(msg))

        return (results.ok, lines)

    def do_top(self, args):
        """- [<nodes>]

        For each of the nodes, prints the status of the Zeek process in
        a *top*-like format, including CPU usage and memory consumption. If
        executed interactively, the display is updated frequently
        until key ``q`` is pressed. If invoked non-interactively, the
        status is printed only once."""

        if not self.interactive:
            success, lines = self._do_top_once(args)
            for line in lines:
                self.info(line)

            return success

        utilcurses.enterCurses()
        utilcurses.clearScreen()

        count = 0

        while utilcurses.getCh() != "q":
            if count % 10 == 0:
                success, lines = self._do_top_once(args)
                utilcurses.clearScreen()
                utilcurses.printLines(lines)
            time.sleep(0.1)
            count += 1

        utilcurses.leaveCurses()

        return success

    def do_diag(self, args):
        """- [<nodes>]

        If a node has terminated unexpectedly, this command prints a (somewhat
        cryptic) summary of its final state including excerpts of any
        stdout/stderr output, resource usage, and also a stack backtrace if a
        core dump is found. The same information is sent out via mail when a
        node is found to have crashed (the "crash report"). While the
        information is mainly intended for debugging, it can also help to find
        misconfigurations (which are usually, but not always, caught by the
        check_ command)."""

        results = self.zeekctl.diag(node_list=args)

        for node, success, output in results.get_node_output():
            self.info(f"[{node}]")
            self.info(output)

        return results.ok

    def do_cron(self, args):
        """- [enable|disable|?] | [--no-watch]

        This command has two modes of operation. Without arguments (or just
        ``--no-watch``), it performs a set of maintenance tasks, including
        the logging of various statistical information, expiring old log
        files, checking for dead hosts, and restarting nodes which terminated
        unexpectedly (the latter can be suppressed with the ``--no-watch``
        option if no auto-restart is desired). This mode is intended to be
        executed regularly via *cron*, as described in the installation
        instructions. While not intended for interactive use, no harm will be
        caused by executing the command manually: all the maintenance tasks
        will then just be performed one more time.

        The second mode is for interactive usage and determines if the regular
        tasks are indeed performed when ``zeekctl cron`` is executed. In other
        words, even with ``zeekctl cron`` in your crontab, you can still
        temporarily disable it by running ``cron disable``, and
        then later reenable with ``cron enable``. This can be helpful while
        working, e.g., on the ZeekControl configuration and ``cron`` would
        interfere with that. ``cron ?`` can be used to query the current state.
        """

        watch = True

        if args == "--no-watch":
            watch = False
        elif args:
            if args == "enable":
                self.zeekctl.setcronenabled(True)
            elif args == "disable":
                self.zeekctl.setcronenabled(False)
            elif args == "?":
                results = self.zeekctl.cronenabled()
                cron_state = "enabled" if results else "disabled"
                self.info("cron " + cron_state)
            else:
                self.error("invalid cron argument")
                return False

            return True

        self.zeekctl.cron(watch)

        return True

    def do_check(self, args):
        """- [<nodes>]

        Verifies a modified configuration in terms of syntactical correctness
        (most importantly correct syntax in policy scripts).

        Note that this command checks the site-specific policy files as found
        in SitePolicyPath_ rather than the ones installed by the install_
        command.  Therefore, new errors in a policy script can be detected
        before affecting currently running nodes, even when they need to be
        restarted.

        This command should be executed for each configuration change *before*
        using install_ to put the change into place.  However, when using the
        deploy command there is no need to first run check, because deploy
        automatically runs check before installing the policy scripts."""

        results = self.zeekctl.check(node_list=args)

        for node, success, output in results.get_node_output():
            if success:
                self.info(f"{node} scripts are ok.")
            else:
                self.info(f"{node} scripts failed.")
                self.err(output)

        return results.ok

    def do_cleanup(self, args):
        """- [--all] [<nodes>]

        Clears the nodes' spool directories, but only for nodes that are not
        running. This implies that their persistent state is flushed. Nodes
        that were crashed are reset into the "stopped" state.

        If ``--all`` is specified, this command also removes the content of
        the node's TmpDir_, in particular deleting any data
        potentially saved there for reference from previous crashes.
        Generally, if you want to reset the installation back into a clean
        state, you can first stop_ all nodes, then execute
        ``cleanup --all``, then install_, and finally start_ all nodes
        again."""

        cleantmp = False
        if args.startswith("--all"):
            args = args[5:]
            cleantmp = True

        self.info("cleaning up nodes ...")

        results = self.zeekctl.cleanup(cleantmp=cleantmp, node_list=args)

        return results.ok

    def do_capstats(self, args):
        """- [<nodes>] [<interval>]

        Determines the current load on the network interfaces monitored by
        each of the given worker nodes. The load is measured over the
        specified interval (in seconds), or by default over 10 seconds. This
        command uses the :doc:`capstats<../../components/capstats/README>`
        tool, which is installed along with ``zeekctl``."""

        interval = 10
        args = args.split()

        if args:
            try:
                interval = max(1, int(args[-1]))
                args = args[0:-1]
            except ValueError:
                pass

        args = " ".join(args)

        def outputcapstats(tag, data):
            def output_one(tag, vals):
                return "{:<21s} {:<10s} {:s}".format(
                    tag, str(vals.get("kpps", "")), str(vals.get("mbps", ""))
                )

            self.info(
                "{:<21s} {:<10s} {:<10s} ({:d}s average)\n{:s}".format(
                    tag, "kpps", "mbps", interval, "-" * 40
                )
            )

            totals = None

            for node, success, vals in data:
                if not success:
                    self.err(vals["output"])
                    continue

                if str(node) != "$total":
                    hostnetif = f"{node.host}/{node.interface}"
                    self.info(output_one(hostnetif, vals))
                else:
                    totals = vals

            if totals:
                self.info("")
                self.info(output_one("Total", totals))

        results = self.zeekctl.capstats(interval=interval, node_list=args)

        nodedata = results.get_node_data()
        if nodedata:
            outputcapstats("Interface", nodedata)
        else:
            self.error(
                "No network interfaces suitable for use with capstats were found."
            )

        return results.ok

    def do_df(self, args):
        """- [<nodes>]

        Reports the amount of disk space available on the nodes. Shows only
        paths relevant to the zeekctl installation."""

        results = self.zeekctl.df(node_list=args)

        self.info(
            "{:>27s}  {:>15s}  {:<5s}  {:<5s}  {:<5s}".format(
                "", "", "total", "avail", "capacity"
            )
        )
        for node, success, dfs in results.get_node_data():
            for key, diskinfo in sorted(dfs.items()):
                if key == "FAIL":
                    self.error(f"df helper failed on {node}: {diskinfo}")
                    continue
                nodehost = f"{node.name}/{node.host}"
                self.info(
                    f"{nodehost:>28s}  {diskinfo.fs:>15s}  {util.number_unit_str(diskinfo.total):<5s}  {util.number_unit_str(diskinfo.available):<5s}  {diskinfo.percent:<5.1f}%"
                )

        return results.ok

    def do_print(self, args):
        """- <id> [<nodes>]

        Reports the *current* live value of the given Zeek script ID on all of
        the specified nodes (which obviously must be running). This can for
        example be useful to (1) check that policy scripts are working as
        expected, or (2) confirm that configuration changes have in fact been
        applied.  Note that IDs defined inside a Zeek namespace must be
        prefixed with ``<namespace>::`` (e.g.,
        ``print Log::enable_remote_logging``)."""

        args = args.split()
        try:
            id = args[0]
            args = " ".join(args[1:])
        except IndexError:
            raise CommandSyntaxError("no id given to print")

        results = self.zeekctl.print_id(id=id, node_list=args)

        for node, success, msg in results.get_node_output():
            if success:
                out = msg.split("\n", 1)
                self.info(f"{node:>12s}   {out[0]} = {out[1]}")
            else:
                self.err(f"{node:>12s}   <error: {msg}>")

        return results.ok

    def do_peerstatus(self, args):
        """- [<nodes>]

        Primarily for debugging, ``peerstatus`` reports statistics about the
        network connections cluster nodes are using to communicate with other
        nodes."""

        results = self.zeekctl.peerstatus(node_list=args)

        for node, success, msg in results.get_node_output():
            if success:
                self.info(f"{node:>11s}\n{msg}")
            else:
                self.err(f"{node:>11s}   <error: {msg}>")

        return results.ok

    def do_netstats(self, args):
        """- [<nodes>]

        Queries each of the nodes for their current counts of captured and
        dropped packets."""

        results = self.zeekctl.netstats(node_list=args)

        for node, success, msg in results.get_node_output():
            if success:
                self.info(f"{node:>11s}: {msg}")
            else:
                self.err(f"{node:>11s}: <error: {msg}>")

        return results.ok

    def do_exec(self, args):
        """- <command line>

        Executes the given Unix shell command line on all hosts configured to
        run at least one Zeek instance. This is handy to quickly perform an
        action across all systems."""

        results = self.zeekctl.execute(cmd=args)

        for node, success, output in results.get_node_output():
            out = "\n> ".join(output.splitlines())
            error = " " if success else "error"
            self.info(f"[{node.name}/{node.host}] {error}\n> {out}")

        return results.ok

    def do_scripts(self, args):
        """- [-c] [<nodes>]

        Primarily for debugging Zeek configurations, the ``scripts``
        command lists all the Zeek scripts loaded by each of the nodes in the
        order they will be parsed by the node at startup.  The pathnames
        of each script are indented such that it is possible to determine
        from where a script was loaded based on the amount of indentation.

        If ``-c`` is given, the command operates as check_ does: it reads
        the policy files from their *original* location, not the copies
        installed by install_. The latter option is useful to check a
        not yet installed configuration."""

        check = False

        args = args.split()

        try:
            while args[0].startswith("-"):
                opt = args[0]

                if opt == "-c":
                    # Check non-installed policies.
                    check = True
                else:
                    raise CommandSyntaxError(
                        f"invalid argument for the scripts command: {opt}"
                    )

                args = args[1:]

        except IndexError:
            pass

        args = " ".join(args)

        results = self.zeekctl.scripts(check=check, node_list=args)

        for node, success, output in results.get_node_output():
            if success:
                self.info(f"{node} scripts are ok.")
                for line in output.splitlines():
                    self.info(f"  {line}")
            else:
                self.info(f"{node} scripts failed.")
                self.err(output)

        return results.ok

    def do_process(self, args):
        """- <trace> [options] [-- <scripts>]

        Runs Zeek offline on a given trace file using the same configuration as
        when running live. It does, however, use the potentially
        not-yet-installed policy files in SitePolicyPath_ and disables log
        rotation. Additional Zeek command line flags and scripts can
        be given (each argument after a ``--`` argument is interpreted as
        a script).

        Upon completion, the command prints a path where the log files can be
        found. Subsequent runs of this command may delete these logs.

        In cluster mode, Zeek is run with *both* manager and worker scripts
        loaded into a single instance. While that doesn't fully reproduce the
        live setup, it is often sufficient for debugging analysis scripts.
        """
        options = []
        scripts = []
        trace = ""
        in_scripts = False

        for arg in args.split():
            if not trace:
                trace = arg
                continue

            if arg == "--":
                if in_scripts:
                    raise CommandSyntaxError(
                        'cannot parse the arguments of the process command (too many "--")'
                    )

                in_scripts = True
                continue

            if not in_scripts:
                options += [arg]

            else:
                scripts += [arg]

        if not trace:
            raise CommandSyntaxError(
                "the process command requires the pathname of a trace file"
            )

        results = self.zeekctl.process(trace, options, scripts)

        return results.ok

    def completedefault(self, text, line, begidx, endidx):
        # Commands that take a "<nodes>" argument.
        nodes_cmds = [
            "capstats",
            "check",
            "cleanup",
            "df",
            "diag",
            "netstats",
            "print",
            "restart",
            "start",
            "status",
            "stop",
            "top",
            "update",
            "peerstatus",
            "scripts",
        ]

        args = line.split()

        if not args or args[0] not in nodes_cmds:
            return []

        nodes = self.zeekctl.node_groups() + self.zeekctl.node_names()

        return [n for n in nodes if n.startswith(text)]

    def do_help(self, args):
        """Prints a brief summary of all commands understood by the shell."""

        # One line per command contributed by a plugin; commands that take
        # arguments are shown together with them.
        plugin_lines = []
        for name, argspec, descr in self.zeekctl.plugins.allCustomCommands():
            usage = f"{name} {argspec}" if argspec else name
            plugin_lines.append(f"  {usage:<32s} - {descr}\n")

        # The plugin section, heading included, appears only when at least
        # one plugin command exists.
        plugin_help = ""
        if plugin_lines:
            plugin_help = "\nCommands provided by plugins:\n\n" + "".join(plugin_lines)

        self.info(
            f"""
ZeekControl Version {version.VERSION}

  capstats [<nodes>] [<secs>]      - Report interface statistics with capstats
  check [<nodes>]                  - Check configuration before installing it
  cleanup [--all] [<nodes>]        - Delete working dirs (flush state) on nodes
  config                           - Print zeekctl configuration
  cron [--no-watch]                - Perform jobs intended to run from cron
  cron enable|disable|?            - Enable/disable "cron" jobs
  deploy                           - Check, install, and restart
  df [<nodes>]                     - Print nodes' current disk usage
  diag [<nodes>]                   - Output diagnostics for nodes
  exec <shell cmd>                 - Execute shell command on all hosts
  exit                             - Exit shell
  install                          - Update zeekctl installation/configuration
  netstats [<nodes>]               - Print nodes' current packet counters
  nodes                            - Print node configuration
  peerstatus [<nodes>]             - Print status of nodes' remote connections
  print <id> [<nodes>]             - Print values of script variable at nodes
  process <trace> [<op>] [-- <sc>] - Run Zeek with options and scripts on trace
  quit                             - Exit shell
  restart [--clean] [<nodes>]      - Stop and then restart processing
  scripts [-c] [<nodes>]           - List the Zeek scripts the nodes will load
  start [<nodes>]                  - Start processing
  status [<nodes>]                 - Summarize node status
  stop [<nodes>]                   - Stop processing
  top [<nodes>]                    - Show Zeek processes ala top
  {plugin_help}"""
        )


def main():
    """Run zeekctl either as a one-shot command or as an interactive shell.

    With command-line arguments, they are joined into one command line that
    is executed once; without arguments, the interactive command loop is
    started.  The special invocations ``--print-doc <arg>`` and
    ``--version`` print and return immediately.

    Returns 0 for the two special invocations, 1 if the shell could not be
    set up, and otherwise the negated success value of the command or loop
    (a false value, i.e. exit status 0, on success).
    """
    # Undocumented option to print the documentation.
    if len(sys.argv) == 3 and sys.argv[1] == "--print-doc":
        from ZeekControl import printdoc

        printdoc.print_zeekctl_docs(sys.argv[2], ZeekCtlCmdLoop)
        return 0

    if len(sys.argv) == 2 and sys.argv[1] == "--version":
        print(f"ZeekControl version {version.VERSION}")
        return 0

    # Any argument at all means a single command is run non-interactively.
    interactive = len(sys.argv) <= 1

    # The command name is the first argument, however many arguments follow
    # it.  Taking it only when there is exactly one argument would leave
    # e.g. "cron --no-watch" or "install --local" unrecognised, and the shell
    # would then print the install warning it is meant to skip for them.
    cmd = sys.argv[1] if len(sys.argv) > 1 else ""

    try:
        loop = ZeekCtlCmdLoop(ZeekCtl, interactive, cmd)
    except ZeekControlError as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1

    if not interactive:
        cmdline = " ".join(sys.argv[1:])
        loop.precmd(cmdline)
        try:
            cmdsuccess = loop.onecmd(cmdline)
            loop.postcmd(False, cmdline)
        except ZeekControlError as e:
            cmdsuccess = False
            print(f"Error: {e}", file=sys.stderr)
        except KeyboardInterrupt:
            cmdsuccess = False
        finally:
            # Runs whether the command succeeded, failed or was interrupted.
            loop.finish()

    else:
        try:
            cmdsuccess = loop.cmdloop(
                f'\nWelcome to ZeekControl {version.VERSION}\n\nType "help" for help.\n'
            )
        except ZeekControlError as e:
            cmdsuccess = False
            print(f"Error: {e}", file=sys.stderr)
        except KeyboardInterrupt:
            cmdsuccess = False
        finally:
            # Runs whether the loop ended normally, failed or was interrupted.
            loop.finish()

    return not cmdsuccess


# Run the shell only when this file is executed as a script; the value
# returned by main() becomes the process exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)
