From 105a90b47147b7cec283ee4ade03f1df90563b9a Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 19 Aug 2021 11:28:42 +0200
Subject: [PATCH] Added slope and aspect parameters to file naming convention.

---
 climax/core/dataset.py         | 7 +++++--
 climax/main/downscale_infer.py | 6 +++---
 climax/main/downscale_train.py | 6 +++---
 3 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/climax/core/dataset.py b/climax/core/dataset.py
index 7059ba7..05e6d75 100644
--- a/climax/core/dataset.py
+++ b/climax/core/dataset.py
@@ -104,7 +104,7 @@ class EoDataset(torch.utils.data.Dataset):
 
     @staticmethod
     def state_file(model, predictand, predictors, plevels, dem=False,
-                   doy=False):
+                   dem_features=False, doy=False):
 
         # naming convention:
         # <model>_<predictand>_<Ppredictors>_<plevels>_<Spredictors>.pt
@@ -116,8 +116,11 @@ class EoDataset(torch.utils.data.Dataset):
         state_file = '_'.join([model.__name__, str(predictand), Ppredictors,
                                *plevels, Spredictors])
 
-        # check whether digital elevation model and day of year were used
+        # check whether digital elevation model, slope and aspect, and the day
+        # of year were used
         state_file = '_'.join([state_file, 'dem']) if dem else state_file
+        state_file = ('_'.join([state_file, 'sa']) if dem_features else
+                      state_file)
         state_file = '_'.join([state_file, 'doy']) if doy else state_file
 
         # add file extension: .pt
diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py
index 3401d89..f6df06e 100644
--- a/climax/main/downscale_infer.py
+++ b/climax/main/downscale_infer.py
@@ -11,14 +11,13 @@ from datetime import timedelta
 from logging.config import dictConfig
 
 # externals
-import numpy as np
 import xarray as xr
 
 # locals
 from pysegcnn.core.trainer import LogConfig
 from pysegcnn.core.models import Network
 from pysegcnn.core.logging import log_conf
-from pysegcnn.core.utils import img2np, search_files
+from pysegcnn.core.utils import search_files
 from climax.core.dataset import ERA5Dataset
 from climax.core.predict import predict_ERA5
 from climax.core.utils import split_date_range
@@ -38,7 +37,8 @@ if __name__ == '__main__':
 
     # filename of pretrained model
     state_file = ERA5Dataset.state_file(
-        NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM, doy=DOY)
+        NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
+        dem_features=DEM_FEATURES, doy=DOY)
     state_file = MODEL_PATH.joinpath(state_file)
 
     # initialize logging
diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index 56ee010..b810451 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -11,13 +11,12 @@ from logging.config import dictConfig
 
 # externals
 import torch
-import numpy as np
 import xarray as xr
 from sklearn.model_selection import train_test_split
 from torch.utils.data import DataLoader
 
 # locals
-from pysegcnn.core.utils import search_files, img2np
+from pysegcnn.core.utils import search_files
 from pysegcnn.core.trainer import NetworkTrainer, LogConfig
 from pysegcnn.core.models import Network
 from pysegcnn.core.logging import log_conf
@@ -39,7 +38,8 @@ if __name__ == '__main__':
 
     # initialize network filename
     state_file = ERA5Dataset.state_file(
-        NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM, doy=DOY)
+        NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
+        dem_features=DEM_FEATURES, doy=DOY)
     state_file = MODEL_PATH.joinpath(state_file)
 
     # initialize logging
-- 
GitLab