# =============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All rights reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================

import json
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Optional

import pandas as pd
from qti.aisw.accuracy_debugger.argparser.compare_encodings_parser import CompareEncodingsParser
from qti.aisw.accuracy_debugger.argparser.framework_runner_parser import FrameworkRunnerParser
from qti.aisw.accuracy_debugger.argparser.inference_engine_parser import InferenceEngineParser
from qti.aisw.accuracy_debugger.argparser.model_snooper_parser import ModelSnooperParser
from qti.aisw.accuracy_debugger.argparser.tensor_visualizer_parser import TensorVisualizerParser
from qti.aisw.accuracy_debugger.argparser.verifier_parser import VerifierParser
from qti.aisw.accuracy_debugger.common_config import EncodingInputConfig
from qti.aisw.accuracy_debugger.compare_encodings.compare_encodings import CompareEncodings
from qti.aisw.accuracy_debugger.inference_engine.qairt_inference_engine import (
    InferenceEngine,
    InferenceEngineInputConfig,
)
from qti.aisw.accuracy_debugger.model_snooper_module import ModelSnooper, ModelSnooperInputConfig
from qti.aisw.accuracy_debugger.snooping.snooper import Snooper
from qti.aisw.accuracy_debugger.tensor_visualizer.tensor_visualizer import TensorVisualizer
from qti.aisw.accuracy_debugger.utils.file_utils import dump_csv
from qti.aisw.tools.core.utilities.framework.framework_manager import FrameworkManager
from qti.aisw.tools.core.utilities.framework.utils.helper import Helper
from qti.aisw.tools.core.utilities.qairt_logging.log_areas import LogAreas
from qti.aisw.tools.core.utilities.qairt_logging.logging_utility import QAIRTLogger
from qti.aisw.tools.core.utilities.verifier.verifier import Verifier


def execute_framework_runner(args: list) -> None:
    """Parses arguments for Framework runner and executes it.

    Loads the input model through ``FrameworkManager``, runs it on the input sample and
    saves the generated intermediate outputs into a timestamped ``framework_runner``
    sub-directory of the working directory.

    Args:
        args (list): List of arguments to be parsed.
    """
    # Parse arguments for framework runner.
    framework_args = FrameworkRunnerParser().parse(args)

    # Get logger for framework runner.
    logger = get_logger(
        logger_name="Framework Runner",
        level=framework_args.log_level.upper(),
    )

    # Create framework manager object.
    framework_manager = FrameworkManager(parent_logger=logger)

    # Load input model and input data.
    input_model = framework_manager.load(framework_args.input_model)
    input_data = Snooper._load_input_tensors(framework_args.input_sample)

    # Request shape inference only when the graph carries no value_info entries.
    # A truthiness check is used instead of ``== []`` because repeated protobuf fields
    # are not guaranteed to support equality against a plain list (the pure-Python
    # protobuf backend raises TypeError). NOTE(review): assumes the loaded model exposes
    # ``graph.value_info`` (e.g. an ONNX ModelProto) — confirm for other frameworks.
    infer_shape = not input_model.graph.value_info

    # Generate intermediate outputs.
    reference_outputs = framework_manager.generate_intermediate_outputs(
        input_model,
        input_data,
        output_tensor_names=framework_args.output_tensor,
        infer_shape=infer_shape,
    )
    output_dir = create_working_directory(
        working_dir=framework_args.working_directory, subdirectory="framework_runner"
    )

    Helper.save_output_to_file(reference_outputs, output_dir)
    logger.info(f"Framework Runner Completed. Results path: {output_dir}")


