# =============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All Rights Reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================

from __future__ import annotations

import json
from pathlib import Path
from typing import Any

import tqdm


class QwenTokenizer:
    """Build a Hugging Face `tokenizers` BPE definition from a Qwen tokenizer directory.

    `_create_qwen_bpe` loads the tokenizer in `dir_model` through `AutoTokenizer`, derives the BPE vocab and merges
    from its `mergeable_ranks`, and returns the assembled tokenizer as a JSON-compatible dict. When
    `export_tokenizer_json` is set, the dict is also written to `export_path / "tokenizer.json"`.
    """

    def __init__(self, dir_model: Path, export_path: Path | None = None, export_tokenizer_json: bool = False) -> None:
        # Directory holding the source tokenizer and its config.json
        self.dir_model = dir_model
        # Directory that receives tokenizer.json; only required when export_tokenizer_json is True
        self.export_path = export_path
        self.export_tokenizer_json = export_tokenizer_json

    # tiktoken allows representation of tokens as byte arrays and does not guarantee tokens to be valid UTF-8 bytes
    @staticmethod
    def token_bytes_to_string(b: bytes) -> str:
        """Map every byte of `b` to its GPT-2 byte-level unicode character and join the result."""
        from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
        byte_encoder = bytes_to_unicode()
        # Iterating over bytes yields ints in 0..255, the same values ord() gave after a latin-1 decode
        return ''.join(byte_encoder[byte] for byte in b)

    # to generate BPE merges from tiktoken, for each token in vocab we iteratively try to merge consecutive sub-tokens
    # (initially starting with all consecutive byte pairs). If the newly merged sub-token is present in the vocab and
    # its token id is less than the original token (of which this merge is a sub-token) we add it to the merge list. If
    # at a given stage multiple sub-token pairs in consideration are present in the vocab, then we take the pair whose
    # merge gives us a token with the lowest token id
    @staticmethod
    def _extract(mergeable_ranks: dict[bytes, int], disable: bool = True) -> tuple[dict[str, int], list[tuple]]:
        """Derive the string vocab and the rank-ordered BPE merges from byte-level `mergeable_ranks`.

        Args:
            mergeable_ranks: mapping from token bytes to token id (rank).
            disable: forwarded to tqdm; True hides the progress bar.

        Returns:
            `(vocab, merges)` where `vocab` maps the byte-level string of each token to its rank and `merges` is a
            list of `(left, right)` string pairs sorted by the rank of the token they produce.

        Raises:
            AssertionError: if a multi-byte token cannot be reduced to exactly two pieces.
        """
        merges = []
        vocab = {}
        for token, rank in tqdm.tqdm(mergeable_ranks.items(), total = len(mergeable_ranks), disable = disable):
            vocab[QwenTokenizer.token_bytes_to_string(token)] = rank
            if len(token) == 1:
                continue
            pieces = [bytes([byte]) for byte in token]
            while True:
                min_idx = None
                min_rank = None
                for idx, (piece_l, piece_r) in enumerate(zip(pieces[:-1], pieces[1:])):
                    rank_merge = mergeable_ranks.get(piece_l + piece_r)
                    # Compare against None explicitly: a rank of 0 is a valid token id, not "missing"
                    if rank_merge is not None and (min_rank is None or rank_merge < min_rank):
                        min_idx = idx
                        min_rank = rank_merge
                # Stop when no adjacent pair is in the vocab, or the best candidate is not older than the token itself
                if min_rank is None or min_rank >= rank:
                    break
                pieces[min_idx:min_idx + 2] = [pieces[min_idx] + pieces[min_idx + 1]]
            assert len(pieces) == 2, f"token {token!r} could not be reduced to a single merge"
            # The two remaining pieces concatenate back to `token`, so the merge rank is the token's own rank
            merges.append((pieces[0], pieces[1], rank))
        merges = sorted(merges, key = lambda merge: merge[2])
        merges = [(QwenTokenizer.token_bytes_to_string(piece_l), QwenTokenizer.token_bytes_to_string(piece_r)) for (piece_l, piece_r, _) in merges]
        return vocab, merges

    def _create_qwen_bpe(self, disable: bool = True) -> dict[str, Any]:
        """Assemble the Qwen tokenizer as a `tokenizers` JSON dict and optionally export it.

        Args:
            disable: forwarded to `_extract`; True hides the progress bar.

        Returns:
            The tokenizer definition as a dict parsed from `Tokenizer.to_str()`, with a few fields patched afterwards.

        Raises:
            AssertionError: if the largest token id is not below `vocab_size` from config.json.
            NotADirectoryError: if export is requested but `export_path` is None.
        """
        dir_model = self.dir_model

        from transformers import AutoTokenizer
        # NOTE(review): trust_remote_code = True executes Python code shipped in dir_model; only use trusted models
        tokenizer = AutoTokenizer.from_pretrained(str(dir_model), trust_remote_code = True)
        with open(dir_model / "config.json", "rb") as config_file:
            vocab_size = json.load(config_file)["vocab_size"]
        assert max(tokenizer.get_vocab().values()) < vocab_size

        vocab, merges = QwenTokenizer._extract(tokenizer.mergeable_ranks, disable = disable)

        # Special tokens, ordered by their token id
        added_vocab = [token for token, _ in sorted(tokenizer.special_tokens.items(), key = lambda item: item[1])]

        # Fill the gap up to vocab_size with placeholder tokens; the range is empty when there is no gap
        added_vocab.extend(f"[PAD{i}]" for i in range(vocab_size - (len(added_vocab) + len(vocab))))

        # Create a Tokenizer object
        from tokenizers import Tokenizer, Regex, models, normalizers, decoders, pre_tokenizers, processors
        custom_tokenizer = Tokenizer(models.BPE(vocab = vocab, merges = merges))
        custom_tokenizer.add_special_tokens(added_vocab)

        PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
        custom_tokenizer.normalizer = normalizers.NFC()
        custom_tokenizer.pre_tokenizer = pre_tokenizers.Sequence([pre_tokenizers.Split(pattern = Regex(PAT_STR), behavior = "isolated", invert = False),
                                                                  pre_tokenizers.ByteLevel(add_prefix_space = False, use_regex = False)])
        custom_tokenizer.post_processor = processors.ByteLevel(trim_offsets = False)
        custom_tokenizer.decoder = decoders.ByteLevel()

        # Patch the serialized form directly for the ByteLevel fields set below
        custom_tokenizer = json.loads(custom_tokenizer.to_str())
        custom_tokenizer["pre_tokenizer"]["pretokenizers"][1]["trim_offsets"] = False
        custom_tokenizer["post_processor"]["add_prefix_space"] = False
        custom_tokenizer["post_processor"]["use_regex"] = False
        custom_tokenizer["decoder"]["add_prefix_space"] = False
        custom_tokenizer["decoder"]["trim_offsets"] = False
        custom_tokenizer["decoder"]["use_regex"] = False

        if self.export_tokenizer_json:
            if self.export_path is None:
                raise NotADirectoryError('Output Directory not Specified')
            # The context manager guarantees the file is flushed and closed before returning
            with open(self.export_path / "tokenizer.json", "w", encoding = "utf-8") as export_file:
                json.dump(custom_tokenizer, export_file, indent = 2, ensure_ascii = False)

        return custom_tokenizer

