from collections import Counter
from itertools import product
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from hdbscan import HDBSCAN
from joblib import Parallel, delayed
from sklearn import ensemble as ens
from sklearn.metrics import (
    calinski_harabasz_score,
    davies_bouldin_score,
    silhouette_score,
)
from tqdm import tqdm

from justclust.hopkins import hopkins

# %%

def get_algclusters(preprocs=None):
    """Build the catalogue of clustering algorithms to explore.

    Parameters
    ----------
    preprocs : list, optional
        Pre-processing steps stored under the ``"pre-processing"`` key of
        the active model. ``None`` (the default) yields an empty list.

    Returns
    -------
    dict
        Mapping from the algorithm name to its configuration dictionary,
        which holds the keys:

        * ``"cls"``: the clustering class to instantiate;
        * ``"opts"``: list of keyword-argument dictionaries, one for each
          parameter combination to try;
        * ``"pre-processing"``: list of pre-processing steps;
        * ``"label"``: callable turning one ``opts`` entry into a short
          identifier;
        * ``"title"``: callable turning one ``opts`` entry into a
          human-readable title.
    """
    # define the cluster algorithm to use/test
    models = {
        # "KMEANS": {
        #     "cls": clu.KMeans,
        #     "opts": [
        #         dict(
        #             n_clusters=k,
        #             n_init=20,
        #             max_iter=600,
        #             random_state=23,
        #             ordered_feature_names=selected.columns,
        #             feature_importance_method="wcss_min",
        #         )
        #         for k in range(2, 25)
        #     ],
        #     "pre-processing": [
        #         # PowerTransformer(), # using it_ hopkins is reduced to 0.6884 from 0.9273
        #         pre.RobustScaler(
        #             with_centering=True,
        #             with_scaling=True,
        #             quantile_range=(20, 80),
        #         ),
        #         apply_weight,
        #         dec.DictionaryLearning(fit_algorithm="cd", alpha=0.1, n_jobs=-1),
        #     ],
        #     "label": lambda x: (f"kmeans__k-{x['n_clusters']:02d}"),
        #     "title": lambda x: (f"KMeans: k: {x['n_clusters']:2d} "),
        # },
        # "DBSCAN": {
        #     "cls": clu.DBSCAN,
        #     "opts": [
        #         {"eps": eps, "min_samples": ms}
        #         for eps, ms in product(
        #             (0.1, 0.25, 0.5, 0.75, 0.9), (1, 3, 5, 10, 15, 20)
        #         )
        #     ],
        #     "pre-processing": [
        #         pre.RobustScaler(
        #             with_centering=True,
        #             with_scaling=True,
        #             quantile_range=quantile_range,
        #         ),
        #         apply_weight,
        #         dec.DictionaryLearning(fit_algorithm="cd", alpha=0.1, n_jobs=-1),
        #     ],
        #     "label": lambda x: (
        #         f"dbscan__eps-{x['eps']:3.2f}_"
        #         f"ms-{x['min_samples'] if x['min_samples'] else 0:02d}"
        #     ),
        #     "title": lambda x: (
        #         f"DBSCAN: eps: {x['eps']:3.1f} " f"min_samples: {x['min_samples']}"
        #     ),
        # },
        "HDBSCAN": {
            "cls": HDBSCAN,
            "opts": [
                {
                    "min_cluster_size": mcs,
                    "min_samples": ms,
                    "metric": mtr,
                    # `p` is only meaningful for the minkowski metric
                    "p": 0.01 if mtr == "minkowski" else None,
                    "cluster_selection_method": csm,
                }
                for mcs, ms, mtr, csm in product(
                    [2, 3, 5, 7, 10, 12, 15],  # min_cluster_size
                    [None, 1, 2, 3, 5],  # min_samples
                    [
                        # NOTE(review): the active metric entries were lost
                        # from this file. The three below are the ones the
                        # surrounding code/comments refer to ("same as
                        # manhattan", "same as euclidean", the minkowski
                        # `p` switch) -- confirm against the intended list.
                        "euclidean",
                        "manhattan",
                        "minkowski",
                        # "haversine", # only 2D
                        # "cityblock", # same as manhattan
                        # "l1",   # same as manhattan
                        # "l2",   # same as euclidean
                        # "dice",   # only boolean vectors
                        # "hamming",   # only boolean vectors
                        # "jaccard",   # only boolean vectors
                        # "kulsinski",   # only boolean vectors
                        # "mahalanobis", # Must provide either V or VI for Mahalanobis distance
                        # "rogerstanimoto",  # only boolean vectors
                        # "russellrao",  # only boolean vectors
                        # "seuclidean",
                        # "sokalmichener",  # only boolean vectors
                        # "sokalsneath",  # only boolean vectors
                        # "yule", # only boolean vectors
                    ],  # metric
                    ["eom", "leaf"],  # cluster_selection_method
                )
            ],
            "pre-processing": preprocs if preprocs is not None else [],
            # NOTE(review): only the `ms-` fragment of the label survived in
            # this file; the other fragments are reconstructed so that every
            # parameter combination gets a distinct label -- confirm naming.
            "label": lambda x: (
                f"hdbscan__mcs-{x['min_cluster_size']:02d}_"
                f"ms-{x['min_samples'] if x['min_samples'] else 0:02d}_"
                f"m-{x['metric']}_"
                f"csm-{x['cluster_selection_method']}"
            ),
            # fields are separated by a blank so the title stays readable
            "title": lambda x: (
                f"Scaled/Weighted HDBSCAN: min_cluster_size: {x['min_cluster_size']} "
                f"min_samples: {x['min_samples']} "
                f"metric: {x['metric']} "
                f"csm: {x['cluster_selection_method']}"
            ),
        },
        # "BIRCH": {
        #     "cls": clu.Birch,
        #     "opts": [
        #         dict(
        #             threshold=thr,
        #             branching_factor=bf,
        #             n_clusters=None,
        #         )
        #         for thr, bf, k in product(
        #             (0.1, 0.25, 0.5, 0.75, 0.9, 1.15, 1.5),
        #             (25, 50, 100),
        #             (None, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 15, 20, 25),
        #         )
        #     ],
        #     "pre-processing": [
        #         pre.RobustScaler(
        #             with_centering=True,
        #             with_scaling=True,
        #             quantile_range=quantile_range,
        #         ),
        #         apply_weight,
        #         dec.DictionaryLearning(fit_algorithm="cd", alpha=0.1, n_jobs=-1),
        #     ],
        #     "label": lambda x: (
        #         f"birch__t-{x['threshold']:4.2f}_"
        #         f"b-{x['branching_factor']}_"
        #         f"k-{x['n_clusters']}"
        #     ),
        #     "title": lambda x: (
        #         f"Birch t: {x['threshold']:4.2f} "
        #         f"b: {x['branching_factor']} "
        #         f"k: {x['n_clusters']}"
        #     ),
        # },
    }
    return models

