# ==============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All Rights Reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""Implements utility classes to handle AIMET encodings"""

from abc import ABCMeta, abstractmethod


class AimetEncodings(metaclass=ABCMeta):
    """Abstract base class for AIMET encodings"""

    def __init__(self, encodings: dict = {}) -> None:
        """Initializes the AimetEncodings object

        Args:
            encodings (Optional): Dictionary containing encoding data. Defaults to empty dict
        """
        self._dict = encodings

    @property
    @abstractmethod
    def encodings(self) -> dict:
        """Abstract property to return the internal encoding dictionary

        Returns:
            dict: The encoding dictionary
        """
        pass

    @property
    def version(self) -> str:
        """Return the version of the encodings that is initialized"""
        return self._dict["version"]

    @abstractmethod
    def get(self, name: str) -> dict | None:
        """Abstract method to retrieve encoding by name

        Args:
            name: Name of the parameter or activation

        Returns:
            dict | None: Encoding dictionary if found, else None
        """
        pass

    @abstractmethod
    def copy(self, original_name: str, new_name: str, value: dict) -> bool:
        """Abstract method to copy an encoding from one name to another

        Args:
            original_name: Original encoding name
            new_name: New encoding name
            value: Encoding value to copy

        Returns:
            bool: True if copy was successful, False otherwise
        """
        pass

    @abstractmethod
    def delete(self, name: str) -> bool:
        """Abstract method to delete an encoding by name

        Args:
            name: Name of the encoding

        Returns:
            bool: True if deletion was successful, False otherwise
        """
        pass

    @abstractmethod
    def _map_slices(self, original: str, slices: list[str]) -> bool:
        """Abstract method to map an encoding to multiple slices

        Args:
            original: Original encoding name
            slices: List of slice names

        Returns:
            bool: True if mapping was successful, False otherwise
        """
        pass


class AimetEncodingsV061(AimetEncodings):
    """Handles AIMET encodings for version 0.6.1

    ``param_encodings`` and ``activation_encodings`` are used as name-keyed
    mappings and are referenced in place, so every change made through this
    object is visible in the dictionary that was passed in.
    """

    def __init__(self, encodings: dict | None = None) -> None:
        """Initializes the AimetEncodingsV061 object

        Args:
            encodings (Optional): AIMET encodings dictionary. When omitted, ``None``
                or empty, a new empty version 0.6.1 layout is created.

        Raises:
            KeyError: If a non-empty ``encodings`` lacks ``"param_encodings"`` or
                ``"activation_encodings"``
        """
        # ``None`` instead of a mutable ``{}`` default; both still select the
        # empty 0.6.1 skeleton below.
        if not encodings:
            encodings = {
                "param_encodings": {},
                "activation_encodings": {},
                "quantizer_args": {},
                "excluded_layers": [],
                "version": "0.6.1",
            }
        super().__init__(encodings)
        self.param_enc = encodings["param_encodings"]
        self.act_enc = encodings["activation_encodings"]

    @property
    def encodings(self) -> dict:
        """Returns the encoding dictionary"""
        return self._dict

    def get(self, name: str) -> dict | None:
        """Retrieves encoding by name

        Parameter encodings take precedence over activation encodings.

        Args:
            name: Name of the parameter or activation

        Returns:
            dict | None: The stored encoding if ``name`` is known, else None
        """
        # Test membership rather than truthiness so this agrees with ``copy`` and
        # ``delete``: an entry that exists but is empty is still "found".
        if name in self.param_enc:
            return self.param_enc[name]
        return self.act_enc.get(name)

    def copy(self, original_name: str, new_name: str, value: dict) -> bool:
        """Copies an encoding from one name to another

        ``original_name`` only decides which table (parameter or activation)
        receives the new entry; ``value`` is what gets stored under ``new_name``.

        Args:
            original_name: Name of an existing encoding
            new_name: Name to store ``value`` under
            value: Encoding value to store (stored as-is, not copied)

        Returns:
            bool: True if ``original_name`` exists and the entry was written,
                False otherwise
        """
        if original_name in self.param_enc:
            self._set_param_encodings(new_name, value)
            return True
        if original_name in self.act_enc:
            self._set_activation_encodings(new_name, value)
            return True
        return False

    def _set_activation_encodings(self, name: str, value: dict) -> None:
        """Sets activation encoding"""
        self.act_enc[name] = value

    def _set_param_encodings(self, name: str, value: dict) -> None:
        """Sets parameter encoding"""
        self.param_enc[name] = value

    def delete(self, name: str) -> bool:
        """Deletes an encoding by name

        Args:
            name: Name of the encoding

        Returns:
            bool: True if an entry was removed, False if ``name`` is unknown
        """
        # Parameter table first, mirroring the lookup order used by ``get``.
        for table in (self.param_enc, self.act_enc):
            if name in table:
                del table[name]
                return True
        return False

    def _map_slices(self, original: str, slices: list[str]) -> bool:
        """Maps an encoding to multiple slices

        Each slice name must end in ``:<index>``; the integer after the last
        ``:`` selects which equal-sized chunk of the original encoding the
        slice receives. An encoding of length 0 or 1 is copied whole to every
        slice.

        Args:
            original: Original encoding name
            slices: List of slice names

        Returns:
            bool: True if ``original`` was found and mapped, False otherwise

        Raises:
            ValueError: If a slice name does not end in an integer index
        """
        enc = self.get(original)

        if enc is None:
            return False

        n_slices = len(slices)
        # NOTE(review): the encoding is sliced positionally below, so it is
        # presumably a per-channel list rather than a dict - confirm with callers.
        n_dim = len(enc)
        for slice_name in slices:
            index = int(slice_name.split(":")[-1])

            if n_dim > 1:
                # Any remainder of n_dim / n_slices is left unassigned.
                head_dim = n_dim // n_slices
                start = index * head_dim
                end = min(start + head_dim, n_dim)
                new_enc = enc[start:end]
            else:
                new_enc = enc.copy()

            self.copy(original, slice_name, new_enc)

        return True


