#!/usr/bin/env python
# -----------------------------------------------------------------------------
# The proprietary software and information contained in this file is
# confidential and may only be used by an authorized person under a valid
# licensing agreement from Arm Limited or its affiliates.
#
# Copyright (C) 2025. Arm Limited or its affiliates. All rights reserved.
#
# This entire notice must be reproduced on all copies of this file and
# copies of this file may only be made by an authorized person under a valid
# licensing agreement from Arm Limited or its affiliates.
# -----------------------------------------------------------------------------

import argparse
import sys
from os import path

from datetime import datetime
from importlib import import_module

import asct.core.constants as constants
import asct.core.logger as log

from asct.core.asct_env import ASCTGlobalSettings as AGS, SystemState
from asct.core.recipes.recipe_info import get_all_tags, TAG_DEFAULT
from asct.core.cmd.helpers.version_helpers import get_version
from asct.core.cmd.helpers.help_helpers import (
    ASCTCustomHelpFormatter,
    ASCTCommandHelpAction,
    ASCTParser,
    get_formatted_benchmark_str,
)
from asct.core.resources.output_folder import OutputFolder
from asct.core.resources.resource_manager import ResourceManager
from asct.core.term_ui.term_manager import TermManager
from asct.core.utility.misc import create_dict_path
from asct.core.cache import ASCTCache


def initialize_asct(args):
    """Prepare the process-wide ASCT state before a command module runs.

    Steps, in order:

    * initialise the ``TermManager`` on ``sys.stderr``;
    * configure console logging when a console level is known;
    * when an output directory is requested, create it through the
      ``ResourceManager`` and configure file logging inside it;
    * load environment variables into ``ASCTGlobalSettings``;
    * initialise ``SystemState``;
    * set up (and optionally invalidate) the benchmark cache.

    Options are read with ``getattr`` defaults because not every sub-command
    defines every option (e.g. ``help`` and ``version`` have no output options).

    :param args: parsed ``argparse.Namespace``. It may be updated in place:
        ``output_dir_path`` and possibly ``log_file`` are set when an output
        directory is requested, and ``no_cache`` is forced to ``True`` in dev mode.
    :raises SystemExit: with status 1 if the output directory cannot be created
        or ``SystemState`` fails to initialise.
    """
    output_dir = getattr(args, "output_dir", None)

    quiet = getattr(args, "quiet", False)
    log_file = getattr(args, "log_file", None)

    TermManager().initialize(sys.stderr)

    log_config = log.LogConfigurator()

    log_level = getattr(args, "log_level", None)
    log_level_console = getattr(args, "log_level_console", None)
    log_level_file = getattr(args, "log_level_file", None)

    # --log-level overrides both the console and the file specific levels
    if log_level:
        log_level_console, log_level_file = log_level, log_level

    if log_level_console:
        log_config.configure_console_logging(not quiet, log_level_console).apply()

    if output_dir:
        # ResourceManager follows the singleton pattern, for this reason
        # only one instance is used throughout the application
        resource_manager = ResourceManager()
        output_folder_resource = OutputFolder(output_dir, getattr(args, "force", False))
        resource_manager.register(output_folder_resource, global_scope=True)
        try:
            resource_manager.apply_all(global_scope=True)
        except Exception as exc:
            log.error(f"Unable to create output directory: {exc}")
            # sys.exit, not the site-provided exit(): the latter is meant for the
            # interactive shell and is absent when Python runs without `site`
            sys.exit(1)
        output_dir_path = output_folder_resource.get_output_folder_path()
        args.output_dir_path = output_dir_path

        # by default write logs to output/asct.log based on current log level
        # quiet flag gives the user the option to turn off logging to file as well
        if log_file is None and not quiet:
            log_file = path.join(output_dir_path, "asct.log")
            args.log_file = log_file
        log_config.configure_file_logging(bool(log_file), log_level_file, log_file).apply()

    AGS().read_env_vars()

    try:
        SystemState().initialize()
    except Exception as exc:
        log.critical(f"{exc}")
        sys.exit(1)

    # ensure that cache is not used in dev mode
    if getattr(args, "dev_mode", False):
        args.no_cache = True

    cache = ASCTCache(
        use_cache=not getattr(args, "no_cache", False), output_folder=getattr(args, "output_dir_path", "./")
    )
    cache.clear_asct_cache(invalidate=getattr(args, "clear_cache", False))


def exec_cmd(command, args):
    """Initialise ASCT from *args*, then hand over to a command module.

    :param command: module name below ``asct.core.cmd`` (e.g. ``"run"``);
        that module must expose ``run(args)``.
    :param args: parsed ``argparse.Namespace`` forwarded to the command.
    """
    initialize_asct(args)

    module_name = f"asct.core.cmd.{command}"
    command_module = import_module(module_name)
    command_module.run(args)


