#!/usr/bin/env python3
# SPDX-License-Identifier: BSD-3-Clause
#
# Copyright (C) 2024 by Arm Limited.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
This script is a formatter for raw profiling data generated by Streamline's
command line prototype analysis tooling. It consumes a CSV input file,
generated by Streamline's prototype analysis command line tool, and outputs a
formatted XLSX file. Formatting is configured based on a customizable YAML
configuration.

This script requires Python 3.8, and requires additional third-party modules as
defined in the requirements.txt file.
'''
import argparse
import datetime
import numpy as np
import os
import pandas as pd
import re
import statistics
import sys
import yaml

# Default config file name, looked up in the same directory as this script
# when no config file is given on the command line
DEFAULT_CONFIG = 'sl-format_default-config.yaml'

# Build number (patched by release process)
RELEASE_BUILD = 'v9.2.2'

# Highest shade index on the ramp from white (index 0) towards red. We only
# use 160 shades, rather than the full 256, because shades of very saturated
# red are hard to visually distinguish.
MAX_SHADE = 159
MAX_SHADE_COUNT = MAX_SHADE + 1

# Must be at least one input column in the data using the primary key. This is
# used to cross-reference between PMU and SPE data sets, as the rows may not
# line up between the two.
PRIMARY_KEY = 'symbol'

# To ensure we have at least one common sample count column we create a "Meta:
# Samples" column by copying the first of these columns that is present, in
# precedence order: metrics, periodic, spe. One of these must exist in each
# data file loaded.
META_SAMPLES = "Meta: Samples"
METRICS_SAMPLES = "Metrics: Sample Count"
PERIODIC_SAMPLES = "Periodic Samples: Periodic Samples"
SPE_SAMPLES = "SPE: # Samples"


class ConfigSet(list):
    '''
    A parsed YAML configuration file.

    The set is a list of the primary series entries, in config file order.
    Entries that are declared as an alternative for another series are not
    members of the list; they are attached to the "alternatives" list of the
    primary entry that they can stand in for.
    '''

    def __init__(self, config_file):
        '''
        Populate a config set from the YAML config file.

        Args:
            config_file: The input config file.

        Raises:
            FileNotFoundError: File not found.
            KeyError: Expected YAML key missing, or the file does not contain
                a list of series entries.
        '''
        super().__init__()

        # Load the data from file
        with open(config_file, 'r') as handle:
            raw_config = yaml.safe_load(handle)

        # An empty file loads as None, which cannot be iterated; report it
        # using the same exception type as any other malformed config
        if not isinstance(raw_config, list):
            raise KeyError('ERROR: Config file must contain a list of series')

        # Generate classes for each node
        configs = [Config(i, x) for i, x in enumerate(raw_config)]

        # Pull out primaries first, so that an alternative can be declared
        # in the file before the series that it is an alternative for
        self.extend(x for x in configs if not x.is_alt_for)

        # Insert alternatives with primary as parent
        for config in configs:
            if not config.is_alt_for:
                continue

            target = config.is_alt_for
            primary = self.get_by_dst(target)
            if primary is None:
                raise KeyError(f'ERROR: Parent series {target} is missing')
            primary.alternatives.append(config)

    def get_by_dst(self, dst_name):
        '''
        Get a series config in this set by dst_name.

        Args:
            dst_name: The name to match.

        Returns:
            Return the first matching Config, or None if not found.
        '''
        for config in self:
            if config.dst_name == dst_name:
                return config

        return None


class Config:
    '''
    A single series entry parsed from the YAML configuration.
    '''

    def __init__(self, index, yaml):
        '''
        Build a series description from one entry of the config file.

        Args:
            index: The entry index in the config file.
            yaml: The YAML config entry.

        Raises:
            KeyError: Expected YAML key missing.
        '''
        # Every entry must be wrapped in a "series" node
        if 'series' not in yaml:
            raise KeyError(f'ERROR: Config item {index} is not a series')

        # Every series must name the source column that it reads
        series = yaml['series']
        if 'src_name' not in series:
            raise KeyError('ERROR: Config item missing "src_name" key')

        # Keep the raw series node for the less common settings
        self.yaml = series

        # Column naming; the output name defaults to the input name
        self.src_name = series['src_name']
        self.dst_name = series.get('dst_name', self.src_name)

        # Data source and optional data type
        self.src_type = series.get('src_type', 'pmu')
        self.data_type = series.get('dtype', None)

        # Fallback handling; alternatives are attached later by the owner
        self.is_alt_for = series.get('alt_for', None)
        self.alternatives = []

        self.is_primary = self.src_name == PRIMARY_KEY


def get_copy_yr():
    '''
    Build the copyright year string for the startup banner.

    Returns:
        Just the first year if that is the current year, otherwise a
        "first-current" year range.
    '''
    first_year = 2024
    this_year = datetime.datetime.now().year

    if first_year == this_year:
        return f"{first_year}"

    return f"{first_year}-{this_year}"


def printe(*args):
    '''
    Write a message to stderr, formatted as print() would format it.
    '''
    error_stream = sys.stderr
    print(*args, file=error_stream)


def load_raw_data(config, data_file):
    '''
    Load a raw CSV file as a pandas data frame.

    Args:
        config: The series configs, used to select per-column converters.
        data_file: The CSV file to load.

    Returns:
        The loaded data frame with a common sample count column added, or
        None if the file is missing or has no sample count column.
    '''
    # Explicit converters on typed columns avoid NaNs in empty string cells
    converters = {}
    for entry in config:
        kind = entry.data_type
        if kind == 'str':
            caster = str
        elif kind == 'int':
            caster = int
        else:
            continue

        converters[entry.src_name] = caster

    # Try to load the file
    try:
        frame = pd.read_csv(data_file, converters=converters)
    except FileNotFoundError:
        printe(f'ERROR: No source file "{data_file}"')
        return None

    # Find the highest precedence sample count column that is present
    candidates = (METRICS_SAMPLES, PERIODIC_SAMPLES, SPE_SAMPLES)
    source = next((x for x in candidates if x in frame.keys()), None)
    if source is None:
        printe(f'ERROR: No sample count column found in "{data_file}"')
        return None

    # Expose it under a common name, no matter what data is used
    frame[META_SAMPLES] = frame[source]
    return frame


def load_raw_spe_data(config, data_file):
    '''
    Load a raw CSV file as a pandas data frame, checking it holds SPE data.

    Returns:
        The loaded data frame, or None if the file could not be loaded or
        none of its column names start with "SPE".
    '''
    frame = load_raw_data(config, data_file)
    if frame is None:
        return None

    # Any single SPE column is enough to accept the file
    if any(name.startswith("SPE") for name in frame.keys()):
        return frame

    printe(f'ERROR: No SPE data found in "{data_file}"')
    return None


def strip_params(value):
    '''
    Strip function parameters from a function prototype.

    WARNING: This is a simplistic implementation, so it is conservative. Only
    values containing exactly one pair of parentheses are modified.

    Args:
        value: The prototype to process. Non-string values, such as the NaN
            that pandas uses for an empty cell, are converted to a string.

    Returns:
        The string value, without its parameter list if one was stripped.
    '''
    # Convert before using the regex, which only accepts string input
    text = str(value)

    # Preprocess any sized array pointers into basic pointers
    # For example, "unsigned char (*) [4]" is turned into "unsigned char*"
    pattern = re.compile(r" \(\*\) \[\d+\]")
    while True:
        # Repeat because a replacement can expose a new outer match
        text, count = pattern.subn("*", text)
        if not count:
            break

    # Strip parameters from anything with a single set of parens
    if text.count('(') == 1 and text.count(')') == 1:
        return text.split('(')[0]

    return text


def select_series(config, main_series, data):
    """
    Pick the main series, or its first usable alternative, for a data set.

    If an alternative is picked it replaces the main series in the config.

    Args:
        config: The ConfigSet to modify, if needed.
        main_series: The main series Config object.
        data: The data table to search for a match.

    Returns:
        Returns the matching series config, or None if no matches.
    """
    slot = config.index(main_series)

    # Main series is present, so nothing to swap
    if main_series.src_name in data:
        return main_series

    main_name = main_series.dst_name

    # Main series not found, so try to match a fallback alternative
    usable = (x for x in main_series.alternatives if x.src_name in data)
    fallback = next(usable, None)

    # Have no alternative so drop series completely
    if fallback is None:
        printe(f'WARNING: No series "{main_name}" or alternative')
        return None

    # Have an alternative so swap to that
    alt_name = fallback.dst_name
    printe(f'WARNING: No series "{main_name}", using alternative "{alt_name}"')
    config[slot] = fallback
    return fallback


def select_pmu_columns(config, data):
    '''
    Build a new data frame from the configured PMU columns of the raw data.

    Series that cannot be matched, even by an alternative, are removed from
    the config.

    Returns:
        New data frame and dst_name of the primary key column.
    '''
    selected = pd.DataFrame()
    primary_dst_name = None
    unmatched = []

    for entry in config:
        # Only handle PMU columns here
        if entry.src_type != 'pmu':
            continue

        # Select a series or the first fallback
        chosen = select_series(config, entry, data)
        if not chosen:
            unmatched.append(entry)
            continue

        if chosen.is_primary:
            primary_dst_name = chosen.dst_name

        # Copy with any inline parameter stripping
        raw_values = data[chosen.src_name].values
        if chosen.yaml.get('strip_params', False):
            selected[chosen.dst_name] = list(map(strip_params, raw_values))
        else:
            selected[chosen.dst_name] = raw_values

    if not primary_dst_name:
        printe(f'ERROR: PMU data missing "src_name={PRIMARY_KEY}" primary key')

    # Removal is deferred so the config is not resized while iterating it
    for entry in unmatched:
        config.remove(entry)

    return (selected, primary_dst_name)


def select_spe_columns(config, data, spe_data, dst_name_primary):
    '''
    Merge in columns populated from the SPE data, based on using an
    overlapping primary key column value.

    Series that cannot be matched, even by an alternative, are removed from
    the config.

    Args:
        config: The ConfigSet listing the series to transfer.
        data: The PMU data frame to add columns to, edited in place.
        spe_data: The raw SPE data frame.
        dst_name_primary: The name of the primary key column in data.

    Returns:
        The updated data frame, or None if the SPE data has no primary key.
    '''
    if PRIMARY_KEY not in spe_data:
        printe(f'ERROR: SPE data missing "src_name={PRIMARY_KEY}" primary key')
        return None

    # Map each SPE symbol to its row position; the first occurrence wins so
    # that duplicate symbols resolve to the same row as a linear search
    spe_row_by_symbol = {}
    for row, raw_symbol in enumerate(spe_data[PRIMARY_KEY]):
        spe_row_by_symbol.setdefault(strip_params(raw_symbol), row)

    # Build a remap table from PMU row order to SPE row positions, using
    # None for PMU symbols that have no SPE data
    remap_table = [spe_row_by_symbol.get(x) for x in data[dst_name_primary]]

    # Select the SPE columns to transfer
    drop_keys = []
    for series in config:
        # Only handle SPE columns here
        if series.src_type != 'spe':
            continue

        # Select a series or the first fallback
        matched_series = select_series(config, series, spe_data)
        if not matched_series:
            drop_keys.append(series)
            continue

        # Remap entries are row positions, so read the raw values by
        # position rather than by index label
        spe_values = spe_data[matched_series.src_name].values

        # Insert an SPE data column, in PMU row order, into the data table
        data[matched_series.dst_name] = [
            np.nan if row is None else spe_values[row]
            for row in remap_table
        ]

    # Removal is deferred so the config is not resized while iterating it
    for key in drop_keys:
        config.remove(key)

    return data


def filter_rows(config, data):
    '''
    Drop insignificant rows and apply value rewrites, editing data in place.

    All row filters run before any data modification is applied.
    '''
    # Thresholds are computed from an unfiltered snapshot, so that one
    # filter does not change the totals used by the next
    snapshot = data.copy()

    # Filter first
    for entry in config:
        if entry.yaml.get('filter', None) != 'significance':
            continue

        column = entry.dst_name
        fraction = entry.yaml.get('min_row_significance')
        threshold = sum(snapshot[column].values) * fraction

        too_small = data[data[column] < threshold].index
        data.drop(too_small, inplace=True)

    def reverse_percent(percent):
        return 100.0 - percent

    # Data modify second
    for entry in config:
        if entry.yaml.get('data_mod', None) == 'rev%':
            column = entry.dst_name
            data[column] = data[column].apply(reverse_percent)


def style_abs_ramp_up(sconfig, value):
    '''
    Pick the shade index for a value under the absolute_ramp_up style.

    Values at or below the series "min_ramp" are unshaded, values above
    "max_ramp" get the strongest shade, and values in between are scaled
    linearly.
    '''
    ramp_start = sconfig['min_ramp']
    ramp_end = sconfig['max_ramp']
    assert ramp_start <= ramp_end, 'Series min_ramp is not less than max_ramp'

    if value > ramp_end:
        shade = MAX_SHADE
    elif value > ramp_start:
        fraction = (value - ramp_start) / (ramp_end - ramp_start)
        shade = int(fraction * MAX_SHADE)
    else:
        shade = 0

    assert 0 <= shade <= MAX_SHADE
    return shade


def style_abs_ramp_down(sconfig, value):
    '''
    Pick the shade index for a value under the absolute_ramp_down style.

    Values below the series "min_ramp" get the strongest shade, values at or
    above "max_ramp" are unshaded, and values in between are scaled linearly.
    '''
    ramp_start = sconfig['min_ramp']
    ramp_end = sconfig['max_ramp']
    assert ramp_start <= ramp_end, 'Series min_ramp is not less than max_ramp'

    if value < ramp_start:
        shade = MAX_SHADE
    elif value < ramp_end:
        fraction = (value - ramp_start) / (ramp_end - ramp_start)
        shade = MAX_SHADE - int(fraction * MAX_SHADE)
    else:
        shade = 0

    assert 0 <= shade <= MAX_SHADE
    return shade


def style_rel_ramp_up(sconfig, value, min_val, max_val):
    '''
    Determine a style index for a value and a relative_ramp_up style.

    The value is converted to its relative position between min_val and
    max_val, and that position is ramped between the series "min_ramp" and
    "max_ramp" limits. Positions above max_ramp get the strongest shade.

    Args:
        sconfig: The raw series config, containing the ramp limits.
        value: The cell value to style.
        min_val: The smallest value in the column.
        max_val: The largest value in the column.

    Returns:
        The shade index, where 0 is unshaded.
    '''
    min_ramp = sconfig['min_ramp']
    max_ramp = sconfig['max_ramp']
    assert min_ramp <= max_ramp, 'Series min_ramp is not less than max_ramp'

    # A column with no spread has no relative positions to rank, so leave it
    # unshaded instead of dividing by zero
    delta = max_val - min_val
    if delta == 0:
        return 0

    relative = ((value - min_val) / delta)

    assign = 0
    if relative > max_ramp:
        assign = MAX_SHADE
    elif relative > min_ramp:
        ramplen = max_ramp - min_ramp
        assign = int(((relative - min_ramp)/(ramplen)) * MAX_SHADE)

    assert 0 <= assign <= MAX_SHADE
    return assign


def style_rel_ramp_down(sconfig, value, min_val, max_val):
    '''
    Determine a style index for a value and a relative_ramp_down style.

    The value is converted to its relative position between min_val and
    max_val, and that position is ramped between the series "min_ramp" and
    "max_ramp" limits. Positions below min_ramp get the strongest shade.

    Args:
        sconfig: The raw series config, containing the ramp limits.
        value: The cell value to style.
        min_val: The smallest value in the column.
        max_val: The largest value in the column.

    Returns:
        The shade index, where 0 is unshaded.
    '''
    min_ramp = sconfig['min_ramp']
    max_ramp = sconfig['max_ramp']
    assert min_ramp <= max_ramp, 'Series min_ramp is not less than max_ramp'

    # A column with no spread has no relative positions to rank, so leave it
    # unshaded instead of dividing by zero
    delta = max_val - min_val
    if delta == 0:
        return 0

    relative = ((value - min_val) / delta)

    assign = 0
    if relative < min_ramp:
        assign = MAX_SHADE
    elif relative < max_ramp:
        ramplen = max_ramp - min_ramp
        assign = MAX_SHADE
        assign -= int(((relative - min_ramp)/(ramplen)) * MAX_SHADE)

    assert 0 <= assign <= MAX_SHADE
    return assign


def style_stdev_ramp_up(sconfig, value, mean, stdev):
    '''
    Determine a style index for a value and a stdev_ramp_up style.

    The value is converted to a number of standard deviations from the mean,
    and that count is ramped between the series "min_ramp" and "max_ramp"
    limits. Counts above max_ramp get the strongest shade.

    Args:
        sconfig: The raw series config, containing the ramp limits.
        value: The cell value to style.
        mean: The mean of the values in the column.
        stdev: The standard deviation of the values in the column.

    Returns:
        The shade index, where 0 is unshaded.
    '''
    min_ramp = sconfig['min_ramp']
    max_ramp = sconfig['max_ramp']
    assert min_ramp <= max_ramp, 'Series min_ramp is not less than max_ramp'

    # A column with no deviation has no outliers to highlight, so leave it
    # unshaded instead of dividing by zero
    if stdev == 0:
        return 0

    stdev_count = (value - mean) / stdev

    assign = 0
    if stdev_count > max_ramp:
        assign = MAX_SHADE
    elif stdev_count > min_ramp:
        ramplen = max_ramp - min_ramp
        assign = int(((stdev_count - min_ramp)/(ramplen)) * MAX_SHADE)

    assert 0 <= assign <= MAX_SHADE
    return assign


def style_stdev_ramp_down(sconfig, value, mean, stdev):
    '''
    Determine a style index for a value and a stdev_ramp_down style.

    The value is converted to a number of standard deviations from the mean,
    and that count is ramped between the series "min_ramp" and "max_ramp"
    limits. Counts below min_ramp get the strongest shade.

    Args:
        sconfig: The raw series config, containing the ramp limits.
        value: The cell value to style.
        mean: The mean of the values in the column.
        stdev: The standard deviation of the values in the column.

    Returns:
        The shade index, where 0 is unshaded.
    '''
    min_ramp = sconfig['min_ramp']
    max_ramp = sconfig['max_ramp']
    assert min_ramp <= max_ramp, 'Series min_ramp is not less than max_ramp'

    # A column with no deviation has no outliers to highlight, so leave it
    # unshaded instead of dividing by zero
    if stdev == 0:
        return 0

    stdev_count = (value - mean) / stdev

    assign = 0
    if stdev_count < min_ramp:
        assign = MAX_SHADE
    elif stdev_count < max_ramp:
        ramplen = max_ramp - min_ramp
        assign = MAX_SHADE
        assign -= int(((stdev_count - min_ramp)/(ramplen)) * MAX_SHADE)

    assert 0 <= assign <= MAX_SHADE
    return assign


def get_col_coord(col):
    '''
    Convert a zero-based column index to Excel column letters e.g. 0 => A.
    '''
    higher, lowest = divmod(col, 26)
    label = chr(ord('A') + lowest)

    # Excel letters have no zero digit, so each higher position is offset by
    # one before it is split again
    while higher > 0:
        higher, digit = divmod(higher - 1, 26)
        label = chr(ord('A') + digit) + label

    return label


def get_cell_coord(col, row):
    '''
    Convert zero-based column and row indices to an Excel cell coordinate
    e.g. [0,0] => A1.
    '''
    column_letters = get_col_coord(col)
    row_number = row + 1
    return f'{column_letters}{row_number}'


def _make_shader(style, sconfig, values):
    '''
    Build the function that maps a cell value to a shade index for a style.

    Args:
        style: The style name from the series config.
        sconfig: The raw series config, containing the ramp limits.
        values: All of the values in the column being styled.

    Returns:
        A function taking a cell value and returning a shade index, or None
        if the style name is not recognized.
    '''
    def unshaded(value):
        return 0

    if style == 'absolute_ramp_up':
        return lambda value: style_abs_ramp_up(sconfig, value)

    if style == 'absolute_ramp_down':
        return lambda value: style_abs_ramp_down(sconfig, value)

    if style in ('relative_ramp_up', 'relative_ramp_down'):
        # No rows, or no spread between rows, leaves nothing to ramp over
        if len(values) == 0:
            return unshaded

        dmin = min(values)
        dmax = max(values)
        if dmin == dmax:
            return unshaded

        if style == 'relative_ramp_up':
            return lambda value: style_rel_ramp_up(sconfig, value, dmin, dmax)

        return lambda value: style_rel_ramp_down(sconfig, value, dmin, dmax)

    if style in ('stdev_ramp_up', 'stdev_ramp_down'):
        # A standard deviation needs at least two rows
        if len(values) < 2:
            return unshaded

        mean = statistics.mean(values)
        stdev = statistics.stdev(values)
        if stdev == 0:
            return unshaded

        if style == 'stdev_ramp_up':
            return lambda value: style_stdev_ramp_up(
                sconfig, value, mean, stdev)

        return lambda value: style_stdev_ramp_down(
            sconfig, value, mean, stdev)

    return None


def write_xlsx(output_file, config, data):
    '''
    Write the formatted spreadsheet output.

    Args:
        output_file: The XLSX file to write.
        config: The ConfigSet that provides the style of each data column.
        data: The data frame to write.
    '''
    with pd.ExcelWriter(output_file, engine='xlsxwriter') as writer:
        data.to_excel(writer, sheet_name='Sheet1', index=False)

        book = writer.book
        sheet = writer.sheets['Sheet1']

        # Create N shades of red style as float data cell styles, where
        # index zero is plain white
        formats = [
            book.add_format({
                'bg_color': f'#FF{255-i:02x}{255-i:02x}',
                'num_format': '0.00',
                'border': 1,
                'border_color': '#BBBBBB'
            })
            for i in range(0, MAX_SHADE_COUNT)
        ]

        # Create plain white style for basic data cell style
        plain_fmt = book.add_format({
            'bg_color': '#FFFFFF',
            'num_format': '0',
            'border': 1,
            'border_color': '#BBBBBB'
        })

        # Create plain grey style for table header cell style
        title_fmt = book.add_format({
            'bg_color': '#DDDDDD',
            'border': 1,
            'border_color': '#BBBBBB'
        })

        # Wrap the active data in a table for select and filtering
        max_row = len(data.index) + 1
        max_col = get_col_coord(len(data.keys()) - 1)
        table_cols = [{'header': x} for x in data.keys()]
        sheet.add_table(f'A1:{max_col}{max_row}',
                        {'style': None, 'columns': table_cols})

        # Apply column width autofit
        sheet.autofit()

        # Apply cell color styles
        for i, col_name in enumerate(data.keys()):
            series = config.get_by_dst(col_name)
            assert series is not None, f'No config for column "{col_name}"'
            sconfig = series.yaml

            # Style the header
            sheet.write(get_cell_coord(i, 0), col_name, title_fmt)

            # Does this series column need styling?
            style = sconfig.get('style', None)
            column = data[col_name]

            # Default style for number formatting
            if not style:
                is_float = data.dtypes.iloc[i] == np.float64
                fmt = formats[0] if is_float else plain_fmt
                for j, value in enumerate(column):
                    sheet.write(get_cell_coord(i, j+1), value, fmt)
                continue

            # Cells with an unrecognized style keep the default formatting
            # that was written by to_excel
            shader = _make_shader(style, sconfig, column.values)
            if shader is None:
                continue

            for j, value in enumerate(column):
                fmt = formats[shader(value)]
                sheet.write(get_cell_coord(i, j+1), value, fmt)


def parse_cli():
    '''
    Parse the command line.

    The config file defaults to the default config file in the directory
    containing this script.

    Returns:
        Return an argparse results object.
    '''
    cli = argparse.ArgumentParser(
        prog='sl-format',
        description='Streamline CLI tools profile pretty-printer')

    cli.add_argument(
        'pmu_file', metavar='PMU_FILE', type=str, nargs=1,
        help='the input CSV file for PMU sample data')

    # Optional arguments, in the order they are shown in the help text
    optionals = (
        (('-s', '--spe-file'),
         {'metavar': 'SPE_FILE', 'type': str, 'default': None,
          'help': 'the input CSV file for SPE sample data'}),
        (('-c', '--config'),
         {'metavar': 'CONFIG_FILE', 'type': str, 'default': None,
          'help': 'the input YAML configuration file'}),
        (('-o', '--output'),
         {'dest': 'output_file', 'type': str, 'default': None,
          'help': 'the output XLSX file'}),
        (('-l', '--list'),
         {'action': 'store_true', 'default': False,
          'help': 'show the columns in the input data files'}),
    )

    for flags, options in optionals:
        cli.add_argument(*flags, **options)

    parsed = cli.parse_args()

    # Unwrap the single element list produced by nargs=1
    parsed.pmu_file = parsed.pmu_file[0]

    if not parsed.config:
        here = os.path.dirname(__file__)
        parsed.config = os.path.join(here, DEFAULT_CONFIG)

    return parsed


def main():
    '''
    Run the command line tool.

    Loads the config and input data, selects and filters the configured
    columns, optionally lists the input columns, and writes the output file
    if one was requested.

    Returns:
        The process exit code; 0 on success, 1 if an error was reported.
    '''
    print(f"sl-format {RELEASE_BUILD}")
    print(f"Copyright (c) {get_copy_yr()} Arm Limited. All rights reserved.\n")

    args = parse_cli()

    # Load input config and data files
    try:
        config = ConfigSet(args.config)
    except FileNotFoundError:
        printe(f'ERROR: No source file "{args.config}"')
        return 1
    except KeyError as ex:
        # str() of a KeyError is the repr of its key, which would wrap the
        # message in quotes, so print the message argument directly
        printe(ex.args[0] if ex.args else str(ex))
        return 1
    except yaml.YAMLError as ex:
        printe(f'ERROR: Config file "{args.config}" is not valid YAML: {ex}')
        return 1

    if not args.list and not args.output_file:
        printe('ERROR: Output file must be specified')
        return 1

    if args.output_file and not args.output_file.endswith('.xlsx'):
        printe('ERROR: Output file must be *.xlsx')
        return 1

    raw_data = load_raw_data(config, args.pmu_file)
    if raw_data is None:
        return 1

    raw_spe_data = None
    if args.spe_file:
        raw_spe_data = load_raw_spe_data(config, args.spe_file)
        if raw_spe_data is None:
            return 1

    # Select and filter based on PMU data
    data, dst_name_primary = select_pmu_columns(config, raw_data)
    if dst_name_primary is None:
        return 1

    filter_rows(config, data)

    # Supplement with SPE data if available
    if args.spe_file:
        data = select_spe_columns(config, data, raw_spe_data, dst_name_primary)
        if data is None:
            return 1

    # Replace NaNs with zeros to keep xlsxwriter happy
    data.fillna(0, inplace=True)

    if args.list:
        print('PMU data')
        print('========')
        for key in raw_data.keys():
            print(f'  * {key}')
        print('\n')

        if raw_spe_data is not None:
            print('SPE data')
            print('========')
            for key in raw_spe_data.keys():
                print(f'  * {key}')
            print('\n')

    # Was this command list-only?
    if not args.output_file:
        return 0

    try:
        write_xlsx(args.output_file, config, data)
    except PermissionError:
        printe(f'ERROR: Could not open output file "{args.output_file}"')
        return 1

    return 0


if __name__ == '__main__':
    try:
        sys.exit(main())
    except KeyboardInterrupt:
        # Report on stderr and exit non-zero, so that a calling script can
        # tell that the run did not complete
        printe("ERROR: User interrupted execution")
        sys.exit(1)
