# =============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All rights reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================

from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
import psutil
from qti.aisw.tools.core.utilities.framework.utils.helper import Helper


class ActivationStatus:
    """Holds the current processing state of one activation plus an accompanying message.

    The upper-case class attributes are the recognised state values (plain strings). A new
    instance starts in ``INITIALIZED``; ``set_status`` moves it to another state and records a
    message describing the transition.
    """

    INITIALIZED = "INITIALIZED"
    SKIP = "SKIP"
    CONVERTER_FAILURE = "CONVERTER_FAILURE"
    OPTIMIZER_FAILURE = "OPTIMIZER_FAILURE"
    QUANTIZER_FAILURE = "QUANTIZER_FAILURE"
    OFFLINE_PREPARE_FAILURE = "OFFLINE_PREPARE_FAILURE"
    NET_RUN_FAILURE = "NET_RUN_FAILURE"
    CUSTOM_OVERRIDE_GENERATION_FAILURE = "CUSTOM_OVERRIDE_GENERATION_FAILURE"
    INFERENCE_FAILURE = "INFERENCE_FAILURE"
    INFERENCE_DONE = "INFERENCE_DONE"
    VERIFICATION_FAILURE = "VERIFICATION_FAILURE"
    SUCCESS = "SUCCESS"

    def __init__(self, activation_name, msg="initialize") -> None:
        # The name is stored for reference only; no method in this class reads it back.
        self._activation_name = activation_name
        self._msg = msg
        self._current_status = ActivationStatus.INITIALIZED

    def set_status(self, status, msg):
        """Record ``status`` as the current state together with ``msg`` (no validation)."""
        self._current_status, self._msg = status, msg

    def get_status(self):
        """Return the most recently recorded state value."""
        current = self._current_status
        return current

    def get_msg(self):
        """Return the message recorded with the current state."""
        message = self._msg
        return message


@dataclass
class ActivationInfo:
    """Lightweight record describing one activation tensor.

    Attributes:
        dtype: Element data type of the tensor, as a string.
        shape: Dimensions of the tensor.
        distribution: Three float statistics summarising the tensor's values. Which statistics
            they are is decided by whoever builds the record (TODO confirm, e.g. min/max/mean).
    """

    # Element type name of the tensor.
    dtype: str
    # Size of each tensor dimension.
    shape: list[int]
    # Three summary statistics of the tensor values; see the class docstring.
    distribution: tuple[float, float, float]


def filter_snooping_report(
    snooping_report: pd.DataFrame, inference_data: dict[str, np.ndarray]
) -> pd.DataFrame:
    """Filters given snooping report and returns filtered report.

    Filtering is applied to below scenarios:
    1. Conv -> Relu
    2. Add -> Relu
    In both cases, target graphs dump Relu output for Conv/Add nodes, leading to inconsistencies
    between framework outputs or AIMET outputs. A "Conv2d"/"Eltwise_Binary" row that is
    immediately followed by an "ElementWiseNeuron" row with identical inference output (same
    shape and same values) is removed from the snooping report.

    Rows are compared by position, so the report may carry any index (not only 0..n-1). The
    surviving rows keep their original index labels.

    Args:
        snooping_report: Snooping report dataframe with at least "Layer Type" and
            "Source Name" columns.
        inference_data: Inference outputs keyed by
            ``Helper.transform_node_names(<Source Name>)``.

    Returns:
        pd.DataFrame: A new DataFrame containing the filtered snooping report.

    Raises:
        KeyError: If a candidate node's transformed name is missing from ``inference_data``.
    """
    row_count = len(snooping_report.index)
    if row_count < 2:
        # No adjacent pair to inspect; return a copy, as drop() did for an empty label list.
        return snooping_report.copy()

    layer_types = snooping_report["Layer Type"]
    source_names = snooping_report["Source Name"]
    keep_mask = np.ones(row_count, dtype=bool)

    for position in range(row_count - 1):
        # Use positional access (.iloc): label access breaks for non-RangeIndex reports.
        if layer_types.iloc[position] not in ("Conv2d", "Eltwise_Binary"):
            continue
        if layer_types.iloc[position + 1] != "ElementWiseNeuron":
            continue

        current_node_data = inference_data[
            Helper.transform_node_names(source_names.iloc[position])
        ]
        next_node_data = inference_data[
            Helper.transform_node_names(source_names.iloc[position + 1])
        ]

        # The following neuron is not necessarily a Relu, so only drop the row when both dumps
        # are exactly equal. np.array_equal returns False on a shape mismatch.
        if np.array_equal(current_node_data, next_node_data):
            keep_mask[position] = False

    # Positional selection stays correct even with duplicate index labels. copy() returns an
    # independent frame, like drop() did.
    return snooping_report.iloc[keep_mask].copy()


def get_free_cpu_cores(threshold: float = 20.0) -> int:
    """Count the CPU cores whose utilization is below ``threshold``.

    Per-core utilization is sampled with psutil over a blocking one-second interval.

    Args:
        threshold (float, optional): Utilization percentage a core must be strictly below to
            count as free. Defaults to 20.0.

    Returns:
        int: Number of cores whose sampled utilization is below ``threshold``.
    """
    per_core_percent = psutil.cpu_percent(interval=1, percpu=True)
    return sum(1 for percent in per_core_percent if percent < threshold)


