# =============================================================================
#
#  Copyright (c) 2019-2022 Qualcomm Technologies, Inc.
#  All Rights Reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================

import numpy as np

from qti.aisw.converters.common.converter_ir.op_adapter import MultiClassNmsOp, ReshapeOp
from qti.aisw.converters.common.utils import code_to_message, converter_utils
from qti.aisw.converters.common.utils.translation_utils import compare_values
from qti.aisw.converters.tensorflow import util
from qti.aisw.converters.tensorflow.common import LayerDescriptor, LayerResolver, LayerBuilder
from qti.aisw.converters.tensorflow.graph_matcher import (
    ConverterSequenceNode,
    NonConsumableConverterSequenceNode,
    GraphSequence
)
from qti.aisw.converters.tensorflow.layers.constant import ConstantLayerResolver
from qti.aisw.converters.tensorflow.util import ConverterError, get_const_op_value


class NonMaxSuppressionLayerResolver(LayerResolver, object):
    """
    Matches TensorFlow NonMaxSuppressionV2/V3 patterns, folds the Gather ops that consume the selected
    indices into the same layer, and emits NonMaxSuppression descriptors (plus a Constant descriptor
    when the boxes are produced by a Const op).
    """

    class Descriptor(LayerDescriptor):
        """Descriptor for one matched NMS op together with the Gather ops folded into it."""

        def __init__(self, name, nodes, max_output_size, iou_threshold, score_threshold, nms_op, boxes_op,
                     scores_ops, input_boxes_op, input_scores_op, output_names=None):
            super(NonMaxSuppressionLayerResolver.Descriptor, self).__init__('NonMaxSuppression', name, nodes,
                                                                            output_names=output_names)
            # Scalar NMS parameters, evaluated from the graph by the resolver.
            self.max_output_size = max_output_size
            self.iou_threshold = iou_threshold
            self.score_threshold = score_threshold

            # Matched ops: the NMS op, the op bound to the 'boxes' sequence node, and every matched op
            # whose sequence-node name starts with "score".
            self.nms_op = nms_op
            self.boxes_op = boxes_op
            self.scores_ops = scores_ops

            # Input: ops producing the boxes / scores tensors this layer consumes.
            self.input_boxes_op = input_boxes_op
            self.input_scores_op = input_scores_op

            # Output: Gather ops on boxes / scores / other features, set by _resolve_for_gather_layer.
            self.output_boxes_op = None
            self.output_scores_op = None
            self.output_features_op = []

        def is_input_tensor(self, op, tensor):
            """
            Returns False for a Const tensor whose value equals one of the NMS scalar parameters
            (score threshold, max output size, iou threshold), because those values are folded into the
            layer; returns True for every other tensor.
            """
            if tensor.op.type != "Const":
                return True
            const_value = get_const_op_value(tensor.op)
            folded_params = [float(self.score_threshold), int(self.max_output_size), float(self.iou_threshold)]
            return not any([compare_values(const_value, param) for param in folded_params])

        def is_output_op(self, op):
            """Returns True if op is one of the Gather ops whose outputs this layer produces."""
            is_feature_output = op in self.output_features_op
            return bool(is_feature_output or op == self.output_boxes_op or op == self.output_scores_op)

    def __init__(self):

        # Seq #1: boxes reshaped, scores reshaped and sliced, for layer nms/NonMaxSuppressionV2
        sequence_1 = GraphSequence([
            # Multi class scores
            NonConsumableConverterSequenceNode('scores_input', ['?']),
            ConverterSequenceNode('scores_reshape', ['Reshape']),
            ConverterSequenceNode('scores_reshape_input_shape', ['?']),
            ConverterSequenceNode('strided_slice_input_begin', ['?']),
            ConverterSequenceNode('strided_slice_input_end', ['?']),
            ConverterSequenceNode('strided_slice_input_strides', ['?']),
            ConverterSequenceNode('scores', ['StridedSlice']),

            # Boxes
            NonConsumableConverterSequenceNode('boxes_input', ['?']),
            ConverterSequenceNode('boxes', ['Reshape']),
            ConverterSequenceNode('boxes_reshape_input_shape', ['?']),

            ConverterSequenceNode('nms', ['NonMaxSuppressionV2']),
            ConverterSequenceNode('max_output_size', ['Const']),
            ConverterSequenceNode('iou_threshold', ['?']),
        ])
        sequence_1.set_inputs('boxes', ['boxes_input', 'boxes_reshape_input_shape'])
        sequence_1.set_inputs('scores_reshape', ['scores_input', 'scores_reshape_input_shape'])
        sequence_1.set_inputs('scores', ['scores_reshape', 'strided_slice_input_begin', 'strided_slice_input_end',
                                         'strided_slice_input_strides'])
        sequence_1.set_inputs('nms', ['boxes', 'scores', 'max_output_size', 'iou_threshold'])
        sequence_1.set_outputs(['nms'])

        # Seq #2: boxes and scores squeezed, for layer nms/NonMaxSuppressionV2
        sequence_2 = GraphSequence([
            # Multi class scores
            NonConsumableConverterSequenceNode('scores_input', ['?']),
            ConverterSequenceNode('scores', ['Squeeze']),

            # Boxes
            NonConsumableConverterSequenceNode('boxes_input', ['?']),
            ConverterSequenceNode('boxes', ['Squeeze']),

            ConverterSequenceNode('nms', ['NonMaxSuppressionV2']),
            ConverterSequenceNode('max_output_size', ['Const']),
            ConverterSequenceNode('iou_threshold', ['Const']),
        ])
        sequence_2.set_inputs('boxes', ['boxes_input'])
        sequence_2.set_inputs('scores', ['scores_input'])
        sequence_2.set_inputs('nms', ['boxes', 'scores', 'max_output_size', 'iou_threshold'])
        sequence_2.set_outputs(['nms'])

        # Seq #3: no reshapes/slices (the builder adds whatever reshapes are needed in this case)
        sequence_3 = GraphSequence([
            NonConsumableConverterSequenceNode('boxes', ['?']),
            NonConsumableConverterSequenceNode('scores', ['?']),
            NonConsumableConverterSequenceNode('max_output_size', ['Const']),
            NonConsumableConverterSequenceNode('stub_1', ['?']),
            ConverterSequenceNode('nms', ['NonMaxSuppressionV2']),
            NonConsumableConverterSequenceNode('iou_threshold', ['?']),
        ])

        sequence_3.set_inputs('nms', ['boxes', 'scores', 'max_output_size', 'iou_threshold'])
        sequence_3.set_outputs(['nms'])

        # Seq #3 for NonMaxSuppressionV3, which additionally takes a score threshold input
        sequence_3_v3 = GraphSequence([
            NonConsumableConverterSequenceNode('boxes', ['?']),
            NonConsumableConverterSequenceNode('scores', ['?']),
            NonConsumableConverterSequenceNode('max_output_size', ['Const']),
            NonConsumableConverterSequenceNode('stub_1', ['?']),
            ConverterSequenceNode('nms', ['NonMaxSuppressionV3']),
            NonConsumableConverterSequenceNode('iou_threshold', ['?']),
            NonConsumableConverterSequenceNode('score_threshold', ['?']),
        ])
        sequence_3_v3.set_inputs('nms', ['boxes', 'scores', 'max_output_size', 'iou_threshold', 'score_threshold'])
        sequence_3_v3.set_outputs(['nms'])

        self.sequences = [sequence_1, sequence_2, sequence_3, sequence_3_v3]

        # TODO: following added for VIVO support of nms + gather in 1.23.0 to support features as inputs
        #       required until nms layer supported independently
        # Filter seqs: GatherV2 (with an axis input) consuming the NMS indices
        filter_sequence = GraphSequence([
            ConverterSequenceNode('gather', ['GatherV2']),
            ConverterSequenceNode('axis', ['Const']),
            NonConsumableConverterSequenceNode('params', ['?']),
            NonConsumableConverterSequenceNode('indices', ['NonMaxSuppressionV2', 'NonMaxSuppressionV3'])
        ])
        filter_sequence.set_inputs('gather', ['params', 'indices', 'axis'])
        filter_sequence.set_outputs(['gather'])

        # Filter seqs 2: legacy Gather (no axis input) consuming the NMS indices
        filter_sequence_2 = GraphSequence([
            ConverterSequenceNode('gather', ['Gather']),
            NonConsumableConverterSequenceNode('params', ['?']),
            NonConsumableConverterSequenceNode('indices', ['NonMaxSuppressionV2'])
        ])
        filter_sequence_2.set_inputs('gather', ['params', 'indices'])
        filter_sequence_2.set_outputs(['gather'])

        self.g_sequences = [filter_sequence, filter_sequence_2]

    # TODO: following added for VIVO support of nms + gather in 1.23.0 to support features as inputs
    #       required until nms layer supported independently
    def _resolve_for_gather_layer(self, graph_matcher, graph_helper, descriptor):
        """
        Folds the Gather ops that consume descriptor.nms_op's indices into the descriptor.

        Each matched Gather is classified by its params op: boxes (descriptor.boxes_op or
        descriptor.input_boxes_op), scores (any op in descriptor.scores_ops), or an extra feature.
        The Gather's consumed nodes are added to descriptor.child_ops, and descriptor.output_names is
        set to [boxes, scores, <layer>_classes, <layer>_valid_num_detections, *features].

        :raises ConverterError: if no Gather on boxes or no Gather on scores was found.
        """
        for sequence in self.g_sequences:
            for match in graph_matcher.match_sequence(sequence):
                # Only Gather ops whose indices come from this NMS op are relevant.
                if match['indices'] != descriptor.nms_op:
                    continue

                params_op = match['params']
                output_op = match['gather']

                if params_op == descriptor.boxes_op or params_op == descriptor.input_boxes_op:
                    descriptor.output_boxes_op = output_op
                elif params_op in descriptor.scores_ops:
                    descriptor.output_scores_op = output_op
                else:
                    descriptor.output_features_op.append(output_op)

                descriptor.child_ops.extend(match.consumed_nodes)

        # Validation
        if not (descriptor.output_boxes_op and descriptor.output_scores_op):
            raise ConverterError('Cannot find bboxes or scores')

        # Order is important: the builder and the IR op rely on this positional layout.
        output_names = [str(descriptor.output_boxes_op.outputs[0].name),
                        str(descriptor.output_scores_op.outputs[0].name),
                        descriptor.layer_name + "_classes",
                        descriptor.layer_name + "_valid_num_detections"]
        output_names.extend(str(feature_output.outputs[0].name) for feature_output in descriptor.output_features_op)

        descriptor.output_names = output_names

    def resolve_layer(self, graph_matcher, graph_helper):
        """
        Matches every NMS sequence and builds one NonMaxSuppression descriptor per match (with its
        Gather consumers folded in). If the boxes come from a Const op, a ConstantLayerResolver
        descriptor with a leading batch dimension inserted is also emitted.

        :return: list of resolved descriptors.
        :raises ConverterError: if no Gather on boxes/scores is found, or a Const boxes input is not 2-D.
        """
        descriptors = []
        for sequence in self.sequences:
            for match in graph_matcher.match_sequence(sequence):

                # resolve layer for nms operation
                nms_op = match['nms']
                boxes_op = match['boxes']
                scores_ops = [match[k] for k in match.keys() if k.startswith("score")]

                # Sequences without explicit *_input nodes use the NMS inputs directly.
                input_boxes_op = match['boxes_input'] if 'boxes_input' in match else boxes_op
                input_scores_op = match['scores_input'] if 'scores_input' in match else match['scores']

                max_output_size = graph_helper.evaluate_tensor_output(match['max_output_size'].outputs[0])
                iou_threshold = graph_helper.evaluate_tensor_output(match['iou_threshold'].outputs[0])
                if 'score_threshold' in match:
                    score_threshold = graph_helper.evaluate_tensor_output(match['score_threshold'].outputs[0])
                else:
                    score_threshold = float(0)
                # Replace -inf with the lowest finite float32 value
                # (presumably because the backend cannot take an infinite threshold — TODO confirm).
                if score_threshold == -1 * np.inf:
                    score_threshold = np.finfo(np.float32).min

                # None of self.sequences defines an 'axis' node today; kept for sequences that might.
                if "axis" in match:
                    # NMS + gather supported in 1 Op with expectation of Gather on axis 0.
                    # Compare the Const's value, not the op object itself.
                    axis = graph_helper.evaluate_tensor_output(match['axis'].outputs[0])
                    converter_utils.log_assert(axis == 0,
                                               code_to_message.get_error_message('ERROR_TF_NMS_GATHER_INVALID_AXIS')
                                                                                (axis))
                consumed_nodes = match.consumed_nodes

                nms_descriptor = NonMaxSuppressionLayerResolver.Descriptor(
                    str(nms_op.name), consumed_nodes, max_output_size, iou_threshold, score_threshold, nms_op, boxes_op,
                    scores_ops, input_boxes_op, input_scores_op, output_names=[str(nms_op.outputs[0].name)])

                descriptors.append(nms_descriptor)

                # TODO: following added for VIVO support of nms + gather in 1.23.0 to support features as inputs
                #       required until nms layer supported independently
                # resolve layer for gather operation
                self._resolve_for_gather_layer(graph_matcher, graph_helper, nms_descriptor)

                if input_boxes_op.type == 'Const':
                    boxes_tensor = graph_helper.evaluate_tensor_output(input_boxes_op.outputs[0])
                    boxes_shape = graph_helper.get_op_output_shape(input_boxes_op)
                    if len(boxes_shape) != 2:
                        # The message getter returns a formatter; call it with the offending rank.
                        raise ConverterError(code_to_message.get_error_message('ERROR_TF_NMS_BOXES_SHAPE')
                                             (len(boxes_shape)))
                    # Add a leading batch dimension: [num_boxes, 4] -> [1, num_boxes, 4].
                    boxes_shape.insert(0, 1)
                    const_descriptor = ConstantLayerResolver.Descriptor(str(input_boxes_op.name), [input_boxes_op],
                                                                        boxes_tensor, boxes_shape, nms_descriptor)
                    descriptors.append(const_descriptor)

        return descriptors


