/* Copyright (C) 2023-2025 by Arm Limited (or its affiliates). All rights reserved. */

#include "metrics/group_generator.hpp"

#include "lib/Assert.h"
#include "lib/Span.h"
#include "metrics/definitions.hpp"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <string_view>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace metrics {
    namespace {
        // Intermediate, mutable description of a group of metric event sets that
        // are being merged together before conversion to the public combination_t.
        struct raw_combination_t {
            // the metric event sets that have been folded into this combination
            std::unordered_set<metric_events_set_t const *> contains_sets;
            // event code -> frequency multiplier for every (non cycle counter) event in the combination
            std::unordered_map<std::uint16_t, std::uint16_t> event_code_to_freq_multiplier;
            // the priority assigned to the combination as a whole
            metric_priority_t priority;
            // the architecture the combination is tied to
            metric_arch_t arch;
            // the cycle counter flag taken from the metric set used to create the combination
            bool uses_cycles;

            // Member-wise constructor; the containers are taken by value and moved into place.
            raw_combination_t(std::unordered_set<metric_events_set_t const *> sets,
                              std::unordered_map<std::uint16_t, std::uint16_t> code_to_multiplier,
                              metric_priority_t initial_priority,
                              metric_arch_t initial_arch,
                              bool needs_cycles)
                : contains_sets(std::move(sets)),
                  event_code_to_freq_multiplier(std::move(code_to_multiplier)),
                  priority(initial_priority),
                  arch(initial_arch),
                  uses_cycles(needs_cycles)
            {
            }
        };

        // Returns true when `code` is the event code treated as the cycle counter
        // for the given architecture. For `any` (or an unrecognised value) no code
        // is considered to be the cycle counter.
        [[nodiscard]] bool is_cycle_counter(std::uint16_t code, metric_arch_t arch)
        {
            constexpr std::uint16_t arm32_linux_cycle_counter = 0xff;
            constexpr std::uint16_t arm64_linux_cycle_counter = 0x11;

            if (arch == metric_arch_t::v7) {
                return arm32_linux_cycle_counter == code;
            }

            if (arch == metric_arch_t::v8) {
                return arm64_linux_cycle_counter == code;
            }

            // metric_arch_t::any and anything else
            return false;
        }

        // Picks the priority to keep when two items are merged: the one that
        // compares lower wins; `a` is kept when neither is lower than the other.
        [[nodiscard]] constexpr metric_priority_t select_best(metric_priority_t a, metric_priority_t b)
        {
            return (b < a) ? b : a;
        }

        // Merge the architecture constraints of two items that are being combined.
        //
        // `metric_arch_t::any` imposes no constraint, so the result is whatever the
        // other side requires. When both sides name a specific architecture they
        // must agree; a mismatch trips the runtime assertion.
        //
        // Returns the architecture the merged item is tied to (or `any` only when
        // both inputs are `any`).
        [[nodiscard]] metric_arch_t combine_arch(metric_arch_t a, metric_arch_t b)
        {
            if (a == metric_arch_t::any) {
                return b;
            }

            if (b == metric_arch_t::any) {
                // b is unconstrained: keep a's specific architecture rather than
                // widening the result back to `any`
                return a;
            }

            runtime_assert(a == b, "Invalid arch combo");

            return a;
        }

        // Identity mapping from an already-numeric event to its event code; lets
        // generic code call to_event_code() on plain std::uint16_t keys.
        [[nodiscard]] constexpr std::uint16_t to_event_code(std::uint16_t event)
        {
            std::uint16_t const code = event;
            return code;
        }

        // Build the union of two event-code -> frequency-multiplier collections.
        //
        // Each collection is iterated as (event, freq_multiplier) pairs. Events that
        // are the cycle counter for their respective architecture are left out.
        // When the same code is seen more than once, the smallest multiplier is kept.
        template<typename EventCodesA, typename EventCodesB>
        [[nodiscard]] std::unordered_map<std::uint16_t, std::uint16_t> combine_codes(EventCodesA const & event_codes_a,
                                                                                     metric_arch_t arch_a,
                                                                                     EventCodesB const & event_codes_b,
                                                                                     metric_arch_t arch_b)
        {
            std::unordered_map<std::uint16_t, std::uint16_t> merged {};

            // folds one collection into `merged`
            auto const absorb = [&merged](auto const & codes, metric_arch_t arch) {
                for (auto const & [event, freq_multiplier] : codes) {
                    auto const code = to_event_code(event);
                    if (is_cycle_counter(code, arch)) {
                        continue;
                    }

                    auto const emplaced = merged.try_emplace(code, freq_multiplier);
                    if (emplaced.second) {
                        // first time this code is seen
                        continue;
                    }

                    // already present; keep the lower of the two multipliers
                    auto & existing = emplaced.first->second;
                    if (existing > freq_multiplier) {
                        existing = freq_multiplier;
                    }
                }
            };

            absorb(event_codes_a, arch_a);
            absorb(event_codes_b, arch_b);

            return merged;
        }

        // Convert a collection of event descriptors (each exposing `.code` and
        // `.freq_multiplier`) into a code -> multiplier map, dropping any event that
        // is the cycle counter for `arch`. If a code appears more than once the
        // first occurrence wins.
        template<typename EventCodes>
        [[nodiscard]]
        std::unordered_map<std::uint16_t, std::uint16_t> filter_cycles(EventCodes const & event_codes,
                                                                       metric_arch_t arch)
        {
            std::unordered_map<std::uint16_t, std::uint16_t> without_cycles {};

            for (auto const & descriptor : event_codes) {
                if (is_cycle_counter(descriptor.code, arch)) {
                    // the cycle counter is tracked separately from the programmable events
                    continue;
                }

                without_cycles.try_emplace(descriptor.code, descriptor.freq_multiplier);
            }

            return without_cycles;
        }

        // Greedily group the metric event sets into an initial list of combinations.
        //
        // The priority levels are walked from metric_priority_t::min_value up to
        // metric_priority_t::max_value. At each level every eligible, not yet
        // consumed metric becomes the seed of a new combination, and then every
        // other not yet consumed metric that still fits (and passes the priority
        // rule below) is merged into it.
        //
        // max_events         - upper bound on the number of distinct (non cycle counter) event codes per combination
        // metric_events      - the candidate metric event sets
        // filter_predicate   - only metrics for which this returns true are considered
        // has_boundness      - OR-ed with true when a seed metric has priority `boundness` (never cleared here)
        // has_stalled_cycles - OR-ed with true when a seed metric has priority `stall_cycles` (never cleared here)
        // result             - receives the combinations, in creation order
        // consumed_metrics   - tracks the metrics that have already been used; updated as metrics are taken
        void make_initial_combinations_inner(
            std::size_t max_events,
            lib::Span<std::reference_wrapper<metrics::metric_events_set_t const> const> const & metric_events,
            std::function<bool(metric_events_set_t const &)> const & filter_predicate,
            bool & has_boundness,
            bool & has_stalled_cycles,
            std::vector<raw_combination_t> & result,
            std::unordered_set<metric_events_set_t const *> & consumed_metrics)
        {
            // one pass per priority level, inclusive of max_value; the enum is stepped via its unsigned value
            for (auto min_priority = metric_priority_t::min_value; min_priority <= metric_priority_t::max_value;
                 min_priority = metric_priority_t(unsigned(min_priority) + 1)) {

                for (metric_events_set_t const & metric_a : metric_events) {
                    // filter based on current priority level
                    // (on every pass except the final one, a metric whose priority group is at or above
                    // min_value is only used as a seed when its group matches the current level; the final
                    // pass applies no priority filter, so it picks up everything that is left)
                    if ((min_priority < metric_priority_t::max_value)
                        && (metric_a.priority_group >= metric_priority_t::min_value)
                        && (metric_a.priority_group != min_priority)) {
                        continue;
                    }

                    // filter metric based on predicate
                    if (!filter_predicate(metric_a)) {
                        continue;
                    }

                    // dont reuse metrics
                    // (the insert doubles as the "mark as consumed" step for the seed)
                    if (auto const [it, inserted] = consumed_metrics.insert(&metric_a); !inserted) {
                        (void) it; // GCC7 :-(
                        continue;
                    }

                    // seed a new combination from metric_a, with the cycle counter event stripped out
                    raw_combination_t current_combination {
                        {&metric_a},
                        filter_cycles(metric_a.event_codes, metric_a.arch),
                        metric_a.priority_group,
                        metric_a.arch,
                        metric_a.uses_cycles,
                    };

                    // a metric that cannot fit on its own is dropped; note that it has already been
                    // marked as consumed above, so it will not be reconsidered
                    if (current_combination.event_code_to_freq_multiplier.size() > max_events) {
                        continue;
                    }

                    // update the input flags
                    // (only the seed's priority is inspected here, before anything is merged in)
                    has_boundness |= (current_combination.priority == metric_priority_t::boundness);
                    has_stalled_cycles |= (current_combination.priority == metric_priority_t::stall_cycles);

                    // stick normalized priority items together
                    for (metric_events_set_t const & metric_b : metric_events) {
                        // filter metric based on predicate
                        if (!filter_predicate(metric_b)) {
                            continue;
                        }

                        // dont reuse metrics
                        // (this also skips metric_a itself, which was marked as consumed above)
                        if (consumed_metrics.count(&metric_b) != 0) {
                            continue;
                        }

                        // compute metric_b's event codes with the cycle counter removed
                        // (despite the variable name, the mapped value is the freq_multiplier set by filter_cycles)
                        auto const b_event_code_to_numerator_flag = filter_cycles(metric_b.event_codes, metric_b.arch);

                        // combine the event codes
                        auto combined_codes = combine_codes(current_combination.event_code_to_freq_multiplier,
                                                            current_combination.arch,
                                                            b_event_code_to_numerator_flag,
                                                            metric_b.arch);
                        if (combined_codes.size() > max_events) {
                            continue;
                        }

                        // filter based on current priority level (but only if there is events added to the set)
                        // i.e. a metric above the current level may still join if all of its events are
                        // already present in the combination
                        if ((min_priority < metric_priority_t::max_value) && (metric_b.priority_group > min_priority)
                            && (combined_codes.size() > current_combination.event_code_to_freq_multiplier.size())) {
                            continue;
                        }

                        // combine architectures
                        auto const combined_arch = combine_arch(current_combination.arch, metric_b.arch);

                        // update current
                        // NOTE(review): uses_cycles is not OR-ed with metric_b.uses_cycles here, so the
                        // combination keeps only the seed's flag - confirm that consumers do not rely on
                        // it covering every contained set
                        consumed_metrics.insert(&metric_b);
                        current_combination.contains_sets.insert(&metric_b);
                        current_combination.arch = combined_arch;
                        current_combination.event_code_to_freq_multiplier = std::move(combined_codes);
                        current_combination.priority =
                            select_best(current_combination.priority, metric_b.priority_group);
                    }

                    result.emplace_back(std::move(current_combination));
                }
            }
        }

        // Entry point for the first grouping stage: resets the two output flags,
        // then delegates to make_initial_combinations_inner() with a fresh result
        // list and a fresh "already consumed" set.
        //
        // Returns the list of combinations produced by the inner function.
        [[nodiscard]] std::vector<raw_combination_t> make_initial_combinations(
            std::size_t max_events,
            lib::Span<std::reference_wrapper<metrics::metric_events_set_t const> const> events,
            std::function<bool(metric_events_set_t const &)> const & filter_predicate,
            bool & has_boundness,
            bool & has_stalled_cycles)
        {
            // the inner function only ever sets these, so clear them first
            has_stalled_cycles = false;
            has_boundness = false;

            std::unordered_set<metric_events_set_t const *> already_grouped {};
            std::vector<raw_combination_t> combinations {};

            make_initial_combinations_inner(max_events,
                                            events,
                                            filter_predicate,
                                            has_boundness,
                                            has_stalled_cycles,
                                            combinations,
                                            already_grouped);

            return combinations;
        }

        // Returns true if at least one of the metric sets contained in `combination`
        // is already present in `consumed_metrics`.
        [[nodiscard]] bool is_already_consumed(std::unordered_set<metric_events_set_t const *> const & consumed_metrics,
                                               raw_combination_t const & combination)
        {
            auto const & sets = combination.contains_sets;

            return std::any_of(sets.begin(), sets.end(), [&consumed_metrics](metric_events_set_t const * set) {
                return consumed_metrics.count(set) != 0;
            });
        }

        template<typename Predicate>
        [[nodiscard]] std::vector<raw_combination_t> combine_combinations(
            std::size_t max_events,
            std::vector<raw_combination_t> initial_combinations,
            Predicate && predicate)
        {
            static_assert(std::is_invocable_r_v<bool, Predicate, raw_combination_t const &, raw_combination_t const &>);

            // attempt to mush combinations together
            while (true) {
                std::vector<raw_combination_t> new_combinations {};
                std::unordered_set<metric_events_set_t const *> consumed_metrics {};

                bool modified = false;

                for (auto const & combination_a : initial_combinations) {
                    // dont reuse metrics
                    if (is_already_consumed(consumed_metrics, combination_a)) {
                        continue;
                    }

                    if (combination_a.event_code_to_freq_multiplier.size() > max_events) {
                        continue;
                    }

                    // base the new combination of of our starting point
                    raw_combination_t current_combination {combination_a};
                    consumed_metrics.insert(combination_a.contains_sets.begin(), combination_a.contains_sets.end());

                    // attempt to append other combinations to the current combination
                    for (auto const & combination_b : initial_combinations) {
                        // dont reuse metrics
                        if (is_already_consumed(consumed_metrics, combination_b)) {
                            continue;
                        }

                        // check combination
                        if (!predicate(current_combination, combination_b)) {
                            continue;
                        }

                        // combine the event codes
                        auto combined_codes = combine_codes(current_combination.event_code_to_freq_multiplier,
                                                            current_combination.arch,
                                                            combination_b.event_code_to_freq_multiplier,
                                                            combination_b.arch);
                        if (combined_codes.size() > max_events) {
                            continue;
                        }

                        // combine architectures
                        auto const combined_arch = combine_arch(current_combination.arch, combination_b.arch);

                        // update current
                        modified |= combined_codes.size() != current_combination.event_code_to_freq_multiplier.size();
                        consumed_metrics.insert(combination_b.contains_sets.begin(), combination_b.contains_sets.end());
                        current_combination.contains_sets.insert(combination_b.contains_sets.begin(),
                                                                 combination_b.contains_sets.end());
                        current_combination.arch = combined_arch;
                        current_combination.event_code_to_freq_multiplier = std::move(combined_codes);
                        current_combination.priority =
                            select_best(current_combination.priority, combination_b.priority);
                    }

                    new_combinations.emplace_back(std::move(current_combination));
                }

                if (!modified) {
                    return new_combinations;
                }

                initial_combinations = std::move(new_combinations);
            }
        }

        // True when `v` equals any of the priorities listed as template arguments;
        // false for an empty list.
        template<metric_priority_t... Enums>
        [[nodiscard]] constexpr bool is_one_of(metric_priority_t v)
        {
            return ((v == Enums) || ...);
        }

        // Build the merge predicate used by combine_combinations() for one staged pass.
        //
        // Two combinations may be merged when they have the same priority, or when
        // BOTH of their priorities belong to the listed set `Priorities...`.
        // Combinations whose priorities are both outside the set (and differ from
        // each other) are left for a later pass.
        template<metric_priority_t... Priorities>
        [[nodiscard]] constexpr auto filter_for_priorities()
        {
            return [](raw_combination_t const & a, raw_combination_t const & b) -> bool {
                // membership must hold for both sides; comparing the two membership results for
                // equality would also accept the "neither is a member" case
                return a.priority == b.priority
                    || (is_one_of<Priorities...>(a.priority) && is_one_of<Priorities...>(b.priority));
            };
        }

        [[nodiscard]] std::vector<combination_t> convert_to_final(std::vector<raw_combination_t> combinations)
        {
            std::vector<combination_t> result {};
            result.reserve(combinations.size());

            for (auto & combination : combinations) {
                result.emplace_back(std::move(combination.contains_sets),
                                    std::move(combination.event_code_to_freq_multiplier),
                                    combination.arch,
                                    combination.uses_cycles);
            }

            return result;
        }
    }

    // Look up `cset_id` in cpu_metrics_table. Returns a pointer to the mapped
    // entry, or nullptr when the table has no entry for that id.
    metric_cpu_event_map_entry_t const * find_events_for_cset(std::string_view cset_id)
    {
        auto const entry = cpu_metrics_table.find(cset_id);

        if (entry == cpu_metrics_table.end()) {
            return nullptr;
        }

        return &(entry->second);
    }

    std::vector<combination_t> make_combinations(
        std::size_t max_events,
        lib::Span<std::reference_wrapper<metrics::metric_events_set_t const> const> events,
        std::function<bool(metric_events_set_t const &)> const & filter_predicate)
    {
        bool has_boundness = false;
        bool has_stalled_cycles = false;

        // make the initial set
        auto raw_combinations =
            make_initial_combinations(max_events, events, filter_predicate, has_boundness, has_stalled_cycles);

        // merge boundness and top_level if possible
        raw_combinations =
            combine_combinations(max_events,
                                 std::move(raw_combinations),
                                 filter_for_priorities<metric_priority_t::top_level, metric_priority_t::boundness>());

        // merge branch and top_level if the group has boundness and stalled_cycles (branches are prioritized over stall cycles)
        if (has_boundness && has_stalled_cycles) {
            raw_combinations =
                combine_combinations(max_events,
                                     std::move(raw_combinations),
                                     filter_for_priorities<metric_priority_t::top_level, metric_priority_t::branch>());
        }

        // merge stalled_cycles and top_level if possible
        raw_combinations = combine_combinations(
            max_events,
            std::move(raw_combinations),
            filter_for_priorities<metric_priority_t::top_level, metric_priority_t::stall_cycles>());

        // merge branch and top_level if not done previously
        if (!has_boundness || !has_stalled_cycles) {
            raw_combinations =
                combine_combinations(max_events,
                                     std::move(raw_combinations),
                                     filter_for_priorities<metric_priority_t::top_level, metric_priority_t::branch>());
        }

        // merge boundness, stall_cylces, frontend, backend
        raw_combinations = combine_combinations(max_events,
                                                std::move(raw_combinations),
                                                filter_for_priorities<metric_priority_t::boundness,
                                                                      metric_priority_t::stall_cycles,
                                                                      metric_priority_t::frontend,
                                                                      metric_priority_t::backend>());

        // merge data and top_level
        raw_combinations =
            combine_combinations(max_events,
                                 std::move(raw_combinations),
                                 filter_for_priorities<metric_priority_t::top_level, metric_priority_t::data>());

        // merge data and ls
        raw_combinations =
            combine_combinations(max_events,
                                 std::move(raw_combinations),
                                 filter_for_priorities<metric_priority_t::data, metric_priority_t::ls>());

        // merge data, ls, l2
        raw_combinations = combine_combinations(
            max_events,
            std::move(raw_combinations),
            filter_for_priorities<metric_priority_t::data, metric_priority_t::ls, metric_priority_t::l2>());

        // merge data, ls, l2, l3
        raw_combinations = combine_combinations(max_events,
                                                std::move(raw_combinations),
                                                filter_for_priorities<metric_priority_t::data,
                                                                      metric_priority_t::ls,
                                                                      metric_priority_t::l2,
                                                                      metric_priority_t::l3>());

        // merge data, ls, l2, l3, ll
        raw_combinations = combine_combinations(max_events,
                                                std::move(raw_combinations),
                                                filter_for_priorities<metric_priority_t::data,
                                                                      metric_priority_t::ls,
                                                                      metric_priority_t::l2,
                                                                      metric_priority_t::l3,
                                                                      metric_priority_t::ll>());

        // merge anything else that will fit together
        raw_combinations =
            combine_combinations(max_events,
                                 std::move(raw_combinations),
                                 [](raw_combination_t const & /*a*/, raw_combination_t const & /*b*/) { return true; });

        return convert_to_final(std::move(raw_combinations));
    }
}
