# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================

from qti.aisw.converters.common import ir_graph
from qti.aisw.converters.common.converter_ir import op_adapter
from qti.aisw.converters.common.converter_ir.op_graph import QuantUpdatableMode
from qti.aisw.converters.common.utils.converter_utils import log_debug


class GraphPropertySetter(object):
    """
    Stateless utility class for setting graph properties.

    The properties handled here are the UPDATABLE, NO_QUANT_UPDATABLE and
    ADAPTER_ONLY_QUANT_UPDATABLE bits of the graph property mask
    (ir_graph.IR_GRAPH_PROPERTY_MASK_*).

    Methods:
        set_graph_properties(self, graph, quant_updatable_mode, lora_tensor_names):
            Infers and sets graph properties based on quant updatable mode and
            list of lora tensor names.

        copy_ir_graph_properties(self, py_graph, cpp_graph):
            Copy graph properties from an IrGraph instance to IrOpGraph instance.
    """
    def set_graph_properties(self, graph, quant_updatable_mode, lora_tensor_names):
        """
        Given mode and tensor names, sets graph properties related to which
        tensor's values and/or quant encodings are updatable. Currently used for
        lora models.

        :param graph: graph whose property mask is updated in place.
        :param quant_updatable_mode: a QuantUpdatableMode member, or None for
            the backward-compatible behaviour where ADAPTER_ONLY_QUANT_UPDATABLE
            is inferred from the tensors in lora_tensor_names.
        :param lora_tensor_names: names of the tensors given in
            --lora_weight_list. UPDATABLE is set iff this is non-empty.
        :raises ValueError: if the mode is unsupported, or the mix of weight
            (constant) and activation tensors does not suit the mode.
        :raises RuntimeError: if an activation tensor is produced by an
            unsupported op type, or the final flag combination is invalid
            (e.g. mode NONE with an empty lora_tensor_names leaves UPDATABLE
            unset while NO_QUANT_UPDATABLE is set).
        """
        self._set_updatable(graph, lora_tensor_names)
        self._set_quant_updatable_flags(graph, quant_updatable_mode, lora_tensor_names)
        self._validate_graph_properties(graph)

    def copy_ir_graph_properties(self, py_graph, cpp_graph):
        """
        Copy graph properties from an IrGraph instance to IrOpGraph instance.

        The flags are read from ``cpp_graph`` and written into ``py_graph``'s
        property mask, after which the resulting combination is validated.

        :raises RuntimeError: if the copied combination of flags is invalid.
        """
        py_graph.update_graph_property_mask(ir_graph.IR_GRAPH_PROPERTY_MASK_NO_QUANT_UPDATABLE,
                                            cpp_graph.is_no_quant_updateable())
        py_graph.update_graph_property_mask(ir_graph.IR_GRAPH_PROPERTY_MASK_UPDATABLE,
                                            cpp_graph.is_updateable())
        py_graph.update_graph_property_mask(ir_graph.IR_GRAPH_PROPERTY_MASK_ADAPTER_ONLY_QUANT_UPDATABLE,
                                            cpp_graph.is_adapter_only_quant_updateable())
        self._validate_graph_properties(py_graph)

    def _set_updatable(self, graph, lora_tensor_names):
        """Set UPDATABLE to True iff at least one lora tensor name was given."""
        is_updatable = len(lora_tensor_names) > 0
        graph.update_graph_property_mask(ir_graph.IR_GRAPH_PROPERTY_MASK_UPDATABLE, is_updatable)

    def _set_quant_updatable_flags(self, graph, quant_updatable_mode, lora_tensor_names):
        """
        Set NO_QUANT_UPDATABLE and ADAPTER_ONLY_QUANT_UPDATABLE for the mode.

        Both bits are written explicitly for every supported mode (only after
        the mode-specific validation succeeds). This keeps the result
        independent of whatever the mask held before the call.

        :raises ValueError: for an unsupported mode or tensors inconsistent
            with the mode.
        """
        if quant_updatable_mode is None or quant_updatable_mode == QuantUpdatableMode.ADAPTER_ONLY:
            is_adapter_only_quant_updatable = self._resolve_adapter_only_quant_updatable(
                graph, quant_updatable_mode, lora_tensor_names)
            self._write_quant_updatable_flags(graph,
                                              no_quant_updatable=False,
                                              adapter_only_quant_updatable=is_adapter_only_quant_updatable)

            if is_adapter_only_quant_updatable:
                log_debug("LoRA graph has adapter only updatable quantization encodings.")
            else:
                log_debug("LoRA graph does not have adapter only updatable quantization encodings.")
        elif quant_updatable_mode == QuantUpdatableMode.NONE:
            self._validate_has_only_static_tensors(graph, quant_updatable_mode, lora_tensor_names)
            self._write_quant_updatable_flags(graph,
                                              no_quant_updatable=True,
                                              adapter_only_quant_updatable=False)
        elif quant_updatable_mode == QuantUpdatableMode.ALL:
            self._validate_has_only_static_tensors(graph, quant_updatable_mode, lora_tensor_names)
            self._write_quant_updatable_flags(graph,
                                              no_quant_updatable=False,
                                              adapter_only_quant_updatable=False)
        else:
            raise ValueError(f"Quant updatable mode, {quant_updatable_mode}, "
                             "is not supported.")

    def _resolve_adapter_only_quant_updatable(self, graph, quant_updatable_mode, lora_tensor_names):
        """
        Decide ADAPTER_ONLY_QUANT_UPDATABLE for mode None or ADAPTER_ONLY.

        With mode None (backward compatible), the flag is inferred. It is True
        only when lora_tensor_names contains both constant (weight) and
        activation tensors. With ADAPTER_ONLY, that condition is required.

        :raises ValueError: if mode is ADAPTER_ONLY and the condition fails.
        """
        has_both_kinds = self.has_native_and_static_tensors(graph, lora_tensor_names)
        if quant_updatable_mode is None:
            return has_both_kinds
        if not has_both_kinds:
            raise ValueError(f"Quant updatable mode, {quant_updatable_mode.value}, "
                             "expects both lora weights and activation tensors in --lora_weight_list.")
        return True

    @staticmethod
    def _write_quant_updatable_flags(graph, no_quant_updatable, adapter_only_quant_updatable):
        """Write both quant-updatable bits of the graph property mask."""
        graph.update_graph_property_mask(ir_graph.IR_GRAPH_PROPERTY_MASK_NO_QUANT_UPDATABLE,
                                         no_quant_updatable)
        graph.update_graph_property_mask(ir_graph.IR_GRAPH_PROPERTY_MASK_ADAPTER_ONLY_QUANT_UPDATABLE,
                                         adapter_only_quant_updatable)

    def has_only_static_tensors(self, graph, lora_tensor_names):
        """
        Return True if every listed tensor that exists in ``graph`` is produced
        by a ConstantOp. Names without a buffer in the graph are ignored.
        """
        return all(
            isinstance(graph.get_producer_op(tensor_name), op_adapter.ConstantOp)
            for tensor_name in lora_tensor_names
            if graph.has_buffer(tensor_name)
        )

    def has_native_and_static_tensors(self, graph, lora_tensor_names):
        """
        Return True if the listed tensors include at least one tensor produced
        by a ConstantOp ("static") and at least one produced by any other op
        ("native"). Names without a buffer in the graph are ignored.

        :raises RuntimeError: if any native tensor is produced by an op type
            other than conv2d or "elementwise_product".
        """
        # Hoisted out of the loop: the only producer types accepted for
        # activation tensors listed in --lora_weight_list.
        allowed_native_producer_types = (ir_graph.QNN_OP_CONV_2D, "elementwise_product")

        has_updatable_static_tensor = False
        has_updatable_native_tensor = False
        invalid_native_tensors = []

        for tensor_name in lora_tensor_names:
            if not graph.has_buffer(tensor_name):
                continue
            producer_op = graph.get_producer_op(tensor_name)
            if isinstance(producer_op, op_adapter.ConstantOp):
                has_updatable_static_tensor = True
                continue

            has_updatable_native_tensor = True
            if producer_op.type not in allowed_native_producer_types:
                invalid_native_tensors.append(tensor_name)

        if invalid_native_tensors:
            raise RuntimeError("Activation tensors in lora_weight_list can be only "
                               "lora conv2d or mul outputs (lora/adapter branch tensors). "
                               "The below tensors are outputs from other op types.\n{}"
                               .format("\n".join(invalid_native_tensors)))

        return has_updatable_static_tensor and has_updatable_native_tensor

    def _validate_has_only_static_tensors(self, graph, quant_updatable_mode, lora_tensor_names):
        """Raise ValueError if any listed tensor is not produced by a ConstantOp."""
        if not self.has_only_static_tensors(graph, lora_tensor_names):
            raise ValueError(f"Quant updatable mode, {quant_updatable_mode.value}, "
                             "expects only lora weights in --lora_weight_list but "
                             "contains weights and activation tensors.")

    def _validate_graph_properties(self, graph):
        """
        Check that the graph's property mask holds a valid combination.

        :raises RuntimeError: if a quant-updatable flag is set without
            UPDATABLE, or if NO_QUANT_UPDATABLE and
            ADAPTER_ONLY_QUANT_UPDATABLE are both set.
        """
        is_updatable = graph.get_graph_property_mask(ir_graph.IR_GRAPH_PROPERTY_MASK_UPDATABLE)
        is_no_quant_updatable = graph.get_graph_property_mask(ir_graph.IR_GRAPH_PROPERTY_MASK_NO_QUANT_UPDATABLE)
        is_adapter_only_quant_updatable = graph.get_graph_property_mask(
            ir_graph.IR_GRAPH_PROPERTY_MASK_ADAPTER_ONLY_QUANT_UPDATABLE)

        if not is_updatable:
            if is_adapter_only_quant_updatable:
                raise RuntimeError("Invalid graph property combination: "
                                   "ADAPTER_ONLY_QUANT_UPDATABLE is set but UPDATABLE is not.")
            elif is_no_quant_updatable:
                raise RuntimeError("Invalid graph property combination: "
                                   "NO_QUANT_UPDATABLE is set but UPDATABLE is not.")

        if is_no_quant_updatable and is_adapter_only_quant_updatable:
            raise RuntimeError("Invalid graph property combination: "
                               "NO_QUANT_UPDATABLE and ADAPTER_ONLY_QUANT_UPDATABLE are both set.")
