# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides passes for shape inference on ONNX IR graphs, together with the constant propagation that supports it.
"""
import copy
import os
from typing import List

import numpy as np
import onnx
import onnxscript.evaluator
from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .base.rewriter import BaseGraphRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .base.visitor import BaseTreeVisitor
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import (
    clean_model_proto,
    convert_attr_to_py,
    convert_attrs_to_py,
    get_constant_np,
    get_shape_from_value_info_proto,
    get_value_numeric_shape,
    has_static_shape_on_value,
    iter_all_values,
    logger,
)


def serialize_model_without_data(model: ir.Model, size_threshold_bytes=128) -> onnx.ModelProto:
    """Serialize ir.Model into onnx.ModelProto, without large data.

    Initializers whose data is strictly larger than ``size_threshold_bytes`` are
    temporarily replaced by an ``ir.ExternalTensor`` that points at a dummy
    external file (``dummy_data_for_shape_infer.data``). The returned proto thus
    carries only their metadata (dtype, shape, name), not their bytes. The
    original tensors are put back on the model before returning, even if
    serialization raises.
    Modified from onnxscript.ir.external_data.unload_from_model and onnxscript.ir.save.

    Args:
        model: model to serialize ir.Model
        size_threshold_bytes: size threshold, in bytes; tensors with more bytes
            than this are not embedded
    Returns:
        proto: serialized protobuf
    """
    # Snapshot the initializers and their tensors so the model can be restored.
    initializer_values = tuple(model.graph.initializers.values())
    tensors = [v.const_value for v in initializer_values]

    try:
        # Tensors strictly larger than the threshold are swapped for dummy
        # external references; smaller ones are serialized inline as usual.
        for value in initializer_values:
            const_value = value.const_value
            if const_value is None:
                # Filter out the uninitialized initializer values
                continue
            if const_value.nbytes > size_threshold_bytes:
                value.const_value = ir.ExternalTensor(
                    os.path.normpath("./dummy_data_for_shape_infer.data"),
                    0,
                    const_value.nbytes,
                    value.dtype,  # type: ignore[arg-type]
                    shape=value.shape,  # type: ignore[arg-type]
                    name=value.name,  # type: ignore[arg-type]
                )

        proto = ir.serde.serialize_model(model)

    finally:
        # Restore the original initializer values so the model is unchanged
        for initializer, tensor in zip(initializer_values, tensors, strict=True):
            initializer.const_value = tensor
    return proto


def serialize_model_with_folded_constant(model: ir.Model, with_data=True) -> onnx.ModelProto:
    """Serialize ir.Model into onnx.ModelProto with inferred constants folded in.

    Every output of a non-``Constant`` node whose
    ``meta["extra_info"].infered_constant_value`` is set is materialized in the
    proto as a ``Constant`` node producing a tensor of the same name.
    - Nodes whose outputs are all folded are dropped.
    - Nodes with only some folded outputs are kept. Their folded outputs are
      renamed so the new Constant node is the only producer of that name.

    ``clean_model_proto`` is applied to the result at the end.

    Args:
        model: model to serialize ir.Model
        with_data: if True serialize initializer data as-is; otherwise go through
            ``serialize_model_without_data`` (large data is not embedded)
    Returns:
        proto: serialized protobuf
    """
    if not with_data:
        proto = serialize_model_without_data(model)
    else:
        proto = ir.serde.serialize_model(model)

    # Gather the inferred constant values of activations; existing Constant
    # nodes are left untouched.
    folded = {}
    for graph_node in model.graph:
        if graph_node.op_type == "Constant":
            continue
        for out in graph_node.outputs:
            known = out.meta["extra_info"].infered_constant_value
            if known is not None:
                folded[out.name] = known

    constant_nodes = []
    kept_nodes = []
    for proto_node in proto.graph.node:
        # A node survives only if at least one of its outputs is not folded.
        keep_node = False
        for idx, out_name in enumerate(proto_node.output):
            if out_name not in folded:
                keep_node = True
                continue
            constant_nodes.append(
                onnx.helper.make_node(
                    "Constant", [], [out_name], name=out_name + ".cst", domain="",
                    value=onnx.numpy_helper.from_array(folded[out_name], name=out_name),
                )
            )
            # Rename the original output so it no longer clashes with the new
            # Constant node's output.
            # TODO, handle renaming properly
            proto_node.output[idx] = out_name + "_###.##_origin##_$"
        if keep_node:
            kept_nodes.append(proto_node)

    # Rebuild the node list: new Constant nodes first, then the surviving nodes.
    while proto.graph.node:
        proto.graph.node.pop()
    proto.graph.node.extend(constant_nodes)
    proto.graph.node.extend(kept_nodes)

    # After folding, some nodes may no longer be used, so remove them.
    return clean_model_proto(proto)


class ShapeInference(BaseGraphRewriter):
    """Pass to infer shapes and element types on the whole graph.

    Each round runs three steps:
    1. Constant propagation.
    2. ONNX's ``infer_shapes`` on a constant-folded serialization of the model.
    3. Manual inference for ops ONNX does not handle.

    Rounds repeat until one updates nothing or an iteration cap is reached.
    """

    def __init__(self, graph: ir.Graph, model: ir.Model):
        """Initialize an instance of ShapeInference

        Args:
            graph: graph whose values get their shape/type filled in
            model: model ir. It is serialized for ONNX inference, and results are
                matched back to ``graph``'s values by name.
        """
        super().__init__(graph)
        self.model = model

    def infer_by_onnx(self):  # pylint: disable=R0912,R0914
        """Infer shapes with onnx.shape_inference.infer_shapes.

        Constants are folded into the serialized model first, which lets ONNX's
        data propagation resolve more shapes. Only fully static inferred shapes
        are used.

        Returns:
            Number of values in ``self.graph`` whose shape and/or type was updated.

        Raises:
            ValueError: if ONNX infers a static shape or a type that conflicts
                with the one already recorded on the value.
        """
        # Convert the model into a proto (without large data) with constants folded
        proto = serialize_model_with_folded_constant(self.model, with_data=False)

        origin_value_infos = {info.name: info for info in proto.graph.value_info}
        origin_value_infos.update({info.name: info for info in proto.graph.output})

        inferred_proto = onnx.shape_inference.infer_shapes(
            proto, check_type=True, strict_mode=False, data_prop=True,
        )
        inferred_value_infos = {info.name: info for info in inferred_proto.graph.value_info}
        inferred_value_infos.update({info.name: info for info in inferred_proto.graph.output})

        # Decide which values need their shape/type updated
        shape_to_update = {}
        type_to_update = {}
        for name, value_info in inferred_value_infos.items():
            new_shape = get_shape_from_value_info_proto(value_info, allow_symbols=False)
            if new_shape is None:
                continue

            origin_value_info = origin_value_infos.get(name)
            if origin_value_info is None:
                # Nothing was known about this value before inference
                shape_to_update[name] = new_shape
                new_type = ir.serde.deserialize_type_proto_for_type(value_info.type)
                # Never overwrite a value's type with "unknown"
                if new_type is not None:
                    type_to_update[name] = new_type
                continue

            origin_shape = get_shape_from_value_info_proto(origin_value_info, allow_symbols=False)
            if origin_shape is None:
                shape_to_update[name] = new_shape
            elif origin_shape != new_shape:
                raise ValueError(
                    f"mismatch shape at {name}, infered {new_shape}, origin {origin_shape}")

            # A protobuf sub-message is never None; HasField tells whether it is set.
            value_type = ir.serde.deserialize_type_proto_for_type(value_info.type) \
                if value_info.HasField("type") else None
            origin_value_type = ir.serde.deserialize_type_proto_for_type(origin_value_info.type) \
                if origin_value_info.HasField("type") else None
            if value_type is not None and origin_value_type is not None:
                if value_type != origin_value_type:
                    raise ValueError(
                        f"mismatch type at {name}, infered {value_type}, origin {origin_value_type}")
            elif value_type is not None:
                type_to_update[name] = value_type

        # Apply the updates, counting each touched value once
        updated_counts = 0
        for value in iter_all_values(self.graph):
            updated = False
            if value.name in shape_to_update:
                value.shape = ir.Shape(shape_to_update[value.name])
                updated = True
            if value.name in type_to_update:
                value.type = type_to_update[value.name]
                updated = True
            if updated:
                updated_counts += 1
        return updated_counts

    def apply(self):
        """Infer shapes on the whole graph.

        Iteratively do constant propagation and shape inference until no extra
        information (shape/type/constant value) can be inferred, capped at 100
        rounds. A warning is logged if the cap is reached.

        Returns:
            1, unconditionally.
        """
        loop_counts_max = 100
        for _ in range(loop_counts_max):
            # step1: propagate some constants, such as some nodes that construct a shape
            ConstantPropagation(self.graph, self.model).apply()
            # step2: infer shape by onnx
            updated_counts = self.infer_by_onnx()
            # step3: infer shape on specific ops manually
            specific_op_infer = ShapeInferForSpecificOp(self.graph)
            specific_op_infer.apply()
            updated_counts += specific_op_infer.updated_count

            if updated_counts == 0:
                break
        else:
            # No round reached a fixed point within the cap
            logger.warning("shape inference failed, loop counts max reached")
        return 1


class ShapeInferForSpecificOp(BaseTreeVisitor):

    """Pass to infer shapes on specific ops.

    For example, onnx.shape_inference.infer_shapes currently doesn't handle QDQ
    nodes, so they are handled manually here.

    Shape inference for self-defined (custom) ops can also be handled in this pass.

    ``updated_count`` counts the nodes whose outputs received new shape/type
    information during the pass.
    """

    def __init__(self, graph):
        """Initialize an instance of ShapeInferForSpecificOp

        Args:
            graph: graph
        """
        super().__init__(graph)
        self.updated_count = 0

    @staticmethod
    def _get_output_dtype_attr(node):
        """Return the node's ``output_dtype`` attribute as ir.DataType (UNDEFINED if absent)."""
        attr = node.attributes.get("output_dtype")
        if attr is None:
            return ir.DataType.UNDEFINED
        # node.attributes holds attribute objects, not raw ints
        return ir.DataType(convert_attr_to_py(attr, "as_int"))

    def visit_node_DequantizeLinear(self, node):  # pylint: disable=C0103
        """Infer the output shape/type of a DequantizeLinear op.

        The shape is copied from input 0. The type comes from the ``output_dtype``
        attribute or, if that is not supplied, from x_scale (input 1).
        """
        updated = False
        output = node.outputs[0]
        if output.shape is None and node.inputs[0].shape is not None:
            output.shape = copy.deepcopy(node.inputs[0].shape)
            updated = True

        if output.dtype is None:
            output_dtype = self._get_output_dtype_attr(node)
            if output_dtype == ir.DataType.UNDEFINED:
                # If not supplied, the output data type is inferred from x_scale
                output_dtype = node.inputs[1].dtype
            # x_scale's dtype may still be unknown; retry on a later round
            if output_dtype is not None:
                output.type = ir.TensorType(output_dtype)
                updated = True

        if updated:
            self.updated_count += 1

    def visit_node_QuantizeLinear(self, node):  # pylint: disable=C0103
        """Infer the output shape/type of a QuantizeLinear op.

        The shape is copied from input 0. The type comes from the ``output_dtype``
        attribute; otherwise from y_zero_point (input 2), or uint8 when
        y_zero_point is omitted (ONNX default).
        """
        updated = False
        output = node.outputs[0]
        if output.shape is None and node.inputs[0].shape is not None:
            output.shape = copy.deepcopy(node.inputs[0].shape)
            updated = True

        if output.dtype is None:
            output_dtype = self._get_output_dtype_attr(node)
            if output_dtype == ir.DataType.UNDEFINED:
                # If not supplied, the output data type is inferred from
                # y_zero_point, which is itself optional (defaults to uint8)
                zero_point = node.inputs[2] if len(node.inputs) > 2 else None
                output_dtype = zero_point.dtype if zero_point is not None else ir.DataType.UINT8
            # y_zero_point's dtype may still be unknown; retry on a later round
            if output_dtype is not None:
                output.type = ir.TensorType(output_dtype)
                updated = True

        if updated:
            self.updated_count += 1

    def visit_node_GroupSlice(self, node: ir.Node):  # pylint: disable=C0103
        """Infer output shapes/types of a GroupSlice op.

        Output ``i`` has the input's shape, except along ``axis``, where its size is
        ``ends[i] - starts[i]``. Every output takes the input's dtype.
        """
        updated = False
        data = node.inputs[0]
        assert data is not None  # check for mypy
        # Every output is visited (outputs that are already complete are simply
        # left alone), so updates on earlier outputs are always counted.
        for i, output in enumerate(node.outputs):
            if output.shape is None and data.shape is not None:
                output_shape = list(get_value_numeric_shape(data))
                axis = convert_attr_to_py(node.attributes["axis"], "as_int")
                start = convert_attr_to_py(node.attributes["starts"], "as_ints")[i]
                end = convert_attr_to_py(node.attributes["ends"], "as_ints")[i]
                output_shape[axis] = end - start
                output.shape = ir.Shape(output_shape)
                updated = True
            if output.type is None and data.dtype is not None:
                output.type = ir.TensorType(data.dtype)
                updated = True
        if updated:
            self.updated_count += 1

    def visit_node_FastHadamardTransform(self, node: ir.Node):  # pylint: disable=C0103
        """Infer the output of a FastHadamardTransform op: same shape/type as input 0."""
        updated = False
        if node.inputs[0] is None:
            return

        if node.outputs[0].shape is None and node.inputs[0].shape is not None:
            node.outputs[0].shape = copy.deepcopy(node.inputs[0].shape)
            updated = True
        if node.outputs[0].type is None and node.inputs[0].type is not None:
            node.outputs[0].type = copy.deepcopy(node.inputs[0].type)
            updated = True
        if updated:
            self.updated_count += 1

def get_constant_inputs(node: ir.Node):
    """Collect the known constant values feeding a node.

    Each present input's constant is taken from
    ``meta["extra_info"].infered_constant_value``. Failing that, if the input is
    an initializer of the node's graph, it comes from ``get_constant_np``.
    Omitted (None) inputs contribute None but do not count as non-constant.

    Args:
        node: the target node
    Returns:
        all_inputs_cst: True if every present input has a known constant value
        inputs_const_values: per input, its constant value or None
    """
    collected = []
    missing_constant = False
    for inp in node.inputs:
        if inp is None:
            collected.append(None)
            continue
        known = inp.meta["extra_info"].infered_constant_value
        if known is None and node.graph is not None and inp.name in node.graph.initializers:
            known = get_constant_np(inp)
        if known is None:
            missing_constant = True
        collected.append(known)

    return not missing_constant, collected


def set_shape_dtype(
    value: ir.Value,
    shape: List[int],
    dtype: np.dtype | ir.DataType | ir.TensorType | None,
):
    """Set, or validate against, the shape and element type of a value.

    - Type: an existing type on ``value`` is never overwritten; it must equal the
      new one.
    - Shape: an existing static shape must equal ``shape``; otherwise ``shape``
      is written.

    Args:
        value: the target value
        shape: the shape to set
        dtype: the data type to set, as a numpy dtype, ir.DataType or
            ir.TensorType; None leaves the type untouched

    Raises:
        ValueError: if ``dtype`` is of an unsupported kind, or if the type/shape
            conflicts with the one already recorded on ``value``.

    Returns: None
    """
    if dtype is None:
        tensor_type = None
    elif isinstance(dtype, np.dtype):
        tensor_type = ir.TensorType(ir.DataType.from_numpy(dtype))
    elif isinstance(dtype, ir.DataType):
        tensor_type = ir.TensorType(dtype)
    elif isinstance(dtype, ir.TensorType):
        tensor_type = dtype
    else:
        raise ValueError(f"unhandled type {dtype}")

    if tensor_type is not None:
        origin_tensor_type = value.type
        if origin_tensor_type is None:
            value.type = tensor_type
        elif origin_tensor_type != tensor_type:
            raise ValueError(
                f"infered type {tensor_type} mismatch with original type {origin_tensor_type}"
                f" at '{value.name}'"
            )

    if not has_static_shape_on_value(value):
        value.shape = ir.Shape(shape)
    else:
        origin_shape = get_value_numeric_shape(value)
        if tuple(origin_shape) != tuple(shape):
            raise ValueError(
                f"infered shape {tuple(shape)} mismatch with original shape {tuple(origin_shape)}"
                f" at '{value.name}'"
            )


def set_infered_cst_value(value: ir.Value, value_np: np.ndarray):
    """Attach an inferred constant to a value and record its shape and dtype.

    The array is stored as ``value.meta["extra_info"].infered_constant_value``.
    ``set_shape_dtype`` is then applied with the array's shape and dtype; it
    raises ValueError if they conflict with what ``value`` already carries.

    Args:
        value: the target value
        value_np: the numpy array to set

    Returns: None
    """
    extra_info = value.meta["extra_info"]
    extra_info.infered_constant_value = value_np
    set_shape_dtype(value, [*value_np.shape], value_np.dtype)


class ConstantPropagation(BaseTreeVisitor):
    """Pass to propagate constants.

    Inferred constants are recorded on values as
    ``value.meta["extra_info"].infered_constant_value`` (with matching shape/type).
    """

    # Ops with random outputs: evaluating them once and recording the result as
    # a constant would misrepresent the graph, so they are never folded.
    _NON_DETERMINISTIC_OPS = frozenset({
        "Bernoulli",
        "Multinomial",
        "RandomNormal",
        "RandomNormalLike",
        "RandomUniform",
        "RandomUniformLike",
    })

    def __init__(self, graph: ir.Graph, model_ir: ir.Model):
        """Initialize an instance of ConstantPropagation

        Args:
            graph: graph
            model_ir: model ir (its opset imports select the op schemas to evaluate)
        """
        super().__init__(graph)
        self.model_ir = model_ir

    def visit_general_node(self, node):
        """General method to propagate constants on a node.

        This is called when no specific ``visit_node_*`` method exists for the
        node.
        - If the first output already has a constant (per ``get_constant_np``),
          it is recorded.
        - Otherwise, if every input is constant and the op has a known ONNX
          schema, the op is evaluated with onnxscript's evaluator and all outputs
          are recorded.

        Args:
            node: the node to process
        """
        if not node.outputs:
            return
        constant_value = get_constant_np(node.outputs[0])
        if constant_value is not None:
            # constant node
            set_infered_cst_value(node.outputs[0], constant_value)
            return

        if node.op_type in self._NON_DETERMINISTIC_OPS:
            return

        # if all inputs have infered value
        # then using onnxscript's evaluator to infer the output value
        all_inputs_cst, inputs_const_values = get_constant_inputs(node)
        if not all_inputs_cst:
            return

        # Domains missing from the model's opset imports cannot be resolved to a schema
        opset_version = self.model_ir.opset_imports.get(node.domain)
        if opset_version is None:
            return
        if not onnx.defs.has(node.op_type, opset_version, domain=node.domain):
            return
        op_schema = onnx.defs.get_schema(node.op_type, opset_version, domain=node.domain)

        attributes_py = convert_attrs_to_py(node.attributes)
        outputs_cst = onnxscript.evaluator.default().eval(
            op_schema, inputs_const_values, attributes=attributes_py)
        # A single-output op returns one tensor rather than a sequence
        if len(node.outputs) == 1:
            outputs_cst = [outputs_cst]
        assert len(outputs_cst) == len(node.outputs)
        for output_cst, output_v in zip(outputs_cst, node.outputs):
            set_infered_cst_value(output_v, output_cst.value)

    def visit_node_Shape(self, node: ir.Node):  # pylint: disable=C0103
        """Constant propagation for the Shape op.

        When the input has a static shape, the output is that shape as an int64
        array, sliced by the optional ``start``/``end`` attributes (Python
        slicing). Otherwise the node falls back to ``visit_general_node``.

        Args:
            node: shape Op
        """
        input_v = node.inputs[0]
        assert input_v is not None  # check for mypy, definitely true
        if has_static_shape_on_value(input_v):
            shape = get_value_numeric_shape(input_v)
            end = None
            start = None
            if "end" in node.attributes:
                end = convert_attr_to_py(node.attributes["end"], "as_int")
            if "start" in node.attributes:
                start = convert_attr_to_py(node.attributes["start"], "as_int")
            sliced_shape = np.array(shape, dtype=np.int64)[slice(start, end)]
            set_infered_cst_value(node.outputs[0], sliced_shape)
        else:
            self.visit_general_node(node)
