# -*- coding: utf-8 -*-
"""
Created on Tue Jul 14 11:04:27 2020

@author: Daniel
"""
# builtins
import os
import itertools

# externals
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib import cm as colormap


# contrast_stretching() applies percentile stretching at the alpha level;
# it can be used to increase contrast for visualization
def contrast_stretching(image, alpha=2):

    # compute the lower and upper percentiles defining the range of the
    # stretch
    inf, sup = np.percentile(image, (alpha, 100 - alpha))

    # rescale the (inf, sup) percentile range to the original intensity
    # range of the image
    norm = ((image - inf) * (image.max() - image.min()) /
            (sup - inf)) + image.min()

    # clip the stretched values to the original intensity range
    norm[norm <= image.min()] = image.min()
    norm[norm >= image.max()] = image.max()

    return norm
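
# Example (illustrative): stretch a synthetic single-band image, clipping the
# darkest and brightest 2% of the pixel values before visualization.
#
#   band = np.random.rand(256, 256)
#   stretched = contrast_stretching(band, alpha=2)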


# running_mean() computes the mean of x over a sliding window of w samples
def running_mean(x, w):
    cumsum = np.cumsum(np.insert(x, 0, 0))
    return (cumsum[w:] - cumsum[:-w]) / w
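
# Example (illustrative): smooth a series with a window of three samples,
# e.g. running_mean(np.array([1, 2, 3, 4, 5]), 3) returns array([2., 3., 4.])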


# plot_sample() plots a false color composite of the scene/tile together
# with the model prediction and the corresponding ground truth
def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
                bands=['red', 'green', 'blue'], stretch=False, state=None,
                outpath=os.path.join(os.getcwd(), '_samples/'),  **kwargs):

    # check whether to apply contrast stretching: enabled explicitly via
    # stretch=True or implicitly by passing keyword arguments for
    # contrast_stretching, e.g. alpha=5
    if kwargs:
        stretch = True
    func = contrast_stretching if stretch else (lambda img, **kwargs: img)

    # create an rgb stack
    rgb = np.dstack([func(x[use_bands.index(band)], **kwargs)
                     for band in bands])

    # get labels and corresponding colors
    ulabels = [label['label'] for label in labels.values()]
    colors = [label['color'] for label in labels.values()]

    # create a ListedColormap
    cmap = ListedColormap(colors)
    boundaries = [*labels.keys(), cmap.N]
    norm = BoundaryNorm(boundaries, cmap.N)

    # create figure: check whether to plot model prediction
    if y_pred is not None:

        # compute the overall accuracy of the prediction
        acc = (y_pred == y).float().mean().item()

        # plot model prediction
        fig, ax = plt.subplots(1, 3, figsize=figsize)
        ax[2].imshow(y_pred, cmap=cmap, interpolation='nearest', norm=norm)
        ax[2].set_title('Prediction ({:.2f}%)'.format(acc * 100), pad=15)

    else:
        fig, ax = plt.subplots(1, 2, figsize=figsize)

    # plot false color composite
    ax[0].imshow(rgb)
    ax[0].set_title('R = {}, G = {}, B = {}'.format(*bands), pad=15)

    # plot ground truth mask
    ax[1].imshow(y, cmap=cmap, interpolation='nearest', norm=norm)
    ax[1].set_title('Ground truth', pad=15)

    # create a patch (proxy artist) for every color
    patches = [mpatches.Patch(color=c, label=l) for c, l in
               zip(colors, ulabels)]

    # plot patches as legend
    plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2,
               frameon=False)

    # save figure
    if state is not None:
        os.makedirs(outpath, exist_ok=True)
        fig.savefig(os.path.join(outpath, state.replace('.pt', '.png')),
                    dpi=300, bbox_inches='tight')

    return fig, ax
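
# Example (illustrative; assumes x is a (bands, height, width) array whose
# band order matches use_bands, y is a (height, width) label mask, and the
# class names and colors below are placeholders):
#
#   labels = {0: {'label': 'Water', 'color': 'navy'},
#             1: {'label': 'Land', 'color': 'forestgreen'}}
#   fig, ax = plot_sample(x, y, use_bands=['blue', 'green', 'red', 'nir'],
#                         labels=labels, alpha=2)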


# plot_confusion_matrix() plots the confusion matrix of the validation/test
# set returned by the pytorch.predict function
def plot_confusion_matrix(cm, labels, normalize=True,
                          figsize=(10, 10), cmap='Blues', state=None,
                          outpath=os.path.join(os.getcwd(), '_graphics/')):

    # number of classes
    nclasses = len(labels)

    # string format to plot values of confusion matrix
    fmt = '.0f'

    # minimum and maximum values of the colorbar
    vmin, vmax = 0, cm.max()

    # check whether to normalize the confusion matrix
    if normalize:
        # normalize each row, i.e. each true class, to sum to one
        cm = cm / cm.sum(axis=1, keepdims=True)

        # change string format to floating point
        fmt = '.2f'
        vmin, vmax = 0, 1

    # create figure
    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # get colormap
    cmap = colormap.get_cmap(cmap, 256)

    # plot confusion matrix
    im = ax.imshow(cm, cmap=cmap, vmin=vmin, vmax=vmax)

    # threshold determining the color of the values
    thresh = (cm.max() + cm.min()) / 2

    # lightest and darkest color of the current colormap
    cmap_min, cmap_max = cmap(0), cmap(cmap.N - 1)

    # plot values of confusion matrix
    for i, j in itertools.product(range(nclasses), range(nclasses)):
        ax.text(j, i, format(cm[i, j], fmt), ha='center', va='center',
                color=cmap_max if cm[i, j] < thresh else cmap_min)

    # axes properties and labels
    ax.set(xticks=np.arange(nclasses),
           yticks=np.arange(nclasses),
           xticklabels=labels,
           yticklabels=labels,
           ylabel='True',
           xlabel='Predicted')

    # add colorbar axes
    cax = fig.add_axes([ax.get_position().x1 + 0.025, ax.get_position().y0,
                        0.05, ax.get_position().y1 - ax.get_position().y0])
    fig.colorbar(im, cax=cax)

    # save figure
    if state is not None:
        os.makedirs(outpath, exist_ok=True)
        fig.savefig(os.path.join(outpath, state.replace('.pt', '_cm.png')),
                    dpi=300, bbox_inches='tight')

    return fig, ax
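
# Example (illustrative): plot a row-normalized confusion matrix for two
# placeholder classes, where cm[i, j] counts samples of true class i
# predicted as class j.
#
#   cm = np.array([[50., 10.],
#                  [5., 35.]])
#   fig, ax = plot_confusion_matrix(cm, labels=['Water', 'Land'])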


def plot_loss(loss_file, figsize=(10, 10), step=5,
              colors=['lightgreen', 'green', 'skyblue', 'steelblue'],
              outpath=os.path.join(os.getcwd(), '_graphics/')):

    # load the model loss
    state = torch.load(loss_file)

    # get all non-zero elements, i.e. get number of epochs trained before
    # early stop
    loss = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for k, v in
            state.items() if k != 'epoch'}

    # compute running mean with a window equal to the number of batches in
    # an epoch
    rm = {k: running_mean(v.flatten('F'), v.shape[0]) for k, v in loss.items()}

    # number of epochs trained
    epochs = np.arange(0, loss['tl'].shape[1])

    # instantiate figure
    fig, ax1 = plt.subplots(1, 1, figsize=figsize)

    # create axes for each parameter to plot
    ax2 = ax1.twinx()
    ax3 = ax1.twiny()
    ax4 = ax2.twiny()

    # list of axes
    axes = [ax1, ax2, ax3, ax4]

    # plot the running mean of the loss and accuracy of the training and
    # validation datasets
    for (k, v), ax, c in zip(rm.items(), axes, colors):
        if v.any():
            ax.plot(v, color=c)

    # axes properties and labels
    nbatches = loss['tl'].shape[0]
    for ax in [ax3, ax4]:
        ax.set(xticks=[], xticklabels=[])
    ax1.set(xticks=np.arange(0, nbatches * epochs[-1] + 1, nbatches * step),
            xticklabels=epochs[::step],
            xlabel='Epoch',
            ylabel='Loss',
            ylim=(0, 1))
    ax2.set(ylabel='Accuracy',
            ylim=(0.5, 1))

    # compute early stopping point
    if loss['va'].any():
        esepoch = np.argmax(loss['va'].mean(axis=0)) * nbatches
        esacc = np.max(loss['va'].mean(axis=0))
        ax1.vlines(esepoch, ymin=ax1.get_ylim()[0], ymax=ax1.get_ylim()[1],
                   ls='--', color='grey')
        ax1.text(esepoch - nbatches, ax1.get_ylim()[0] + 0.01,
                 'epoch = {}, accuracy = {:.1f}%'
                 .format(int(esepoch / nbatches), esacc * 100),
                 ha='right', color='grey')

    # create a proxy artist (line) for each metric to show in the legend
    ulabels = ['Training loss', 'Training accuracy',
               'Validation loss', 'Validation accuracy']
    patches = [mlines.Line2D([], [], color=c, label=l) for c, l in
               zip(colors, ulabels)]
    # plot patches as legend
    ax1.legend(handles=patches, loc='lower left', frameon=False)

    # save figure
    os.makedirs(outpath, exist_ok=True)
    fig.savefig(os.path.join(
        outpath, os.path.basename(loss_file).replace('.pt', '.png')),
                dpi=300, bbox_inches='tight')

    return fig, axes
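
# Example (illustrative; assumes loss_file is a torch checkpoint storing
# per-batch metrics as (batches, epochs) arrays, including the keys 'tl'
# (training loss) and 'va' (validation accuracy) referenced above; the
# filename is hypothetical):
#
#   fig, axes = plot_loss('UNet_loss.pt', step=10)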