Skip to content
Snippets Groups Projects
Commit 970832cd authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

LR-scheduler to statefile.

parent 50232ca4
No related branches found
No related tags found
No related merge requests found
......@@ -113,7 +113,7 @@ class EoDataset(torch.utils.data.Dataset):
def state_file(model, predictand, predictors, plevels, dem=False,
dem_features=False, doy=False, loss=None, cv=None,
season=None, anomalies=False, decay=None, optim=None,
lr=None):
lr=None, lr_scheduler=None):
# naming convention:
# <model>_<predictand>_<Ppredictors>_<plevels>_<Spredictors>.pt
......@@ -155,6 +155,10 @@ class EoDataset(torch.utils.data.Dataset):
state_file = ('_'.join([state_file, 'lr{:.0e}'.format(lr)]) if lr
is not None else state_file)
# add suffix for learning rate scheduler
state_file = ('_'.join([state_file, lr_scheduler.__name__]) if
lr_scheduler is not None else state_file)
# add suffix for training with anomalies
state_file = ('_'.join([state_file, 'anom']) if anomalies else
state_file)
......
......@@ -24,7 +24,7 @@ from climax.core.utils import split_date_range
from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND, NET,
VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS,
DEM, DEM_FEATURES, LOSS, ANOMALIES, OPTIM,
OPTIM_PARAMS, CHUNKS)
OPTIM_PARAMS, CHUNKS, LR_SCHEDULER)
from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH
# module level logger
......@@ -40,7 +40,8 @@ if __name__ == '__main__':
state_file = ERA5Dataset.state_file(
NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr'])
decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr'],
lr_scheduler=LR_SCHEDULER)
# path to model state
state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
......
......@@ -44,7 +44,8 @@ if __name__ == '__main__':
state_file = ERA5Dataset.state_file(
NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr'])
decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr'],
lr_scheduler=LR_SCHEDULER)
# path to model state
state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment