#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 26 16:31:36 2020

@author: Daniel
"""
# builtins
import os
import sys

# externals
import numpy as np
import torch
import torch.nn as nn

# append path to local files to the python search path
sys.path.append('..')

# locals
from pytorch.layers import (Encoder, Decoder, Conv2dPool, Conv2dUnpool,
                            Conv2dUpsample, Conv2dSame)


class Network(nn.Module):

    def __init__(self):
        super().__init__()

    def freeze(self):
        # freeze the network: exclude all parameters from gradient computation
        for param in self.parameters():
            param.requires_grad = False

    def unfreeze(self):
        # unfreeze the network: include all parameters in gradient computation
        for param in self.parameters():
            param.requires_grad = True

    def save(self, state_file, optimizer, bands,
             outpath=os.path.join(os.getcwd(), '_models')):

        # create the output path if it does not exist
        os.makedirs(outpath, exist_ok=True)

        # initialize dictionary to store network parameters
        model_state = {}

        # store input bands
        model_state['bands'] = bands

        # store construction parameters to instantiate the network
        model_state['params'] = {
            'skip': self.skip,
            'filters': self.nfilters,
            'nclasses': self.nclasses,
            'in_channels': self.in_channels
            }

        # store optional keyword arguments
        model_state['kwargs'] = self.kwargs

        # store model epoch
        model_state['epoch'] = self.epoch

        # store model and optimizer state
        model_state['model_state_dict'] = self.state_dict()
        model_state['optim_state_dict'] = optimizer.state_dict()

        # write the model state dictionary, including all trainable
        # parameters, to disk
        state = os.path.join(outpath, state_file)
        torch.save(model_state, state)
        print('Network parameters saved in {}'.format(state))

        return state

    def load(self, state_file, optimizer=None,
             inpath=os.path.join(os.getcwd(), '_models')):

        # load the model state file
        state = os.path.join(inpath, state_file)
        model_state = torch.load(state)

        # resume network parameters
        print('Loading network parameters from {} ...'.format(state))
        self.load_state_dict(model_state['model_state_dict'])
        self.epoch = model_state['epoch']

        # resume optimizer parameters
        if optimizer is not None:
            print('Loading optimizer parameters from {} ...'.format(state))
            optimizer.load_state_dict(model_state['optim_state_dict'])

        return state


class UNet(Network):

    def __init__(self, in_channels, nclasses, filters, skip, **kwargs):
        super().__init__()

        # number of input channels
        self.in_channels = in_channels

        # number of classes
        self.nclasses = nclasses

        # configuration of the convolutional layers in the network
        self.kwargs = kwargs
        self.nfilters = filters

        # number of feature maps per layer: the number of input channels
        # followed by the number of filters of each convolutional block
        self.filters = np.hstack([np.array(in_channels), np.array(filters)])

        # whether to apply skip connections
        self.skip = skip

        # number of epochs trained
        self.epoch = 0

        # construct the encoder
        self.encoder = Encoder(filters=self.filters, block=Conv2dPool,
                               **kwargs)

        # construct the decoder
        self.decoder = Decoder(filters=self.filters, block=Conv2dUnpool,
                               skip=skip, **kwargs)

        # construct the classifier
        self.classifier = Conv2dSame(in_channels=filters[0],
                                     out_channels=self.nclasses,
                                     kernel_size=1)

    def forward(self, x):

        # forward pass: encoder
        x = self.encoder(x)

        # forward pass: decoder
        x = self.decoder(x, self.encoder.cache)

        # clear intermediate outputs
        del self.encoder.cache

        # classification
        return self.classifier(x)
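

# Example usage: a minimal sketch, not part of the original training pipeline.
# The filter sizes, input shape, band names and optimizer below are
# illustrative assumptions, and the example assumes that the local layer
# classes (Encoder, Decoder, Conv2dPool, Conv2dUnpool, Conv2dSame) work
# without additional keyword arguments.
if __name__ == '__main__':

    # instantiate a small UNet: 3 input channels (e.g. RGB), 2 output classes
    net = UNet(in_channels=3, nclasses=2, filters=[32, 64, 128], skip=True)

    # a dummy batch containing a single 3-band image of size 128 x 128
    x = torch.rand(1, 3, 128, 128)

    # forward pass: the output has shape (batch, nclasses, height, width)
    y = net(x)
    print(y.shape)

    # save the model and optimizer state to the default output path
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    net.save('unet_example.pt', optimizer, bands=['red', 'green', 'blue'])

    # restore the model and optimizer state from disk
    net.load('unet_example.pt', optimizer=optimizer)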