# %%
# Explore clustering algorithms

def feature_importances(
    data: pd.DataFrame,
    labels: List[int],
    label: str,
    clf: Optional[Any] = None,
    max_depth: int = 10,
) -> pd.DataFrame:
    """Compute feature importance.

    A classifier is fitted to predict the cluster ``labels`` from ``data``
    and its ``feature_importances_`` attribute is returned as a table.

    Parameters
    ----------
    data : pd.DataFrame
        DataFrame with the transformed data used for clustering
    labels : List[int]
        Label with the assigned cluster for each element in the data
    label : str
        name of the column that will be added in the returned DataFrame
    clf : object, optional
        Instance with a `fit` method that computes a
        `feature_importances_` attribute; if not defined by the user
        a RandomForestClassifier is used.
    max_depth : int, default: 10
        Maximum depth of the default RandomForestClassifier; ignored when
        ``clf`` is given.

    Returns
    -------
    pd.DataFrame
        DataFrame with the feature importances computed: one row per
        column of ``data`` (index named ``"features"``), a single column
        named ``label``, sorted from the most to the least important.
    """
    if clf is None:
        clf = ens.RandomForestClassifier(
            max_depth=max_depth, n_estimators=500, random_state=1
        )
    clf.fit(data, labels)
    return pd.DataFrame(
        clf.feature_importances_,
        columns=[label],
        index=pd.Series(data.columns, name="features"),
    ).sort_values(label, ascending=False)

