#!/usr/bin/env python3
# =============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================

"""

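Checks and installs Python dependencies for the QAIRT SDK.
Currently supported: Python 3.8 and 3.10. On Linux the required version depends on the
Ubuntu release: Python 3.8 on Ubuntu 20.04 and Python 3.10 on Ubuntu 22.04.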

"""

import argparse
import os
import pkg_resources
import re
import subprocess
import sys

from typing import Dict


class QAIRTEnvironmentError(Exception):
    """Raised when the host OS or Python version is not supported.

    The string form is the first constructor argument prefixed with
    ``[ERROR] `` so the exception can be printed directly to the user.
    """

    def __str__(self):
        # Guard against an instance created without a message: indexing
        # self.args[0] unconditionally would raise IndexError while formatting.
        message = self.args[0] if self.args else ""
        return f"[ERROR] {message}"


class SupportedVersion:
    """A supported Python version together with its package version overrides."""

    # Python version string in "major.minor" form, e.g. "3.8".
    version: str
    # Maps a package name to the version that replaces the default one.
    overrides: Dict[str, str]

    def __init__(self, ver: str) -> None:
        # Every instance starts with its own, empty override table.
        self.overrides = dict()
        self.version = ver


def get_ubuntu_version():
    """
    Reads the VERSION_ID entry from /etc/os-release. Surrounding quotes are removed, so a line such as
    VERSION_ID="22.04" yields 22.04
    :return: the VERSION_ID value as a string
    :raises FileNotFoundError: if /etc/os-release does not exist or contains no VERSION_ID entry
    """
    os_release_path = "/etc/os-release"
    if not os.path.exists(os_release_path):
        raise FileNotFoundError(f"{os_release_path} was not found.")

    # os-release is specified as UTF-8; do not depend on the locale encoding, and never fail on a stray byte
    # in an unrelated field since only the ASCII VERSION_ID entry is needed.
    with open(os_release_path, encoding="utf-8", errors="replace") as f:
        for line in f:
            if line.startswith("VERSION_ID="):
                # Keep everything after the first '='; the value may be unquoted, single- or double-quoted.
                _, _, value = line.strip().partition("=")
                return value.strip("\"'")

    # The file exists but does not declare a version: report that, not a missing file.
    raise FileNotFoundError(f"VERSION_ID was not found in {os_release_path}.")


def normalize_package_name(pkg: str) -> str:
    """Return pkg lower-cased, with every run of '-', '_' or '.' collapsed into a single '-'."""
    lowered = pkg.lower()
    return re.sub(r"[-_.]+", "-", lowered)


def get_installed_versions(pkg_list, pip_list):
    """
    Looks up the installed version of every package named in pkg_list
    :param pkg_list: dictionary whose keys are the package names to look up
    :param pip_list: dictionary of installed versions keyed by normalized package name
    :return: a tuple (versions, missing); versions holds, in pkg_list order, the installed version of each
             package or "MISSING" when it is absent from pip_list, and missing lists the absent package names
    """
    versions = []
    absent = []
    for name in pkg_list:
        # pip_list is keyed by normalized names, so normalize before the lookup
        lookup_key = normalize_package_name(name)
        if lookup_key not in pip_list:
            versions.append("MISSING")
            absent.append(name)
            continue
        versions.append(pip_list[lookup_key])
    return versions, absent


def print_package_summary(pkg_list, installed_versions):
    """
    Prints a tab-separated table with one row per package: name, recommended version and installed version.
    Each cell is left-justified to the widest entry of its column (header included). With no packages only the
    header and the separator line are printed
    :param pkg_list: dictionary of packages with pkg_name:recommended_version as key:value
    :param installed_versions: installed version strings, in the same order as pkg_list
    """
    headers = ["Package", "Recommended", "Installed"]
    # Materialize every column: each one is traversed twice (width computation, then row output), which
    # would silently exhaust a one-shot iterator.
    columns = [list(pkg_list.keys()), list(pkg_list.values()), list(installed_versions)]
    # default=0 keeps an empty column from raising ValueError in max()
    widths = [
        max(len(header), max((len(cell) for cell in column), default=0))
        for header, column in zip(headers, columns)
    ]

    def format_row(cells):
        return '\t'.join(cell.ljust(width) for cell, width in zip(cells, widths))

    print(format_row(headers))
    print('\t'.join('~' * width for width in widths))
    for row in zip(*columns):
        print(format_row(row))


def get_pkgs_to_install(pkg_list, installed_pkgs):
    """
    Selects the packages from pkg_list that are not installed yet. A package that is already installed is never
    selected: a confirmation is printed when its version equals the tested version, a warning otherwise
    :param pkg_list: dictionary of packages to vet with pkg_name:tested_version as key:value
    :param installed_pkgs: dictionary of packages already installed with pkg_name:installed_version as key:value
    :return: a dictionary of packages to install with pkg_name:pkg_version as key:value
    """
    pending = {}
    for name, tested_version in pkg_list.items():
        if name not in installed_pkgs:
            pending[name] = tested_version
            continue

        current_version = installed_pkgs[name]
        if tested_version == current_version:
            print(f"{name} is already installed with tested version: {tested_version}")
        else:
            # A mismatching version is reported only; the installed package is left untouched
            print(
                f"WARNING: {name} installed version: {current_version} "
                f"does not match tested version: {tested_version}"
            )

    return pending


def main():
    """
    Script entry point. Verifies that the running Python version is supported (on Linux: that it matches the
    version expected for the host Ubuntu release) and that a virtual or conda environment is active, installs
    the packages that are not installed yet unless --dry-run is given, then prints a summary of recommended
    versus installed versions for the required and the optional packages
    :raises SystemExit: with status 1 if the Python/Ubuntu version is unsupported or no virtual environment
                        is active
    :raises RuntimeError: if at least one required package is still missing when the summary is built
    :raises subprocess.CalledProcessError: if upgrading pip or running "pip list" fails
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--dry-run", required=False, action='store_true',
                        help="perform the check but dont install")
    parser.add_argument("-o", "--with_optional", required=False, action='store_true',
                        help="include installation of optional dependencies")

    args = parser.parse_args()

    # Package versions used on Python 3.8 in place of the defaults listed in required_packages below
    version38 = SupportedVersion("3.8")
    version38.overrides['absl-py'] = '0.13.0'
    version38.overrides['attrs'] = '22.2.0'
    version38.overrides['decorator'] = '5.1.1'
    version38.overrides['invoke'] = '2.0.0'
    version38.overrides['joblib'] = '1.0.1'
    version38.overrides['jsonschema'] = '4.19.0'
    version38.overrides['lxml'] = '4.6.2'
    version38.overrides['mock'] = '5.1.0'
    version38.overrides['numpy'] = '1.23.5'
    version38.overrides['opencv-python'] = '4.5.2.52'
    version38.overrides['optuna'] = '3.3.0'
    version38.overrides['packaging'] = '21.0'
    version38.overrides['pandas'] = '2.0.2'
    version38.overrides['paramiko'] = '3.5.1'
    version38.overrides['pillow'] = '6.2.1'
    version38.overrides['plotly'] = '5.16.1'
    version38.overrides['psutil'] = '5.9.5'
    version38.overrides['pytest'] = '7.0.1'
    version38.overrides['pyyaml'] = '3.10'
    version38.overrides['scipy'] = '1.9.1'
    version38.overrides['tabulate'] = '0.8.5'
    version38.overrides['typing-extensions'] = '4.6.2'
    version38.overrides['xlsxwriter'] = '3.1.9'

    version310 = SupportedVersion("3.10")

    supported_python_versions = {
        version38.version: version38,
        version310.version: version310
    }

    # Python 3.10 uses the default versions except for these two packages on Windows
    if sys.platform.startswith('win'):
        version310.overrides["matplotlib"] = '3.5.0'
        version310.overrides['pandas'] = '2.0.2'

    current_python_version = f"{sys.version_info[0]}.{sys.version_info[1]}"

    # Verify the currently installed Python version
    try:
        # For Linux, we support specific Python versions depending on the host Ubuntu version
        if sys.platform.startswith('linux'):

            ubuntu_to_python_version = {
                "22.04": version310.version,
                "20.04": version38.version
            }

            try:
                ubuntu_version = get_ubuntu_version()
            except FileNotFoundError as err:
                # Report an undeterminable OS version like any other environment error instead of a traceback
                raise QAIRTEnvironmentError(f"Unable to determine the Ubuntu version: {err}") from err

            if ubuntu_version not in ubuntu_to_python_version:
                raise QAIRTEnvironmentError(f"Unsupported Ubuntu version {ubuntu_version}.")

            if current_python_version != ubuntu_to_python_version[ubuntu_version]:
                raise QAIRTEnvironmentError(
                    f"Python {current_python_version} is unsupported on Ubuntu {ubuntu_version}."
                    f" Please use Python {ubuntu_to_python_version[ubuntu_version]} instead.")
        else:
            # Just verify that the Python version is one of the supported versions
            if current_python_version not in supported_python_versions:
                raise QAIRTEnvironmentError(f"Unsupported Python version {current_python_version}. "
                                            f"Please use one of: {', '.join(supported_python_versions.keys())}")

    except QAIRTEnvironmentError as e:
        print(e)
        # Exit with a non-zero status so calling scripts can detect the failure; sys.exit is used because the
        # bare exit() helper is injected by the site module and is not guaranteed to exist.
        sys.exit(1)

    if ((('VIRTUAL_ENV' not in os.environ) and
         ('CONDA_DEFAULT_ENV' not in os.environ))):
        print(
            "Error: Script should only be execute in virtual environment, "
            "please check you had activated the python virtual environment"
        )
        sys.exit(1)

    required_packages = {
        "safetensors": "0.4.3",
        "absl-py": "2.1.0",
        "aenum": "3.1.15",
        "attrs": "23.2.0",
        "dash": "2.12.1",
        "decorator": "4.4.2",
        "invoke": "1.7.3",
        "joblib": "1.4.0",
        "jsonschema": "4.19.0",
        "lxml": "5.2.1",
        "mako": "1.1.0",
        "matplotlib": "3.3.4",
        "mock": "3.0.5",
        "numpy": "1.26.4",
        "opencv-python": "4.5.4.58",
        "optuna": "3.3.0",
        "packaging": "24.0",
        "pandas": "2.0.1",
        "paramiko": "3.5.1",
        "pathlib2": "2.3.6",
        "pillow": "10.2.0",
        "plotly": "5.20.0",
        "protobuf": "3.19.6",
        "psutil": "6.1.1",
        "pydantic": "2.7.4",
        "pytest": "8.1.1",
        "pyyaml": "5.3",
        "rich": "13.9.4",
        "scikit-optimize": "0.9.0",
        "scipy": "1.10.1",
        "six": "1.16.0",
        "tabulate": "0.9.0",
        "typing-extensions": "4.14.0",
        "xlsxwriter": "1.2.2"
    }

    optional_packages = {
        "qairt-visualizer": "0.5.1"
    }

    # Apply the version overrides registered for the running Python version
    required_packages.update(supported_python_versions[current_python_version].overrides)

    # Collect the installed version of every package this script manages. The name set is built once rather
    # than rebuilding a list for each installed distribution.
    tracked_package_names = set(required_packages) | set(optional_packages)
    installed_packages = {
        dist.key: dist.version
        for dist in pkg_resources.working_set
        if dist.key in tracked_package_names
    }

    required_packages = dict(sorted(required_packages.items()))
    to_be_installed_packages = {}
    if args.with_optional:
        optional_packages = dict(sorted(optional_packages.items()))
        to_be_installed_packages.update(get_pkgs_to_install(optional_packages, installed_packages))
    to_be_installed_packages.update(get_pkgs_to_install(required_packages, installed_packages))

    if to_be_installed_packages and not args.dry_run:
        print(f"Python Modules missing: {', '.join(map(str, to_be_installed_packages.keys()))}")
        subprocess.check_call(
            [
                sys.executable,
                '-m',
                "pip",
                "install",
                "--upgrade",
                "pip"
            ],
            stdout=subprocess.DEVNULL,
        )

        print("Installing missing modules using pip3")

        for pkg_name, pkg_ver in to_be_installed_packages.items():
            print(f"Installing {pkg_name} version: {pkg_ver}")
            try:
                subprocess.check_call(
                    [sys.executable, '-m', "pip", "install", f"{pkg_name}=={pkg_ver}"],
                    stdout=subprocess.DEVNULL,
                )
            except subprocess.CalledProcessError:
                # Keep going: a package that failed to install shows up as MISSING in the summary below
                print(f"Error: Failed to install {pkg_name}:{pkg_ver}")

    # Re-query pip so the summary reflects what is installed after the installation step
    pip_output = subprocess.check_output([sys.executable, '-m', 'pip', 'list']).decode('utf-8')
    # Skip the first two output lines (the column header and its separator in pip's default listing) and ignore
    # any row that does not carry at least a name and a version
    pip_rows = (line.split() for line in pip_output.split('\n')[2:])
    pip_list = {normalize_package_name(fields[0]): fields[1] for fields in pip_rows if len(fields) >= 2}

    print("\nSummary:")
    print("Required Packages:\n")
    installed_versions, required_packages_missing = get_installed_versions(required_packages, pip_list)
    print_package_summary(required_packages, installed_versions)

    print("\nOptional Packages:")
    installed_versions, optional_package_missing = get_installed_versions(optional_packages, pip_list)
    if not args.with_optional and optional_package_missing:
        print("To install optional packages, pass -o|--with_optional\n")
    print_package_summary(optional_packages, installed_versions)

    if required_packages_missing:
        # Only throw runtime error if there are required packages missing
        raise RuntimeError("Missing required package(s): " + ", ".join(required_packages_missing))


# Run the dependency check only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
