#!/usr/bin/env python3

# Copyright (C) 2022 Canonical Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import argparse
import binascii
import hashlib
import itertools
import logging
import os
import re
import subprocess
import yaml

from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import cpu_count
from pathlib import Path
from threading import Lock
from typing import Dict, List


# Settings taken from the environment. They are read once, at import time.

# shell dialect handed to `shellcheck -s`; read from $SHELLCHECK_SHELL,
# falling back to bash
SHELLCHECK_SHELL = os.getenv("SHELLCHECK_SHELL", "bash")
# set to non-empty to ignore all errors (has the same effect as --no-errors)
NO_FAIL = os.getenv("NO_FAIL")
# set to non-empty to enable 'set -x'; in this script the variable is only
# consulted when picking the log level, where it enables debug logging
D = os.getenv("D")
# set to non-empty to enable verbose (debug level) logging
V = os.getenv("V")
# set to a number to use that many threads; when unset or empty the CPU count
# is used. The value is the default of the -P/--max-procs option
N = int(os.getenv("N") or cpu_count())
# file with the list of files that are allowed to fail validation; when set it
# replaces the value of the --can-fail option
CAN_FAIL = os.getenv("CAN_FAIL")

# Names of the sections that are looked up in each YAML document (and in each
# suite of spread.yaml) and passed to shellcheck, in the order they are
# checked. All of them hold a shell snippet except "skip", which is a
# composite section (see iscomposite) made of a list of entries whose optional
# "if" key carries the shell code.
SECTIONS: List[str] = [
    "prepare",
    "prepare-each",
    "restore",
    "restore-each",
    "debug",
    "debug-each",
    "execute",
    "repack",
    "skip",
]


def parse_arguments():
    """Parse the command line of the spread shellcheck helper.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv`` with the
        attributes ``shell``, ``no_errors``, ``verbose``, ``can_fail``,
        ``max_procs``, ``exclude``, ``no_cache`` and ``paths``.
    """
    # (names, add_argument keywords), listed in the order shown by --help;
    # the table is rebuilt on every call so the mutable `exclude` default is
    # never shared between parses
    option_table = (
        (("-s", "--shell"), dict(default="bash", help="shell")),
        (
            ("-n", "--no-errors"),
            dict(action="store_true", default=False, help="ignore all errors "),
        ),
        (
            ("-v", "--verbose"),
            dict(action="store_true", default=False, help="verbose logging"),
        ),
        (
            ("--can-fail",),
            dict(
                default=None,
                help="file with list of files that are can fail validation",
            ),
        ),
        (
            ("-P", "--max-procs"),
            dict(
                default=N,
                type=int,
                metavar="N",
                help="run these many shellchecks in parallel (default: %(default)s)",
            ),
        ),
        (
            ("-e", "--exclude"),
            dict(
                default=[],
                action="append",
                help="path to exclude of the shell check",
            ),
        ),
        (("--no-cache",), dict(help="disable caching", action="store_true")),
        (("paths",), dict(nargs="+", help="paths to check")),
    )

    cli = argparse.ArgumentParser(description="spread shellcheck helper")
    for names, keywords in option_table:
        cli.add_argument(*names, **keywords)
    return cli.parse_args()


class ShellcheckRunError(Exception):
    """Signals that one shellcheck invocation reported problems.

    The ``stderr`` attribute stores whatever output the raiser supplied; in
    this file checksection() passes the captured standard output of
    shellcheck as bytes, and the consumers decode it as UTF-8.
    """

    def __init__(self, stderr):
        Exception.__init__(self)
        # raw tool output, kept undecoded
        self.stderr = stderr


class ShellcheckError(Exception):
    """Accumulates the shellcheck failures found in a single file.

    The length of an instance is the number of failed sections, so an
    instance without recorded failures is falsy.
    """

    def __init__(self, path):
        Exception.__init__(self)
        # file the failures belong to
        self.path = path
        # section name -> error text
        self.sectionerrors = {}

    def addfailure(self, section, error):
        """Store *error* for *section*, replacing any earlier entry."""
        self.sectionerrors.update({section: error})

    def __len__(self):
        """Return the number of sections with a recorded failure."""
        return len(self.sectionerrors)