class BaichuanTokenizer:
    """Build a Hugging Face `tokenizers` BPE definition from a SentencePiece model.

    The SentencePiece model is read from `dir_model / "tokenizer.model"`. `_create_baichuan_bpe` returns the
    tokenizer as a JSON-compatible dict and, when `export_tokenizer_json` is set, also writes it to
    `export_path / "tokenizer.json"`.
    """

    def __init__(self, dir_model: Path, export_path: Path | None = None, export_tokenizer_json: bool = False) -> None:
        # Directory that receives tokenizer.json; only required when export_tokenizer_json is True
        self.export_path = export_path
        self.export_tokenizer_json = export_tokenizer_json
        self.model_path = dir_model / "tokenizer.model"
        from sentencepiece import SentencePieceProcessor
        self.sp = SentencePieceProcessor(str(self.model_path))

    # to generate BPE merges from sentencepiece, we need every pair (left, right) for which left, right and
    # left + right are all tokens of the vocab. Rather than testing the full cartesian product of the token list with
    # itself, each token is split at every position and a split is kept when both halves are in the vocab; this yields
    # the same pairs with far fewer lookups. The list is then sorted by token id to generate the final merges
    @staticmethod
    def _extract(mergeable_ranks: dict[str, int]) -> tuple[dict[str, int], list[tuple]]:
        """Derive the BPE merges implied by `mergeable_ranks`.

        Args:
            mergeable_ranks: mapping from piece to token id.

        Returns:
            `(vocab, merges)` where `vocab` is `mergeable_ranks` itself and `merges` is a list of `(left, right)`
            pairs ordered by the id of the merged token; ties keep the order in which `left`, then `right`, appear
            in `mergeable_ranks`.
        """
        vocab = mergeable_ranks
        # Position of each piece in iteration order, used as the tie-breaker when sorting
        order = {piece: position for position, piece in enumerate(mergeable_ranks)}
        merges = []
        for token, rank in tqdm.tqdm(mergeable_ranks.items(), total = len(mergeable_ranks)):
            if rank is None:
                continue
            # Split positions 0 and len(token) only match when the empty string is itself a piece, which is what a
            # cartesian product over the vocab would also produce
            for split in range(len(token) + 1):
                piece_l, piece_r = token[:split], token[split:]
                if piece_l in order and piece_r in order:
                    merges.append((piece_l, piece_r, rank))
        merges.sort(key = lambda merge: (merge[2], order[merge[0]], order[merge[1]]))
        merges = [(piece_l, piece_r) for (piece_l, piece_r, _) in merges]

        return vocab, merges

    def _create_baichuan_bpe(self) -> dict[str, Any]:
        """Assemble the tokenizer as a `tokenizers` JSON dict and optionally export it.

        Returns:
            The tokenizer definition as a dict parsed from `Tokenizer.to_str()`.

        Raises:
            NotADirectoryError: if export is requested but `export_path` is None.
        """
        mergeable_ranks = {self.sp.id_to_piece(index): index for index in range(self.sp.GetPieceSize())}
        vocab, merges = BaichuanTokenizer._extract(mergeable_ranks)

        # Create the Tokenizer
        from tokenizers import Tokenizer, models, normalizers, decoders
        custom_tokenizer = Tokenizer(models.BPE(vocab = vocab, merges = merges))

        # Create the Normalizer and Decoder pipelines
        custom_tokenizer.normalizer = normalizers.Replace(" ", "▁")
        custom_tokenizer.decoder = decoders.Replace("▁", " ")

        custom_tokenizer = json.loads(custom_tokenizer.to_str())

        if self.export_tokenizer_json:
            if self.export_path is None:
                raise NotADirectoryError('Output Directory not Specified')
            # The context manager guarantees the file is flushed and closed before returning
            with open(self.export_path / "tokenizer.json", "w", encoding = "utf-8") as export_file:
                json.dump(custom_tokenizer, export_file, indent = 2, ensure_ascii = False)

        return custom_tokenizer

