#!/usr/bin/env python3
# =============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================
"""
Please note that this script should be compilable with python2 as well.
The reason is that this script is mostly invoked with python2, in which case
it should error out stating the expected python version.
Please use "python2 -m py_compile <currentFilePath>" to check that it still compiles.
"""
import os
import re
import sys
from typing import Dict

class SupportedVersion:
    """A Python version string together with its package-version overrides.

    Each instance stores the version string it was constructed with in
    ``version`` and starts with an empty ``overrides`` dict, which maps a
    package name to the version string to use for that interpreter.
    """

    version: str
    overrides: Dict[str, str]

    def __init__(self, ver):
        # A fresh dict per instance, so overrides are never shared.
        self.overrides = dict()
        self.version = ver

# Package versions to use on Python 3.8; they are merged over the default
# requiredPackages table further down in this script.
version38 = SupportedVersion("3.8")
version38.overrides.update({
    'absl-py': '0.13.0',
    'attrs': '22.2.0',
    'decorator': '5.1.1',
    'invoke': '2.0.0',
    'joblib': '1.0.1',
    'jsonschema': '4.19.0',
    'lxml': '4.6.2',
    'mock': '5.1.0',
    'numpy': '1.23.5',
    'opencv-python': '4.5.2.52',
    'optuna': '3.3.0',
    'packaging': '21.0',
    'pandas': '2.0.2',
    'paramiko': '3.3.1',
    'pillow': '6.2.1',
    'plotly': '5.16.1',
    'psutil': '5.9.5',
    'pytest': '7.0.1',
    'pyyaml': '3.10',
    'scipy': '1.9.1',
    'tabulate': '0.8.5',
    'typing-extensions': '4.6.2',
    'xlsxwriter': '3.1.9',
})

# Python 3.10 uses the default table, except on Windows where matplotlib and
# pandas are overridden.
version310 = SupportedVersion("3.10")
if sys.platform.startswith('win'):
    version310.overrides.update({
        "matplotlib": '3.5.0',
        'pandas': '2.0.2',
    })

# Registry of accepted interpreters, keyed by version string.
supportedPythonVersions = {
    version38.version: version38,
    version310.version: version310,
}

# Running interpreter as "major.minor", the same form used for the keys of
# supportedPythonVersions.
currentVersion = f"{sys.version_info[0]}.{sys.version_info[1]}"
if currentVersion != version38.version:
    # Emitted for every interpreter other than 3.8, including ones that the
    # support check below still accepts.
    print(
        f'Warning:  Supported versions of Python are [3.8]. Found instead: {currentVersion}'
    )
if currentVersion not in supportedPythonVersions:
    print(
        "Error: Supported versions of Python are [{0}]. Found instead: {1}".format(
            ', '.join(supportedPythonVersions),
            currentVersion,
        )
    )
    # Exit with a non-zero status so callers can detect the failure. The bare
    # exit() builtin used previously ended the process with status 0 and is
    # only available when the site module is loaded.
    sys.exit(1)

# Refuse to run outside a virtualenv/venv or conda environment: packages are
# installed with pip into whichever interpreter is running this script.
if 'VIRTUAL_ENV' not in os.environ and 'CONDA_DEFAULT_ENV' not in os.environ:
    print(
        "Error: Script should only be execute in virtual environment, please check you had activiated the python virtual environment"
    )
    sys.exit(1)

import subprocess
import pkg_resources

# Default package name -> version table. Entry order is preserved because the
# per-package status messages below are printed in this order.
requiredPackages = {
    "safetensors": "0.4.3",
    "absl-py": "2.1.0",
    "attrs": "23.2.0",
    "dash": "2.12.1",
    "decorator": "4.4.2",
    "invoke": "1.7.3",
    "joblib": "1.4.0",
    "jsonschema": "4.19.0",
    "lxml": "5.2.1",
    "mako": "1.1.0",
    "matplotlib": "3.3.4",
    "mock": "3.0.5",
    "numpy": "1.26.4",
    "opencv-python": "4.5.4.58",
    "optuna": "3.3.0",
    "packaging": "24.0",
    "pandas": "2.0.1",
    "paramiko": "3.4.0",
    "pathlib2": "2.3.6",
    "pillow": "10.2.0",
    "plotly": "5.20.0",
    "protobuf": "3.19.6",
    "psutil": "5.6.4",
    "pydantic": "2.7.4",
    "pytest": "8.1.1",
    "pyyaml": "5.3",
    "scikit-optimize": "0.9.0",
    "scipy": "1.10.1",
    "six": "1.16.0",
    "tabulate": "0.9.0",
    "typing-extensions": "4.10.0",
    "xlsxwriter": "1.2.2",
}

