//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#pragma once

#include <climits>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#define GGUF_CHECK_ERROR_NE(fileRead, size)   \
  do {                                        \
    int x = fileRead;                         \
    if (x != static_cast<int>((size))) {      \
      goto exit;                              \
    }                                         \
  } while (0)

#define GGUF_CHECK_ERROR_EQ(cmd, error)       \
  do {                                        \
    int x = cmd;                              \
    if (x == static_cast<int>((error))) {     \
      goto exit;                              \
    }                                         \
  } while (0)

// Metadata key whose string value names the model architecture.
#define GGUF_KEY_ARCH     "general.architecture"
// Architecture name that marks a cross-attention decoder model. Note: despite
// the _KEY_ in its name, this is compared against the *value* stored under
// GGUF_KEY_ARCH, not against a key.
#define GGUF_KEY_DECODER  "cross_attention_decoder"

/// Type tag of a GGUF metadata value, as read from the file.
/// The numeric values are read straight from disk, so they must never change.
enum class GGUFValueType : int {
  UINT8   = 0,   // 1-byte unsigned integer
  INT8    = 1,   // 1-byte signed integer
  UINT16  = 2,   // 2-byte unsigned integer
  INT16   = 3,   // 2-byte signed integer
  UINT32  = 4,   // 4-byte unsigned integer
  INT32   = 5,   // 4-byte signed integer
  FLOAT32 = 6,   // 4-byte IEEE float
  BOOL    = 7,   // boolean scalar
  STRING  = 8,   // uint64 byte length followed by that many bytes
  ARRAY   = 9,   // element type tag, uint64 count, then the elements
  UINT64  = 10,  // 8-byte unsigned integer
  INT64   = 11,  // 8-byte signed integer
  FLOAT64 = 12,  // 8-byte IEEE double
};

/// Fixed-size header read from the very start of a GGUF file, fields in the
/// order they are read.
struct gguf_file {
  std::uint32_t magic;     // format identifier from the first four bytes
  std::uint32_t version;   // format version number
  std::uint64_t n_tensor;  // tensor count recorded in the header
  std::uint64_t n_kv;      // number of metadata key/value pairs that follow
};

/// Reads one GGUF string (a uint64 byte length followed by that many bytes,
/// with no terminator) from `fp` into a freshly malloc'd, NUL-terminated
/// buffer.
///
/// @param fp      Open file positioned at the start of a GGUF string.
/// @param string  Out-param. On success it receives a buffer the caller must
///                free(). On any failure (short read, length that cannot be
///                allocated, malloc failure) it is set to nullptr and nothing
///                is leaked, so callers can simply test for nullptr.
static inline void ggufStringRead(FILE* fp, char** string) {
  *string = nullptr;

  uint64_t length = 0;
  if (fread(&length, sizeof(length), 1, fp) != 1) {
    return;
  }
  // `length` comes straight from the file; reject values for which
  // `length + 1` would overflow or not be representable as a size_t.
  if (length >= static_cast<uint64_t>(SIZE_MAX)) {
    return;
  }

  const size_t byte_count = static_cast<size_t>(length);
  char* buffer = static_cast<char*>(malloc(byte_count + 1));
  if (buffer == nullptr) {
    return;
  }
  if (fread(buffer, sizeof(char), byte_count, fp) != byte_count) {
    // Truncated payload: never hand back an unterminated buffer.
    free(buffer);
    return;
  }
  buffer[byte_count] = '\0';
  *string = buffer;
}