class ShellcheckFailures(Exception):
    """Set of paths of the files that failed validation.

    Behaves like a small read-mostly set: it has a length (an empty instance
    is falsy), can be iterated, and offers intersection/difference against
    other iterables.
    """

    def __init__(self, failures=None):
        Exception.__init__(self)
        # always a fresh set, never the object the caller passed in
        self.failures = set(failures) if failures else set()

    def merge(self, otherfailures):
        """Add the paths held by another ShellcheckFailures to this one."""
        self.failures = {*self.failures, *otherfailures.failures}

    def __len__(self):
        """Return the number of failed paths."""
        return len(self.failures)

    def intersection(self, other):
        """Return the failed paths that are also present in *other*."""
        return self.failures.intersection(other)

    def difference(self, other):
        """Return the failed paths that are absent from *other*."""
        return self.failures.difference(other)

    def __iter__(self):
        """Iterate over the failed paths."""
        return iter(self.failures)


def iscomposite(section: str) -> bool:
    """Tell whether *section* is a composite section.

    Only "skip" is composite; every other section name is not.
    """
    composite_sections = ("skip",)
    return section in composite_sections


def _env_assignment_lines(key, value):
    """Render one environment entry as the shell lines that define and export it.

    Conversions applied to the value:
        FOO: "$(HOST: echo $foo)"     -> FOO="$(echo $foo)"
        FOO: "$(HOST: echo \"$foo\")" -> FOO="$(echo "$foo")"
        FOO: "foo"                    -> FOO="foo"
        FOO: "\"foo\""                -> FOO="\"foo\""

    Returns:
        A list of lines: optional shellcheck directives, the assignment and
        the export statement.
    """
    value = str(value)
    assign_disabled = set()
    export_disabled = set()

    def replacement(match):
        if match.group(0) == '"':
            # SC2089 and SC2090 are about quotes vs arrays
            # We cannot have arrays in environment variables of spread
            # So we do have to use quotes
            assign_disabled.add("SC2089")
            export_disabled.add("SC2090")
            return r"\""
        assert match.group("command") is not None
        # "Useless" echo. This is what we get.
        # We cannot just evaluate to please shellcheck.
        assign_disabled.add("SC2116")
        return "$({})".format(match.group("command"))

    value = re.sub(r'[$][(]HOST:(?P<command>.*)[)]|"', replacement, value)

    lines = []
    if assign_disabled:
        # sorted so the generated script does not depend on set ordering
        lines.append(
            "# shellcheck disable={}".format(",".join(sorted(assign_disabled)))
        )
    lines.append('{}="{}"'.format(key, value))
    if export_disabled:
        lines.append(
            "# shellcheck disable={}".format(",".join(sorted(export_disabled)))
        )
    lines.append("export {}".format(key))
    return lines


def checksection(data, env: Dict[str, str]):
    """Run shellcheck on one section, preceded by its environment.

    The script handed to shellcheck on standard input consists of 'set -e',
    one assignment plus export per environment entry, and the section itself.

    Args:
        data: shell snippet of the section. ``None``, which is what an empty
            YAML key loads as, is accepted and means there is nothing to
            check.
        env: environment variable names mapped to their values; values are
            converted with ``str()``.

    Raises:
        ShellcheckRunError: shellcheck exited with a non-zero status; the
            exception carries the captured standard output (bytes).
        subprocess.TimeoutExpired: shellcheck did not finish within 60
            seconds. The child process is killed before this is raised.
    """
    if data is None:
        # joining the script below would fail with a TypeError
        return

    # spread shell snippets are executed under 'set -e' shell, make sure
    # shellcheck knows about that
    script_lines = ["set -e"]
    for key, value in env.items():
        script_lines.extend(_env_assignment_lines(key, value))
    script_lines.append(data)

    # subprocess.run() kills and reaps the child when the timeout expires,
    # and the argument list keeps the shell name from being interpreted by
    # an intermediate shell
    proc = subprocess.run(
        ["shellcheck", "-s", SHELLCHECK_SHELL, "-x", "-"],
        input="\n".join(script_lines).encode("utf-8"),
        stdout=subprocess.PIPE,
        timeout=60,
        check=False,
    )
    if proc.returncode != 0:
        raise ShellcheckRunError(proc.stdout)