class NonMaxSuppressionLayerBuilder(LayerBuilder):
    """Builds a MultiClassNmsOp (plus the reshapes around it) from a NonMaxSuppression descriptor."""

    @staticmethod
    def _build_input_layers(ir_graph, converter_context, descriptor, names):
        """
        Reshapes the inputs of tf.image.non_max_suppression to the ranks the IR multiclassnms expects:
            Boxes: [batch, num_boxes, 4]
            Scores: [batch, num_boxes, classes]
            Input Features: [batch, ***]

        For every op in `names` whose output has rank < 3, a ReshapeOp is added to ir_graph, and
        names[op] is updated in place to the reshaped tensor's name.

        :param names: dict mapping descriptor.input_boxes_op / descriptor.input_scores_op to IR tensor names.
        """
        for op in names:
            # Copy so padding the shape below never mutates a list owned by the graph helper.
            input_shape = list(converter_context.graph_helper.get_op_output_shape(op))
            if len(input_shape) >= 3:
                continue

            input_name = names[op]
            intermediate_output_name = input_name + '_nms_reshape_to_3d'
            names[op] = intermediate_output_name

            # Identify scores by op identity rather than by tensor name (names need not contain "score").
            if op == descriptor.input_scores_op:
                if len(input_shape) == 2 and input_shape[0] == 1:
                    # [1, num_boxes] -> [1, num_boxes, 1]
                    input_shape.append(1)
                else:
                    # 1-D scores [num_boxes]: append the class dim so num_boxes lands on axis 1
                    # after the leading batch dim is prepended.
                    if len(input_shape) == 1:
                        input_shape.append(1)
                    input_shape = util.expand_to_rank(input_shape, 3)
            else:
                # Boxes [num_boxes, 4] -> [1, num_boxes, 4] (also correct for a single box [1, 4]).
                input_shape = util.expand_to_rank(input_shape, 3)

            ir_graph.add(ReshapeOp(input_name + '_pre_reshape',
                                   shape=input_shape),
                         [input_name],
                         [intermediate_output_name])

    @staticmethod
    def _compare_op_shapes(converter_context, ops):
        """
        Compares the output shapes of all ops in the list.

        :param converter_context: converters.tensorflow.converter.ConverterContext
        :param ops: list of ops
        :return: True if all shapes are equal or the list is empty, False otherwise
        """
        if not ops:
            print("WARNING: empty list provided to compare nms ops shapes")
            return True
        graph_helper = converter_context.graph_helper
        first_shape = graph_helper.get_op_output_shape(ops[0])
        return all(graph_helper.get_op_output_shape(op) == first_shape for op in ops)

    def build_layer(self, ir_graph, converter_context, descriptor, input_descriptors, output_descriptors):
        """
        Adds the MultiClassNmsOp for `descriptor` to ir_graph.

        Boxes and scores inputs are reshaped to rank 3 as needed. The boxes and scores outputs are produced
        under "<name>_intermediate" and then reshaped back to their TensorFlow output shapes.

        :type ir_graph: converters.common.converter_ir.op_graph.IROpGraph
        :type converter_context: converters.tensorflow.converter.ConverterContext
        :type descriptor: NonMaxSuppressionLayerResolver.Descriptor
        :type input_descriptors: [converters.tensorflow.common.LayerDescriptor]
        :type output_descriptors: [converters.tensorflow.common.LayerDescriptor]
        :rtype: None
        :raises ConverterError: if the boxes and scores inputs cannot both be found.
        """

        names = {}
        for input_descriptor in input_descriptors:
            if input_descriptor.is_output_op(descriptor.input_boxes_op):
                names[descriptor.input_boxes_op] = input_descriptor.output_names[0]
            elif input_descriptor.is_output_op(descriptor.input_scores_op):
                names[descriptor.input_scores_op] = input_descriptor.output_names[0]

        if len(names) != 2:
            raise ConverterError("Failed to detect inputs for nms op.")

        input_names = [names[descriptor.input_boxes_op], names[descriptor.input_scores_op]]
        # Remaining inputs are feature inputs. Keep their first-seen order, deduplicated.
        # A set difference would make the order (and so the feature-input/output pairing) vary between runs.
        for extra_name in self.get_input_names(converter_context, descriptor, input_descriptors):
            if extra_name not in input_names:
                input_names.append(extra_name)

        # Outputs that are reshaped back to their TF shape after the NMS op.
        reshaped_output_ops = [descriptor.output_boxes_op, descriptor.output_scores_op]

        # add reshape input layers as needed to input layers to work with snpe multiclassnms layer
        self._build_input_layers(ir_graph, converter_context, descriptor, names)

        # _build_input_layers may insert ReshapeOp which changes the inputs to NMS
        # Replace the input_names with those from the output of _build_input_layers
        input_names[0] = names[descriptor.input_boxes_op]
        input_names[1] = names[descriptor.input_scores_op]
        output_names = descriptor.output_names[:]

        # adding suffix for boxes and scores since we need to do post reshape(below) to get back to TF shape
        temp_output_names = output_names[:]
        for output_op in reshaped_output_ops:
            tf_output_name = str(output_op.outputs[0].name)
            for i, name in enumerate(output_names):
                if name == tf_output_name:
                    temp_output_names[i] = tf_output_name + "_intermediate"

        ir_graph.add(MultiClassNmsOp(name=descriptor.layer_name,
                                     score_threshold=descriptor.score_threshold,
                                     iou_threshold=descriptor.iou_threshold,
                                     max_total_detections=descriptor.max_output_size),
                     input_names=input_names,
                     output_names=temp_output_names)

        # Post-processing, revert back reshaped layers to the expected output shape from Tensorflow.
        # Exact name match, consistent with the renaming pass above (a substring match could hit a
        # feature output whose name merely contains the boxes/scores name).
        for output_op in reshaped_output_ops:
            tf_output_name = str(output_op.outputs[0].name)
            for i, name in enumerate(output_names):
                if name != tf_output_name:
                    continue
                shape = converter_context.graph_helper.get_op_output_shape(output_op)
                ir_graph.add(ReshapeOp(tf_output_name + '_post_reshape_to_' + str(len(shape)) + 'd',
                                       shape=shape),
                             [temp_output_names[i]],
                             [tf_output_name])