class CerebrasTokenizer:
    """Build a Hugging Face `tokenizers` BPE definition from `vocab.json` and `merges.txt`.

    Both files are read from `dir_model`. `_create_cerebras_bpe` returns the tokenizer as a JSON-compatible dict
    and, when `export_tokenizer_json` is set, also writes it to `export_path / "tokenizer.json"`.
    """

    def __init__(self, dir_model: Path, export_path: Path | None = None, export_tokenizer_json: bool = False) -> None:
        # Directory that receives tokenizer.json; only required when export_tokenizer_json is True
        self.export_path = export_path
        self.vocab_path = dir_model / "vocab.json"
        self.merges_path = dir_model / "merges.txt"
        self.export_tokenizer_json = export_tokenizer_json

    @staticmethod
    def _extract(vocab_path : Path, merges_path : Path, disable: bool = True) -> tuple[dict[str, int], list[tuple]]:
        """Load the vocab and the merge pairs from disk.

        Args:
            vocab_path: JSON file holding the vocab.
            merges_path: text file with one space-separated merge pair per line after the first line.
            disable: forwarded to tqdm; True hides the progress bar.

        Returns:
            `(vocab, merges)` where `merges` is a list of `(left, right)` pairs in file order.

        Raises:
            ValueError: if a non-empty merge line does not split into exactly two space-separated parts.
        """
        with open(vocab_path, "rb") as vocab_file:
            vocab = json.load(vocab_file)
        with open(merges_path, "r", encoding = "utf-8") as merges_file:
            lines = merges_file.read().split("\n")
        # The first line is always skipped (presumably the "#version" header - confirm against the file format).
        # Empty lines are ignored so that a missing trailing newline no longer drops the last merge and a stray blank
        # line no longer aborts the conversion.
        merge_list = [line for line in lines[1:] if line]
        merges = []
        for merge in tqdm.tqdm(merge_list, total = len(merge_list), disable = disable):
            piece_l, piece_r = merge.split(" ")
            merges.append((piece_l, piece_r))
        return vocab, merges

    def _create_cerebras_bpe(self, disable: bool = True) -> dict[str, Any]:
        """Assemble the tokenizer as a `tokenizers` JSON dict and optionally export it.

        Args:
            disable: forwarded to `_extract`; True hides the progress bar.

        Returns:
            The tokenizer definition as a dict parsed from `Tokenizer.to_str()`.

        Raises:
            NotADirectoryError: if export is requested but `export_path` is None.
        """
        vocab, merges = CerebrasTokenizer._extract(self.vocab_path, self.merges_path, disable = disable)

        # Create the Tokenizer
        from tokenizers import Tokenizer, models, pre_tokenizers, decoders, processors
        custom_tokenizer = Tokenizer(models.BPE(vocab = vocab, merges = merges))

        # Create the Pre-Tokenizer, Post-Processor and Decoder pipelines
        custom_tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space = False, trim_offsets = True)
        custom_tokenizer.post_processor = processors.ByteLevel(add_prefix_space = True, trim_offsets = False)
        custom_tokenizer.decoder = decoders.ByteLevel(add_prefix_space = True, trim_offsets = True)

        custom_tokenizer = json.loads(custom_tokenizer.to_str())

        if self.export_tokenizer_json:
            if self.export_path is None:
                raise NotADirectoryError('Output Directory not Specified')
            # The context manager guarantees the file is flushed and closed before returning
            with open(self.export_path / "tokenizer.json", "w", encoding = "utf-8") as export_file:
                json.dump(custom_tokenizer, export_file, indent = 2, ensure_ascii = False)

        return custom_tokenizer