def compute_metrics(
    model, data: pd.DataFrame, labels: np.array, alpha_k: float = 0.02
) -> Tuple[
    int,
    Optional[float],
    Optional[float],
    Optional[float],
    Optional[float],
    Optional[float],
]:
    """Compute the quality scores of a clustering result.

    Elements with a negative label (outliers) are excluded from every
    computation.

    Parameters
    ----------
    model
        Fitted clustering instance. When it exposes an ``inertia_``
        attribute the inertia-based scores are computed.
    data : pd.DataFrame
        Data used for the clustering, one row per element.
    labels : np.array
        Cluster label assigned to each row of ``data``; negative values
        mark outliers.
    alpha_k : float, default: 0.02
        Penalty added to the scaled inertia for each cluster.

    Returns
    -------
    tuple
        ``(k, inertia, scaled_inertia, silhouette, davies_bouldin,
        calinski_harabasz)`` where ``k`` is the number of clusters found
        (outliers excluded). The inertia values are ``None`` when the model
        has no ``inertia_``; the last three are ``None`` unless
        ``2 <= k <= n_valid_samples - 1``.
    """
    labels = np.asarray(labels)
    # avoid to compute metrics for outliers, rm: -1
    valid = labels >= 0
    labels = labels[valid]
    data = data.loc[valid, :]
    # count the clusters on the valid labels only: subtracting 1 from the
    # number of distinct labels is wrong when no outlier label is present
    k = int(np.unique(labels).size)

    if hasattr(model, "inertia_"):
        # total sum of squares around the global mean of the valid rows
        inertia_o = np.square((data.values - data.values.mean(axis=0))).sum()
        inertia = model.inertia_
        scaledi = inertia / inertia_o + alpha_k * k
    else:
        inertia = None
        scaledi = None

    if 2 <= k <= (len(labels) - 1):
        # the scores are defined only for 2 to n_samples - 1 clusters
        sil = silhouette_score(data, labels)
        dvb = davies_bouldin_score(data, labels)
        cal = calinski_harabasz_score(data, labels)
        return k, inertia, scaledi, sil, dvb, cal
    return k, inertia, scaledi, None, None, None

def exec_model(
    model: Dict[str, Any], opts: Dict[str, Any], data: pd.DataFrame
) -> Tuple[Any, List[int]]:
    """Execute the single model to the data

    Parameters
    ----------
    model : Dict[str, Any]
        Dictionary containing a `cls` key with the class
        of the cluster algorithm that need to be applied;
        Optional key is `post-processing` that contain a
        list of callable instances taking as input
        * the algorithm cluster instance
        * the transformed data as pd DataFrame
        * the cluster labels as List[int]
        that are applied after the execution of the
        clustering process.
    opts : Dict[str, Any]
        Arguments to be used to instantiate the cluster
        algorithm `model["cls"](**opts)`.
    data : pd.DataFrame
        DataFrame with the transformed data to be used
        through the `fit_predict` method of the model

    Returns
    -------
    Tuple[Any, List[int]]
        * the fitted cluster algorithm instance
        * the labels returned by its `fit_predict` method
    """
    # apply the model
    clst = model["cls"](**opts)
    labels = clst.fit_predict(data)

    # apply all the post-processing actions; `None` entries are skipped and
    # the return value of each action is ignored
    for post in model.get("post-processing", []):
        if post is not None:
            post(clst, data, labels)
    return clst, labels

def cls_counter(labels: List[int]) -> Tuple[pd.DataFrame, float]:
    """Count the number of clusters found

    Parameters
    ----------
    labels : List[int]
        List containing the cluster labels assigned to each
        element. Negative value are classified as outliers
        or not belonging to any cluster.

    Returns
    -------
    Tuple[pd.DataFrame, float]
        * DataFrame with the number of statistical units count
          per label (columns ``cluster``, ``count`` and ``%``), sorted by
          decreasing share
        * the float with the percentage of number of units that
          have been assigned to a cluster.
    """
    clcount = pd.DataFrame(Counter(labels).most_common(), columns=["cluster", "count"])
    tot = clcount["count"].sum()
    clcount["%"] = clcount["count"].astype(float) * 100.0 / tot
    # share of the elements with a non-negative label (i.e. not outliers)
    perc = clcount.loc[clcount["cluster"] >= 0, "%"].sum()
    clcount = clcount.sort_values(by="%", ascending=False)
    # sanity check: the shares can never add up to more than 100%
    if round(perc, 2) > 100.0:
        print("#" * 60)
        print(f"WARNING: percentage > 100%: {perc!r}")
        print("#" * 60)
    return clcount, perc