def execute_inference_engine(args: list) -> None:
    """Parses arguments for Inference engine and executes it.

    Builds an ``InferenceEngineInputConfig`` from the parsed arguments (with
    ``dump_output=True``) and runs it. The working directory passed to the engine is a
    timestamped ``inference_engine`` sub-directory created by ``create_working_directory``.

    Args:
        args (list): Arguments to be provided for inference engine parser.

    Raises:
        Exception: Any error raised while building the config or running inference is
            logged and re-raised unchanged.
    """
    # Parse arguments for inference engine.
    inference_args = InferenceEngineParser().parse(args)

    # Create Logger using get_logger method for Inference Engine
    logger = get_logger(
        logger_name="Inference Engine",
        level=inference_args.log_level.upper(),
    )

    # Create inference engine object.
    inference_engine = InferenceEngine(logger=logger)

    # log_level was only needed to configure the logger above; drop it from the args.
    delattr(inference_args, "log_level")

    inference_args.working_directory = create_working_directory(
        working_dir=inference_args.working_directory, subdirectory="inference_engine"
    )
    try:
        # Create inference engine input config
        inference_input_config = InferenceEngineInputConfig(
            input_model=inference_args.input_model,
            converter_arguments=inference_args.converter_args,
            quantizer_arguments=inference_args.quantizer_args,
            backend=inference_args.backend,
            platform=inference_args.platform,
            context_bin_backend_extension=inference_args.offline_prepare_backend_extension_config,
            offline_prepare=inference_args.offline_prepare,
            net_run_arguments=inference_args.net_run_args,
            net_run_input_data=inference_args.input_list,
            net_run_backend_extension=inference_args.netrun_backend_extension_config,
            dump_output=True,
            remote_host_details=inference_args.remote_host_details,
            working_directory=inference_args.working_directory,
            context_bin_gen_arguments=inference_args.context_bin_args,
            soc_model=inference_args.soc_model,
        )

        # Run inference engine. The return value is not used here; results are
        # reported via the working directory path logged below.
        inference_engine.run_inference_engine(inference_input_config)
    except Exception as e:
        logger.error(f"Inference Engine Failed: {e}")
        # Bare raise preserves the original traceback without adding this frame.
        raise

    logger.info(f"Inference Engine Completed. Results path: {inference_args.working_directory}")


def execute_verification(args: list) -> None:
    """Parses arguments for Verifier and executes it.

    Runs the requested comparators over the inference and reference tensors and writes
    the results to ``verification.csv`` inside a timestamped ``verification``
    sub-directory of the working directory.

    Args:
        args (list): List of arguments to be parsed.
    """
    # Parse arguments for Verifier
    verification_args = VerifierParser().parse(args)

    # Create logger object for Verifier
    logger = get_logger(
        logger_name="Verifier",
        level=verification_args.log_level.upper(),
    )
    logger.info("Starting Verifier...")

    if verification_args.graph_info:
        logger.info(f"Loading graph info from {verification_args.graph_info}...")
        # JSON text is UTF-8; do not depend on the platform's locale default encoding.
        with open(verification_args.graph_info, "r", encoding="utf-8") as fp:
            verification_args.graph_info = json.load(fp)

    # Create Verifier object with required comparators
    logger.info(
        "Creating Verifier object with the following comparators: %s",
        ", ".join([comp.name for comp in verification_args.comparators]),
    )
    verifier_obj = Verifier(logger=logger, comparators=verification_args.comparators)

    # Execute verification
    logger.info(
        "Running verifier for the following tensors: \ninference_tensor: %s \nreference_tensor: %s",
        verification_args.inference_tensor,
        verification_args.reference_tensor,
    )
    verifier_output = verifier_obj.verify_directory_of_tensors(
        inference_tensors=verification_args.inference_tensor,
        inference_dtype=verification_args.inference_dtype,
        reference_tensors=verification_args.reference_tensor,
        reference_dtype=verification_args.reference_dtype,
        dlc_file=verification_args.dlc_file,
        graph_info=verification_args.graph_info,
        disable_layout_transform=verification_args.is_qnn_golden_reference,
    )

    output_dir = create_working_directory(
        working_dir=verification_args.working_directory, subdirectory="verification"
    )

    index_names = ["inference_tensor_name", "reference_tensor_name"]
    verifier_df = pd.DataFrame.from_dict(verifier_output, orient="index")
    if verifier_output:
        # Result keys are expected to be (inference, reference) name pairs, which pandas
        # turns into a two-level index.
        verifier_df.index.names = index_names
    else:
        # An empty result yields a single-level index, so assigning two names would raise
        # ValueError. Use an empty two-level index so a header-only CSV is still written.
        verifier_df.index = pd.MultiIndex.from_tuples([], names=index_names)
    output_csv = output_dir / "verification.csv"

    logger.info("Saving verification results to %s", output_csv)
    dump_csv(data_frame=verifier_df, csv_path=output_csv, index=True)

    logger.info("Finished Verifier execution.")


