From 801e7f3e6520dec6c42af8e77a78f878f09d4906 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Mon, 18 Oct 2021 16:23:44 +0200 Subject: [PATCH] Learning rate to statefile name. --- climax/core/dataset.py | 7 ++++++- climax/main/downscale_infer.py | 2 +- climax/main/downscale_train.py | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/climax/core/dataset.py b/climax/core/dataset.py index d90a72c..7bb306c 100644 --- a/climax/core/dataset.py +++ b/climax/core/dataset.py @@ -104,7 +104,8 @@ class EoDataset(torch.utils.data.Dataset): @staticmethod def state_file(model, predictand, predictors, plevels, dem=False, dem_features=False, doy=False, loss=None, cv=None, - season=None, anomalies=False, decay=None, optim=None): + season=None, anomalies=False, decay=None, optim=None, + lr=None): # naming convention: # <model>_<predictand>_<Ppredictors>_<plevels>_<Spredictors>.pt @@ -142,6 +143,10 @@ class EoDataset(torch.utils.data.Dataset): state_file = ('_'.join([state_file, 'd{:.0e}'.format(decay)]) if decay is not None else state_file) + # add suffix for learning rate values + state_file = ('_'.join([state_file, 'd{:.0e}'.format(lr)]) if lr + is not None else state_file) + # add suffix for training with anomalies state_file = ('_'.join([state_file, 'anom']) if anomalies else state_file) diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py index beff9e7..0a83e0f 100644 --- a/climax/main/downscale_infer.py +++ b/climax/main/downscale_infer.py @@ -40,7 +40,7 @@ if __name__ == '__main__': state_file = ERA5Dataset.state_file( NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM, dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES, - decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM) + decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr']) # path to model state state_file = MODEL_PATH.joinpath(PREDICTAND, state_file) diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py index 65c1530..7468fa8 100644 --- a/climax/main/downscale_train.py +++ b/climax/main/downscale_train.py @@ -44,7 +44,7 @@ if __name__ == '__main__': state_file = ERA5Dataset.state_file( NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM, dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES, - decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM) + decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr']) # path to model state state_file = MODEL_PATH.joinpath(PREDICTAND, state_file) -- GitLab