def checkcompositesection(compositesection: str, data, env):
    """Check the shell code embedded in a composite section.

    Args:
        compositesection: name of the section; only "skip" is supported.
        data: content of the section. For "skip" it is a list of conditions,
            each a mapping that may carry shell code under the "if" key.
        env: environment passed on to checksection().

    Raises:
        RuntimeError: *compositesection* is not a supported composite section.
        ShellcheckRunError: propagated from checksection().
    """
    if compositesection != "skip":
        raise RuntimeError(f"unsupported composite section {compositesection}")

    # skip data is a list of conditions, entries without "if" have no code
    for condition in data:
        if "if" in condition:
            logging.debug("checking composite 'if' entry")
            checksection(condition["if"], env)


class Cacher:
    """Remembers which file contents already passed validation.

    A successful validation is recorded as an empty stamp file under
    ``~/.cache/spread-shellcheck/<first two digest characters>/<digest>``,
    where the digest is the SHA-256 of the ``shellcheck --version`` output
    followed by the file content. A single shared instance is created with
    init() and obtained with get().
    """

    # shared instance, None until init() is called
    _instance = None
    # result type of the `stats` property, created once
    _Stats = namedtuple("Stats", ["hit", "miss"])

    def __init__(self):
        self._enabled = True
        # protects the hit/miss counters
        self._lock = Lock()
        self._hit = 0
        self._miss = 0
        self._shellcheck_version = None
        self._probe_shellcheck_version()

    @classmethod
    def init(cls):
        """Create the shared instance; this runs ``shellcheck --version``."""
        cls._instance = Cacher()

    @classmethod
    def get(cls):
        """Return the shared instance, or None when init() was not called."""
        return cls._instance

    def disable(self):
        """Turn caching off: lookups always miss and nothing is stamped."""
        logging.debug("caching is disabled")
        self._enabled = False

    @staticmethod
    def _cache_path_for(digest):
        """Return the path of the stamp file for a hex digest."""
        prefix = digest[:2]
        return Path.home().joinpath(".cache", "spread-shellcheck", prefix, digest)

    def is_cached(self, data, path):
        """Look up the stamp for a file content.

        Args:
            data: raw content of the file (bytes).
            path: path of the file; currently not part of the digest.

        Returns:
            A ``(hit, digest)`` tuple. When caching is disabled the result is
            ``(False, "")`` and no cache event is recorded.
        """
        if not self._enabled:
            return False, ""
        # the digest uses script content and shellcheck versions as inputs, but
        # consider other possible inputs: path to the *.yaml file (so moving
        # the script around would cause a miss) or even the contents of this
        # script
        # NOTE(review): the shell dialect given to shellcheck is not an input
        # either, so a stamp made under one shell is reused under another
        h = hashlib.sha256()
        h.update(self._shellcheck_version)
        h.update(data)
        hdg = h.hexdigest()
        cachepath = Cacher._cache_path_for(hdg)
        # probe once so the logged and the returned value cannot disagree
        hit = cachepath.exists()
        logging.debug("cache stamp %s, exists? %s", cachepath.as_posix(), hit)
        self._record_cache_event(hit)
        return hit, hdg

    def cache_success(self, digest, path):
        """Create the stamp file for *digest*; a no-op when caching is disabled.

        Args:
            digest: hex digest as returned by is_cached().
            path: path of the validated file; currently unused.
        """
        if not self._enabled:
            return
        cachepath = Cacher._cache_path_for(digest)
        logging.debug("cache success, path %s", cachepath.as_posix())
        cachepath.parent.mkdir(parents=True, exist_ok=True)
        cachepath.touch()

    def _record_cache_event(self, hit):
        """Count one cache lookup as a hit or a miss."""
        with self._lock:
            if hit:
                self._hit += 1
            else:
                self._miss += 1

    def _probe_shellcheck_version(self):
        """Store the raw output of ``shellcheck --version`` for the digest."""
        logging.debug("probing shellcheck version")
        self._shellcheck_version = subprocess.check_output(
            ["shellcheck", "--version"]
        )

    @property
    def stats(self):
        """Lookup counters as a ``Stats(hit, miss)`` named tuple."""
        return Cacher._Stats(self._hit, self._miss)