def convert_raw_file(
    input_file: Path, input_dtype: str, target_dtype: str, output_file: Path = None
) -> Path:
    """Load a raw binary tensor file, cast it to ``target_dtype`` and write it to a new file.

    The input is read with ``np.fromfile`` as a flat 1-D array of ``input_dtype`` elements (no
    header, no shape). The converted data is written back out the same way.

    Args:
        input_file (Path): Path to the raw input file. A ``str`` path is also accepted.
        input_dtype (str): NumPy dtype the input bytes are stored as.
        target_dtype (str): NumPy dtype to cast the data to.
        output_file (Path): The path to the output file. If not provided, the output is
            written to ``converted_<input name>`` in the same directory as the input file.

    Returns:
        Path: The path the converted data was written to.

    Raises:
        FileNotFoundError: If the input file doesn't exist.
        IOError: If reading, casting or writing fails. This includes an invalid dtype or a
            missing output directory.
    """
    if output_file is None:
        # Path() so that plain string input paths work too.
        source = Path(input_file)
        output_file = source.parent / f"converted_{source.name}"

    # Phase 1: read and cast. A FileNotFoundError here really means the *input* is missing.
    try:
        converted_data = np.fromfile(input_file, dtype=input_dtype).astype(target_dtype)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Input file {input_file} not found") from err
    except Exception as err:
        raise IOError(f"Error converting file {input_file}: {err}") from err

    # Phase 2: write. Kept separate so that a missing output directory is not reported as a
    # missing input file.
    try:
        converted_data.tofile(output_file)
    except Exception as err:
        raise IOError(f"Error converting file {input_file}: {err}") from err

    return output_file


def _unique_converted_path(dump_dir: Path, file_name: str, taken_names: set[str]) -> Path:
    """Return a path in ``dump_dir`` whose name is not yet taken, and mark it as taken."""
    candidate = file_name
    counter = 1
    # "input_list.txt" is reserved: convert_data writes the new input list under that name.
    while candidate in taken_names or candidate == "input_list.txt":
        candidate = f"{counter}_{file_name}"
        counter += 1
    taken_names.add(candidate)
    return dump_dir / candidate


def convert_data(input_list: Path, user_provided_dtypes: list, output_dir: Path) -> Path:
    """Convert every tensor referenced by an input list to float32 for the converter.

    Each non-blank line of ``input_list`` holds whitespace-separated entries of the form
    ``<path>`` or ``<name>:=<path>``. Entry ``i`` of a line is read with dtype
    ``user_provided_dtypes[i]`` (``None`` means float32), cast to float32 and written into
    ``<output_dir>/converted_calib_data``. A new ``input_list.txt`` in that directory lists the
    converted files and keeps any ``<name>:=`` prefixes.

    A converted file keeps its source basename unless that name is already used, in which case
    it gets a numeric prefix. This stops distinct sources with the same basename from
    overwriting each other.

    Args:
        input_list: Input list file provided (``str`` is also accepted).
        user_provided_dtypes: Datatype of each input tensor, one per entry on a line.
        output_dir: Directory in which ``converted_calib_data`` is created.

    Returns:
        Path to the new input list file.

    Raises:
        FileNotFoundError: If the input_list file doesn't exist.
        ValueError: If a line's entry count doesn't match the number of dtypes provided.
        IOError: If there's an issue reading, converting or writing files.
    """
    input_list = Path(input_list)
    if not input_list.exists():
        raise FileNotFoundError(f"Input list file {input_list} not found")

    # Every tensor is cast to float32, whatever its source dtype.
    target_dtype = "float32"

    try:
        converted_input_file_dump_path = Path(output_dir) / "converted_calib_data"
        converted_input_file_dump_path.mkdir(parents=True, exist_ok=True)
        new_input_list_file_path = converted_input_file_dump_path / "input_list.txt"

        # (source path, source dtype) -> converted file. A tensor repeated across lines is
        # converted only once.
        converted_files: dict[tuple[str, str], Path] = {}
        taken_names: set[str] = set()

        with open(input_list, "r") as old_file, open(new_input_list_file_path, "w") as new_file:
            for line in old_file:
                entries = line.split()
                if not entries:
                    continue
                if len(entries) != len(user_provided_dtypes):
                    raise ValueError("Number of input files doesn't match the dtypes provided")

                new_entries = []
                for user_provided_dtype, entry in zip(user_provided_dtypes, entries):
                    # Split on the first ":=" only, so the path part is kept whole.
                    if ":=" in entry:
                        tensor_name, _, tensor_path = entry.partition(":=")
                    else:
                        tensor_name, tensor_path = None, entry
                    source_dtype = (
                        "float32" if user_provided_dtype is None else user_provided_dtype
                    )
                    source_path = Path(tensor_path)

                    cache_key = (str(source_path), source_dtype)
                    output_file = converted_files.get(cache_key)
                    if output_file is None:
                        output_file = convert_raw_file(
                            source_path,
                            source_dtype,
                            target_dtype,
                            _unique_converted_path(
                                converted_input_file_dump_path, source_path.name, taken_names
                            ),
                        )
                        converted_files[cache_key] = output_file

                    prefix = f"{tensor_name}:=" if tensor_name else ""
                    new_entries.append(prefix + str(output_file))

                new_file.write(" ".join(new_entries) + "\n")

        return new_input_list_file_path
    except OSError as err:
        # IOError is an alias of OSError. This also wraps the FileNotFoundError/IOError raised
        # by convert_raw_file.
        raise IOError(f"Error processing input list file: {err}") from err