bool getIsCrossAttentionDecoder(const char* file_name) {
  // GGUF value type size lambda
  auto getGGUFValueTypeSize = [](GGUFValueType type) -> size_t {
    static const std::unordered_map<GGUFValueType, size_t> s_ggufValueTypeToSize {
      {GGUFValueType::UINT8,    sizeof(uint8_t)},
      {GGUFValueType::INT8,     sizeof(int8_t)},
      {GGUFValueType::UINT16,   sizeof(uint16_t)},
      {GGUFValueType::INT16,    sizeof(int16_t)},
      {GGUFValueType::UINT32,   sizeof(uint32_t)},
      {GGUFValueType::INT32,    sizeof(int32_t)},
      {GGUFValueType::FLOAT32,  sizeof(float)},
      {GGUFValueType::UINT64,   sizeof(uint64_t)},
      {GGUFValueType::INT64,    sizeof(int64_t)},
      {GGUFValueType::FLOAT64,  sizeof(double)},
      {GGUFValueType::BOOL,     sizeof(bool)},
      {GGUFValueType::STRING,   sizeof(char*)}
    };

    return s_ggufValueTypeToSize.contains(type) ? s_ggufValueTypeToSize.at(type)
                                                : std::numeric_limits<size_t>::max();
  };

  FILE* fp = fopen(file_name, "rb");
  if (!fp) {
    return false;
  }

  bool is_arch_key = false;
  bool is_cross_attn_decoder = false;
  struct gguf_file* f = static_cast<struct gguf_file*>(calloc(1, sizeof(struct gguf_file)));

  // Read header
  GGUF_CHECK_ERROR_NE(fread(&f->magic, sizeof(uint32_t), 1, fp), 1);
  GGUF_CHECK_ERROR_NE(fread(&f->version, sizeof(uint32_t), 1, fp), 1);
  GGUF_CHECK_ERROR_NE(fread(&f->n_tensor, sizeof(uint64_t), 1, fp), 1);
  GGUF_CHECK_ERROR_NE(fread(&f->n_kv, sizeof(uint64_t), 1, fp), 1);

  // Read key-value pairs
  for (size_t i = 0; i < f->n_kv; i++) {
    char* key = nullptr;
    enum GGUFValueType type;

    // Read key
    ggufStringRead(fp, &key);
    is_arch_key = !strcmp(key, GGUF_KEY_ARCH);
    if (key) { free(key); }

    // Read value type
    GGUF_CHECK_ERROR_NE(fread(&type, sizeof(type), 1, fp), 1);

    // Read value / Seek
    switch (type) {
      case GGUFValueType::STRING: {
        char *value;
        ggufStringRead(fp, &value);
        if (is_arch_key) {
          is_cross_attn_decoder = !strcmp(value, GGUF_KEY_DECODER);
        }
        free(value);
        break;
      }
      case GGUFValueType::ARRAY: {
        enum GGUFValueType array_type;
        GGUF_CHECK_ERROR_NE(fread(&array_type, sizeof(array_type), 1, fp), 1);
        uint64_t length;
        GGUF_CHECK_ERROR_NE(fread(&length, sizeof(length), 1, fp), 1);
        const size_t size = getGGUFValueTypeSize(array_type);
        if (array_type == GGUFValueType::STRING) {
          for (size_t j = 0; j < length; j++) {
            uint64_t str_length;
            GGUF_CHECK_ERROR_NE(fread(&str_length, sizeof(uint64_t), 1, fp), 1);

            const int64_t offset = static_cast<int64_t>(sizeof(char) * str_length);
            GGUF_CHECK_ERROR_EQ(fseek(fp, offset, SEEK_CUR), (-1));
          }
        } else {
          const int64_t offset = static_cast<int64_t>(size * length);
          GGUF_CHECK_ERROR_EQ(fseek(fp, offset, SEEK_CUR), (-1));
        }
        break;
      }
      case GGUFValueType::UINT8:
      case GGUFValueType::INT8:
      case GGUFValueType::UINT16:
      case GGUFValueType::INT16:
      case GGUFValueType::UINT32:
      case GGUFValueType::INT32:
      case GGUFValueType::FLOAT32:
      case GGUFValueType::BOOL:
      case GGUFValueType::UINT64:
      case GGUFValueType::INT64:
      case GGUFValueType::FLOAT64: {
        const int64_t offset = static_cast<int64_t>(getGGUFValueTypeSize(type));
        GGUF_CHECK_ERROR_EQ(fseek(fp, offset, SEEK_CUR), (-1));
        break;
      }
      default: {
        return false;
      }
    }

    if (is_arch_key) {
      break;
    }
  }

exit:
  free(f);
  fclose(fp);
  return is_cross_attn_decoder;
}
