Skip to content
Snippets Groups Projects
Commit 7642ca3a authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Decluttered statefile name.

parent fa4ae83c
No related branches found
No related tags found
No related merge requests found
......@@ -19,6 +19,7 @@ import dask.array as da
from climax.core.constants import (ERA5_VARIABLES, ERA5_PRESSURE_LEVELS,
ERA5_P_VARIABLE_NAME, ERA5_S_VARIABLE_NAME,
PROJECTION)
from climax.core.loss import BernoulliGammaLoss, BernoulliGenParetoLoss
from pysegcnn.core.utils import search_files, img2np
from pysegcnn.core.trainer import LogConfig
......@@ -102,7 +103,7 @@ class EoDataset(torch.utils.data.Dataset):
@staticmethod
def state_file(model, predictand, predictors, plevels, dem=False,
dem_features=False, doy=False):
dem_features=False, doy=False, loss=None, cv=False):
# naming convention:
# <model>_<predictand>_<Ppredictors>_<plevels>_<Spredictors>.pt
......@@ -121,6 +122,21 @@ class EoDataset(torch.utils.data.Dataset):
state_file)
state_file = '_'.join([state_file, 'doy']) if doy else state_file
# check which loss function is used
if predictand == 'pr':
if (isinstance(loss, BernoulliGammaLoss) or
isinstance(loss, BernoulliGenParetoLoss)):
# adjust state file for precipitation
state_file = '_'.join([state_file, '{}mm'.format(
str(loss.min_amount).replace('.', '')),
repr(loss).strip('()')])
else:
# add name of loss function to state file
state_file = '_'.join([state_file, repr(loss).strip('()')])
# add suffix for training with cross-validation
state_file = '_'.join([state_file, 'cv']) if cv else state_file
# add file extension: .pt
return '.'.join([state_file, 'pt'])
......
......@@ -21,7 +21,6 @@ from pysegcnn.core.utils import search_files
from climax.core.dataset import ERA5Dataset
from climax.core.predict import predict_ERA5
from climax.core.utils import split_date_range
from climax.core.loss import BernoulliGammaLoss
from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND, NET,
VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS,
DEM, DEM_FEATURES, LOSS, CV)
......@@ -36,23 +35,10 @@ if __name__ == '__main__':
# initialize timing
start_time = time.monotonic()
# filename of pretrained model
# initialize network filename
state_file = ERA5Dataset.state_file(
NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
dem_features=DEM_FEATURES, doy=DOY)
# adjust statefile name for precipitation
if PREDICTAND == 'pr':
if isinstance(LOSS, BernoulliGammaLoss):
state_file = state_file.replace('.pt', '_{}mm_{}.pt'.format(
str(LOSS.min_amount).replace('.', ''),
repr(LOSS).strip('()')))
else:
state_file = state_file.replace('.pt', '_{}.pt'.format(
repr(LOSS).strip('()')))
# add suffix for training with cross-validation
state_file = state_file.replace('.pt', '_cv.pt') if CV else state_file
dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, cv=CV)
# path to model state
state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
......
......@@ -41,20 +41,7 @@ if __name__ == '__main__':
# initialize network filename
state_file = ERA5Dataset.state_file(
NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
dem_features=DEM_FEATURES, doy=DOY)
# adjust statefile name for precipitation
if PREDICTAND == 'pr':
if isinstance(LOSS, BernoulliGammaLoss):
state_file = state_file.replace('.pt', '_{}mm_{}.pt'.format(
str(LOSS.min_amount).replace('.', ''),
repr(LOSS).strip('()')))
else:
state_file = state_file.replace('.pt', '_{}.pt'.format(
repr(LOSS).strip('()')))
# add suffix for training with cross-validation
state_file = state_file.replace('.pt', '_cv.pt') if CV else state_file
dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, cv=CV)
# path to model state
state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment