# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
Helper functions related to encodings
"""
import enum
import json
from typing import Dict, List

import numpy as np
from onnxscript import ir


class EncType(enum.Enum):
    """
    Granularity of a tensor's encodings.

    Every member's value equals its name, so the serialized string can be
    turned back into a member with either ``EncType[s]`` or ``EncType(s)``.
    LPBQ encodings additionally carry ``block_size``, ``compressed_bw`` and
    ``per_block_int_scale`` on the tensor encodings object.
    """
    PER_CHANNEL = "PER_CHANNEL"   # one scale/offset entry per channel
    PER_TENSOR = "PER_TENSOR"     # one scale/offset entry for the whole tensor
    LPBQ = "LPBQ"


class EncKind(enum.Enum):
    """
    Whether an encoding belongs to a parameter or to an activation.

    The member value is the top-level json key (and the matching
    GraphEncodingInfo attribute) under which encodings of this kind live.
    """
    PARAM = "param_encodings"
    ACTIVATION = "activation_encodings"


class GraphEncodingInfo:
    """
    Encodings of the whole graph

    Param and activation encodings are kept in two separate
    ``name -> TensorEncodingInfo`` dicts; ``add_tensor_encodings`` refuses a
    name that is already present in either of them.
    """

    def __init__(self):
        self.param_encodings: Dict[str, TensorEncodingInfo] = {}
        self.activation_encodings: Dict[str, TensorEncodingInfo] = {}
        # content of the json "quantizer_args" section (None until set)
        self.quantizer_args = None

    def add_tensor_encodings(self, name, encodings):
        """
        Add tensor encodings to the graph encoding info

        Args:
            name: name of the tensor
            encodings: TensorEncInfo; its ``enc_kind`` selects whether it is
                stored in ``param_encodings`` or ``activation_encodings``
        Raises:
            AssertionError: if ``name`` is already registered (as a param or
                an activation), or ``encodings.enc_kind`` is not a known EncKind
        """
        # Explicit raises rather than `assert False`: the checks must still run
        # under `python -O`, otherwise a duplicate would silently overwrite.
        # AssertionError is kept so existing callers see the same exception.
        if name in self.param_encodings or name in self.activation_encodings:
            raise AssertionError(f"already added {name}")

        if encodings.enc_kind == EncKind.PARAM:
            self.param_encodings[name] = encodings
        elif encodings.enc_kind == EncKind.ACTIVATION:
            self.activation_encodings[name] = encodings
        else:
            raise AssertionError(
                f"unknown encodings kind {encodings.enc_kind!r} for {name}")

    def get_tensor_encodings(self, name):
        """
        Get tensor encodings by name

        Param encodings are searched first, then activation encodings.

        Args:
            name: name of the tensor
        Returns:
            TensorEncInfo
        Raises:
            ValueError: if no encodings are registered under ``name``
        """
        if name in self.param_encodings:
            return self.param_encodings[name]
        if name in self.activation_encodings:
            return self.activation_encodings[name]

        raise ValueError(f"cannot find tensor encodings {name}")


class TensorEncodingInfo: # pylint: disable=[too-many-instance-attributes]
    """
    Encodings of the tensor

    The constructor takes the common fields. LPBQ-only fields
    (compressed_bw, block_size, per_block_int_scale) start as None and are
    assigned afterwards; channel_axis / block_axis start at (0, 1).
    """
    # pylint: disable=[too-many-arguments,too-many-positional-arguments,redefined-builtin]
    def __init__(self,
                 enc_type:EncType,
                 bw:int,
                 dtype:str,
                 is_sym:bool|None,
                 offset:np.ndarray|None,
                 scale:np.ndarray|None,
                 max:np.ndarray|None,
                 min:np.ndarray|None,
                 enc_kind:EncKind):
        assert isinstance(enc_type, EncType)
        assert isinstance(enc_kind, EncKind)

        self.enc_kind = enc_kind
        self.enc_type = enc_type
        self.bw = bw
        self.dtype = dtype
        self.is_sym = is_sym
        self.offset:np.ndarray|None = offset
        self.scale:np.ndarray|None = scale

        # optional
        self.max:np.ndarray|None = max
        self.min:np.ndarray|None = min

        # extra attributes for LPBQ
        self.compressed_bw:int|None = None
        self.block_size:int|None = None
        self.per_block_int_scale:np.ndarray|None = None

        # channel_axis and block_axis should be inferred automatically
        # by calling infer_encodings_chn_block_axis
        self.channel_axis = 0
        self.block_axis = 1

        self.is_signed:bool|None = None  # optional

    def __eq__(self, other):
        """
        Field-by-field equality; numpy fields must match in shape and values.

        Returns NotImplemented for objects that are not TensorEncodingInfo, so
        Python falls back to its default comparison (``==`` gives False).
        """
        if not isinstance(other, TensorEncodingInfo):
            return NotImplemented
        for attr, v in self.__dict__.items():
            if attr not in other.__dict__:
                return False
            other_v = other.__dict__[attr]
            v_is_arr = isinstance(v, np.ndarray)
            other_is_arr = isinstance(other_v, np.ndarray)
            if v_is_arr or other_is_arr:
                # An array never equals a non-array (e.g. None). Comparing them
                # with `!=` would broadcast to an array whose truth value is
                # ambiguous and raise ValueError.
                if not (v_is_arr and other_is_arr):
                    return False
                if not np.array_equal(v, other_v):
                    return False
            elif v != other_v:
                return False
        return True


def _get_field_as_np_array(dict_json, name, default=None):
    """
    Fetch ``dict_json[name]`` wrapped in a numpy array.

    Falls back to ``default`` when the key is missing or holds null/None.
    """
    raw_value = dict_json.get(name)
    return default if raw_value is None else np.array(raw_value)


def load_encodings(encodings_src: str) -> GraphEncodingInfo:
    """
    Read an encodings json file and turn it into a GraphEncodingInfo.

    Args:
        encodings_src: path of the json file (read as utf-8)
    Returns:
        the result of deserialize_encodings on the parsed json
    """
    with open(encodings_src, "r", encoding="utf-8") as handle:
        parsed = json.loads(handle.read())
    return deserialize_encodings(parsed)


def deserialize_encodings(src_enc: Dict) -> GraphEncodingInfo:
    """
    Deserialize json encodings to GraphEncodingInfo

    Args:
        src_enc: parsed json encodings; a "0.6.1" document is first upgraded
            with convert_v0_6_1_to_v1, after which the version must be "1.0.0"
    Returns:
        GraphEncodingInfo holding every param and activation encoding, keyed
        by tensor name
    """
    if src_enc["version"] == "0.6.1":
        src_enc = convert_v0_6_1_to_v1(src_enc)
    assert src_enc["version"] == "1.0.0"

    result = GraphEncodingInfo()
    result.quantizer_args = src_enc["quantizer_args"]

    for kind in (EncKind.PARAM, EncKind.ACTIVATION):
        # the enum value is both the json key and the GraphEncodingInfo dict name
        bucket = getattr(result, kind.value)
        for entry in src_enc[kind.value]:
            tensor_name = entry["name"]
            info = TensorEncodingInfo(
                enc_type=EncType[entry["enc_type"]],
                bw=entry["bw"],
                dtype=entry["dtype"],
                is_sym=entry.get("is_sym", None),
                offset=_get_field_as_np_array(entry, "offset", None),
                scale=_get_field_as_np_array(entry, "scale", None),
                max=_get_field_as_np_array(entry, "max", None),
                min=_get_field_as_np_array(entry, "min", None),
                enc_kind=kind,
            )

            # LPBQ entries carry three extra mandatory fields
            if info.enc_type is EncType.LPBQ:
                info.per_block_int_scale = np.array(entry["per_block_int_scale"])
                info.block_size = entry["block_size"]
                info.compressed_bw = entry["compressed_bw"]
            bucket[tensor_name] = info
    return result

def infer_encodings_chn_block_axis(tensor_enc:TensorEncodingInfo, tensor:ir.Value) -> None:
    """
    Infer encodings' channel axis and block axis for the given tensor

    Sets ``tensor_enc.channel_axis`` and ``tensor_enc.block_axis`` in place;
    PER_TENSOR encodings are left untouched.

    * If any consumer is a MatMul that takes ``tensor`` as its input #1
      (the weight operand), the last dim is the channel axis and the dim
      before it is the block axis.
    * Otherwise (e.g. conv / rmsnorm weights) dim 0 is the channel axis and
      dim 1 is the block axis.

    Args:
        tensor_enc: encodings to update
        tensor: the ir value these encodings belong to
    Raises:
        ValueError: if ``tensor`` has no shape (shape inference not run yet)
    """
    if tensor_enc.enc_type == EncType.PER_TENSOR:
        return

    if tensor.shape is None:
        raise ValueError(f"cannot get the shape of {tensor.name}, shape inference should be called firstly")
    # Only the rank is needed. len() also works for shapes with symbolic
    # dims, whereas Shape.numpy() requires every dim to be a concrete int.
    rank = len(tensor.shape)

    used_as_matmul_weights = any(
        user.op_type == "MatMul" and arg_id == 1 for user, arg_id in tensor.uses()
    )
    if used_as_matmul_weights:
        tensor_enc.block_axis = rank - 2
        tensor_enc.channel_axis = rank - 1
    else:
        # for other cases, for example conv/rmsnorm,
        # we assume the first dim is the channel dim
        # and the second dim is the block dim
        tensor_enc.block_axis = 1
        tensor_enc.channel_axis = 0



def serialize_graph_encodings(encodings: GraphEncodingInfo) -> Dict:
    """
    Build the v1.0.0 json dict for a GraphEncodingInfo.

    Each of the "activation_encodings" / "param_encodings" entries is a list
    of per-tensor dicts produced by serialize_tensor_encodings, in the
    insertion order of the corresponding GraphEncodingInfo dict.
    """
    def _tensor_list(attr_name):
        per_tensor = getattr(encodings, attr_name)
        return [serialize_tensor_encodings(enc, tensor_name)
                for tensor_name, enc in per_tensor.items()]

    return {
        "quantizer_args": encodings.quantizer_args,
        "activation_encodings": _tensor_list("activation_encodings"),
        "param_encodings": _tensor_list("param_encodings"),
        "version": "1.0.0",
    }


def save_encodings(encodings: GraphEncodingInfo, dst_path: str):
    """
    Write a GraphEncodingInfo to ``dst_path`` as utf-8 json indented by 4
    spaces, using the dict produced by serialize_graph_encodings.
    """
    payload = serialize_graph_encodings(encodings)

    with open(dst_path, "w", encoding="utf-8") as out_file:
        out_file.write(json.dumps(payload, indent=4))


def serialize_tensor_encodings(t_enc: TensorEncodingInfo, name: str) -> Dict:
    """
    Serialize TensorEncodingInfo to json format

    Always emits name / enc_type / bw / dtype. is_sym and the numpy fields
    (offset, scale, max, min) are emitted only when not None, arrays as plain
    lists. LPBQ encodings also emit per_block_int_scale, block_size and
    compressed_bw.
    """
    curr_enc = {
        "name": name,
        "enc_type": t_enc.enc_type.value,
        "bw": t_enc.bw,
        "dtype": t_enc.dtype,
    }
    if t_enc.is_sym is not None:
        curr_enc["is_sym"] = t_enc.is_sym

    for field_name in ("offset", "scale", "max", "min"):
        array_value = getattr(t_enc, field_name)
        if array_value is not None:
            curr_enc[field_name] = array_value.tolist()

    if t_enc.enc_type == EncType.LPBQ:
        assert isinstance(t_enc.per_block_int_scale, np.ndarray)
        curr_enc.update(
            per_block_int_scale=t_enc.per_block_int_scale.tolist(),
            block_size=t_enc.block_size,
            compressed_bw=t_enc.compressed_bw,
        )
    return curr_enc


def convert_v0_6_1_to_v1(src_enc: Dict) -> Dict:
    """
    Convert v0.6.1 aimet encodings to v1 encodings

    In v0.6.1 each tensor name maps to a list of encoding dicts, one per
    channel. In v1.0.0 each tensor is a single dict whose "offset", "scale",
    "max" and "min" are lists with one element per v0.6.1 entry. bitwidth,
    dtype and is_symmetric are taken from the first entry.

    Args:
        src_enc: parsed json encodings, version "0.6.1" or "1.0.0"
    Returns:
        encodings in the v1.0.0 layout; a "1.0.0" input is returned as-is
    """
    if src_enc["version"] == "1.0.0":
        return src_enc
    assert src_enc["version"] == "0.6.1"
    dst_enc: Dict = {
        "version": "1.0.0"
    }

    # "is_symmetric" is normally the string "True"/"False" in v0.6.1 files;
    # real json booleans are accepted too instead of failing with KeyError.
    is_symmetric_map = {
        "False": False,
        "True": True,
        False: False,
        True: True,
    }

    for e_type in ["activation_encodings", "param_encodings"]:
        dst_enc[e_type] = []
        for name, t_enc in src_enc[e_type].items():
            first = t_enc[0]
            new_t_enc = {
                "bw": first["bitwidth"],
                "dtype": first["dtype"].upper(),
                "name": name,
            }
            if "is_symmetric" in first:
                new_t_enc["is_sym"] = is_symmetric_map[first["is_symmetric"]]
            # collect each per-channel value into a list, one item per entry
            for field in ("offset", "scale", "max", "min"):
                if field in first:
                    new_t_enc[field] = [x[field] for x in t_enc]

            # more than one offset means one encoding per channel
            if "offset" in new_t_enc and len(new_t_enc["offset"]) > 1:
                new_t_enc["enc_type"] = "PER_CHANNEL"
            else:
                new_t_enc["enc_type"] = "PER_TENSOR"
            dst_enc[e_type].append(new_t_enc)

    dst_enc["quantizer_args"] = src_enc["quantizer_args"]
    return dst_enc