def model_worker(
    mname: str,
    model: Dict[str, Any],
    opts: Dict[str, Any],
    tdata: pd.DataFrame,
    rdir: Optional[Path] = None,
) -> Tuple[str, List[int], List[Any], pd.DataFrame]:
    """Execute the single model and compute the main scores

    Parameters
    ----------
    mname : str
        Name of the algorithm in use (currently not used by the worker).
    model : Dict[str, Any]
        Dictionary with the main options and parameters
    opts : Dict[str, Any]
        Dictionary with the parameters to be used by the model
    tdata : pd.DataFrame
        Transformed and balanced data, ready to be used as input
        for the clustering task.
    rdir : Path, optional
        When given, the table with the number of elements per cluster is
        written to ``rdir / "<long label>.xlsx"``.

    Returns
    -------
    Tuple[str, List[int], List[Any], pd.DataFrame]
        * the first value is the string that is used as column name
        * the result of the cluster with the assigned labels for
          each row of the transformed-data
        * list containing the scores and main features of the cluster
        * DataFrame containing the feature importance of each cluster
    """
    label = model["label"](opts)
    clst, labels = exec_model(model, opts, tdata)
    score = compute_metrics(clst, tdata, labels, alpha_k=0.02)
    clcount, perc = cls_counter(labels)
    if score[0] > 1:
        # size statistics are reported only when more than one cluster exists
        n_of_cls1 = (clcount["count"] == 1).sum()
        # NOTE(review): `.loc[0:, ...]` slices by index *label*, so the rows
        # selected depend on the row order of `clcount`; presumably meant to
        # cover the clusters' sizes -- confirm the intent.
        mean_el_per_cl = clcount.loc[0:, "count"].mean()
        std_el_per_cl = clcount.loc[0:, "count"].std()
    else:
        n_of_cls1 = None
        mean_el_per_cl = None
        std_el_per_cl = None
    long_label = f"k{len(clcount):03d}_perc{perc:05.1f}_{label}"
    # NOTE(review): the leading entries of this list were lost from the file;
    # reconstructed from the values computed above. Layout: label, % covered
    # by clusters, clusters with one element, mean and std of the elements
    # per cluster, followed by the metrics tuple -- confirm against the
    # column names used by the caller.
    xcores = [
        label,
        perc,
        n_of_cls1,
        mean_el_per_cl,
        std_el_per_cl,
    ] + [s for s in score]
    if rdir is not None:
        clcount.to_excel(rdir / f"{long_label}.xlsx")
    fimp = feature_importances(tdata, labels, label)
    return label, labels, xcores, fimp


def apply_filter(
    sd: Dict[str, Optional[float]], filters: List[Tuple[str, Tuple[float, float]]]
) -> List[bool]:
    """Check conditions to a dictionary. Return a list of booleans for each condition.

    Parameters
    ----------
    sd : Dict[str, int  |  float]
        Dictionary with the values to be filtered. A ``None`` value is
        treated as NaN, so it fails every bounded condition.
    filters : List[Tuple[str, Tuple[float, float]]]
        List of rules to be verified, each one as
        ``(key, (minimum, maximum))``. Both bounds are inclusive and a
        ``None`` bound is not checked.

    Returns
    -------
    List[bool]
        List of booleans, one per rule, telling whether the condition
        is satisfied

    Raises
    ------
    KeyError
        If a rule refers to a key that is missing from ``sd``.

    Examples
    --------
    >>> apply_filter(
    ...     dict(a=10, b=50, c=100),
    ...     [
    ...          ("a", (5, None)),
    ...          ("b", (40, 60)),
    ...          ("c", (None, 100)),
    ...     ]
    ... )
    [True, True, True]
    >>> apply_filter(
    ...     dict(a=10, b=50, c=100),
    ...     [
    ...          ("a", (15, None)),
    ...          ("b", (60, 80)),
    ...          ("c", (None, 90)),
    ...     ]
    ... )
    [False, False, False]
    >>> apply_filter(
    ...     dict(a=None, b=50, c=100),
    ...     [
    ...          ("a", (15, None)),
    ...          ("b", (60, 80)),
    ...          ("c", (None, 90)),
    ...     ]
    ... )
    [False, False, False]
    """
    appends = []
    for fname, (fmin, fmax) in filters:
        val = sd[fname]
        if val is None:
            # NaN compares False with any bound
            val = float("nan")

        # an undefined bound is always satisfied
        min_ok = fmin is None or val >= fmin
        max_ok = fmax is None or val <= fmax
        appends.append(bool(min_ok and max_ok))
    return appends



