# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
import sys
import yaml
from enum import Enum, auto
import argparse
import threading

try:
    from qti.aisw.converters.common import ir_optimizer
    from qti.aisw.converters.common.backend_awareness import BackendInfo
    from qti.aisw.converters.common.utils.argparser_util import ArgParserWrapper
    from qti.aisw.converters.common.utils import validation_utils
except ImportError as ie:
    print("Failed to find necessary optimization packages:")
    print(str(ie))
    print("Please ensure that $QNN_SDK_ROOT/lib/python is in your PYTHONPATH")
    sys.exit(1)

# supported optimization stages
class OptimizationStage(Enum):
    """
    Python-side enum of the optimization stages exposed by ir_optimizer.

    Each member's value is the corresponding ir_optimizer.OptimizationStage value.
    GraphOptimizer.optimize() forwards ``stage.value`` of every stage it is given to
    ir_optimizer.IrOptimizer (presumably members of this enum are what callers pass in).
    """
    PreLayout = ir_optimizer.OptimizationStage.PreLayout
    PostLayout = ir_optimizer.OptimizationStage.PostLayout
    PreQuant = ir_optimizer.OptimizationStage.PreQuant
    PostQuant = ir_optimizer.OptimizationStage.PostQuant

# supported optimization modes
class OptPassMode(Enum):
    """
    Modes selectable with --optimization_pass_mode (see GraphOptimizer.ArgParser).

    IrOptimizerMainline           -- run the mainline pass list with CPP-IR and skip those passes in Python
    IrOptimizerExperimental       -- run mainline + experimental pass lists with CPP-IR and skip them in Python
    IrOptimizerDisable            -- run every pass in Python only (the default)
    DisablePostQuantOptimizations -- selected by "disable_post_quant_optimizations"
    """
    # Explicit values identical to what enum.auto() would assign (1, 2, 3, 4).
    IrOptimizerMainline = 1
    IrOptimizerExperimental = 2
    IrOptimizerDisable = 3
    DisablePostQuantOptimizations = 4

# passMode: mode in which the CPP-IR passes are executed
# mainline: set of CPP-IR passes enabled in the mainline pipeline by default
# expt: set of CPP-IR passes that are dev-tested and kept behind the experimental flag; they are not enabled in mainline by default
# NOTE: Once formatted, both mainline and expt are maps of PyIR optimization name -> [CPP-IR optimization name, enabled flag]

class OptimizationPassInfo:
    """
    Static tables describing the PyIR optimizations that have a CPP-IR counterpart.

    pass_name_mapping: PyIR optimization name -> CPP-IR pass name. Every name used in
        passes_map (or in a user optimization_pass_mode_config) must have an entry here,
        otherwise OptimizationPassModeManager.format_optimization_map raises.
    passes_map: default pipelines. 'mainline_passes' are handed to CPP-IR in the
        ir_optimizer_mainline and ir_optimizer_experimental modes; 'expt_passes' only in
        ir_optimizer_experimental. Within a pipeline, 'ENABLED' passes run and 'DISABLED'
        passes are listed but not run. GraphOptimizer.dump_optimization_pass_mode_config
        writes this table out (in this order) as the config template.
    """
    # PyIR optimization name -> CPP-IR pass name.
    pass_name_mapping = {
        'FOLD_RESHAPES':'FoldReshapes',
        'SQUASH_TRANSPOSE_RESHAPE':'OptimizeTransposeReshape',
        'HANDLE_GATHER_NEGATIVE_INDICES':'HandleGatherNegativeIndices',
        'FOLD_MULTIPLE_TRANSPOSE':'OptimizeTransposes',
        'SINK_TRANSPOSE_BELOW_SUM':'SinkTransposeBelowSum',
        'REMOVE_IDENTITY':'RemoveNoOps',
        'SQUASH_MULTIPLE_PERMUTE':'SquashMultiplePermute',
        'SQUASH_RESHAPE':'SquashReshape',
        'MATCH_SPACETODEPTH':'MatchSpaceToDepth',
        'MATCH_DEPTHTOSPACE':'OptimizeDepthToSpace',
        'MATCH_GATHERND':'OptimizeGatherND',
        'SQUASH_PAD':'SquashPad',
        'ALIGN_MATMUL_RANKS':'AlignMatmulRanks',
        'FOLD_SOFTMAX':'FoldSoftmax',
        'FOLD_CONCATS':'FoldConcats',
        'REPLACE_6D_OPERATION':'Replace6dOps'
    }

    # Default pipelines. New passes go into 'expt_passes' first and are moved to
    # 'mainline_passes' once verified.
    passes_map = {
        'mainline_passes': {
            'ENABLED': [
                'FOLD_RESHAPES',
                'SQUASH_TRANSPOSE_RESHAPE',
                'HANDLE_GATHER_NEGATIVE_INDICES',
                'FOLD_MULTIPLE_TRANSPOSE',
                'REMOVE_IDENTITY',
                'SQUASH_MULTIPLE_PERMUTE',
                'SQUASH_RESHAPE',
                'MATCH_GATHERND',
                'SQUASH_PAD',
                'ALIGN_MATMUL_RANKS',
                'REPLACE_6D_OPERATION'
            ],
            'DISABLED': []
        },
        'expt_passes': {
            'ENABLED': [
                'FOLD_CONCATS'
            ],
            'DISABLED': [
                'MATCH_SPACETODEPTH',
                'MATCH_DEPTHTOSPACE',
                'FOLD_SOFTMAX',
            ]
        }
    }