def checkfile(path, executor):
    """Validate every shell section of one spread.yaml or task.yaml file.

    Files whose content is already stamped in the cache are skipped. The
    top-level sections are checked in the calling thread; for a file whose
    name ends with "spread.yaml" the sections of each suite are submitted to
    *executor*. A file that validates cleanly is stamped in the cache.

    Args:
        path: path of the YAML file, as a string.
        executor: executor used for the per-suite checks.

    Raises:
        ShellcheckError: at least one section failed; the exception lists the
            failing sections with the shellcheck output.
    """
    logging.debug("checking file %s", path)
    with open(path, mode="rb") as inf:
        rawdata = inf.read()

    cached, digest = Cacher.get().is_cached(rawdata, path)
    if cached:
        logging.debug("entry %s already cached", digest)
        return
    # an empty document loads as None, treat it as a file without sections
    data = yaml.safe_load(rawdata) or {}

    errors = ShellcheckError(path)
    # TODO: handle stacking of environment from other places that influence it:
    # spread.yaml -> global env + backend env + suite env -> task.yaml (task
    # env + variant env).
    env = {}
    # an `environment:` key without entries loads as None
    for key, value in (data.get("environment") or {}).items():
        if "/" in key:
            # TODO: re-check with each variant's value set.
            key = key.split("/", 1)[0]
        env[key] = value
    for section in SECTIONS:
        if section not in data:
            continue
        try:
            if not iscomposite(section):
                logging.debug("%s: checking section %s", path, section)
                checksection(data[section], env)
            else:
                logging.debug("%s: checking composite section %s", path, section)
                checkcompositesection(section, data[section], env)
        except ShellcheckRunError as serr:
            errors.addfailure(section, serr.stderr.decode("utf-8"))

    if path.endswith("spread.yaml") and "suites" in data:
        # check suites
        # NOTE: checkpaths() runs this function on the very executor the
        # checks below are submitted to, and this thread then blocks on their
        # results; that is the reason for the single worker deadlock
        # workaround applied at startup
        suites_sections_and_futures = []
        for suite, suitedata in (data["suites"] or {}).items():
            # a suite without any keys loads as None
            suitedata = suitedata or {}
            for section in SECTIONS:
                if section not in suitedata:
                    continue
                logging.debug(
                    "%s (suite %s): checking section %s", path, suite, section
                )
                if not iscomposite(section):
                    future = executor.submit(checksection, suitedata[section], env)
                else:
                    future = executor.submit(
                        checkcompositesection, section, suitedata[section], env
                    )
                suites_sections_and_futures.append((suite, section, future))
        for suite, section, future in suites_sections_and_futures:
            try:
                future.result()
            except ShellcheckRunError as serr:
                errors.addfailure(
                    "suites/" + suite + "/" + section, serr.stderr.decode("utf-8")
                )

    # a ShellcheckError without recorded failures is falsy
    if errors:
        raise errors
    # only stamp the cache when the script was found to be valid
    Cacher.get().cache_success(digest, path)


def is_file_in_dirs(file, dirs):
    """Tell whether *file* is located below any of the directories in *dirs*.

    Both sides are made absolute before comparing, and a directory only
    matches on a full path component boundary. A "Skipping <file>" line is
    printed for a match.

    Returns:
        True when the file is below one of the directories, False otherwise.
    """
    for excluded in dirs:
        candidate = os.path.abspath(file)
        prefix = f"{os.path.abspath(excluded)}/"
        if not candidate.startswith(prefix):
            continue
        print(f"Skipping {file}")
        return True

    return False


def findfiles(locations, exclude):
    """Yield the paths of the files to validate.

    A directory location is walked recursively and contributes every
    spread.yaml and task.yaml found below it, as the walk root joined with
    the file name. Any other location is taken as a file and contributes its
    absolute path. Paths below one of the *exclude* directories are left out.

    Args:
        locations: iterable of file or directory paths.
        exclude: directories whose content is skipped.
    """
    wanted_names = ("spread.yaml", "task.yaml")
    for loc in locations:
        if not os.path.isdir(loc):
            single = os.path.abspath(loc)
            if not is_file_in_dirs(single, exclude):
                yield single
            continue

        for root, _, files in os.walk(loc, topdown=True):
            for name in files:
                if name not in wanted_names:
                    continue
                found = os.path.join(root, name)
                if not is_file_in_dirs(found, exclude):
                    yield found


def check1path(path, executor):
    """Check one file and hand the failure back instead of raising it.

    Returns:
        The ShellcheckError raised by checkfile(), or None when the file
        passed. Other exceptions propagate.
    """
    failure = None
    try:
        checkfile(path, executor)
    except ShellcheckError as err:
        failure = err
    return failure


