#!/usr/bin/python3

# Author: Benjamin Drung <bdrung@ubuntu.com>

"""Check convert_timezone from tzdata.config for consistency."""

import argparse
import functools
import logging
import os
import pathlib
import re
import subprocess
import sys
import typing

import debian.deb822

# Format of the emitted log records: the level name followed by the message.
LOG_FORMAT = "%(levelname)s: %(message)s"
# Special timezones that should not be selectable in debconf
SPECIAL = {"Factory", "localtime", "posixrules"}
# Not selectable timezones that are not mentioned in the backward file.
# See also https://launchpad.net/bugs/2030684
EXCLUDE_UNSELECTABLE = {"GMT", "UTC"}


class ConvertTimezone:
    """Wrap convert_timezone from tzdata.config.

    The shell function is extracted from the given file once. Calling the
    instance runs that function in /bin/sh and returns what it prints.
    """

    def __init__(self, tzdata_config: pathlib.Path) -> None:
        """Extract the convert_timezone shell function from tzdata_config.

        Raises AssertionError if the function cannot be found in the file.
        """
        self.tzdata_config = tzdata_config
        content = tzdata_config.read_text(encoding="utf-8")
        # Greedy match: from the function header up to the last line of the
        # file that starts with a closing brace.
        match = re.search(r"convert_timezone\(\).*\n}", content, flags=re.DOTALL)
        if match is None:
            # Raised explicitly instead of using an assert statement, because
            # assert statements are removed when Python runs with -O. The
            # exception type is kept for backward compatibility.
            raise AssertionError(
                f"convert_timezone function not found in {tzdata_config}"
            )
        self.convert_timezone = match.group(0)

    @functools.lru_cache(maxsize=8192)
    def __call__(self, timezone: str) -> str:
        """Return the output of the convert_timezone shell function for timezone.

        Results are cached, so every timezone spawns at most one shell.

        Raises subprocess.CalledProcessError if the shell exits non-zero.
        """
        # Hand the timezone over as positional parameter ($1) instead of
        # pasting it into the script text. This way names containing quotes
        # or other shell metacharacters are passed through verbatim.
        shell_script = f'{self.convert_timezone}\nconvert_timezone "$1"\n'
        shell = subprocess.run(
            # The argument after the script becomes $0, the next one $1.
            ["/bin/sh", "-c", shell_script, "sh", timezone],
            capture_output=True,
            check=True,
            encoding="utf-8",
        )
        return shell.stdout.strip()

    def filter_converted_timezones(
        self, timezones: typing.Iterable[str]
    ) -> typing.Dict[str, str]:
        """Return dict of timezones that will be converted by convert_timezone.

        The keys are the given timezones whose conversion result differs from
        their own name, the values are the conversion results.
        """
        converted = {}
        for timezone in timezones:
            conversion = self(timezone)
            if conversion != timezone:
                converted[timezone] = conversion
        return converted

    def filter_unconverted_timezones(
        self, timezones: typing.Set[str]
    ) -> typing.Set[str]:
        """Return set of timezones that will not be converted by convert_timezone."""
        return timezones - set(self.filter_converted_timezones(timezones))

    def get_targets(self) -> typing.Set[str]:
        """Return set of conversion targets.

        The targets are the double-quoted arguments of echo in the shell
        function that contain no "$", i.e. no variable references.
        """
        targets = set(re.findall(r'echo "([^"$]+)"', self.convert_timezone))
        logging.getLogger(__name__).info(
            "Available conversion targets in %s: %i", self.tzdata_config, len(targets)
        )
        return targets


def _is_tzif_file(fpath):
    """Return True if the file at fpath starts with the bytes b"TZif"."""
    magic = b"TZif"
    with open(fpath, mode="rb") as stream:
        header = stream.read(len(magic))
    return header == magic


def get_available_timezones(directory: typing.Optional[pathlib.Path]):
    """Return a set of available timezones in the directory.

    If directory is not set, use the system's default (/usr/share/zoneinfo).
    Timezone names are the paths of the TZif files relative to that
    directory. The process exits with status 1 if no timezone is found.
    """
    logger = logging.getLogger(__name__)

    tz_root = str(directory or "/usr/share/zoneinfo")
    found = set()
    for current, subdirs, filenames in os.walk(tz_root):
        if current == tz_root:
            # right/ and posix/ are special directories and shouldn't be
            # included in the output of available zones. Removing them from
            # the list in place stops os.walk from descending into them.
            for special in ("right", "posix"):
                if special in subdirs:
                    subdirs.remove(special)

        for filename in filenames:
            full_path = os.path.join(current, filename)
            if not _is_tzif_file(full_path):
                continue
            found.add(os.path.relpath(full_path, start=tz_root))

    if not found:
        logger.error("Found no timezones in %s.", tz_root)
        sys.exit(1)
    logger.info("Available timezones in %s: %i", directory or "system", len(found))
    return found