class MistralTokenizer:
    def __init__(self, dir_model: Path, export_path: Path | None = None, export_tokenizer_json: bool = False) -> None:
        self.export_path = export_path
        self.tokenizer_path = dir_model / "tokenizer.json"
        self.export_tokenizer_json = export_tokenizer_json

    def _create_mistral_tokenizer(self, disable: bool = True) -> dict[str, any]:
        custom_tokenizer = json.loads(open(self.tokenizer_path, "rb").read())
        # Delete the strip-space decoder block from the tokenizer's pipeline
        del custom_tokenizer["decoder"]["decoders"][3]
        if self.export_tokenizer_json:
            if self.export_path is not None:
                json.dump(custom_tokenizer, open(str(self.export_path / "tokenizer.json"), "w", encoding = "utf-8"), indent = 2, ensure_ascii = False)
            else:
                raise NotADirectoryError('Output Directory not Specified')
        return custom_tokenizer

class T5Tokenizer:
    """Build a Hugging Face `tokenizers` BPE definition from a SentencePiece model.

    The SentencePiece model is read from `dir_model / "tokenizer.model"`. `_create_t5_bpe` returns the tokenizer as
    a JSON-compatible dict and, when `export_tokenizer_json` is set, also writes it to
    `export_path / "tokenizer.json"`.
    """

    def __init__(self, dir_model: Path, export_path: Path | None = None, export_tokenizer_json: bool = False) -> None:
        # Directory that receives tokenizer.json; only required when export_tokenizer_json is True
        self.export_path = export_path
        self.export_tokenizer_json = export_tokenizer_json
        self.model_path = dir_model / "tokenizer.model"
        from sentencepiece import SentencePieceProcessor
        self.sp = SentencePieceProcessor(str(self.model_path))

    # to generate BPE merges from sentencepiece, we need every pair (left, right) for which left, right and
    # left + right are all tokens of the vocab. Rather than testing the full cartesian product of the token list with
    # itself, each token is split at every position and a split is kept when both halves are in the vocab; this yields
    # the same pairs with far fewer lookups. The list is then sorted by token id to generate the final merges
    @staticmethod
    def _extract(mergeable_ranks: dict[str, int]) -> tuple[dict[str, int], list[tuple]]:
        """Derive the BPE merges implied by `mergeable_ranks`.

        Args:
            mergeable_ranks: mapping from piece to token id.

        Returns:
            `(vocab, merges)` where `vocab` is `mergeable_ranks` itself and `merges` is a list of `(left, right)`
            pairs ordered by the id of the merged token; ties keep the order in which `left`, then `right`, appear
            in `mergeable_ranks`.
        """
        vocab = mergeable_ranks
        # Position of each piece in iteration order, used as the tie-breaker when sorting
        order = {piece: position for position, piece in enumerate(mergeable_ranks)}
        merges = []
        for token, rank in tqdm.tqdm(mergeable_ranks.items(), total = len(mergeable_ranks)):
            if rank is None:
                continue
            # Split positions 0 and len(token) only match when the empty string is itself a piece, which is what a
            # cartesian product over the vocab would also produce
            for split in range(len(token) + 1):
                piece_l, piece_r = token[:split], token[split:]
                if piece_l in order and piece_r in order:
                    merges.append((piece_l, piece_r, rank))
        merges.sort(key = lambda merge: (merge[2], order[merge[0]], order[merge[1]]))
        merges = [(piece_l, piece_r) for (piece_l, piece_r, _) in merges]

        return vocab, merges

    def _create_t5_bpe(self) -> dict[str, Any]:
        """Assemble the tokenizer as a `tokenizers` JSON dict and optionally export it.

        Returns:
            The tokenizer definition as a dict parsed from `Tokenizer.to_str()`.

        Raises:
            NotADirectoryError: if export is requested but `export_path` is None.
        """
        mergeable_ranks = {self.sp.id_to_piece(index): index for index in range(self.sp.GetPieceSize())}
        vocab, merges = T5Tokenizer._extract(mergeable_ranks)

        # Create the Tokenizer
        from tokenizers import Tokenizer, models, pre_tokenizers, decoders
        custom_tokenizer = Tokenizer(models.BPE(vocab = vocab, merges = merges))

        # Create the Pre-Tokenizer & Decoder pipelines
        custom_tokenizer.pre_tokenizer = pre_tokenizers.Sequence([pre_tokenizers.WhitespaceSplit(), pre_tokenizers.Metaspace(replacement = "▁", prepend_scheme = "always")])
        custom_tokenizer.decoder = decoders.Replace("▁", " ")

        custom_tokenizer = json.loads(custom_tokenizer.to_str())

        if self.export_tokenizer_json:
            if self.export_path is None:
                raise NotADirectoryError('Output Directory not Specified')
            # The context manager guarantees the file is flushed and closed before returning
            with open(self.export_path / "tokenizer.json", "w", encoding = "utf-8") as export_file:
                json.dump(custom_tokenizer, export_file, indent = 2, ensure_ascii = False)

        return custom_tokenizer