def execute_compare_encodings(args: list) -> None:
    """Parses arguments and runs the Compare encodings.

    Builds one ``EncodingInputConfig`` per side (encoding file and/or quantized DLC) and
    runs ``CompareEncodings`` on them. Results go to a timestamped ``compare_encodings``
    sub-directory of the working directory.

    Args:
        args (list): List of arguments to be parsed.
    """
    # Parse arguments for Compare encodings
    compare_encodings_args = CompareEncodingsParser().parse(args)

    # Create logger for Compare Encodings
    logger = get_logger(
        logger_name="Compare Encodings",
        level=compare_encodings_args.log_level.upper(),
    )

    # Create EncodingInputConfig for each of the two encodings being compared
    encoding_config1 = EncodingInputConfig(
        encoding_file_path=compare_encodings_args.encoding1_file_path,
        quantized_dlc_path=compare_encodings_args.quantized_dlc1_path,
    )
    encoding_config2 = EncodingInputConfig(
        encoding_file_path=compare_encodings_args.encoding2_file_path,
        quantized_dlc_path=compare_encodings_args.quantized_dlc2_path,
    )

    output_dir = create_working_directory(
        working_dir=compare_encodings_args.working_directory, subdirectory="compare_encodings"
    )

    # Execute Compare encodings module
    logger.info("Running Compare Encodings...")
    CompareEncodings(logger).run(
        encoding_config1=encoding_config1,
        encoding_config2=encoding_config2,
        output_dir=output_dir,
        framework_model_path=compare_encodings_args.framework_model_path,
        scale_threshold=compare_encodings_args.scale_threshold,
    )
    logger.info("Finished Compare Encodings.")


def execute_tensor_visualizer(args: list) -> None:
    """Run the Tensor Visualizer sub-command.

    Target tensors are plotted against golden tensors by ``TensorVisualizer``. The raw
    working directory argument is handed straight to the visualizer; no timestamped
    sub-directory is created here.

    Args:
        args (list): List of arguments to be parsed.
    """
    visualizer_args = TensorVisualizerParser().parse(args)
    log_level = visualizer_args.log_level.upper()
    logger = get_logger(logger_name="Tensor Visualizer", level=log_level)

    logger.info("Running Tensor Visualizer...")
    visualizer = TensorVisualizer(logger)
    visualizer.run(
        target_tensors=visualizer_args.target_tensors,
        golden_tensors=visualizer_args.golden_tensors,
        working_directory=visualizer_args.working_directory,
        datatype=visualizer_args.data_type,
    )
    logger.info("Finished Tensor Visualizer...")