def explore_models(
    data: pd.DataFrame,
    cldf: "gpd.GeoDataFrame",
    rdir: Path,
    models: Optional[Dict[str, Dict[str, Any]]] = None,
    filters: Optional[List[Tuple[str, Tuple[float, float]]]] = None,
    n_jobs: int = -1,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Execute all the models and save the labels' result
    to a dedicated GeoDataFrame.

    Parameters
    ----------
    data : pd.DataFrame
        Data to be used for the clustering, is a DataFrame with
        the selected features.
    cldf : gpd.GeoDataFrame
        For each model a new column is added to this GeoDataFrame
        (modified in place).
    rdir : Path
        Directory containing all the outputs. NOTE(review): the value is
        not used by the code of this function -- confirm whether it should
        be forwarded to the workers.
    models : Dict[str, Dict[str, Any]]
        A dictionary of dictionaries with all the cluster algorithms
        that will be executed
    filters : List[Tuple[str, Tuple[float, float]]], optional
        Rules, as ``(score column name, (minimum, maximum))``, checked with
        `apply_filter` on the scores of each run; the labels of a run are
        stored in ``cldf`` only when all the rules are satisfied. When
        ``None`` every run is stored.
    n_jobs: int, default: -1
        Number of parallel jobs to be used to explore the
        algorithms. A negative value mean that all the cores
        available will be used.

    Returns
    -------
    Tuple[pd.DataFrame, pd.DataFrame]
        * the DataFrame with the main score and characteristics
          of the clusters
        * the DataFrame with the feature importance
    """
    if models is None:
        # NOTE(review): `MODELS` is not defined in the visible part of this
        # module -- confirm it exists at module level.
        models = MODELS

    num_opts = sum(len(m["opts"]) for m in models.values())
    print(
        f"Exploring the cluster space with {len(models)} "
        f"models and {num_opts} tasks"
    )

    # Compute the Hopkins statistic per pre-processing chain per model.
    # It is a statistical test of whether the data follow a uniform
    # distribution, i.e. whether clustering can be useful to classify the
    # observations.
    # NOTE(review): the original note reads scores tending to 0 as
    # "not uniform, clusterable" and scores above ~0.3 as "uniform";
    # confirm against the convention of `justclust.hopkins.hopkins`.
    trans = {}
    for mname, model in models.items():
        tdata = data.copy()
        # apply all the transformations
        for pre in model.get("pre-processing", []):
            if pre is not None:
                tdata = pre.fit_transform(tdata)
        # avoid to pre-process the data for every attempt
        trans[mname] = tdata
        # compute Hopkins statistics to test the clusterability of the dataset
        hop = hopkins(tdata, sampling_size=150)
        print(f"{mname}, Hopkins: {hop:.5f}")

    pll = Parallel(n_jobs=n_jobs, verbose=0)
    # one task per (model, parameter combination); the transformed data are
    # wrapped again in a DataFrame carrying the original index and columns
    tasks = [
        (
            mname,
            model,
            opts,
            pd.DataFrame(trans[mname], index=data.index, columns=data.columns),
        )
        for mname, model in models.items()
        for opts in model["opts"]
    ]

    res = pll(
        delayed(model_worker)(mname, model, opts, df)
        for mname, model, opts, df in tqdm(tasks)
    )

    # NOTE(review): only "% covered by clusters", "davies bouldin" and
    # "calinski harabasz" survived in this file; the other names are
    # reconstructed and must follow the order of the score list returned by
    # `model_worker` -- confirm names and order.
    cols = [
        "label",
        "% covered by clusters",
        "n_of_cls1",
        "mean_el_per_cl",
        "std_el_per_cl",
        "k",
        "inertia",
        "scaled_inertia",
        "silhouette",
        "davies bouldin",
        "calinski harabasz",
    ]
    scores, fimps = [], []
    for label, labels, xcores, fimp in res:
        if filters is not None:
            sd = dict(zip(cols, xcores))
            appends = apply_filter(sd, filters)
        else:
            # skip filters
            appends = [True]

        # save results
        scores.append(xcores)
        fimps.append(fimp)
        # check conditions to save it in the vector layer only
        # when filter conditions are satisfied
        if all(appends):
            cldf.loc[data.index, label] = labels

    sc = pd.DataFrame(scores, columns=cols)
    # pd.concat raises on an empty list (no task executed)
    fi = pd.concat(fimps, axis=1) if fimps else pd.DataFrame()
    return sc, fi