# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================


import os
import json
import onnx
from onnx import helper, TensorProto
import copy
from qti.aisw.converters.common.tensor_transforms.operator import Operator
from qti.aisw.converters.common import modeltools
from qti.aisw.converters.qnn_backend.ir_to_dlc import DLCBackend
from qti.aisw.converters.common.utils.converter_utils import log_info
from qti.aisw.lora.helpers import validate_file_path


class Transform:
    """
    One tensor transformation: an operator that maps source tensors to destination tensors.

    Attributes:
        src_tensors (list[str]): Names of the tensors the operator consumes.
        dest_tensors (list[str]): Names of the tensors the operator produces.
        operator (Operator): Operator object holding the ONNX op type, attributes and shapes.
    """

    def __init__(self, src_tensors, dest_tensors, operator):
        """
        Store the tensor names and the operator describing this transform.

        Args:
            src_tensors (list[str]): Names of the input tensors
            dest_tensors (list[str]): Names of the output tensors
            operator (Operator): Operator object containing the ONNX operator information
        """
        self.src_tensors, self.dest_tensors, self.operator = src_tensors, dest_tensors, operator

    def __str__(self):
        """Return a readable summary of the source/destination tensors and the operator."""
        fields = (
            f"src_tensors={self.src_tensors}",
            f"dest_tensors={self.dest_tensors}",
            f"operator={self.operator}",
        )
        return "Transform(" + ", ".join(fields) + ")"


