import os
import time
import xarray as xr
import rasterio
from rasterio.io import MemoryFile
import pyproj
from tqdm import tqdm
from multiprocessing import Pool
from datetime import datetime


def _cog_file_name(timestamp: str, version_code: str) -> str:
    """Return the COG file name for the time step dated *timestamp* (YYYY-MM-DD)."""
    formatted_timestamp = datetime.strptime(timestamp, "%Y-%m-%d").strftime("%Y%m%d")

    generic_var_name = "swe"
    method_standard = "reconstruction"
    var_type = "c"
    spatial_support = "50m"
    depth_reference = "s"
    bounding_box = "it"
    epsg_code = "epsg.3035"

    name_parts = (
        generic_var_name,
        method_standard,
        var_type,
        spatial_support,
        depth_reference,
        formatted_timestamp,
        bounding_box,
        epsg_code,
        version_code,
    )
    return "_".join(name_parts) + ".tif"


def netcdf_to_cog(netcdf_path: str, cog_dir: str) -> None:
    """
    Convert a SWE netCDF file to a series of Cloud Optimized GeoTIFF (COG) files.

    One COG is written per step along the ``time`` dimension. Each file is named
    from fixed dataset descriptors, the date of the time step and a version tag
    derived from the date on which the conversion runs.

    Parameters
    ----------
    netcdf_path : str
        Path to the netCDF file. It must provide a ``SWE`` variable, ``x``, ``y``
        and ``time`` coordinates and a ``spatial_ref`` variable carrying a
        ``crs_wkt`` attribute.
    cog_dir : str
        Directory for the output COG files. It is created if it does not exist.
    """
    time_start: float = time.perf_counter()

    os.makedirs(cog_dir, exist_ok=True)

    # Computed once so that every file of a run carries the same version tag,
    # even when the run crosses midnight.
    version_code = "v" + datetime.now().strftime("%Y%m%d")

    # The context manager releases the netCDF file handle even if a step fails.
    with xr.open_dataset(netcdf_path) as dataset:
        spatial_ref = dataset["spatial_ref"].attrs["crs_wkt"]
        crs = pyproj.CRS.from_string(spatial_ref)

        # Plain floats keep xarray objects out of the affine transform.
        # NOTE(review): the first/last coordinate values are used as the outer
        # bounds. If x/y hold cell centres (the usual netCDF convention) the
        # grid ends up half a pixel short on every side - confirm against the
        # producer of the file.
        transform = rasterio.transform.from_bounds(
            float(dataset.x[0]),
            float(dataset.y[-1]),
            float(dataset.x[-1]),
            float(dataset.y[0]),
            dataset.sizes["x"],
            dataset.sizes["y"],
        )

        for i in tqdm(range(dataset.sizes["time"])):
            time_slice = dataset.isel(time=i)
            timestamp = str(time_slice.time.values.astype("datetime64[D]"))
            cog_path = os.path.join(cog_dir, _cog_file_name(timestamp, version_code))
            swe = time_slice.SWE

            # The intermediate GeoTIFF lives in memory instead of a fixed
            # "temp.tif" in the working directory: parallel runs cannot clobber
            # each other and nothing is left on disk when a step raises.
            with MemoryFile() as memfile:
                with memfile.open(
                    driver="GTiff",
                    height=time_slice.y.size,
                    width=time_slice.x.size,
                    count=1,
                    dtype=str(swe.dtype),
                    crs=crs,
                    transform=transform,
                ) as temp:
                    temp.write(swe.values, 1)

                with memfile.open() as src:
                    profile = src.profile
                    profile.update(
                        driver="COG",
                        dtype=rasterio.float32,
                        nodata=0,
                        compress="deflate",
                        predictor=2,
                        blockxsize=256,
                        blockysize=256,
                        tiled=True,
                    )

                    with rasterio.open(cog_path, "w", **profile) as dst:
                        # Cast explicitly so the array matches the float32
                        # dtype declared in the output profile.
                        band = src.read(1).astype(rasterio.float32, copy=False)
                        dst.write(band, 1)

    time_end: float = time.perf_counter()
    print(f"COG creation time: {time_end - time_start:.2f} seconds")


def process_multiple_datasets(netcdf_paths, cog_dirs):
    """
    Convert several SWE netCDF files to COG files, one output directory per file.

    Parameters
    ----------
    netcdf_paths : iterable of str
        Paths to the netCDF files.
    cog_dirs : iterable of str
        Output directories, paired positionally with ``netcdf_paths``.

    Raises
    ------
    ValueError
        If the two inputs do not contain the same number of entries.
    """
    # Materialize both inputs so that generators can be length-checked before
    # any conversion starts.
    netcdf_paths = list(netcdf_paths)
    cog_dirs = list(cog_dirs)

    # A bare zip() would silently drop the unpaired tail and skip those files.
    if len(netcdf_paths) != len(cog_dirs):
        raise ValueError(
            f"Got {len(netcdf_paths)} netCDF paths but {len(cog_dirs)} output "
            "directories; each input file needs exactly one output directory."
        )

    for netcdf_path, cog_dir in zip(netcdf_paths, cog_dirs):
        netcdf_to_cog(netcdf_path, cog_dir)