#!/usr/bin/env python
# -----------------------------------------------------------------------------
# The proprietary software and information contained in this file is
# confidential and may only be used by an authorized person under a valid
# licensing agreement from Arm Limited or its affiliates.
#
# Copyright (C) 2025. Arm Limited or its affiliates. All rights reserved.
#
# This entire notice must be reproduced on all copies of this file and
# copies of this file may only be made by an authorized person under a valid
# licensing agreement from Arm Limited or its affiliates.
# -----------------------------------------------------------------------------

import os
import sys
import argparse
from collections import defaultdict
from contextlib import nullcontext
import json
import csv
import io
import shutil
from typing import Optional
from datetime import datetime
from importlib.metadata import version, PackageNotFoundError

from asct.core.asct_env import ProcessMutex, ASCTGlobalSettings as AGS
from asct.core.asct_env import SystemState
import asct.core.logger as log
import asct.sysreport.sysreport as sr
from asct.core.ubench_reporter import get_reporter
from asct.core.resources.output_folder import OutputFolder
from asct.core.recipes.memory_load_latency import MemoryRecipeCategory
from asct.core.ubench_cli import get_progress_tracker
from asct.core.asct_pmu_api import start_pipe_monitoring

ASCT_BENCHMARKS = {"memory": MemoryRecipeCategory()}


def write_to_file(filepath: str, data: str, mode: str, fmt: Optional[str] = None):
    """
    Writes data to a file within the 'data/' directory, creates the data/
    directory if it doesn't already exist

            filepath: path of file to write to

            data: string to be written to file

       open_mode: python open mode e.g. "w" for write, "a" for append
                  see https://docs.python.org/3/library/functions.html#open

             fmt: string containing the format used in info message
                  e.g. f"{fmt} output written to: ..."

                  if None (default), no message is sent to info

    """
    try:
        with open(filepath, mode) as file:
            file.write(data)
    except PermissionError:
        log.critical(
            f"{filepath} already exists, but we don't have permission to edit - please remove before continuing"
        )
        sys.exit(1)
    except Exception as e:
        log.critical(f"[{e}] Error outputting to {filepath}")
        sys.exit(1)

    if fmt:
        log.info(f"{fmt} output written to: {filepath}")


def get_version():
    """Returns the version of ASCT as defined in pyproject.toml."""
    try:
        return version("asct")  # MUST match [project].name in pyproject.toml
    except PackageNotFoundError:
        return "version not found"


def output_stdout(_, sysreport, memory_results, skipped_list):
    print("\nSystem Information ------------------------------------------------------------\n")
    sr.show(sysreport, False, False, False)

    if memory_results:
        print("\n\nMemory Characterization -------------------------------------------------------\n")
        memory_benchmarks = ASCT_BENCHMARKS["memory"]
        for k in memory_results:
            if k not in skipped_list:
                memory_benchmarks[k].to_stdout()
            else:
                print(f"{k}: skipped - see ERROR log messages")


def output_csv(args, sysreport, memory_results, skipped_list):
    # sysreport
    out_filepath = os.path.join(args.output_dir, "sysreport.csv")
    csv = sr.dump_as_csv(sysreport)
    write_to_file(out_filepath, csv, "w", "CSV")

    if memory_results:
        memory_benchmarks = ASCT_BENCHMARKS["memory"]
        for k in filter(lambda x: x not in skipped_list, memory_results):
            # csv is a pure table format, output only title as filename and tables
            out_filepath = os.path.join(args.output_dir, f"{k}.csv")
            csv = memory_benchmarks[k].to_csv()
            write_to_file(out_filepath, csv, "w", "CSV")


def output_json(args, sysreport, memory_results, skipped_list):
    out_filepath = os.path.join(args.output_dir, "report.json")
    # to create a properly nested structure, we first create one huge
    # nested python dictionary and then dump that whole structure to JSON
    # use a defaultdict so we create top level keys if they don't already exist,
    # instead of returning a KeyError
    out_dict = defaultdict(dict)
    # always include sysreport in the JSON report
    out_dict["sysreport"] = sr.dump_as_dict(sysreport)
    if memory_results:
        memory_benchmarks = ASCT_BENCHMARKS["memory"]
        for k in filter(lambda x: x not in skipped_list, memory_results):
            # nest the results under "memory" heading and benchmark name as sub-group
            out_dict["memory"][k] = memory_benchmarks[k].to_json()

    json_str = json.dumps(out_dict)
    # overwrite any existing file aka "w" open mode
    write_to_file(out_filepath, json_str, "w", "JSON")