class TransformManager:
    class TransformGroup:
        """Inner class to hold a set of transforms"""
        def __init__(self):
            self.transforms = []  # List to store Transform objects
            self.input_tensors = set()  # Tensors that are only inputs (no producers)
            self.output_tensors = set()  # Tensors that are only outputs (no consumers)
            self.produced_tensors = set()  # All tensors that appear as destinations
            self.consumed_tensors = set()  # All tensors that are used as inputs

    def __init__(self):
        """
        Create a manager with no transform groups.

        Groups are keyed by a string ID and are created lazily the first time an ID is
        used (see _get_transform_group). NOTE(review): callers reportedly use 'base' as
        the default group ID; nothing in this class creates or enforces it.
        """
        groups = {}  # group ID -> TransformGroup
        self.transform_groups = groups


    def _get_transform_group(self, group_id):
        """Return the TransformGroup registered under ``group_id``, registering an empty one on first use."""
        try:
            return self.transform_groups[group_id]
        except KeyError:
            fresh_group = self.TransformGroup()
            self.transform_groups[group_id] = fresh_group
            return fresh_group


    def get_group_ids(self):
        """Return the IDs of all transform groups, in the order they were registered.

        Returns:
            list[str]: A new list of group IDs; mutating it does not affect the manager.
        """
        return [group_id for group_id in self.transform_groups]


    def copy_group(self, source_id, new_id):
        """
        Register a deep copy of an existing transform group under another ID.

        Because the copy is deep, later changes to either group (including the Transform
        and Operator objects it holds) do not affect the other. A group already registered
        under ``new_id`` is overwritten.

        Args:
            source_id (str): ID of the existing transform group to copy
            new_id (str): ID for the new copied transform group

        Raises:
            KeyError: If the source transform group doesn't exist
        """
        try:
            original_group = self.transform_groups[source_id]
        except KeyError:
            raise KeyError(f"Transform group '{source_id}' not found.") from None

        self.transform_groups[new_id] = copy.deepcopy(original_group)

    def rename_group(self, source_id, new_id):
        """
        Move a transform group to a new ID (the group object itself is kept, not copied).

        The renamed group is re-registered last in the group ordering. Renaming a group
        to its own ID raises KeyError, because ``new_id`` is then already registered.

        Args:
            source_id (str): ID of the existing transform group to rename
            new_id (str): ID for the renamed transform group

        Raises:
            KeyError: If the source transform group doesn't exist
            KeyError: If the new transform group already exists
        """
        groups = self.transform_groups
        if source_id not in groups:
            raise KeyError(f"Transform group '{source_id}' not found.")
        if new_id in groups:
            raise KeyError(f"Transform group '{new_id}' already exists.")

        groups[new_id] = groups.pop(source_id)


    def add_transform(self, src_tensors, dest_tensors, operator, group_id):
        """
        Append a new Transform to ``group_id`` and update the group's tensor bookkeeping.

        After the call:
          * every destination is in produced_tensors and no longer in input_tensors;
          * a source becomes an input only if no earlier transform produced it;
          * a destination is an output unless some transform already consumed it;
          * sources of this transform are no longer outputs.

        Args:
            src_tensors (list[str]): Names of the tensors consumed by the operator.
            dest_tensors (list[str]): Names of the tensors produced by the operator.
            operator (Operator): Operator applied by this transform.
            group_id (str): Group to add to; created empty if it does not exist yet.

        Returns:
            Transform: The newly created transform.
        """
        group = self._get_transform_group(group_id)
        new_transform = Transform(src_tensors, dest_tensors, operator)
        group.transforms.append(new_transform)

        # A source counts as a group input unless it was already produced in this group.
        group.input_tensors.update(
            name for name in src_tensors if name not in group.produced_tensors)

        # Destinations are produced now, so none of them can stay a group input.
        group.produced_tensors.update(dest_tensors)
        group.input_tensors.difference_update(dest_tensors)

        group.consumed_tensors.update(src_tensors)

        # A destination is an output only while nothing has consumed it.
        for name in dest_tensors:
            if name in group.consumed_tensors:
                group.output_tensors.discard(name)
            else:
                group.output_tensors.add(name)

        # This transform's sources are consumed, so they cannot be outputs.
        group.output_tensors.difference_update(src_tensors)

        return new_transform

    def get_metadata(self):
        """
        Build a JSON-serializable description of every transform group.

        Returns:
            dict: A dictionary with format:
                {
                    "version": "1.0.0",
                    "transform_groups": {
                        <group_id>: {
                            "input_tensors": [...],   # sorted tensor names
                            "output_tensors": [...],  # sorted tensor names
                            "transforms": [           # topological order
                                {
                                    "src_tensors": [...],
                                    "dest_tensors": [...],
                                    "operator": {
                                        "op_type": ...,
                                        "attributes": ...,
                                        "input_shapes": ...,
                                        "output_shapes": ...
                                    }
                                },
                                ...
                            ]
                        },
                        ...
                    }
                }

        Input/output tensor names are stored internally as sets. They are emitted
        sorted so the metadata is identical across runs; set iteration order for
        strings varies with hash randomization. The src/dest tensor-name lists are
        copied, so mutating them in the returned dict does not alter the manager.
        Operator attributes and shapes are still shared references.
        """
        def serialize_transform(transform):
            """Convert one Transform into a plain dict."""
            operator = transform.operator
            return {
                "src_tensors": list(transform.src_tensors),
                "dest_tensors": list(transform.dest_tensors),
                "operator": {
                    "op_type": operator.op_type,
                    "attributes": operator.attributes,
                    "input_shapes": operator.input_shapes,
                    "output_shapes": operator.output_shapes
                }
            }

        def parse_transform_group(transform_group, group_id):
            """Serialize a single transform group, with its transforms in topological order."""
            sorted_transforms = self._sort_transforms_topologically(group_id)
            return {
                "input_tensors": sorted(transform_group.input_tensors),
                "output_tensors": sorted(transform_group.output_tensors),
                "transforms": [serialize_transform(t) for t in sorted_transforms]
            }

        metadata_content = {
            group_id: parse_transform_group(transform_group, group_id)
            for group_id, transform_group in self.transform_groups.items()
        }

        return {
            "version": "1.0.0",
            "transform_groups": metadata_content
        }


    def save_metadata(self, output_path):
        """
        Write ``get_metadata()`` to ``output_path`` as JSON indented by 4 spaces.

        The parent directory of ``output_path`` is created first if it does not exist.

        Args:
            output_path (str): Path to save the metadata JSON
        """
        parent = os.path.dirname(output_path)
        must_create_parent = bool(parent) and not os.path.exists(parent)
        if must_create_parent:
            os.makedirs(parent, exist_ok=True)

        payload = self.get_metadata()
        with open(output_path, mode='w') as handle:
            json.dump(payload, handle, indent=4)
        log_info(f"metadata saved at: {output_path}")


    def load_metadata(self, metadata_path):
        """
        Rebuild transform groups from a metadata JSON file.

        The file is expected in the format written by ``save_metadata``:
            {
                "version": "1.0.0",
                "transform_groups": {
                    "<group_id>": {
                        "transforms": [
                            {
                                "src_tensors": [...],
                                "dest_tensors": [...],
                                "operator": {
                                    "op_type": str,
                                    "attributes": {...},
                                    "input_shapes": [...],
                                    "output_shapes": [...]
                                }
                            },
                            ...
                        ],
                        "input_tensors": [...],    # optional
                        "output_tensors": [...]    # optional
                    },
                    ...
                }
            }
        The "version" field is currently not checked.

        Each group named in the file replaces any existing group with the same ID.
        Groups not named in the file are left untouched. Transforms are re-added through
        ``add_transform``, so the tensor bookkeeping is rebuilt. If both
        "input_tensors" and "output_tensors" are present, they then override the
        derived input/output sets. A group listed with no transforms is still created
        (empty), so loading round-trips what ``save_metadata`` wrote.

        Args:
            metadata_path (str): Path to the JSON file containing the transform metadata.

        Raises:
            KeyError: If "transform_groups", a group's "transforms", or a required
                transform/operator field is missing.
        """
        def load_group(metadata_dict, group_id):
            """Replace group ``group_id`` with the contents of ``metadata_dict``."""
            # Clear existing transforms for this ID
            if group_id in self.transform_groups:
                self.remove_group(group_id)

            # Register the group up front so a group without transforms still exists.
            transform_group = self._get_transform_group(group_id)

            for transform_data in metadata_dict["transforms"]:
                operator_data = transform_data["operator"]
                operator = Operator(
                    op_type=operator_data["op_type"],
                    attributes=operator_data["attributes"],
                    input_shapes=operator_data["input_shapes"],
                    output_shapes=operator_data["output_shapes"]
                )

                # Reuse add_transform so produced/consumed tracking is rebuilt.
                self.add_transform(
                    src_tensors=transform_data["src_tensors"],
                    dest_tensors=transform_data["dest_tensors"],
                    operator=operator,
                    group_id=group_id
                )

            # Explicit input/output lists from the file take precedence over derived ones.
            if "input_tensors" in metadata_dict and "output_tensors" in metadata_dict:
                transform_group.input_tensors = set(metadata_dict["input_tensors"])
                transform_group.output_tensors = set(metadata_dict["output_tensors"])

        with open(metadata_path, 'r') as fp:
            metadata = json.load(fp)

        for group_id, group_metadata in metadata['transform_groups'].items():
            load_group(group_metadata, group_id)


    def serialize_onnx(self, output_dir):
        """
        Write each transform group to ``<output_dir>/<group_id>.onnx`` using the ONNX helper API.

        Each group becomes one ONNX graph, with its transforms as nodes in topological
        order. Every tensor is declared FLOAT, with the shape recorded on the operators
        that touch it; if several operators record a shape, the last one visited wins.
        The group's input_tensors/output_tensors become the graph inputs/outputs. All
        other tensors become value_info entries.

        Serialization is best-effort per group: a failure is logged and the remaining
        groups are still written.

        Args:
            output_dir (str): Directory to write the models into. An empty string means
                the current working directory.
        """
        def create_node(transform):
            """Create an ONNX node from a Transform object"""
            return helper.make_node(
                transform.operator.op_type,
                inputs=transform.src_tensors,
                outputs=transform.dest_tensors,
                **transform.operator.attributes
            )

        # Process each transform group
        for group_id, transform_group in self.transform_groups.items():
            try:
                # Get all transforms in topological order
                sorted_transforms = self._sort_transforms_topologically(group_id)

                # Map every tensor name to its shape; later transforms overwrite earlier entries.
                tensor_shapes = {}
                for transform in sorted_transforms:
                    # Map input tensor names to their shapes
                    for idx, tensor_name in enumerate(transform.src_tensors):
                        tensor_shapes[tensor_name] = transform.operator.input_shapes[idx]
                    # Map output tensor names to their shapes
                    for idx, tensor_name in enumerate(transform.dest_tensors):
                        tensor_shapes[tensor_name] = transform.operator.output_shapes[idx]

                # No per-tensor dtype is tracked, so every tensor is declared FLOAT.
                value_infos = [
                    helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)
                    for name, shape in tensor_shapes.items()
                ]

                # Split value infos into inputs, outputs, and intermediates
                inputs = [vi for vi in value_infos if vi.name in transform_group.input_tensors]
                outputs = [vi for vi in value_infos if vi.name in transform_group.output_tensors]
                intermediates = [vi for vi in value_infos
                                 if vi.name not in transform_group.input_tensors
                                 and vi.name not in transform_group.output_tensors]

                # Create nodes from transforms in topological order
                nodes = [create_node(transform) for transform in sorted_transforms]

                # Create graph with value info for all tensors
                graph = helper.make_graph(
                    nodes=nodes,
                    name=f"transform_group_{group_id}",
                    inputs=inputs,
                    outputs=outputs,
                    value_info=intermediates,
                    initializer=[]
                )

                # Create model
                model = helper.make_model(graph, producer_name='transform_manager')

                # Save model. dirname() is '' when output_dir is '' (current directory),
                # and os.makedirs('') raises, so only create a directory when one is named.
                # Using dirname(output_path) also covers group IDs containing separators.
                output_path = os.path.join(output_dir, f"{group_id}.onnx")
                model_dir = os.path.dirname(output_path)
                if model_dir:
                    os.makedirs(model_dir, exist_ok=True)
                onnx.save(model, output_path)
                log_info(f"model saved at: {output_path}")
            except Exception as e:
                # Best-effort per group: report the failure and continue with the rest.
                log_info(f"Error serializing transform group '{group_id}': {str(e)}")


    def _sort_transforms_topologically(self, group_id):
        """
        Return the transforms of a group ordered so producers precede their consumers.

        Breadth-first ready-queue traversal:

        * A source tensor is "available" from the start if it is a group input or is not
          produced by any transform in the group.
        * Transforms whose sources are all available seed the queue, in insertion order.
        * When a transform is dequeued, each unscheduled consumer of its outputs is
          enqueued once every one of its sources is available or has its producer scheduled.

        Consumers are found through a tensor -> consumers index built once, instead of
        rescanning every transform for each produced tensor. The queue is a deque, so
        dequeuing is O(1).

        Transforms that can never become ready are appended after the ordered ones, in
        insertion order, rather than dropped. Examples are transforms in a dependency
        cycle, or ones consuming their own output. This way get_metadata and
        serialize_onnx never silently lose transforms.

        Args:
            group_id (str): ID of the group to sort; a missing group is created empty
                (via _get_transform_group) and yields [].

        Returns:
            list[Transform]: All transforms of the group, dependency-ordered where possible.
        """
        from collections import deque

        transform_group = self._get_transform_group(group_id)
        transforms = transform_group.transforms
        input_tensors = transform_group.input_tensors
        produced_tensors = transform_group.produced_tensors

        # Index producers and consumers once. If several transforms produce the same
        # tensor, the last one added is treated as its producer.
        tensor_producers = {}
        tensor_consumers = {}
        for transform in transforms:
            for dest in transform.dest_tensors:
                tensor_producers[dest] = transform
            # dict.fromkeys drops repeated sources while keeping their order.
            for src in dict.fromkeys(transform.src_tensors):
                tensor_consumers.setdefault(src, []).append(transform)

        def is_available(src):
            """True if ``src`` does not need a producer from this group."""
            return src in input_tensors or src not in produced_tensors

        # Seed with transforms that only use tensors available from the start.
        queue = deque(
            transform for transform in transforms
            if all(is_available(src) for src in transform.src_tensors))
        seen_transforms = set(queue)

        sorted_transforms = []
        while queue:
            current = queue.popleft()
            sorted_transforms.append(current)

            # Schedule consumers of the current outputs whose prerequisites are now met.
            for dest in current.dest_tensors:
                for consumer in tensor_consumers.get(dest, ()):
                    if consumer in seen_transforms:
                        continue
                    if all(is_available(src) or tensor_producers.get(src) in seen_transforms
                           for src in consumer.src_tensors):
                        queue.append(consumer)
                        seen_transforms.add(consumer)

        # Keep transforms that could not be ordered (e.g. cycles) instead of dropping them.
        sorted_transforms.extend(t for t in transforms if t not in seen_transforms)
        return sorted_transforms


    def has_group(self, group_id):
        """
        Report whether a transform group is registered under ``group_id``.

        Unlike _get_transform_group, this never creates a group.

        Args:
            group_id (str): ID of the transform group to look up

        Returns:
            bool: True if the group exists, False otherwise
        """
        registered = self.transform_groups.keys()
        return group_id in registered


    def remove_group(self, group_id):
        """
        Unregister the transform group stored under ``group_id``.

        Args:
            group_id (str): ID of the transform group to remove

        Raises:
            KeyError: If the group_id doesn't exist
        """
        try:
            del self.transform_groups[group_id]
        except KeyError:
            raise KeyError(f"Transform group '{group_id}' not found") from None

    def save_od_transforms_dlc(self, dlc_paths, output_path):
        """
        Collect the on-device transform graphs from several DLC files into one DLC.

        Every IR graph whose name starts with ``od_transform_`` is copied into a new DLC
        written to ``output_path``, in file order and then reader order.

        Args:
            dlc_paths (Iterable[str]): DLC files to read; each is passed to
                ``validate_file_path`` before it is opened.
            output_path (str): Destination path of the combined DLC.

        NOTE(review): readers are never explicitly closed, and ``finish()`` is skipped if
        an exception occurs part-way; confirm whether the modeltools objects need cleanup.
        """
        od_prefix = "od_transform_"
        writer = modeltools.IrDlcSerializer(output_path, "", "", "", "")
        writer.initialize()

        for dlc_path in dlc_paths:
            validate_file_path(dlc_path)
            reader = modeltools.IrDlcReader()
            reader.open(dlc_path)
            for name in list(reader.get_ir_graph_names()):
                if not name.startswith(od_prefix):
                    continue
                writer.serialize(reader.get_ir_graph(name))

        writer.finish()
        log_info("On-device transform DLC saved at: " + output_path)
