# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
Helper functions
"""
import json
import os
from collections import deque
from typing import Callable, Dict, Iterable, List, Sequence, Set, Tuple

import numpy as np
import onnx
import safetensors.numpy
from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.ir_extra_info import VariableExtraInfo
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.logger import logger


def load_safetensors(safetensor_path):
    """
    Read a safetensors file into a dict of numpy arrays.

    Args:
        safetensor_path: path of the ``.safetensors`` file to read
    Returns:
        whatever ``safetensors.numpy.load_file`` returns for that path
    """
    loaded = safetensors.numpy.load_file(safetensor_path)
    return loaded


def save_safetensors(tensor_dict, safetensor_path):
    """
    Write a dict of numpy arrays to a safetensors file.

    Args:
        tensor_dict: mapping of tensor name to numpy array
        safetensor_path: destination file path
    Returns:
        the return value of ``safetensors.numpy.save_file``
    """
    result = safetensors.numpy.save_file(tensor_dict, safetensor_path)
    return result


def save_named_safetensors(named_safetensors: Dict[str, Dict], dir_path: str, enable_log=True):
    """
    Save each named group of tensors to ``<dir_path>/<encset_name>.safetensors``.

    Args:
        named_safetensors: a dict of {encset_name: safetensors}
        dir_path: the directory to write into
        enable_log: whether to log each written file
    Returns:
        dict of {encset_name: written file path}
    """
    written_paths = {}
    for name, tensors in named_safetensors.items():
        target_path = os.path.join(dir_path, name + ".safetensors")
        save_safetensors(tensors, target_path)
        written_paths[name] = target_path
        if not enable_log:
            continue
        logger.info("saved '%s' safetensors to %s", name, target_path)
    return written_paths


def save_named_encodings(named_graph_encodings: Dict[str, Dict], dir_path: str, enable_log=True):
    """
    Save each named encodings dict as pretty-printed JSON in ``<dir_path>/<encset_name>.json``.

    Args:
        named_graph_encodings: a dict of {encset_name: encodings}
        dir_path: the directory to write into
        enable_log: whether to log each written file
    Returns:
        dict of {encset_name: written file path}
    """
    written_paths = {}
    for name, encodings in named_graph_encodings.items():
        target_path = os.path.join(dir_path, name + ".json")
        with open(target_path, "w", encoding="utf-8") as handle:
            handle.write(json.dumps(encodings, indent=4))
        written_paths[name] = target_path
        if not enable_log:
            continue
        logger.info("saved '%s' encodings to %s", name, target_path)
    return written_paths


def save_updatable_tensor_names(updatable_tensor_names: List[str], path: str):
    """
    Write the tensor names to ``path``, one name per line (each line newline-terminated).

    Args:
        updatable_tensor_names: a list of updatable tensor names
        path: the destination file path (overwritten)
    """
    with open(path, 'w', encoding="utf-8") as handle:
        handle.writelines(name + "\n" for name in updatable_tensor_names)


def safe_replace_all_uses_with(graph: ir.Graph,
                               old_value: ir.Value,
                               new_value: ir.Value|None,
                               except_users=None):
    """
    Redirect the consumers of ``old_value`` to read ``new_value`` instead.

    If ``old_value`` is a graph output, its first slot in ``graph.outputs`` is also
    replaced by ``new_value``. Consumer nodes listed in ``except_users`` keep reading
    ``old_value``. The graph-output replacement is done regardless of ``except_users``.

    Args:
        graph: ir graph owning the values
        old_value: the value to be replaced
        new_value: the replacement value; must not be None (typed as Optional only so
            callers can pass Optional results directly under mypy)
        except_users: optional collection of nodes whose inputs must not be rewritten
    Raises:
        AssertionError: if new_value is None
    """
    assert new_value is not None  # also narrows the Optional type for mypy
    if old_value is new_value:
        return

    # record this before any rewiring happens
    was_graph_output = old_value.is_graph_output()

    # iterate over a snapshot, since rewiring inputs presumably updates old_value.uses()
    for consumer, input_index in tuple(old_value.uses()):
        if except_users and consumer in except_users:
            continue
        consumer.replace_input_with(input_index, new_value)

    if not was_graph_output:
        return
    outputs = graph.outputs
    outputs[outputs.index(old_value)] = new_value


def safe_insert_node_after(graph: ir.Graph,
                           node: ir.Node|None,
                           new_nodes: Iterable[ir.Node] | ir.Node, /) -> None:
    """
    Insert new_nodes after node in the graph.

    If node is None, new_nodes are inserted at the beginning of the graph.
    If the graph has no nodes yet, they are simply appended.

    Args:
        graph: ir graph
        node: the node to insert after, or None to insert at the front
        new_nodes: a single node or an iterable of nodes to insert
    """
    if node is not None:
        graph.insert_after(node, new_nodes)
        return

    # node is None (for example node=v.producer(), where v is an initializer):
    # the new nodes go first so that they precede any consumer.
    if len(graph) > 0:
        graph.insert_before(graph[0], new_nodes)
    elif isinstance(new_nodes, ir.Node):
        # empty graph: graph[0] would raise IndexError, so append instead
        graph.append(new_nodes)
    else:
        graph.extend(new_nodes)


def get_shape_of_slice(src_shape, axes, starts, ends, steps=None):
    """
    Calculate the shape of a slice given the source shape, axes, starts, ends, and steps.

    For each sliced axis the element count follows ONNX Slice / numpy semantics for
    non-negative indices:
    - ``start`` and ``end`` are clamped to the dimension size when that size is a
      known integer (e.g. ``end = INT64_MAX`` means "to the end").
    - The count is ``ceil((end - start) / step)``, never negative.

    Args:
        src_shape (list): The shape of the source array.
        axes (list): The axes to slice.
        starts (list): The start indices for each axis.
        ends (list): The end indices for each axis.
        steps (list, optional): The step sizes for each axis. Defaults to None (all 1).

    Returns:
        list: The shape of the slice, as a new list (``src_shape`` is not modified).

    Raises:
        AssertionError: if the argument lists differ in length, an axis/start/end is
            negative, or a step is not positive.
    """
    if steps is None:
        steps = [1] * len(axes)

    assert len(axes) == len(starts) == len(ends) == len(
        steps), "Axes, starts, ends, and steps must have the same length"

    src_dims = list(src_shape)
    out_shape = list(src_dims)

    for axis, start, end, step in zip(axes, starts, ends, steps):
        assert axis >= 0, "Axis must be non-negative"
        assert start >= 0, "Start index must be non-negative"
        assert end >= 0, "End index must be non-negative"
        assert step > 0, "Step size must be positive"

        dim = src_dims[axis]
        if isinstance(dim, (int, np.integer)):
            # out-of-range indices are clamped, as in ONNX Slice
            start = min(start, dim)
            end = min(end, dim)

        # ceil division: [start, end) with stride `step` keeps ceil(len / step) items
        out_shape[axis] = max(0, -(-(end - start) // step))

    return out_shape


def check_static_shape(value: ir.Value | None):
    """
    Ensure that a value has a fully static shape.

    None is accepted silently.

    Args:
        value: the value to check
    Raises:
        ValueError: if the value has no shape or a non-static shape
    """
    if value is not None and not has_static_shape_on_value(value):
        raise ValueError(
            f"Tensor {value.name} has not static shape, please run shape infer firstly")


def check_static_shape_of_node_io(node: ir.Node):
    """
    Ensure that every input and output of ``node`` has a static shape.

    Inputs are checked first, then outputs. Missing (None) inputs/outputs are skipped.

    Args:
        node: the node to check
    Raises:
        ValueError: for the first input/output without a static shape
    """
    for io_value in (*node.inputs, *node.outputs):
        check_static_shape(io_value)


def have_static_shape_on_node_io(node: ir.Node) -> bool:
    """
    Tell whether all inputs/outputs of ``node`` have a static shape.

    Missing (None) inputs/outputs are ignored.

    Args:
        node: the node to check
    Returns:
        False as soon as one input/output has no shape or a non-static shape,
        True otherwise
    """
    return all(
        io_value is None
        or (io_value.shape is not None and io_value.shape.is_static())
        for io_value in (*node.inputs, *node.outputs)
    )


def has_static_shape_on_value(value: ir.Value) -> bool:
    """
    Tell whether ``value`` has a known, fully static shape.

    Args:
        value: the value to check
    Returns:
        True if the shape is set and static, False otherwise
    """
    value_shape = value.shape
    return value_shape is not None and bool(value_shape.is_static())


def get_value_numeric_shape(value: ir.Value | None):
    """
    Return the shape of ``value`` as produced by ``ir.Shape.numpy()``.

    Raises:
        AssertionError: if value is None or has no shape
    """
    assert value is not None
    assert value.shape is not None
    return value.shape.numpy()


def get_constant_np(value: ir.Value|None):
    """
    Get the constant numpy value carried by ``value``.

    A value is treated as constant when any of the following holds:
    - it has a ``const_value`` (e.g. an initializer);
    - it is produced by a ``Constant`` node;
    - it is produced by an ``Identity`` node whose input is itself constant
      (followed recursively).

    Args:
        value: the value to inspect (None is accepted)
    Returns:
        The constant data: a numpy array for ``const_value`` and for Constant nodes
        using ``value_float(s)``/``value_int(s)``/``value_string(s)``, and the result of
        ``convert_attr_to_py`` for a Constant ``value`` attribute.
        Returns None if ``value`` is None, is not a recognised constant, or is a
        Constant with an unsupported attribute (e.g. ``sparse_value``).
    """
    if value is None:
        return None
    if value.const_value is not None:
        return value.const_value.numpy()

    producer = value.producer()
    if producer is None:
        return None

    if producer.op_type == "Constant":
        attrs = producer.attributes
        if "value" in attrs:
            return convert_attr_to_py(attrs["value"])
        # A Constant may instead carry one of these attributes; the element types
        # (float32 / int64) are fixed by the ONNX Constant operator spec.
        for attr_name, dtype in (("value_float", np.float32),
                                 ("value_floats", np.float32),
                                 ("value_int", np.int64),
                                 ("value_ints", np.int64),
                                 ("value_string", None),
                                 ("value_strings", None)):
            if attr_name in attrs:
                return np.array(convert_attr_to_py(attrs[attr_name]), dtype=dtype)
        return None

    if producer.op_type == "Identity":
        # check for mypy, definitely true
        assert producer.inputs[0] is not None
        return get_constant_np(producer.inputs[0])

    return None


def is_constant(value: ir.Value|None) -> bool:
    """
    Tell whether ``value`` is a constant.

    True when the value has a ``const_value``, is produced by a ``Constant`` node, or is
    produced by a chain of ``Identity`` nodes ending in such a value.
    None is not constant.
    """
    current = value
    while current is not None:
        if current.const_value is not None:
            return True
        producer = current.producer()
        if producer is None:
            return False
        if producer.op_type == "Constant":
            return True
        if producer.op_type != "Identity":
            return False
        # check for mypy, definitely true
        assert producer.inputs[0] is not None
        current = producer.inputs[0]
    return False


def make_initializer(graph: ir.Graph, name: str, np_array: np.ndarray|list|tuple):
    """
    Create a constant initializer named ``name`` holding ``np_array`` and register it
    in ``graph.initializers`` (an existing entry with the same name is overwritten, so
    the name should be unique).

    Args:
        graph: the graph to add the initializer to
        name: initializer / value name
        np_array: the data; lists and tuples are converted with ``np.array``
    Returns:
        the new ``ir.Value``, whose ``meta["extra_info"]`` is a fresh VariableExtraInfo
    """
    array = np.array(np_array) if isinstance(np_array, (list, tuple)) else np_array
    tensor_proto = onnx.numpy_helper.from_array(array, name=name)
    tensor = ir.TensorProtoTensor(tensor_proto)
    initializer = ir.Value(
        name=name,
        type=ir.TensorType(tensor.dtype),
        shape=tensor.shape,
        const_value=tensor,
    )
    # shape and type are (re)assigned explicitly after construction
    initializer.shape = ir.Shape(tuple(array.shape))
    initializer.type = ir.TensorType(tensor.dtype)
    initializer.meta["extra_info"] = VariableExtraInfo()
    graph.initializers[name] = initializer
    return initializer


def iter_all_values(graph_ir: ir.Graph):
    """
    Lazily yield the values of the graph, in this order:
    1. graph inputs;
    2. initializers;
    3. the outputs of each node, in node order.

    No de-duplication is done across these groups.
    """
    for graph_input in graph_ir.inputs:
        yield graph_input
    for initializer in graph_ir.initializers.values():
        yield initializer
    yield from (output for node in graph_ir for output in node.outputs)


def convert_attr_to_py(value, schema:str|None=None):
    # pylint: disable=[too-many-return-statements, too-many-branches]
    """
    Convert an attribute (or an attribute payload) into a python / numpy object.

    Args:
        value: an ``ir.Attr``, or an already-unwrapped payload (tensor, int/float/str,
            or a sequence of these).
        schema: optional accessor name: "as_int", "as_ints", "as_float", "as_floats",
            "as_string", "as_strings", "as_tensor" or "as_tensors".
            - When given, ``value`` must be an ``ir.Attr`` and that accessor is used;
              tensors are converted with ``.numpy()``.
            - Any other schema string falls back to automatic conversion.
    Returns:
        - ``.numpy()`` of a tensor;
        - the value itself for int/float/str;
        - a new list for other sequences (elements converted recursively);
        - the converted ``.value`` of an ``ir.Attr``.
    Raises:
        NotImplementedError: for any other payload type
    """
    if schema is not None:
        # schema can be only applied with ir.Attr
        assert isinstance(value, ir.Attr)
        if schema == "as_tensor":
            return value.as_tensor().numpy()
        if schema == "as_tensors":
            return [tensor.numpy() for tensor in value.as_tensors()]
        if schema in ("as_int", "as_ints", "as_float", "as_floats", "as_string", "as_strings"):
            # the accessor's result is returned as-is
            return getattr(value, schema)()

    is_tensor = isinstance(value, ir.TensorProtocol)
    if is_tensor:
        return value.numpy()
    if isinstance(value, (int, float, str)):
        return value
    if isinstance(value, Sequence):
        # build a new list so the caller cannot mutate the attribute's storage
        return [convert_attr_to_py(item) for item in value]
    if isinstance(value, ir.Attr):
        return convert_attr_to_py(value.value)
    raise NotImplementedError


def convert_attrs_to_py(attrs: Dict[str, ir.Attr]):
    """
    Convert every attribute's ``.value`` with ``convert_attr_to_py``.

    Returns:
        a new dict of {attribute name: python object}
    """
    return {attr_name: convert_attr_to_py(attr.value) for attr_name, attr in attrs.items()}


def _names_read_by_node(node: onnx.NodeProto) -> List[str]:
    """
    Return the value names a node reads.

    This covers the node's explicit inputs plus every name read by the nodes and
    outputs of its subgraphs. Subgraphs (If/Loop/Scan bodies) may reference
    outer-scope values without listing them as node inputs.
    """
    names = list(node.input)
    for attr in node.attribute:
        if attr.type == onnx.AttributeProto.GRAPH:
            subgraphs = [attr.g]
        elif attr.type == onnx.AttributeProto.GRAPHS:
            subgraphs = list(attr.graphs)
        else:
            continue
        for subgraph in subgraphs:
            for sub_node in subgraph.node:
                names.extend(_names_read_by_node(sub_node))
            # a subgraph output may directly be an outer-scope value
            names.extend(output.name for output in subgraph.output)
    return names


def clean_model_proto(proto: onnx.ModelProto):  # pylint: disable=[too-many-branches]
    """
    Remove unused nodes and initializers from the given proto, in place.

    A node is removed when none of its outputs is read by a kept node or listed as a
    graph output. Reads made from inside a kept node's subgraphs also count.
    Nodes are visited in reverse order, so removing a node can make its producers
    unused as well. Removal is most complete when ``graph.node`` is topologically
    sorted; otherwise some dead nodes are conservatively kept.
    Initializers that are no longer read are dropped.

    Args:
        proto: the model to clean (modified in place)
    Returns:
        the same ``proto`` object
    """
    graph = proto.graph
    nodes = list(graph.node)
    names_read = [_names_read_by_node(n) for n in nodes]

    # number of readers per value name (graph outputs count as readers)
    use_counts: Dict[str, int] = {}
    for names in names_read:
        for name in names:
            use_counts[name] = use_counts.get(name, 0) + 1
    for output in graph.output:
        use_counts[output.name] = use_counts.get(output.name, 0) + 1

    kept_nodes = []
    for node, names in zip(reversed(nodes), reversed(names_read)):
        if any(use_counts.get(out, 0) > 0 for out in node.output):
            kept_nodes.append(node)
        else:
            # all outputs are unused: drop the node and release its reads,
            # which may make its producers unused in turn
            for name in names:
                use_counts[name] -= 1
    kept_nodes.reverse()

    # remove unused initializers
    kept_initializers = [init for init in graph.initializer
                         if use_counts.get(init.name, 0) > 0]

    while len(graph.initializer) > 0:
        graph.initializer.pop()
    graph.initializer.extend(kept_initializers)

    while len(graph.node) > 0:
        graph.node.pop()
    graph.node.extend(kept_nodes)
    return proto


def get_shape_from_value_info_proto(
    val_info: onnx.ValueInfoProto,
    allow_symbols: bool = False,
) -> List[str | int] | None:
    """
    Copied from MHA2SHA repo
    Extract the tensor shape from a ValueInfoProto.

    Args:
        val_info: the value info to read
        allow_symbols: keep symbolic dimensions (``dim_param``) as strings instead of
            giving up on them
    Returns:
        The dimensions as a list: ints for ``dim_value``, strings for ``dim_param`` when
        allowed. Returns None if there is no shape, a dimension is symbolic and
        symbols are not allowed, or a dimension has neither value nor name.
    """
    tensor_type = val_info.type.tensor_type
    if not tensor_type.HasField("shape"):
        return None

    dims: List[str | int] = []
    for dim in tensor_type.shape.dim:
        if dim.HasField("dim_value"):
            dims.append(dim.dim_value)
            continue
        if allow_symbols and dim.HasField("dim_param"):
            dims.append(dim.dim_param)
            continue
        # symbolic dimension not allowed, or unknown dimension
        return None
    return dims


def load_updatable_tensor_list(updatable_tensors_path: str) -> List[str]:
    """
    Read tensor names from a text file, one per line.

    Surrounding whitespace is stripped and blank lines are skipped.
    A None path yields an empty list.
    """
    if updatable_tensors_path is None:
        return []
    with open(updatable_tensors_path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw_line.strip() for raw_line in handle)
        return [tensor_name for tensor_name in stripped_lines if tensor_name]


def get_base_enc_name(named_encodings_paths, named_safetensors_paths):
    """
    Infer the encset name of the base graph.

    The base graph has encodings but no safetensors. The candidates are therefore the
    encodings names that are absent from the safetensors names.

    Args:
        named_encodings_paths: mapping {encset_name: encodings path}
        named_safetensors_paths: mapping {encset_name: safetensors path}
    Returns:
        The inferred base encset name. If several candidates exist, the alphabetically
        first one is returned (with a warning), so the result is deterministic.
        Returns None when there are no encodings or no candidate.
    """
    if not named_encodings_paths:
        return None

    candidates = sorted(set(named_encodings_paths) - set(named_safetensors_paths))
    if not candidates:
        logger.warning("Cannot find base enc name, please ensure encodings path for base is correctly set. "
                       "Note: base graph has no safetensors")
        return None
    if len(candidates) > 1:
        logger.warning("Found multiple base enc candidates %s, using %s. "
                       "Note: base graph has no safetensors",
                       candidates,
                       candidates[0])
    return candidates[0]


def get_attribute_with_default(node: ir.Node, attr_name: str, default_value):
    """
    Read a node attribute as a python object, or fall back to a default.

    Args:
        node: the node to query
        attr_name: the attribute name
        default_value: returned unchanged when the node lacks the attribute
    Return:
        the attribute converted with ``convert_attr_to_py``, or ``default_value``
    """
    if attr_name not in node.attributes:
        return default_value
    return convert_attr_to_py(node.attributes[attr_name])


def scan_previous_nearest_candidate(start_values: List[ir.Value|None],
                                    check_fn:Callable[[ir.Value], bool],
                                    ignore_fn:Callable[[ir.Value], bool] | None):
    """
    Scan bottom-up, breadth-first, for the nearest acceptable values.

    Starting from ``start_values``, each value is handled as follows:
    - if ``ignore_fn`` accepts it, the inputs of its producer are scanned instead;
    - otherwise, if ``check_fn`` accepts it, it is yielded;
    - otherwise its branch stops there.
    Each value is processed at most once (identity-based). This keeps diamond-shaped
    graphs linear and avoids yielding the same value twice.

    Args:
        start_values: searching start values (None entries are skipped)
        check_fn: whether the value is acceptable
        ignore_fn: whether the value should be skipped and its producer's inputs
            scanned; None disables skipping
    Yields:
        acceptable values, nearest first
    """
    pending: deque = deque(start_values)
    seen: Set[int] = set()
    while pending:
        candidate_v = pending.popleft()
        if candidate_v is None or id(candidate_v) in seen:
            continue
        seen.add(id(candidate_v))
        if ignore_fn is not None and ignore_fn(candidate_v):
            candidate_v_producer = candidate_v.producer()
            if candidate_v_producer is not None:
                pending.extend(candidate_v_producer.inputs)
        elif check_fn(candidate_v):
            yield candidate_v


class ConditionOnValueProducer:
    """
    Callable predicate on a value's producer.

    It is true when the value has a producer node whose ``op_type`` is in
    ``valid_producer_types``. Used as a check/ignore function for nearest-candidate
    scanning.
    """

    def __init__(self, valid_producer_types):
        # any container supporting `in`, holding op_type strings
        self.valid_producer_types = valid_producer_types

    def __call__(self, v: ir.Value):
        producer_node = v.producer()
        return producer_node is not None and producer_node.op_type in self.valid_producer_types


def load_json(json_path: str):
    """
    Parse a UTF-8 JSON file and return the resulting python object.
    """
    with open(json_path, "r", encoding="utf-8") as handle:
        text = handle.read()
    return json.loads(text)


def autocomplete_opset(model:onnx.ModelProto, default_possible_opsets:Dict[str,int]):
    '''
    Add missing opset imports for domains used by the model's nodes.

    The opset for a custom domain (for example qti.aisw) may not be declared in the
    model, and then onnx.shape_inference fails.
    For each node domain that has no opset_import entry but appears in
    ``default_possible_opsets``:
    - an entry with the default version is appended, once per domain;
    - a warning is logged.
    Only ``model.graph.node`` is inspected. The model is modified in place.
    '''
    declared_domains = {opset.domain for opset in model.opset_import}
    missing: Dict[str, int] = {}
    for node in model.graph.node:
        domain = node.domain
        if domain in declared_domains or domain in missing:
            continue
        if domain not in default_possible_opsets:
            continue
        version = default_possible_opsets[domain]
        logger.warning('opset version for domain "%s" is not set, autocomplete it to version "%d"',
                       domain, version)
        missing[domain] = version

    for domain, version in missing.items():
        model.opset_import.append(
            onnx.OperatorSetIdProto(domain=domain, version=version)
        )

def load_external_data_for_constant(model:ir.Model):
    '''
    Load the external data of Constant nodes into memory.

    Currently onnx_ir does not handle external data of Constant nodes during
    serialization, so the data must be loaded before the graph is serialized into an
    onnx protobuf.
    For every Constant node whose "value" attribute is a tensor stored as an
    ir.ExternalTensor, that attribute is replaced by an in-memory ir.Tensor holding a
    copy of the data, with the same name, dtype and doc_string.
    Only the nodes obtained by iterating ``model.graph`` are visited. The model is
    modified in place.
    '''
    for node in model.graph:
        if node.op_type != "Constant":
            continue
        value_attr = node.attributes.get("value")
        if value_attr is None or value_attr.type != ir.AttributeType.TENSOR:
            continue
        stored_tensor = value_attr.value
        if not isinstance(stored_tensor, ir.ExternalTensor):
            continue
        # read the data from the external file
        in_memory = ir.Tensor(stored_tensor.numpy().copy(),
                              name=stored_tensor.name,
                              dtype=stored_tensor.dtype)
        node.attributes["value"] = ir.AttrTensor(
            name="value",
            value=in_memory,
            doc_string=value_attr.doc_string,
        )

def is_used(v:ir.Value):
    '''
    Tell whether ``v`` has at least one consumer node or is a graph output.
    '''
    if len(v.uses()) > 0:
        return True
    return v.is_graph_output()

def scan_ancestors_with_budget(start_v: ir.Value | None, 
                               max_layers_to_traverse:int|None) -> List[ir.Value]:
    """
    Collect ``start_v`` and its ancestor values, breadth-first.

    Ancestors are found by walking from each value to its producer node's inputs.
    Each value appears once (identity-based); None inputs are skipped.

    Args:
        start_v: the value to start from; None yields an empty list
        max_layers_to_traverse: maximum number of producer hops from ``start_v``.
            None (or any falsy value, including 0) means unbounded.
    Returns:
        the visited values in BFS order, starting with ``start_v``
    """
    if start_v is None:
        return []

    bounded = bool(max_layers_to_traverse)
    order: List[ir.Value] = []
    seen_ids: Set[int] = set()
    pending: deque = deque([(start_v, 0)])  # (value, distance from start_v)

    while pending:
        value, depth = pending.popleft()
        if value is None or id(value) in seen_ids:
            continue
        seen_ids.add(id(value))
        order.append(value)

        if bounded and depth >= max_layers_to_traverse:
            continue
        producer = value.producer()
        if producer is None:
            continue
        pending.extend((parent, depth + 1) for parent in producer.inputs
                       if id(parent) not in seen_ids)
    return order


def scan_least_common_ancestor(start_v1: ir.Value | None, 
                                start_v2: ir.Value | None, 
                                max_layers_to_traverse:int|None):
    """
    Find the nearest common ancestor of two values.

    Ancestors are collected with ``scan_ancestors_with_budget``, where each value counts
    as its own ancestor. The result is the first ancestor of ``start_v1``, in BFS order,
    that is also an ancestor of ``start_v2``.
    Values are compared by identity rather than by name, so unnamed values (name
    None) or distinct values sharing a name cannot produce a false match.

    Args:
        start_v1: first value (None yields no ancestors)
        start_v2: second value (None yields no ancestors)
        max_layers_to_traverse: per-value traversal budget, forwarded to
            ``scan_ancestors_with_budget``
    Returns:
        the common ancestor value, or None if there is none within the budget
    """
    ancestors_of_v1 = scan_ancestors_with_budget(start_v1, max_layers_to_traverse)
    ancestors_of_v2 = scan_ancestors_with_budget(start_v2, max_layers_to_traverse)
    ancestor_ids_of_v2 = {id(v) for v in ancestors_of_v2}

    return next((ancestor for ancestor in ancestors_of_v1
                 if id(ancestor) in ancestor_ids_of_v2), None)
