diff --git a/climax/core/dataset.py b/climax/core/dataset.py index 4e4ad63211d64414e735613b3a7512ecb2c81823..c18df0268bfde1867d26ee0499e222d80c20cf9f 100644 --- a/climax/core/dataset.py +++ b/climax/core/dataset.py @@ -113,7 +113,7 @@ class EoDataset(torch.utils.data.Dataset): def state_file(model, predictand, predictors, plevels, dem=False, dem_features=False, doy=False, loss=None, cv=None, season=None, anomalies=False, decay=None, optim=None, - lr=None): + lr=None, lr_scheduler=None): # naming convention: # <model>_<predictand>_<Ppredictors>_<plevels>_<Spredictors>.pt @@ -155,6 +155,10 @@ class EoDataset(torch.utils.data.Dataset): state_file = ('_'.join([state_file, 'lr{:.0e}'.format(lr)]) if lr is not None else state_file) + # add suffix for learning rate scheduler + state_file = ('_'.join([state_file, lr_scheduler.__name__]) if + lr_scheduler is not None else state_file) + # add suffix for training with anomalies state_file = ('_'.join([state_file, 'anom']) if anomalies else state_file) diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py index 47f104d55a76f755496c4d8ff7eb284f009681b7..5d2689478bb40de3e859a4b3ded2eab9f2961cb4 100644 --- a/climax/main/downscale_infer.py +++ b/climax/main/downscale_infer.py @@ -24,7 +24,7 @@ from climax.core.utils import split_date_range from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND, NET, VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS, DEM, DEM_FEATURES, LOSS, ANOMALIES, OPTIM, - OPTIM_PARAMS, CHUNKS) + OPTIM_PARAMS, CHUNKS, LR_SCHEDULER) from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH # module level logger @@ -40,7 +40,8 @@ if __name__ == '__main__': state_file = ERA5Dataset.state_file( NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM, dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES, - decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr']) + decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr'], + lr_scheduler=LR_SCHEDULER) # path to model state state_file = MODEL_PATH.joinpath(PREDICTAND, state_file) diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py index 770750a672fc4c5dc23218ebf899cab7abb5fc2e..a07649645fe10e4c746643f900f629dca06d5b43 100644 --- a/climax/main/downscale_train.py +++ b/climax/main/downscale_train.py @@ -44,7 +44,8 @@ if __name__ == '__main__': state_file = ERA5Dataset.state_file( NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM, dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES, - decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr']) + decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr'], + lr_scheduler=LR_SCHEDULER) # path to model state state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)