def exec_cmd_version(args):
    """Run the ``version`` sub-command (module ``asct.core.cmd.version``)."""
    exec_cmd(command="version", args=args)


def exec_cmd_help(args):
    """Run the ``help`` sub-command (module ``asct.core.cmd.asct_help``)."""
    exec_cmd(command="asct_help", args=args)


def exec_cmd_list_benchmarks(args):
    """Run the ``list`` sub-command (module ``asct.core.cmd.list_benchmarks``)."""
    exec_cmd(command="list_benchmarks", args=args)


def exec_cmd_run(args):
    """Run the ``run`` sub-command (module ``asct.core.cmd.run``)."""
    exec_cmd(command="run", args=args)


def exec_cmd_diff(args):
    """Run the ``diff`` sub-command (module ``asct.core.cmd.diff``)."""
    exec_cmd(command="diff", args=args)


def exec_cmd_system_info(args):
    """Run the ``system-info`` sub-command (module ``asct.core.cmd.system_info``)."""
    exec_cmd(command="system_info", args=args)


def exec_cmd_sysreg(args):
    """Run the ``sysreg`` sub-command (module ``asct.core.cmd.sysreg``)."""
    exec_cmd(command="sysreg", args=args)


class ASCTHelpAction(argparse.Action):
    """Action bound to the top-level ``-h/--help`` flag.

    Instead of argparse's built-in help output, it runs the ``asct_help``
    command with the triggering parser attached to the namespace, then exits
    with status 0.
    """

    def __call__(self, parser, namespace, *_):
        # the help command renders its page from the parser that invoked it
        setattr(namespace, "parser", parser)
        exec_cmd(command="asct_help", args=namespace)
        parser.exit(0)