def checkpaths(locs, exclude, executor):
    """Check all files found under *locs* and log the failures.

    Args:
        locs: files or directories to look at.
        exclude: directories whose content is skipped.
        executor: executor the per-file checks are mapped onto; it is also
            passed to every check.

    Raises:
        ShellcheckFailures: one or more files failed; carries their paths.
    """
    outcomes = executor.map(
        check1path, findfiles(locs, exclude), itertools.repeat(executor)
    )

    failed = []
    for serr in outcomes:
        if serr is not None:
            logging.error(
                "shellcheck failed for file %s in sections: %s; error log follows",
                serr.path,
                ", ".join(serr.sectionerrors),
            )
            for section in serr.sectionerrors:
                logging.error(
                    "%s: section '%s':\n%s",
                    serr.path,
                    section,
                    serr.sectionerrors[section],
                )
            failed.append(serr.path)

    if failed:
        raise ShellcheckFailures(failures=failed)


def loadfilelist(flistpath):
    """Load a list of paths, one per line, from a text file.

    Surrounding whitespace is removed from every line. Blank lines and lines
    whose first non-blank character is '#' are ignored.

    Args:
        flistpath: path of the list file.

    Returns:
        The set of entries found in the file.
    """
    flist = set()
    with open(flistpath) as inf:
        for line in inf:
            entry = line.strip()
            # a blank line must not end up as an "" entry: the caller would
            # report it as a listed file that validated successfully
            if not entry or entry.startswith("#"):
                continue
            flist.add(entry)
    return flist


def main(opts):
    """Validate the requested paths and report the outcome.

    Args:
        opts: parsed options; ``paths``, ``exclude``, ``max_procs``,
            ``no_cache``, ``can_fail`` and ``no_errors`` are read.

    Raises:
        SystemExit: with status 1 when validation failed and errors are not
            ignored; with a can-fail list, when a file outside the list
            failed or a listed file did not fail.
    """

    def bullet_list(names):
        # one " - <name>" line per entry, in sorted order
        return "\n".join([" - " + name for name in sorted(names)])

    paths = opts.paths or ["."]
    exclude = opts.exclude
    failures = ShellcheckFailures()
    with ThreadPoolExecutor(max_workers=opts.max_procs) as executor:
        try:
            checkpaths(paths, exclude, executor)
        except ShellcheckFailures as sf:
            failures.merge(sf)

    if not opts.no_cache:
        stats = Cacher.get().stats
        logging.info("cache stats: hit %d miss %d", stats.hit, stats.miss)

    # an empty ShellcheckFailures is falsy: everything validated
    if not failures:
        return

    if not opts.can_fail:
        logging.error(
            "validation failed for the following files:\n%s",
            bullet_list(failures),
        )
        if NO_FAIL or opts.no_errors:
            logging.warning("ignoring errors")
            return
        raise SystemExit(1)

    can_fail = loadfilelist(opts.can_fail)

    unexpected = failures.difference(can_fail)
    if unexpected:
        logging.error(
            "validation failed for the following non-whitelisted files:\n%s",
            bullet_list(unexpected),
        )
        raise SystemExit(1)

    did_not_fail = can_fail - failures.intersection(can_fail)
    if did_not_fail:
        logging.error(
            "the following files are whitelisted but validated successfully:\n%s",
            bullet_list(did_not_fail),
        )
        raise SystemExit(1)

    # no unexpected failures


# Script entry point: parse the options, merge in the environment settings,
# set up the cache and run the checks.
if __name__ == "__main__":
    opts = parse_arguments()
    if opts.verbose or D or V:
        lvl = logging.DEBUG
    else:
        lvl = logging.INFO
    logging.basicConfig(level=lvl)

    # settings from the environment take precedence over the command line
    if CAN_FAIL:
        opts.can_fail = CAN_FAIL

    if NO_FAIL:
        opts.no_errors = True

    # apply -s/--shell: checksection() reads the module level
    # SHELLCHECK_SHELL, so the option is propagated through it. As with the
    # settings above, a non-empty $SHELLCHECK_SHELL keeps precedence
    if not os.getenv("SHELLCHECK_SHELL"):
        SHELLCHECK_SHELL = opts.shell

    if opts.max_procs == 1:
        # TODO: temporary workaround for a deadlock when running with a single
        # worker
        opts.max_procs += 1
        logging.warning("workers count bumped to 2 to workaround a deadlock")

    # the cache is always initialized (this probes the shellcheck version),
    # --no-cache only switches lookups and stamping off
    Cacher.init()
    if opts.no_cache:
        Cacher.get().disable()

    main(opts)