# Overlay the overrides registered for the running interpreter version.
requiredPackages.update(supportedPythonVersions[currentVersion].overrides)

# Versions of the required packages already present in this environment,
# keyed by the distribution key reported by pkg_resources.
installedPackages = {
    dist.key: dist.version
    for dist in pkg_resources.working_set
    if dist.key in requiredPackages
}

# Required packages that were not found at all; a version mismatch only
# produces a warning and does not schedule a reinstall.
toBeInstalledPackages = {}
for packageName, testedVersion in requiredPackages.items():
    if packageName not in installedPackages:
        toBeInstalledPackages[packageName] = testedVersion
        continue

    foundVersion = installedPackages[packageName]
    if foundVersion == testedVersion:
        print(
            "{0} is already installed with tested version: {1}".format(
                packageName, testedVersion
            )
        )
    else:
        print(
            "WARNING: {0} installed version: {1} does not match tested version: {2}".format(
                packageName,
                foundVersion,
                testedVersion,
            )
        )

# Install every required package that was not found, one pip call per
# package, each pinned to its tested version.
if toBeInstalledPackages:
    print(
        "Python Modules missing: {0}".format(
            ", ".join(map(str, toBeInstalledPackages))
        )
    )

    # Common prefix of every pip call: run pip with the current interpreter.
    pipInstall = [sys.executable, '-m', "pip", "install"]

    # Upgrade pip itself first; pip's stdout is discarded.
    subprocess.check_call(
        pipInstall + ["--upgrade", "pip"],
        stdout=subprocess.DEVNULL,
    )
    print("Installing missing modules using pip3")
    for missingName, missingVersion in toBeInstalledPackages.items():
        print(
            "Installing {0} version: {1}".format(missingName, missingVersion)
        )
        # check_call raises CalledProcessError if pip exits non-zero.
        subprocess.check_call(
            pipInstall + [missingName + "==" + missingVersion],
            stdout=subprocess.DEVNULL,
        )

def normalize_package_name(pkg):
    """Return *pkg* lowercased with each run of '-', '_' or '.' replaced by one '-'."""
    lowered = pkg.lower()
    return re.sub(r"[-_.]+", "-", lowered)

# Re-query pip after any installs so the summary reflects the final state of
# the environment. The JSON output format is used instead of the default
# column table: it needs no header skipping or whitespace splitting, so a
# change in the table layout or a stray line cannot break the parsing.
import json

pip_listing = subprocess.check_output(
    [
        sys.executable,
        '-m',
        'pip',
        'list',
        '--format=json',
        '--disable-pip-version-check',
    ]
).decode('utf-8')
# Normalized distribution name -> installed version.
final_state = {
    normalize_package_name(entry["name"]): entry["version"]
    for entry in json.loads(pip_listing)
}

package_names = sorted(requiredPackages)
required_versions = [requiredPackages[name] for name in package_names]

# Installed version for each required package, in package_names order;
# packages that pip does not report are shown as MISSING and collected so
# the script can fail after printing the table.
installed_versions = []
missing_val = "MISSING"
missing_packages = []
for package in package_names:
    normalized = normalize_package_name(package)
    if normalized in final_state:
        installed_versions.append(final_state[normalized])
    else:
        installed_versions.append(missing_val)
        missing_packages.append(package)

headers = ["Package", "Recommended", "Installed"]
columns = [package_names, required_versions, installed_versions]
# Each column is as wide as its header or its longest cell, whichever is larger.
widths = [
    max(len(header), max(len(cell) for cell in column))
    for header, column in zip(headers, columns)
]
print("\nSummary:\n")
print('\t'.join(title.ljust(width) for title, width in zip(headers, widths)))
print('\t'.join('~' * width for width in widths))
for row in zip(*columns):
    print('\t'.join(cell.ljust(width) for cell, width in zip(row, widths)))

# Fail only after the full table has been printed.
if missing_packages:
    raise RuntimeError("Missing required package(s): " + ", ".join(missing_packages))