class OptimizationPassModeManager:
    """
    Process-wide singleton that tracks the active optimization pass mode and the pass
    pipelines (mainline / experimental) that are delegated to the CPP-IR optimizer.

    Obtain it with OptimizationPassModeManager.get_instance(). Calling the constructor
    also returns the same object.
    """

    _instance = None
    _lock = threading.Lock()
    _initialized = False

    def __new__(cls):
        # Double-checked locking: the unlocked test keeps the common path cheap, and the
        # locked re-test prevents two racing threads from both creating an instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every constructor call even though __new__ always returns the
        # same object, so only the first call may initialize state. The lock (released
        # by __new__ before __init__ runs) makes the check-and-initialize step atomic.
        with OptimizationPassModeManager._lock:
            if OptimizationPassModeManager._initialized:
                return

            # Default pipelines come from OptimizationPassInfo.passes_map. New passes are
            # added to the experimental list there and moved to mainline once verified.
            self.mainline_default, self.expt_default = self.format_optimization_map(OptimizationPassInfo.passes_map)
            # Pipelines after an optional optimization_pass_mode_config has been applied.
            self.mainline_final = self.mainline_default.copy()
            self.expt_final = self.expt_default.copy()
            # Default to Python-only optimizations: without any pass-mode flag the Python
            # passes must not be skipped.
            self.passMode = OptPassMode.IrOptimizerDisable
            self.cpp_ir_passes_list = None
            self.recomputePassesList = True
            # Set last so a failure above does not leave a half-initialized singleton.
            OptimizationPassModeManager._initialized = True

    @classmethod
    def get_instance(cls):
        """Return the singleton, creating and initializing it on first use."""
        if cls._instance is None:
            OptimizationPassModeManager()
        return cls._instance

    def format_optimization_map(self, passes_map):
        """
        Convert a passes map into per-pipeline dicts keyed by PyIR pass name.

        :param passes_map: dict laid out like OptimizationPassInfo.passes_map: sections
            'mainline_passes' and 'expt_passes', each with 'ENABLED' and 'DISABLED' lists
            of PyIR pass names. A null (None) list is treated as empty.
        :return: tuple (mainline_passes, expt_passes); each maps
            PyIR name -> [CPP-IR name, enabled flag]. If a pass is listed as both ENABLED
            and DISABLED, the DISABLED entry wins (it is written last).
        :raises Exception: if a section is missing or a pass has no CPP-IR mapping.
        """
        pass_name_mapping = OptimizationPassInfo.pass_name_mapping

        def format_section(section_name):
            # Fetch both lists of one section, reporting a malformed layout explicitly
            # instead of a misleading "pass not found" error.
            try:
                section = passes_map[section_name]
                enabled_passes = section['ENABLED'] or []
                disabled_passes = section['DISABLED'] or []
            except (KeyError, TypeError) as e:
                raise Exception(f"Optimization pass map must contain '{section_name}' with "
                                f"'ENABLED' and 'DISABLED' lists") from e

            formatted = {}
            for is_enabled, pass_names in ((True, enabled_passes), (False, disabled_passes)):
                for p in pass_names:
                    if p not in pass_name_mapping:
                        raise Exception(f"Pass {p} not found in Python-CPP_IR pass name mapping")
                    formatted[p] = [pass_name_mapping[p], is_enabled]
            return formatted

        return format_section('mainline_passes'), format_section('expt_passes')

    def get_optimization_pass_mode(self, passMode):
        """
        Map an --optimization_pass_mode string to an OptPassMode member.

        Unrecognized values (including "ir_optimizer_disable") map to IrOptimizerDisable.
        """
        mode_by_name = {
            "ir_optimizer_mainline": OptPassMode.IrOptimizerMainline,
            "ir_optimizer_experimental": OptPassMode.IrOptimizerExperimental,
            "disable_post_quant_optimizations": OptPassMode.DisablePostQuantOptimizations,
        }
        return mode_by_name.get(passMode, OptPassMode.IrOptimizerDisable)

    def set_pass_mode(self, args):
        """
        Set passMode from args.optimization_pass_mode. A missing or empty value selects
        IrOptimizerDisable. This invalidates the cached CPP-IR pass list.
        """
        self.recomputePassesList = True
        self.passMode = self.get_optimization_pass_mode(getattr(args, 'optimization_pass_mode', ""))

    def parse_optimization_pass_mode_config(self, args):
        """
        Read the YAML file given by args.optimization_pass_mode_config (if any) and
        replace the mainline/experimental pipelines with its contents.

        :raises Exception: if combined with dump_optimization_pass_mode_config, if the
            config names unknown passes, or if a pass is in both pipelines.
        :raises RuntimeError: if the file is not valid YAML or not a mapping.
        :return: None
        """
        config_path = getattr(args, 'optimization_pass_mode_config', "")
        if config_path == "":
            return
        if getattr(args, 'dump_optimization_pass_mode_config', False):
            raise Exception("dump_optimization_pass_mode_config and optimization_pass_mode_config cannot be passed together")

        with open(config_path, 'r') as stream:
            try:
                yaml_data = yaml.safe_load(stream)
            except yaml.YAMLError as e:
                raise RuntimeError('Error parsing YAML config file') from e
        # An empty file loads as None; report it clearly instead of failing later.
        if not isinstance(yaml_data, dict):
            raise RuntimeError(f"Optimization pass mode config {config_path} must be a YAML mapping "
                               "with 'mainline_passes' and 'expt_passes' sections")

        mainline_passes, expt_passes = self.format_optimization_map(yaml_data)

        # Validate against the built-in pipelines rather than the *_final ones. The latter
        # may already have been narrowed by an earlier config on this singleton.
        known_passes = {*self.mainline_default, *self.expt_default}
        for p in (*mainline_passes, *expt_passes):
            if p not in known_passes:
                raise Exception(f"Pass {p} not part of Mainline and Experimental passes. "
                                f"Passes should be part of: {sorted(known_passes)}")

        if mainline_passes.keys() & expt_passes.keys():
            raise Exception("Mainline and Experimental Passes should NOT be overlapping")

        # Update the pipelines to be processed; the cached skip list is now stale.
        self.mainline_final = mainline_passes
        self.expt_final = expt_passes
        self.recomputePassesList = True

    def set_cpp_ir_pass_list(self):
        """
        Compute the set of PyIR pass names that the Python optimizer should skip because
        CPP-IR runs them:
          - mainline/experimental/disable_post_quant modes: enabled mainline passes
          - experimental mode additionally: enabled experimental passes
          - ir_optimizer_disable: nothing (empty set)
        """
        skipped = set()
        if self.passMode in (OptPassMode.IrOptimizerMainline,
                             OptPassMode.IrOptimizerExperimental,
                             OptPassMode.DisablePostQuantOptimizations):
            skipped.update(py_pass for py_pass, (_, enabled) in self.mainline_final.items() if enabled)
        if self.passMode == OptPassMode.IrOptimizerExperimental:
            skipped.update(py_pass for py_pass, (_, enabled) in self.expt_final.items() if enabled)
        self.cpp_ir_passes_list = skipped

    def get_cpp_ir_pass_list(self):
        """
        Return the set of PyIR pass names to skip in Python. It is recomputed only when
        the pass mode or pipelines have changed since the last call.
        """
        if self.recomputePassesList:
            self.set_cpp_ir_pass_list()
            self.recomputePassesList = False
        return self.cpp_ir_passes_list