class FSMTTokenizer:
    def __init__(self, dir_model: Path, export_path: Path | None = None, export_tokenizer_json: bool = False) -> None:
        self.dir_model = dir_model
        self.export_path = export_path
        self.export_tokenizer_json = export_tokenizer_json
        self.vocab_path = dir_model / "vocab-src.json"
        self.merges_path = dir_model / "merges.txt"

    @staticmethod
    def _extract(vocab_path : Path, merges_path : Path, disable: bool = True) -> tuple[dict[str, int], list[tuple]]:
        with open(vocab_path, "r", encoding="utf-8") as f:
            vocab = json.load(f)

        with open(merges_path, "r", encoding="utf-8") as f:
            raw_merges = [line.strip() for line in f if line and not line.startswith("#")]

        # Filter merges to include only those where both tokens and their combination exists in vocab
        merges = []
        vocab_keys = set(vocab.keys())
        for line in raw_merges:
            parts = line.split()
            if len(parts) == 2 and all(part in vocab_keys for part in parts) and (parts[0] + parts[1] in vocab_keys):
                merges.append(tuple(parts))

        return vocab, merges

    def _create_fsmt_bpe(self) -> dict[str, any]:
        vocab, merges = FSMTTokenizer._extract(self.vocab_path, self.merges_path)

        # Create the Tokenizer
        from tokenizers import Tokenizer, models, pre_tokenizers, decoders
        custom_tokenizer = Tokenizer(models.BPE(vocab=vocab, merges=merges, end_of_word_suffix="</w>"))
        custom_tokenizer.add_special_tokens(["</s>"])

        # Create the Pre-Tokenizer & Decoder pipelines
        custom_tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
        custom_tokenizer.decoder = decoders.Replace("</w>", " ")

        custom_tokenizer = json.loads(custom_tokenizer.to_str())

        if self.export_tokenizer_json:
            if self.export_path is not None:
                json.dump(custom_tokenizer, open(str(self.export_path / "tokenizer.json"), "w", encoding = "utf-8"), indent = 2, ensure_ascii = False)
            else:
                raise NotADirectoryError('Output Directory not Specified')

        return custom_tokenizer
