diff --git a/climax/core/dataset.py b/climax/core/dataset.py index d90a72c929173f84cae69875de2c47776ddfcdfd..7bb306c43eade0bdac0c500f0162ab30c7b46518 100644 --- a/climax/core/dataset.py +++ b/climax/core/dataset.py @@ -104,7 +104,8 @@ class EoDataset(torch.utils.data.Dataset): @staticmethod def state_file(model, predictand, predictors, plevels, dem=False, dem_features=False, doy=False, loss=None, cv=None, - season=None, anomalies=False, decay=None, optim=None): + season=None, anomalies=False, decay=None, optim=None, + lr=None): # naming convention: # <model>_<predictand>_<Ppredictors>_<plevels>_<Spredictors>.pt @@ -142,6 +143,10 @@ class EoDataset(torch.utils.data.Dataset): state_file = ('_'.join([state_file, 'd{:.0e}'.format(decay)]) if decay is not None else state_file) + # add suffix for learning rate values + state_file = ('_'.join([state_file, 'd{:.0e}'.format(lr)]) if lr + is not None else state_file) + # add suffix for training with anomalies state_file = ('_'.join([state_file, 'anom']) if anomalies else state_file) diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py index beff9e71ec02d17880b435c90e13cb6c3265963b..0a83e0fc20b1508b1b150daf28ecc26b929a3769 100644 --- a/climax/main/downscale_infer.py +++ b/climax/main/downscale_infer.py @@ -40,7 +40,7 @@ if __name__ == '__main__': state_file = ERA5Dataset.state_file( NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM, dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES, - decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM) + decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr']) # path to model state state_file = MODEL_PATH.joinpath(PREDICTAND, state_file) diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py index 65c15300f7fb67c756a6357cf74bba17c25f21ae..7468fa839c5e16f41f94e4e5addb90d552e53457 100644 --- a/climax/main/downscale_train.py +++ b/climax/main/downscale_train.py @@ -44,7 +44,7 @@ if __name__ == '__main__': state_file = ERA5Dataset.state_file( NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM, dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES, - decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM) + decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr']) # path to model state state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)