class AimetEncodingsV100(AimetEncodings):
    """Handles AIMET encodings for version 1.0.0

    ``param_encodings`` and ``activation_encodings`` are lists of entries that
    each carry a ``"name"`` key. They are indexed by name internally and turned
    back into lists whenever ``encodings`` is read.
    """

    def __init__(self, encodings: dict | None = None) -> None:
        """Initializes the AimetEncodingsV100 object

        Args:
            encodings (dict, optional): AIMET encodings dictionary. When omitted,
                ``None`` or empty, a new empty version 1.0.0 layout is created.

        Raises:
            KeyError: If a non-empty ``encodings`` lacks ``"param_encodings"`` or
                ``"activation_encodings"``, or an entry has no ``"name"``
        """
        # ``None`` instead of a mutable ``{}`` default; both still select the
        # empty 1.0.0 skeleton below.
        if not encodings:
            encodings = {
                "param_encodings": [],
                "activation_encodings": [],
                "quantizer_args": {},
                "excluded_layers": [],
                "version": "1.0.0",
            }
        super().__init__(encodings)
        # Name-keyed views of the two lists for direct lookup.
        self.param_enc = {e["name"]: e for e in encodings["param_encodings"]}
        self.act_enc = {e["name"]: e for e in encodings["activation_encodings"]}

    @property
    def encodings(self) -> dict:
        """Returns the internal encoding dictionary

        The two encoding lists are rebuilt from the name-keyed tables on every
        access so that additions and deletions are reflected.
        """
        self._dict["param_encodings"] = list(self.param_enc.values())
        self._dict["activation_encodings"] = list(self.act_enc.values())
        return self._dict

    def get(self, name: str) -> dict | None:
        """Retrieves encoding by name

        Parameter encodings take precedence over activation encodings.

        Args:
            name: Name of the parameter or activation

        Returns:
            dict | None: The stored encoding if ``name`` is known, else None
        """
        # Membership (not truthiness) keeps this consistent with ``copy``/``delete``.
        if name in self.param_enc:
            return self.param_enc[name]
        return self.act_enc.get(name)

    def copy(self, original_name: str, new_name: str, value: dict) -> bool:
        """Copies an encoding from one name to another

        ``original_name`` only decides which table (parameter or activation)
        receives the new entry. ``value`` itself is left untouched: a shallow
        copy with ``"name"`` set to ``new_name`` is what gets stored.

        Args:
            original_name: Name of an existing encoding
            new_name: Name to store the copy under
            value: Encoding value to copy

        Returns:
            bool: True if ``original_name`` exists and the entry was written,
                False otherwise
        """
        if original_name in self.param_enc:
            self._set_param_encodings(new_name, value)
            return True
        if original_name in self.act_enc:
            self._set_activation_encodings(new_name, value)
            return True
        return False

    def _set_activation_encodings(self, name: str, value: dict) -> None:
        """Sets activation encoding"""
        # Store a renamed shallow copy. Writing ``value["name"]`` in place would
        # also rename the caller's dict - e.g. the original entry when called as
        # ``copy(old, new, self.get(old))`` - and alias one dict under two names.
        self.act_enc[name] = {**value, "name": name}

    def _set_param_encodings(self, name: str, value: dict) -> None:
        """Sets parameter encoding"""
        # Same reasoning as for activations: never mutate the caller's dict.
        self.param_enc[name] = {**value, "name": name}

    def delete(self, name: str) -> bool:
        """Deletes an encoding by name

        Args:
            name: Name of the encoding

        Returns:
            bool: True if an entry was removed, False if ``name`` is unknown
        """
        # Parameter table first, mirroring the lookup order used by ``get``.
        for table in (self.param_enc, self.act_enc):
            if name in table:
                del table[name]
                return True
        return False

    def _map_slices(self, original: str, slices: list[str]) -> bool:
        """Maps an encoding to multiple slices

        Each slice name must end in ``:<index>``; the integer after the last
        ``:`` selects which equal-sized chunk of ``scale`` / ``offset`` (and of
        ``per_block_int_scale`` when present) the slice receives. When ``scale``
        has at most one element the encoding is copied unchanged to every slice.

        Args:
            original: Original encoding name
            slices: List of slice names

        Returns:
            bool: True if ``original`` was found and mapped, False otherwise

        Raises:
            ValueError: If a slice name does not end in an integer index
            KeyError: If the encoding has no ``"scale"`` entry, or has more than
                one scale but no ``"offset"`` entry
        """
        enc = self.get(original)

        if enc is None:
            return False

        n_slices = len(slices)
        for slice_name in slices:
            index = int(slice_name.split(":")[-1])
            new_enc = enc.copy()

            n_dim = len(enc["scale"])
            if n_dim > 1:
                # Any remainder of n_dim / n_slices is left unassigned.
                head_dim = n_dim // n_slices
                start = index * head_dim
                end = min(start + head_dim, n_dim)

                for key in ("scale", "offset"):
                    new_enc[key] = enc[key][start:end]

                # NOTE: LPBQ only
                if int_scale := enc.get("per_block_int_scale"):
                    per_slice = len(int_scale) // n_slices
                    block_start = index * per_slice
                    block_end = min(block_start + per_slice, len(int_scale))
                    new_enc["per_block_int_scale"] = int_scale[block_start:block_end]

            self.copy(original, slice_name, new_enc)

        return True


class AimetEncodingsFactory:
    """Factory class to create AimetEncodings instances based on version"""

    @staticmethod
    def from_dict(encodings: dict):
        """Creates an AimetEncodings instance from a dictionary

        Args:
            encodings: AIMET encodings dictinary

        Returns:
            AimetEncodings: AimetEncoidngs instance

        Raises:
            NotImplementedError: If the encodings version is unsupported
        """
        version = encodings["version"]
        if version == "1.0.0":
            return AimetEncodingsV100(encodings)
        elif version == "0.6.1":
            return AimetEncodingsV061(encodings)
        else:
            raise NotImplementedError(f"Unsupported AIMET encodings version: {version}")

    @staticmethod
    def from_version(version: str):
        """Creates an AimetEncodings instance with default values based on version

        Args:
            encodings: AIMET encodings dictinary

        Returns:
            AimetEncodings: AimetEncodings instance, with empty param_encodings and activation_encodings
        Raises:
            NotImplementedError: If the encodings version is unsupported

        """
        if version == "1.0.0":
            return AimetEncodingsV100()
        elif version == "0.6.1":
            return AimetEncodingsV061()
        else:
            raise NotImplementedError(f"Unsupported AIMET encodings version: {version}")