class GraphOptimizer(object):
    """
    Python driver around ir_optimizer.IrOptimizer.

    The constructor builds an ir_optimizer.IrOptimizerContext from the parsed arguments
    (ir_optimizer config, conditional passes, optimization pass mode and IR dump
    options). optimize() then runs the optimizer in place on an IR graph.
    """

    class ArgParser(ArgParserWrapper):
        def __init__(self, **kwargs):
            super(GraphOptimizer.ArgParser, self).__init__(**kwargs)
            graph_optimizer_group = self.add_argument_group(title='Graph Optimizer Options')
            graph_optimizer_group.add_argument('--ir_optimizer_config', type=str,
                                               action=validation_utils.validate_filename_arg(must_exist=True),
                                               help=argparse.SUPPRESS, default="")
            graph_optimizer_group.add_argument('--dump_ir_optimizer_config_template', default=False, action="store_true",
                                               help=argparse.SUPPRESS)
            graph_optimizer_group.add_argument('--dump_pass_trace_info', default=False, action="store_true",
                                               help=argparse.SUPPRESS)
            graph_optimizer_group.add_argument('--dump_ir', nargs='+', type=str, default="",
                                               help=argparse.SUPPRESS)
            # Flag to select one of the pipelines among:
            #   ir_optimizer_mainline: run the mainline pass list with CPP-IR and disable the corresponding passes in Python,
            #   ir_optimizer_experimental: run the mainline + expt pass lists with CPP-IR and disable the corresponding passes in Python,
            #   ir_optimizer_disable: run all passes in Python ONLY,
            #   disable_post_quant_optimizations: disable all passes in Python and G2G
            graph_optimizer_group.add_argument('--optimization_pass_mode', nargs='?', type=str, default="ir_optimizer_disable",
                                               help=argparse.SUPPRESS)
            # Flag to pass a config that manages the passes in each pipeline.
            graph_optimizer_group.add_argument("--optimization_pass_mode_config", type=str,
                                               action=validation_utils.validate_filename_arg(must_exist=True),
                                               help=argparse.SUPPRESS, default="")
            # Dump the hardcoded passes in the pipelines.
            graph_optimizer_group.add_argument('--dump_optimization_pass_mode_config', default=False, action="store_true",
                                               help=argparse.SUPPRESS)

        @classmethod
        def convert_args(cls, args):
            """Return a shallow copy of the parsed argparse namespace as a dict."""
            return vars(args).copy()

    # Values accepted by --optimization_pass_mode.
    _VALID_PASS_MODES = ('ir_optimizer_mainline', 'ir_optimizer_experimental',
                         'ir_optimizer_disable', 'disable_post_quant_optimizations')
    # Dump levels accepted as the first value of --dump_ir.
    _VALID_DUMP_IR_LEVELS = ('OPTIMIZER', 'PASS', 'OP')

    def __init__(self, args):
        """
        Build the optimizer context from the parsed arguments.

        :param args: dict of parsed arguments (see ArgParser.convert_args). Besides this
            class's own options it reads 'backend', 'soc_model', 'enable_match_gathernd',
            'handle_gather_negative_indices' and 'align_matmul_ranks'.
        :raises Exception: on an invalid --optimization_pass_mode or --dump_ir value.
        Exits the process with status 0 after writing a template when
        --dump_ir_optimizer_config_template or --dump_optimization_pass_mode_config is set.
        """
        self.dump_ir_optimizer_config_template = args.get('dump_ir_optimizer_config_template', False)
        self.backend_info_obj = BackendInfo.get_instance(args.get('backend', ""), args.get('soc_model', ""))
        self.optimizer_ctx = ir_optimizer.IrOptimizerContext()
        self.optimizer_ctx.dump_pass_trace_info = args.get('dump_pass_trace_info', False)
        self.optimization_pass_mode_manager = OptimizationPassModeManager.get_instance()

        # Dump ir optimizer config template
        if self.dump_ir_optimizer_config_template:
            self.dump_ir_optimizer_config_yaml_template()
            sys.exit(0)

        # Validate the pipeline flag before any pass configuration is applied.
        self._validate_optimization_pass_mode(args)

        # Parse ir optimizer config
        if args.get('ir_optimizer_config', "") != "":
            self.parse_optimizer_config(args['ir_optimizer_config'])

        # Configure conditional passes
        self.configure_conditional_passes(args)

        # Apply pipeline mode based optimization filter
        self.apply_optimization_pass_mode(args)

        # Validate and initialize dump ir options.
        self._configure_dump_ir(args.get('dump_ir', ""))

        # Dump optimization_pass_mode config
        if args.get('dump_optimization_pass_mode_config', False):
            self.dump_optimization_pass_mode_config()
            sys.exit(0)

    def _validate_optimization_pass_mode(self, args):
        """Raise if --optimization_pass_mode is set (non-empty) to an unsupported value."""
        pass_mode = args.get('optimization_pass_mode', "")
        if pass_mode != "" and pass_mode not in self._VALID_PASS_MODES:
            raise Exception("Invalid optimization pipeline passed, supported values "
                            "ir_optimizer_mainline, ir_optimizer_experimental, ir_optimizer_disable, disable_post_quant_optimizations")

    def _configure_dump_ir(self, dump_ir):
        """
        Validate the --dump_ir values and enable IR dumping on the optimizer context.

        :param dump_ir: sequence whose first entry is the dump level (OPTIMIZER, PASS or
            OP). The remaining entries hold comma-separated pass names (PASS, optional)
            or operator types (OP, required). A falsy value ("" / None / []) leaves
            dumping disabled.
        :raises Exception: on an unknown level, names given with OPTIMIZER, an
            unregistered pass name, or OP without operator types.
        """
        if not dump_ir:
            return

        level, extra_values = dump_ir[0], dump_ir[1:]
        if level not in self._VALID_DUMP_IR_LEVELS:
            raise Exception("Invalid Ir dump level passed, supported values "
                            "are OPTIMIZER, PASS and OP")
        if level == 'OPTIMIZER' and extra_values:
            raise Exception("Passing individual pass names is only supported with "
                            "PASS dump")

        # Accept both "--dump_ir PASS a,b" and "--dump_ir PASS a b".
        requested_names = [name for value in extra_values for name in value.split(",")]

        self.optimizer_ctx.dump_ir = True
        if level == 'OPTIMIZER':
            self.optimizer_ctx.dump_ir_level = ir_optimizer.IrDumpLevel.PreAndPostOptimization
        elif level == 'PASS':
            self.optimizer_ctx.dump_ir_level = ir_optimizer.IrDumpLevel.PostPassChange
            if requested_names:
                self.optimizer_ctx.dump_ir_pass_names = requested_names
                # Every pass name passed with dump_ir must be a registered pass.
                registered_pass_names = ir_optimizer.get_registered_pass_names()
                for pass_name in requested_names:
                    if pass_name not in registered_pass_names:
                        raise Exception(f"Invalid pass name, {pass_name} passed with "
                                        "dump_ir. No such pass is registered")
        else:  # level == 'OP'
            self.optimizer_ctx.dump_ir_level = ir_optimizer.IrDumpLevel.PostOpChange
            # Op level dumping requires operator types to be passed.
            if not requested_names:
                raise Exception("dump_ir OP level debugging requires operator "
                                "types to be passed.")
            self.optimizer_ctx.dump_ir_op_types = requested_names

    def parse_optimizer_config(self, optimizer_config):
        """
        Read the YAML ir_optimizer config and store it on optimizer_ctx.optimizer_config.

        The file maps section -> {pass name -> {"skip": bool}} (the layout written by
        dump_ir_optimizer_config_yaml_template). Passes with a truthy "skip" go to
        "disabledPasses"; all others go to "enabledPasses". Empty files, sections or pass
        entries are tolerated.
        :raises RuntimeError: if the file is not valid YAML or not a mapping.
        :return: None
        """
        with open(optimizer_config, 'r') as stream:
            try:
                yaml_data = yaml.safe_load(stream)
            except yaml.YAMLError as e:
                raise RuntimeError('Error parsing YAML config file') from e

        # An empty file loads as None: treat it as "no passes configured".
        if yaml_data is None:
            yaml_data = {}
        if not isinstance(yaml_data, dict):
            raise RuntimeError(f"IR optimizer config {optimizer_config} must be a YAML mapping")

        opt_config = {"enabledPasses": [], "disabledPasses": []}

        # Not possible to directly update self.optimizer_ctx.optimizer_config due to Pybind11
        # restriction. Create a dict and copy it to optimizer context.
        for section in yaml_data.values():
            for opt, status in (section or {}).items():
                skip = (status or {}).get("skip")
                opt_config["disabledPasses" if skip else "enabledPasses"].append(opt)

        self.optimizer_ctx.optimizer_config = opt_config

    def dump_ir_optimizer_config_yaml_template(self):
        """
        Write ./ir_optimizer_config_template.yaml, a template for --ir_optimizer_config.

        The template has a "common" section (passes returned by
        ir_optimizer.get_dump_pass_config(None)) and, when a backend object is available,
        a "backend" section. Each maps pass name -> {"skip": <value reported by ir_optimizer>}.
        Empty sections are omitted.
        :return: None
        """
        def to_skip_section(pass_config):
            return {pass_name: {"skip": pass_config[pass_name]} for pass_name in pass_config}

        yaml_data = {}
        # backend_info None returns the common optimization passes.
        common_optimizations = ir_optimizer.get_dump_pass_config(None)
        if len(common_optimizations):
            yaml_data["common"] = to_skip_section(common_optimizations)

        # Get backend pass config if backend arg is passed.
        if self.backend_info_obj:
            backend_info = self.backend_info_obj.get_backend_info_ptr()
            backend_optimizations = ir_optimizer.get_dump_pass_config(backend_info)
            if len(backend_optimizations):
                yaml_data["backend"] = to_skip_section(backend_optimizations)

        with open("ir_optimizer_config_template.yaml", 'w', encoding="utf-8") as file:
            yaml.dump(yaml_data, file, sort_keys=False)

    def configure_conditional_passes(self, args):
        """
        Enable/disable the conditional passes from their dedicated arguments and update the
        optimizer_context config. These arguments take precedence over ir_optimizer_config:
        if pass P is enabled by the config but its argument is false, P is disabled.

        Reads args["enable_match_gathernd"], args["handle_gather_negative_indices"] and
        args["align_matmul_ranks"] (KeyError if absent).
        :return: None
        """
        def enablePass(config, passName, doEnable):
            # Move passName to the requested list, removing every occurrence from the
            # other list and never adding a duplicate entry.
            add_to, remove_from = (("enabledPasses", "disabledPasses") if doEnable
                                   else ("disabledPasses", "enabledPasses"))
            while passName in config[remove_from]:
                config[remove_from].remove(passName)
            if passName not in config[add_to]:
                config[add_to].append(passName)

        opt_config = self.optimizer_ctx.optimizer_config
        opt_config.setdefault("enabledPasses", [])
        opt_config.setdefault("disabledPasses", [])

        enablePass(opt_config, "OptimizeGatherND", args["enable_match_gathernd"])
        enablePass(opt_config, "HandleGatherNegativeIndices", args["handle_gather_negative_indices"])
        enablePass(opt_config, "AlignMatmulRanks", args["align_matmul_ranks"])

        self.optimizer_ctx.optimizer_config = opt_config

    def apply_optimization_pass_mode(self, args):
        """
        Enable/disable passes according to the manager's pass mode and update the
        optimizer_context config. By this point the ir_optimizer_config and
        conditional-pass filtering have already been applied.

        Case a: ir_optimizer_disable OR disable_post_quant_optimizations
            All enabled passes plus every known mainline + expt pass are disabled, so no
            CPP-IR pass runs.
        Case b: ir_optimizer_mainline OR ir_optimizer_experimental
            Every enabled pass of the mainline pipeline (plus expt for experimental) runs
            unless it is already in disabledPasses. Any other pass in enabledPasses, and
            every other known pass, is moved to disabledPasses.

        Be CAUTIOUS when combining this with ir_optimizer_config: the Python
        (op_graph_optimization) passes are not enabled/disabled by ir_optimizer_config.
        For example, if pass "A" is in mainline but disabled by ir_optimizer_config, "A"
        is skipped in Python and also not run by CPP-IR.
        :return: None
        """
        manager = self.optimization_pass_mode_manager
        opt_config = self.optimizer_ctx.optimizer_config
        opt_config.setdefault("enabledPasses", [])
        opt_config.setdefault("disabledPasses", [])
        requested_enabled = set(opt_config["enabledPasses"])
        requested_disabled = set(opt_config["disabledPasses"])

        def cpp_names(pipelines, enabled=None):
            # CPP-IR names of the passes in the given pipelines, optionally filtered by
            # their enabled flag (values are [cpp_name, enabled]).
            return {cpp_name
                    for pipeline in pipelines
                    for cpp_name, is_enabled in pipeline.values()
                    if enabled is None or bool(is_enabled) == enabled}

        # Every built-in pass, so that passes outside the selected pipeline get disabled.
        all_known_passes = cpp_names([manager.mainline_default, manager.expt_default])

        opt_pass_mode = manager.passMode
        if opt_pass_mode in (OptPassMode.IrOptimizerDisable, OptPassMode.DisablePostQuantOptimizations):
            opt_config["enabledPasses"] = []
            opt_config["disabledPasses"] = list(requested_enabled | all_known_passes | requested_disabled)

        elif opt_pass_mode in (OptPassMode.IrOptimizerMainline, OptPassMode.IrOptimizerExperimental):
            pipelines = [manager.mainline_final]
            if opt_pass_mode == OptPassMode.IrOptimizerExperimental:
                pipelines.append(manager.expt_final)
            # Pipeline passes run unless explicitly disabled.
            enabled_set = cpp_names(pipelines, enabled=True) - requested_disabled
            disabled_set = ((requested_enabled - enabled_set)          # enabled but not in pipeline
                            | requested_disabled                       # stay disabled
                            | cpp_names(pipelines, enabled=False)      # disabled by the pipeline config
                            | (all_known_passes - enabled_set))        # known passes outside the pipeline
            opt_config["enabledPasses"] = list(enabled_set)
            opt_config["disabledPasses"] = list(disabled_set)

        self.optimizer_ctx.optimizer_config = opt_config

    def dump_optimization_pass_mode_config(self):
        """
        Write ./optimization_pass_mode_config.yaml: a header comment followed by
        OptimizationPassInfo.passes_map. The file can be edited and passed back with
        --optimization_pass_mode_config.
        :return: None
        """
        with open("optimization_pass_mode_config.yaml", 'w', encoding="utf-8") as file:
            comment_str = "# By Default ir_optimizer_mainline pipeline is used\n" \
                          "# Experimental passes can be tested in two way:\n" \
                          "# \t1. Pass ir_optimizer_experimental as the flag \n" \
                          "# \t2. Move the pass from experimental list to mainline line in this file\n\n"
            file.write(comment_str)
            yaml.dump(OptimizationPassInfo.passes_map, file, sort_keys=False)

    def optimize(self, ir_graph, optimization_stages: list):
        """
        Optimize the IR graph in place.

        :param ir_graph: graph handed to ir_optimizer.IrOptimizer.
        :param optimization_stages: stages to run; each item's ``.value`` is forwarded.
        :return: None
        """
        enabled_optimization_stages = [stage.value for stage in optimization_stages]
        # Pass backend info only when a specific backend was requested, to enable
        # backend-specific optimizations.
        backend_info = self.backend_info_obj.get_backend_info_ptr() if self.backend_info_obj else None
        optimizer = ir_optimizer.IrOptimizer(ir_graph, enabled_optimization_stages, backend_info, self.optimizer_ctx)

        # Apply optimizations (in place) on IrGraph
        optimizer.optimize()
