diff --git a/climax/core/dataset.py b/climax/core/dataset.py index 7059ba7929c96587fbf0314549f626e01e05d62f..05e6d7535fcfcf100c14b9a5d9c4681ac56e3621 100644 --- a/climax/core/dataset.py +++ b/climax/core/dataset.py @@ -104,7 +104,7 @@ class EoDataset(torch.utils.data.Dataset): @staticmethod def state_file(model, predictand, predictors, plevels, dem=False, - doy=False): + dem_features=False, doy=False): # naming convention: # <model>_<predictand>_<Ppredictors>_<plevels>_<Spredictors>.pt @@ -116,8 +116,11 @@ class EoDataset(torch.utils.data.Dataset): state_file = '_'.join([model.__name__, str(predictand), Ppredictors, *plevels, Spredictors]) - # check whether digital elevation model and day of year were used + # check whether digital elevation model, slope and aspect, and the day + # of year were used state_file = '_'.join([state_file, 'dem']) if dem else state_file + state_file = ('_'.join([state_file, 'sa']) if dem_features else + state_file) state_file = '_'.join([state_file, 'doy']) if doy else state_file # add file extension: .pt diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py index 3401d893d4aa5fc539bf367f23240efab53199ae..f6df06eea6c650e1b92d8491a20dd8351c4f0718 100644 --- a/climax/main/downscale_infer.py +++ b/climax/main/downscale_infer.py @@ -11,14 +11,13 @@ from datetime import timedelta from logging.config import dictConfig # externals -import numpy as np import xarray as xr # locals from pysegcnn.core.trainer import LogConfig from pysegcnn.core.models import Network from pysegcnn.core.logging import log_conf -from pysegcnn.core.utils import img2np, search_files +from pysegcnn.core.utils import search_files from climax.core.dataset import ERA5Dataset from climax.core.predict import predict_ERA5 from climax.core.utils import split_date_range @@ -38,7 +37,8 @@ if __name__ == '__main__': # filename of pretrained model state_file = ERA5Dataset.state_file( - NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM, doy=DOY) + NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM, + dem_features=DEM_FEATURES, doy=DOY) state_file = MODEL_PATH.joinpath(state_file) # initialize logging diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py index 56ee01006a35a51a1548ba796f845c6281f8a368..b8104516daf1ab21f5a2869472a94257c6a8b3ff 100644 --- a/climax/main/downscale_train.py +++ b/climax/main/downscale_train.py @@ -11,13 +11,12 @@ from logging.config import dictConfig # externals import torch -import numpy as np import xarray as xr from sklearn.model_selection import train_test_split from torch.utils.data import DataLoader # locals -from pysegcnn.core.utils import search_files, img2np +from pysegcnn.core.utils import search_files from pysegcnn.core.trainer import NetworkTrainer, LogConfig from pysegcnn.core.models import Network from pysegcnn.core.logging import log_conf @@ -39,7 +38,8 @@ if __name__ == '__main__': # initialize network filename state_file = ERA5Dataset.state_file( - NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM, doy=DOY) + NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM, + dem_features=DEM_FEATURES, doy=DOY) state_file = MODEL_PATH.joinpath(state_file) # initialize logging