class ASCTSettingsAction(argparse.Action):
    """Action that turns ``SETTING=VALUE`` items into a nested settings dict.

    Each setting name is split on ``.`` and stored with ``create_dict_path``
    under that key path (e.g. ``a.b=1`` uses the path ``["a", "b"]``). Only the
    first ``=`` separates name and value, so values may themselves contain ``=``.

    Repeated occurrences of the option (``--opt a=1 --opt b=2``) are merged into
    one dict instead of the last occurrence silently discarding the earlier ones.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        import copy

        # values is a list because we use nargs='+'.
        # Validate every item before storing anything, so a malformed entry is
        # reported without leaving a half-built dict behind.
        parsed_items = []
        for item in values:
            key, separator, value = item.partition("=")
            key_path = key.split(".")
            # reject a missing '=' as well as empty names / empty dotted segments
            if not separator or not all(key_path):
                parser.error(f"{option_string} expects SETTING=VALUE, got {item}")
            parsed_items.append((key_path, value))

        # merge with settings from an earlier occurrence of this option; work on a
        # copy so a dict default given to add_argument() is never mutated in place
        existing = getattr(namespace, self.dest, None)
        values_dict = copy.deepcopy(existing) if isinstance(existing, dict) else {}
        for key_path, value in parsed_items:
            create_dict_path(values_dict, key_path, value)
        setattr(namespace, self.dest, values_dict)


def _log_level_parent_parser():
    """Return a parent parser holding the ``--log-level*`` options.

    Inherited by the run, system-info, list, diff and sysreg commands.
    """
    # query the levels once; list() so the same values back both ``choices``
    # and the metavar string of all three options
    log_levels = list(log.get_log_levels())
    log_levels_metavar = f"[{','.join(map(str, log_levels))}]"

    parent = argparse.ArgumentParser(add_help=False)
    parent.add_argument(
        "--log-level",
        "-L",
        choices=log_levels,
        type=str.lower,
        default="",
        help=f"Logging level for both file and console logging (defaults: console: {log.DEFAULT_LOG_LEVEL_CONSOLE}"
        f" file: {log.DEFAULT_LOG_LEVEL_FILE})",
        metavar=log_levels_metavar,
    )
    parent.add_argument(
        "--log-level-console",
        choices=log_levels,
        type=str.lower,
        default=log.DEFAULT_LOG_LEVEL_CONSOLE,
        help=f"Logging level for console logging (default: {log.DEFAULT_LOG_LEVEL_CONSOLE})",
        metavar=log_levels_metavar,
    )
    parent.add_argument(
        "--log-level-file",
        choices=log_levels,
        type=str.lower,
        default=log.DEFAULT_LOG_LEVEL_FILE,
        help=f"Logging level for file logging (default: {log.DEFAULT_LOG_LEVEL_FILE})",
        metavar=log_levels_metavar,
    )
    return parent


def _output_parent_parser():
    """Return a parent parser with the log-file, format, output-directory and cache options."""
    parent = argparse.ArgumentParser(add_help=False)
    parent.add_argument(
        "--log-file",
        type=str,
        default=None,
        help="File to output logging messages (default: asct.log)",
    )
    parent.add_argument(
        "--format",
        "-f",
        choices=constants.FORMAT_OPTIONS,
        type=str.lower,
        default="stdout",
        help="Output format (default: stdout)",
        metavar=f"[{','.join(map(str, constants.FORMAT_OPTIONS))}]",
    )
    parent.add_argument(
        "--output-dir",
        "-o",
        type=str,
        # timestamp is taken once, when the parser is built
        default=f"data.{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}",
        help="Path to the output directory (default: 'data.<YYYYMMDD_HHMMSS_microseconds>' in cwd)",
    )
    parent.add_argument(
        "--force", action="store_true", default=False, help="Reuse the output directory if it exists"
    )
    parent.add_argument(
        "--quiet",
        "-q",
        action="store_true",
        help="Disable default logging to the asct.log file, and all output to stdout/stderr,"
        " including critical errors and log messages. Use --log-file to capture and view logs.",
    )
    parent.add_argument(
        "--no-cache",
        action="store_true",
        default=False,
        help="Disable the use of cached benchmark data",
    )
    parent.add_argument(
        "--clear-cache",
        action="store_true",
        default=False,
        help="Clear the cached benchmark data",
    )
    return parent


def _add_run_command(commands, parents):
    """Register the ``run`` sub-command and its options."""
    run_cmd = commands.add_parser(
        "run",
        usage="asct run keyword1 keyword2... [options...]",
        parents=parents,
        help="Run a list of benchmarks based on a provided list of keywords",
        description="Run a list of benchmarks based on a provided list of keywords and output "
        "the results in a chosen format. Keywords can be negated using the ^ character.",
        epilog=f"\nbenchmarks:\n{get_formatted_benchmark_str(True)}"
        "\n\nexamples:\n  asct run all\n  asct run latency bandwidth ^sweep --format json --output-dir my_dir\n"
        "  asct run idle-latency peak-bandwidth --format csv",
        formatter_class=ASCTCustomHelpFormatter,
    )

    run_cmd.add_argument(
        "benchmarks",
        nargs="*",
        default=TAG_DEFAULT,
        choices=get_all_tags(),
        help=argparse.SUPPRESS,
    )

    run_cmd.add_argument(
        "--dev-mode",
        action="store_true",
        help=argparse.SUPPRESS,  # hidden option - only for development, don't expose to users
    )

    run_cmd.add_argument(
        "--quick-mode",
        action="store_true",
        help=argparse.SUPPRESS,  # hidden option - only for testing, don't expose to users
    )

    run_cmd.add_argument(
        "--dry-run",
        action="store_true",
        help="Print a list of benchmarks that would run based on the provided "
        "command line arguments and exit immediately",
    )

    run_cmd.add_argument(
        "--no-progress-bar",
        action="store_true",
        help="Disable the progress bar, use single line update messages instead",
    )

    run_cmd.add_argument(
        "--user-config",
        nargs="+",
        action=ASCTSettingsAction,
        metavar="SETTING=VALUE",
        default=None,
        help="Set benchmark user-defined settings",
    )

    run_cmd.add_argument(
        "--user-config-file",
        default="",
        help="Read benchmark user-defined settings from a given JSON file",
    )

    run_cmd.set_defaults(func=exec_cmd_run)


def _add_system_info_command(commands, parents):
    """Register the ``system-info`` sub-command."""
    system_info_cmd = commands.add_parser(
        "system-info",
        usage="asct system-info [options...]",
        parents=parents,
        help="Get system information",
        description="Output a report containing information about the hardware and software installed on the system.",
        epilog="examples:\n  asct system-info\n  asct system-info --format json --output-dir my_dir",
        formatter_class=ASCTCustomHelpFormatter,
    )
    system_info_cmd.set_defaults(func=exec_cmd_system_info)


def _add_help_command(commands, root_parser):
    """Register the ``help`` sub-command; it needs the root parser and the sub-command registry."""
    help_cmd = commands.add_parser(
        "help",
        usage="asct help [command]",
        help="Display the general help page or the help page for a specified command",
        description="Display the general help page or the help page for a specified command.",
        epilog="examples:\n  asct help\n  asct help list",
        formatter_class=ASCTCustomHelpFormatter,
    )
    help_cmd.add_argument("cmd", nargs="?", default=None, action=ASCTCommandHelpAction, help=argparse.SUPPRESS)
    help_cmd.set_defaults(func=exec_cmd_help, parser=root_parser, subcommands=commands)


def _add_version_command(commands):
    """Register the ``version`` sub-command."""
    version_cmd = commands.add_parser(
        "version",
        usage="asct version [-h]",
        help="Print version number",
        description="Print version information for the tool",
        epilog="example:\n  asct version",
        formatter_class=ASCTCustomHelpFormatter,
    )
    version_cmd.set_defaults(func=exec_cmd_version)


def _add_list_command(commands, parents):
    """Register the ``list`` sub-command."""
    list_bm_cmd = commands.add_parser(
        "list",
        usage="asct list [options...]",
        parents=parents,
        help="Get a list of available benchmarks",
        description="Output a list of available benchmarks and their description.",
        epilog="examples:\n  asct list\n  asct list --format json --output-dir my_dir",
        formatter_class=ASCTCustomHelpFormatter,
    )
    list_bm_cmd.set_defaults(func=exec_cmd_list_benchmarks)


def _add_diff_command(commands, parents):
    """Register the ``diff`` sub-command and its options."""
    diff_cmd = commands.add_parser(
        "diff",
        usage="asct diff run_dir [run_dir2 ...] [options...]",
        parents=parents,
        help="Compare results from previous ASCT runs",
        description="Compare results from previous ASCT runs. "
        "If --baseline is omitted, the first run_dir is used as baseline and at least one more run_dir is required. "
        "If --baseline is provided, only one run_dir is required.",
        epilog="examples:\n"
        "  asct diff runA runB\n"
        "  asct diff runB --baseline runA\n"
        "  asct diff runB runC --baseline runA --sort-by runC",
        formatter_class=ASCTCustomHelpFormatter,
    )

    diff_cmd.add_argument(
        "run_dirs",
        nargs="+",
        type=str,
        help="Run output directories. Need 2+ if --baseline not set, otherwise 1+.",
    )

    diff_cmd.add_argument(
        "--baseline",
        "-b",
        type=str,
        help="Specify the baseline run for comparison",
    )

    diff_cmd.add_argument(
        "--benchmarks",
        "-k",
        nargs="+",
        metavar="BENCH",
        default=[],
        help="Limit comparison to these benchmarks (default: all benchmarks in the runs).",
    )

    diff_cmd.add_argument(
        "--sort-by",
        "-s",
        type=str,
        default=None,
        help="Sort the output table by the specified column (default: no sorting)",
    )

    diff_cmd.set_defaults(func=exec_cmd_diff)


def _add_sysreg_command(commands, parents):
    """Register the ``sysreg`` sub-command."""
    sysreg_cmd = commands.add_parser(
        "sysreg",
        usage="asct sysreg [options...]",
        parents=parents,
        help="Get system register information",
        description="Output a report containing information about the system register values",
        epilog="examples:\n  asct sysreg\n  asct sysreg --format json --output-dir my_dir",
        formatter_class=ASCTCustomHelpFormatter,
    )
    sysreg_cmd.set_defaults(func=exec_cmd_sysreg)


def _build_parser():
    """Build and return the top-level ``asct`` parser with every sub-command registered."""
    parser = ASCTParser(
        description=f"ASCT: Arm System Characterization Tool (version: {get_version()})",
        add_help=False,
        usage="%(prog)s command [command_args...]",
        prog="asct",
        formatter_class=ASCTCustomHelpFormatter,
    )

    # -h/--help is routed to the asct_help command instead of argparse's help
    parser.add_argument(
        "-h",
        "--help",
        action=ASCTHelpAction,
        nargs=0,
        help=argparse.SUPPRESS,
    )

    commands = parser.add_subparsers(dest="command", title=None, metavar="ASCT commands", required=True)

    # parents shared by every sub-command that accepts logging/output options
    common_parents = [_log_level_parent_parser(), _output_parent_parser()]

    # keep this order: argparse lists sub-commands in registration order
    _add_run_command(commands, common_parents)
    _add_system_info_command(commands, common_parents)
    _add_help_command(commands, parser)
    _add_version_command(commands)
    _add_list_command(commands, common_parents)
    _add_diff_command(commands, common_parents)
    _add_sysreg_command(commands, common_parents)

    return parser


def main():
    """Parse the command line and run the selected ``asct`` sub-command.

    Every sub-command stores its handler in ``args.func``. Running ``asct``
    without arguments behaves like ``asct help``.
    """
    parser = _build_parser()

    # if no arguments provided, show the help message instead
    args = parser.parse_args(args=None if sys.argv[1:] else ["help"])
    args.func(args)


# Allow running this module directly as a script; importing it has no side effects.
if __name__ == "__main__":
    main()