def get_debconf_choices(template_filename: pathlib.Path) -> typing.Set[str]:
    """Extract the timezone choices from the debconf template.

    Only paragraphs whose Template field starts with "tzdata/Zones/" are
    used. Their choices are prefixed with the area name and a slash, except
    for the "Legacy" area whose choices are taken as they are. The process
    exits with status 1 if no choice is found.
    """
    logger = logging.getLogger(__name__)
    debconf_choices = set()
    with template_filename.open(encoding="utf-8") as template_file:
        for paragraph in debian.deb822.Deb822.iter_paragraphs(template_file):
            # Paragraphs without a Template field are skipped, not fatal.
            area_match = re.match("tzdata/Zones/(.*)", paragraph.get("Template", ""))
            if not area_match:
                continue
            area = area_match.group(1)
            raw_choices = paragraph.get("Choices", paragraph.get("__Choices", ""))
            # Splitting an empty string yields [""]. Drop empty entries so a
            # paragraph without choices cannot add a bogus timezone such as
            # "Europe/" and thereby hide the "no timezones" error below.
            choices = [choice for choice in raw_choices.split(", ") if choice]
            if area == "Legacy":
                debconf_choices.update(choices)
            else:
                debconf_choices.update(f"{area}/{choice}" for choice in choices)
    if not debconf_choices:
        logger.error("Found no selectable timezones in %s.", template_filename)
        sys.exit(1)
    logger.info(
        "Selectable timezones in %s: %i", template_filename, len(debconf_choices)
    )
    return debconf_choices


def _check_symlink(tzpath: pathlib.Path, source: str, target: str) -> bool:
    """Check if the given timezone source symlinks to the given target.

    Both names are relative to tzpath. A source that does not exist is
    treated as matching.
    """
    source_path = tzpath / source
    if not source_path.exists():
        return True
    target_path = pathlib.Path(tzpath) / target
    return source_path.resolve() == target_path.resolve()


def check_symlinks(tzpath: typing.Optional[pathlib.Path], timezones):
    """Check timezone replacements are identical to the symlinks on disk.

    timezones maps each converted timezone to its replacement. Return the
    set of converted timezones that do not resolve to their replacement
    below tzpath (/usr/share/zoneinfo if tzpath is None). Timezones below
    posix/ and right/ are not checked.
    """
    zoneinfo_root = pathlib.Path("/usr/share/zoneinfo") if tzpath is None else tzpath
    return {
        source
        for source, target in timezones.items()
        if not source.startswith(("posix/", "right/"))
        and not _check_symlink(zoneinfo_root, source, target)
    }


def existing_dir_path(string: str) -> pathlib.Path:
    """Return string as path if it names an existing directory.

    Raises argparse.ArgumentTypeError otherwise.
    """
    candidate = pathlib.Path(string)
    if candidate.is_dir():
        return candidate
    raise argparse.ArgumentTypeError(f"Directory {string} does not exist")


def parse_args() -> argparse.Namespace:
    """Parse command line arguments and return namespace.

    The namespace has the attributes zoneinfo_directory (None unless given),
    debian_directory and all_selectable.
    """
    zoneinfo_help = (
        "Directory containing the generated zoneinfo files (default: system)"
    )
    debian_help = (
        "Path to debian directory containing tzdata.config"
        " and tzdata.templates (default: %(default)s)"
    )
    all_selectable_help = (
        "Require all available timezones to be selectable in debconf"
    )

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "-z", "--zoneinfo-directory", type=existing_dir_path, help=zoneinfo_help
    )
    arg_parser.add_argument(
        "-d",
        "--debian-directory",
        type=existing_dir_path,
        default=pathlib.Path("debian"),
        help=debian_help,
    )
    arg_parser.add_argument(
        "--all-selectable", action="store_true", help=all_selectable_help
    )
    return arg_parser.parse_args()


def main() -> int:
    """Check convert_timezone from tzdata.config for consistency.

    Timezones that are selectable but converted only cause a warning. Every
    other failed check is logged as an error. The return value is the number
    of failed checks.
    """
    args = parse_args()
    logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)
    logger = logging.getLogger(__name__)

    def summarize(timezones):
        """Return the count and the sorted, newline separated timezones."""
        return len(timezones), "\n".join(sorted(timezones))

    debian_directory = args.debian_directory
    selectable = get_debconf_choices(debian_directory / "tzdata.templates")
    available = get_available_timezones(args.zoneinfo_directory)
    convert_timezone = ConvertTimezone(debian_directory / "tzdata.config")
    conversion_targets = convert_timezone.get_targets()
    failures = 0

    # Only a warning: selecting such a timezone still works.
    converted = set(convert_timezone.filter_converted_timezones(selectable))
    if converted:
        logger.warning(
            "Following %i timezones can be selected, but will be converted:\n%s",
            *summarize(converted),
        )

    unselectable = available - selectable - SPECIAL - EXCLUDE_UNSELECTABLE
    if args.all_selectable and unselectable:
        logger.error(
            "Following %i timezones cannot be selected, but are available:\n%s",
            *summarize(unselectable),
        )
        failures += 1

    # Every timezone that cannot be selected needs a conversion.
    missing = convert_timezone.filter_unconverted_timezones(unselectable)
    if missing:
        logger.error(
            "Following %i timezones cannot be selected, but are not converted:\n%s",
            *summarize(missing),
        )
        failures += 1

    unavailable_targets = conversion_targets - available
    if unavailable_targets:
        logger.error(
            "Following %i timezones are conversion targets, but are not available:\n%s",
            *summarize(unavailable_targets),
        )
        failures += 1

    unselectable_targets = conversion_targets - selectable
    if unselectable_targets:
        logger.error(
            "Following %i timezones are conversion targets,"
            " but are not selectable:\n%s",
            *summarize(unselectable_targets),
        )
        failures += 1

    converted_available = convert_timezone.filter_converted_timezones(available)
    mismatch = check_symlinks(args.zoneinfo_directory, converted_available)
    if mismatch:
        logger.error(
            "Following %i timezones are converted,"
            " but they do not match their symlink targets:\n%s",
            *summarize(mismatch),
        )
        failures += 1

    return failures


if __name__ == "__main__":
    sys.exit(main())