def output(args, sysreport, memory_results, skipped_list):
    # this just dispatches to output_{stdout,csv,json} functions in a single line
    # we already validated the args.format selections/choices with argparse
    getattr(sys.modules[__name__], f"output_{args.format}")(args, sysreport, memory_results, skipped_list)


def print_asct_help(print_only_benchmarks=False):
    if not print_only_benchmarks:
        print("info:")
        print("  --sysreport, -s              Print available system information and quit\n")
        print("benchmarks:")
        print("  --all, -a                    Run all available benchmarks")
    for benchmark_category in ASCT_BENCHMARKS.values():
        print(
            benchmark_category.get_help_string(
                not print_only_benchmarks,
                base_left_indent=2,
                items_left_indent=4,
                column_spacing=3,
                total_width=shutil.get_terminal_size().columns - 1,
            )
        )


def print_help(parser):
    default_help = parser.format_help()
    print(default_help)
    # we want to manually display benchmark options - the default is bad
    # and doesn't give the level of detail we want
    print_asct_help()


class CustomHelpFormatter(argparse.HelpFormatter):
    """
    The default argparse help message has some duplication when you have both a
    short and a long argument version.

    This custom formatter allows us to end up with something like the below:
        --format, -f [stdout,csv,json]
    instead of:
        --format [stdout,csv,json], -f [stdout,csv,json]
    """

    def _format_action_invocation(self, action):
        if not action.option_strings or action.nargs == 0:
            return super()._format_action_invocation(action)
        default = self._get_default_metavar_for_optional(action)
        args_string = self._format_args(action, default)
        return ", ".join(action.option_strings) + " " + args_string


def write_benchmark_list_json(output_dir):
    benchmark_list = {}
    for name, benchmark_category in ASCT_BENCHMARKS.items():
        benchmark_list[name] = benchmark_category.to_json_dict()
    json_str = json.dumps(benchmark_list)
    out_filepath = os.path.join(output_dir, "benchmark_list.json")
    write_to_file(out_filepath, json_str, "w", "JSON")


def write_benchmark_list_csv(output_dir):
    benchmark_data = [["category", "benchmark", "description"]]
    for benchmark_category in ASCT_BENCHMARKS.values():
        benchmark_data += benchmark_category.to_csv_array()
    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerows(benchmark_data)
    csv_string = output.getvalue()
    output.close()
    out_filepath = os.path.join(output_dir, "benchmark_list.csv")
    write_to_file(out_filepath, csv_string, "w", "CSV")


def output_benchmark_list(output_format, output_dir):
    if output_format == "stdout":
        print_asct_help(True)
        return

    output_folder_resource = OutputFolder(output_dir, True)
    output_folder_resource.setup()

    if output_format == "json":
        write_benchmark_list_json(output_dir)
    elif output_format == "csv":
        write_benchmark_list_csv(output_dir)

    output_folder_resource.teardown()