def execute_model_snooper(args: list) -> None:
    """Parses arguments for Model snooper and executes it.

    Builds a ``ModelSnooperInputConfig`` from the parsed arguments (with
    ``dump_output_tensors=True``) and runs ``ModelSnooper`` on it. Unlike the other
    sub-commands, no timestamped sub-directory is created. When no working directory
    is given, ``<cwd>/working_directory`` is created and used directly.

    Args:
        args (list): List of arguments to be parsed.

    Raises:
        Exception: Any error raised while building the config or snooping is logged and
            re-raised unchanged.
    """
    # Parse arguments for model snooper
    snooper_args = ModelSnooperParser().parse(args)

    # Create Logger using get_logger method for Model Snooper
    logger = get_logger(
        logger_name="Snooping",
        level=snooper_args.log_level.upper(),
    )

    # Create model snooper object
    model_snooper = ModelSnooper(logger=logger)

    # log_level was only needed to configure the logger above; drop it from the args.
    delattr(snooper_args, "log_level")

    if snooper_args.working_directory is None:
        snooper_args.working_directory = Path.cwd() / "working_directory"
        snooper_args.working_directory.mkdir(parents=True, exist_ok=True)

    try:
        # Create model snooper input config object
        snooper_input_config = ModelSnooperInputConfig(
            input_model=snooper_args.input_model,
            input_sample=snooper_args.input_sample,
            algorithm=snooper_args.algorithm,
            converter_arguments=snooper_args.converter_args,
            quantizer_arguments=snooper_args.quantizer_args,
            context_bin_gen_arguments=snooper_args.context_bin_args,
            context_bin_backend_extension=snooper_args.offline_prepare_backend_extension_config,
            offline_prepare=snooper_args.offline_prepare,
            net_run_arguments=snooper_args.net_run_args,
            net_run_backend_extension=snooper_args.netrun_backend_extension_config,
            comparators=snooper_args.comparator,
            debug_subgraph_inputs=snooper_args.debug_subgraph_inputs,
            debug_subgraph_outputs=snooper_args.debug_subgraph_outputs,
            working_directory=snooper_args.working_directory,
            backend=snooper_args.backend,
            platform=snooper_args.platform,
            soc_model=snooper_args.soc_model,
            remote_host_details=snooper_args.remote_host_details,
            golden_reference_path=snooper_args.golden_reference,
            is_qnn_golden_reference=snooper_args.is_qnn_golden_reference,
            retain_compilation_artifacts=snooper_args.retain_compilation_artifacts,
            dump_output_tensors=True,
        )
        # Run Model Snooper
        output = model_snooper.run(snooper_input_config)
    except Exception as e:
        logger.error(f"Snooping Failed: {e}")
        # Bare raise preserves the original traceback without adding this frame.
        raise

    logger.info(f"Snooping Completed. Report generated at: {output.snooping_report}")


def get_logger(logger_name: str, level: str = "INFO") -> Logger:
    """Register a QAIRT log area for a component and return its logger.

    The logger uses the "simple" formatter and writes to the "dev_console" handler.

    Args:
        logger_name (str): Name used to register the log area.
        level (str, optional): Log level. Defaults to "INFO".

    Returns:
        Logger: The area logger returned by ``QAIRTLogger.register_area_logger``.
    """
    return QAIRTLogger.register_area_logger(
        area=LogAreas.register_log_area(logger_name),
        level=level,
        formatter_val="simple",
        handler_list=["dev_console"],
    )


def create_working_directory(subdirectory: str, working_dir: Optional[Path] = None) -> Path:
    """Create and return a timestamped working directory path.

    The created directory is ``<working_dir>/<subdirectory>/<YYYY-mm-dd_HH-MM-SS>``.

    Args:
        subdirectory (str): Sub-directory placed between the base directory and the
            timestamp.
        working_dir (Optional[Path]): Existing base directory (a ``str`` path is also
            accepted). Defaults to ``<cwd>/working_directory``, which is created on demand.

    Returns:
        Path: The full path to the created working directory.

    Raises:
        FileNotFoundError: If ``working_dir`` is given but does not exist.
        NotADirectoryError: If ``working_dir`` exists but is not a directory.
    """
    if working_dir is None:
        # Default base; the final mkdir(parents=True) below creates it if missing.
        base_dir = Path.cwd() / "working_directory"
    else:
        base_dir = Path(working_dir)
        if not base_dir.exists():
            raise FileNotFoundError(f"Working directory {working_dir} does not exist.")
        # Check this up front so the error is clear instead of an opaque mkdir failure.
        if not base_dir.is_dir():
            raise NotADirectoryError(f"Working directory {working_dir} is not a directory.")

    # Second-resolution timestamp: calls made within the same second reuse one directory.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_dir = base_dir / subdirectory / timestamp
    output_dir.mkdir(parents=True, exist_ok=True)

    return output_dir