def main():
    parser = argparse.ArgumentParser(
        description=f"ASCT: Arm System Characterization Tool (version: {get_version()})",
        add_help=False,
        formatter_class=CustomHelpFormatter,
    )

    parser.add_argument(
        "--help",
        "-h",
        action="store_true",
        help="Display this help message/usage",
    )

    format_options = ["stdout", "csv", "json"]
    parser.add_argument(
        "--format",
        "-f",
        choices=format_options,
        type=str.lower,
        default="stdout",
        help="Output format (default: stdout)",
        metavar=f"[{','.join(map(str, format_options))}]",
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        type=str,
        default=f"data.{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}",
        help="Path to the output directory (default: 'data.<YYYYMMDD_HHMMSS_microseconds>' in cwd)",
    )
    parser.add_argument("--force", action="store_true", default=False, help="Reuse the output directory if it exists")
    parser.add_argument(
        "--log-level",
        "-L",
        choices=log.get_log_levels(),
        type=str.lower,
        default=log.DEFAULT_LOG_LEVEL,
        help="Logging level (default: info)",
        metavar=f"[{','.join(map(str, log.get_log_levels()))}]",
    )
    parser.add_argument(
        "--log-file",
        type=str,
        default=None,
        help="File to output logging messages (in addition to stderr)",
    )
    parser.add_argument(
        "--quiet",
        "-q",
        action="store_true",
        help="Disable all output to stdout/stderr, including critical errors and log messages. Use --log-file to "
        "capture and view logs.",
    )
    parser.add_argument(
        "--all",
        "-a",
        action="store_true",
        help=argparse.SUPPRESS,
    )
    parser.add_argument(
        "--dev-mode",
        action="store_true",
        help=argparse.SUPPRESS,  # hidden option - only for development, don't expose to users
    )
    parser.add_argument(
        "--quick-mode",
        action="store_true",
        help=argparse.SUPPRESS,  # hidden option - only for testing, don't expose to users
    )

    parser.add_argument(
        "--list-benchmarks",
        "-b",
        action="store_true",
        help=argparse.SUPPRESS,  # hidden option (for now), used by ATP, don't expose to users
    )

    parser.add_argument(
        "--sysreport",
        "-s",
        action="store_true",
        help=argparse.SUPPRESS,
    )

    parser.add_argument(
        "--no-progress-bar",
        action="store_true",
        help="Disable the progress bar, use single line update messages instead",
    )

    parser.add_argument(
        "--version",
        "-v",
        action="store_true",
        help="Display the version of ASCT",
    )

    for benchmark_category in ASCT_BENCHMARKS.values():
        parser.add_argument(
            benchmark_category.cli_arg_long,
            benchmark_category.cli_arg_short,
            nargs="*",
            choices=benchmark_category.get_recipe_names(),
            help=argparse.SUPPRESS,
        )

    args = parser.parse_args()

    log.initialize(args.log_level, args.log_file, args.quiet)

    if args.version:
        print(f"ASCT {get_version()}")
        sys.exit(0)

    if args.help:
        print_help(parser)
        sys.exit(0)

    if args.list_benchmarks:
        output_benchmark_list(args.format, args.output_dir)
        sys.exit(0)

    memory_categ = ASCT_BENCHMARKS["memory"]
    if args.all:  # --all passed, run all benchmarks
        args.memory = memory_categ.get_recipe_names()

    # if neither --sysreport or --memory flags given, run default benchmarks
    elif not args.sysreport and not args.memory:
        log.info("No benchmarks specified, running default memory benchmarks.")
        args.memory = memory_categ.get_default_recipe_names()

    elif args.memory == []:  # --memory passed with no arguments
        log.info("No benchmarks specified, running default memory benchmarks.")
        args.memory = memory_categ.get_default_recipe_names()

    settings = AGS()
    settings.read_env_vars()

    if args.dev_mode:
        ctx = nullcontext()  # dev-mode doesn't run benchmarks so lock is not required
        settings.set_dev_mode()
    else:
        ctx = ProcessMutex("_asct_", retry_count=5, retry_wait=0.5)
        if args.quick_mode:
            settings.set_quick_mode()

    # Create the output folder resource but don't set it up yet
    # we will do that after acquiring the lock
    output_folder_resource = OutputFolder(args.output_dir, args.force)

    with ctx as global_lock:
        if global_lock and not global_lock.lock_successful():
            log.error(global_lock.get_error())
            sys.exit(1)

        # TODO: Ideally all the shared resources should be registered and applied
        # here currently some shared resources are registered in the SystemState
        # constructor as a result of asct/src/core/asct_pmu_api.py:180
        # thread using the pipe resource on before this main function runs.
        SystemState().resource_manager.register(output_folder_resource, global_scope=True)
        SystemState().resource_manager.apply_all(global_scope=True)

        get_reporter().output_dir = output_folder_resource.get_output_folder_path()

        # always run sysreport first and pass to the benchmarks to use
        sysreport = sr.run_sysreport()

        priority_list = []
        skipped_list = []

        if args.memory:
            if settings.enable_pmu:
                start_pipe_monitoring()

            get_progress_tracker().initialize(2, not args.no_progress_bar, args.quiet)

            # get the benchmark dependencies for the specified memory benchmarks that are not already included
            # in the args.memory list and add them to the list
            dependencies = memory_categ.resolve_dependencies(args.memory)
            if dependencies:
                log.info(f"Adding dependencies to the benchmark list: {', '.join(dependencies)}")
                args.memory.extend(dependencies)

            # sorting the benchmarks by their priority, highest first
            priority_list = memory_categ.get_recipes(args.memory)
            log.info(
                f"Executing memory benchmarks in the following order: {', '.join([name for name, _ in priority_list])}"
            )

            for benchmark_name, benchmark in get_progress_tracker().iterate(
                priority_list, lambda _, elem: f"{elem[0]}"
            ):
                benchmark.set_config({"sysreport": sysreport, "pmu_mode": settings.enable_pmu})
                try:
                    benchmark.run()
                except Exception as e:
                    log.error(f"Error running {benchmark_name} - skipping benchmark")
                    log.error(f"Reason for skipping: {e}")
                    skipped_list.append(benchmark_name)
                    continue

            get_progress_tracker().wait_until_idle()
            get_progress_tracker().terminate()

    output(args, sysreport, args.memory, skipped_list)


if __name__ == "__main__":
    main()
