Skip to content
Snippets Groups Projects
Commit aa5e72b8 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Optimizer to statefile name.

parent c62eb4ab
No related branches found
No related tags found
No related merge requests found
...@@ -104,7 +104,7 @@ class EoDataset(torch.utils.data.Dataset): ...@@ -104,7 +104,7 @@ class EoDataset(torch.utils.data.Dataset):
@staticmethod @staticmethod
def state_file(model, predictand, predictors, plevels, dem=False, def state_file(model, predictand, predictors, plevels, dem=False,
dem_features=False, doy=False, loss=None, cv=None, dem_features=False, doy=False, loss=None, cv=None,
season=None, anomalies=False, decay=None): season=None, anomalies=False, decay=None, optim=None):
# naming convention: # naming convention:
# <model>_<predictand>_<Ppredictors>_<plevels>_<Spredictors>.pt # <model>_<predictand>_<Ppredictors>_<plevels>_<Spredictors>.pt
...@@ -134,6 +134,10 @@ class EoDataset(torch.utils.data.Dataset): ...@@ -134,6 +134,10 @@ class EoDataset(torch.utils.data.Dataset):
# add name of loss function to state file # add name of loss function to state file
state_file = '_'.join([state_file, repr(loss).strip('()')]) state_file = '_'.join([state_file, repr(loss).strip('()')])
# add suffix for optimizer
state_file = ('_'.join([state_file, optim.__name__]) if optim is not
None else state_file)
# add suffix for weight decay values # add suffix for weight decay values
state_file = ('_'.join([state_file, 'd{:.0e}'.format(decay)]) if decay state_file = ('_'.join([state_file, 'd{:.0e}'.format(decay)]) if decay
is not None else state_file) is not None else state_file)
......
...@@ -30,7 +30,6 @@ assert all([var in ERA5_P_VARIABLES for var in ERA5_P_PREDICTORS]) ...@@ -30,7 +30,6 @@ assert all([var in ERA5_P_VARIABLES for var in ERA5_P_PREDICTORS])
# ERA5_S_PREDICTORS = ['mean_sea_level_pressure', 'orography', '2m_temperature'] # ERA5_S_PREDICTORS = ['mean_sea_level_pressure', 'orography', '2m_temperature']
# ERA5_S_PREDICTORS = ['mean_sea_level_pressure'] # ERA5_S_PREDICTORS = ['mean_sea_level_pressure']
ERA5_S_PREDICTORS = ['surface_pressure'] ERA5_S_PREDICTORS = ['surface_pressure']
# ERA5_S_PREDICTORS = ['total_precipitation']
assert all([var in ERA5_S_VARIABLES for var in ERA5_S_PREDICTORS]) assert all([var in ERA5_S_VARIABLES for var in ERA5_S_PREDICTORS])
# ERA5 predictor variables # ERA5 predictor variables
...@@ -43,6 +42,9 @@ ERA5_PLEVELS = [500, 850] ...@@ -43,6 +42,9 @@ ERA5_PLEVELS = [500, 850]
# Anomaly = (time_series - mean(time_series)) / (std(time_series)) # Anomaly = (time_series - mean(time_series)) / (std(time_series))
ANOMALIES = False ANOMALIES = False
# Dask chunk size for loading the training data
CHUNKS = {'time': 365}
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Auxiliary predictors -------------------------------------------------------- # Auxiliary predictors --------------------------------------------------------
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
...@@ -74,6 +76,9 @@ VALID_PERIOD = np.arange( ...@@ -74,6 +76,9 @@ VALID_PERIOD = np.arange(
datetime.datetime.strptime('1991-01-01', '%Y-%m-%d').date(), datetime.datetime.strptime('1991-01-01', '%Y-%m-%d').date(),
datetime.datetime.strptime('2011-01-01', '%Y-%m-%d').date()) datetime.datetime.strptime('2011-01-01', '%Y-%m-%d').date())
# entire reference period
REFERENCE_PERIOD = np.concatenate([CALIB_PERIOD, VALID_PERIOD], axis=0)
# stratify training/validation set for precipitation by number of wet days # stratify training/validation set for precipitation by number of wet days
STRATIFY = False STRATIFY = False
...@@ -82,7 +87,7 @@ STRATIFY = False ...@@ -82,7 +87,7 @@ STRATIFY = False
# 10% of CALIB_PERIOD for validation # 10% of CALIB_PERIOD for validation
VALID_SIZE = 0.1 VALID_SIZE = 0.1
# whether to train using cross-validation # number of folds for training with KFold cross-validation
CV = 5 CV = 5
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
...@@ -118,23 +123,23 @@ LOSS = MSELoss() ...@@ -118,23 +123,23 @@ LOSS = MSELoss()
# stochastic optimization algorithm # stochastic optimization algorithm
OPTIM = torch.optim.SGD OPTIM = torch.optim.SGD
OPTIM_PARAMS = {'lr': 0.005, # learning rate # OPTIM = torch.optim.Adam
OPTIM_PARAMS = {'lr': 1e-1, # learning rate
'weight_decay': 1e-6 # regularization rate 'weight_decay': 1e-6 # regularization rate
} }
if OPTIM == torch.optim.SGD: if OPTIM == torch.optim.SGD:
OPTIM_PARAMS['momentum'] = 0.9 OPTIM_PARAMS['momentum'] = 0.9
# learning rate scheduler # learning rate scheduler
# LR_SCHEDULER = torch.optim.lr_scheduler.ExponentialLR # LR_SCHEDULER = torch.optim.lr_scheduler.MultiStepLR
LR_SCHEDULER = None LR_SCHEDULER = None
LR_SCHEDULER_PARAMS = {'gamma': 0.9} LR_SCHEDULER_PARAMS = {'gamma': 0.25, 'milestones': [1, 3]}
# whether to randomly shuffle time steps or to conserve time series for model # whether to randomly shuffle time steps or to conserve time series for model
# training # training
SHUFFLE = True SHUFFLE = True
# whether to normalize the training data to [0, 1] (True) or to standardize to # whether to normalize the training data to [0, 1]
# mean=0, std=1 (False)
NORM = True NORM = True
# batch size: number of time steps processed by the net in each iteration # batch size: number of time steps processed by the net in each iteration
...@@ -143,11 +148,11 @@ BATCH_SIZE = 16 ...@@ -143,11 +148,11 @@ BATCH_SIZE = 16
# network training configuration # network training configuration
TRAIN_CONFIG = { TRAIN_CONFIG = {
'checkpoint_state': {}, 'checkpoint_state': {},
'epochs': 50, 'epochs': 250,
'save': True, 'save': True,
'save_loaders': False, 'save_loaders': False,
'early_stop': True, 'early_stop': True,
'patience': 5, 'patience': 25,
'multi_gpu': True, 'multi_gpu': True,
'classification': False, 'classification': False,
'clip_gradients': True 'clip_gradients': True
......
...@@ -23,7 +23,7 @@ from climax.core.predict import predict_ERA5 ...@@ -23,7 +23,7 @@ from climax.core.predict import predict_ERA5
from climax.core.utils import split_date_range from climax.core.utils import split_date_range
from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND, NET, from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND, NET,
VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS, VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS,
DEM, DEM_FEATURES, LOSS, ANOMALIES, DEM, DEM_FEATURES, LOSS, ANOMALIES, OPTIM,
OPTIM_PARAMS) OPTIM_PARAMS)
from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH
...@@ -40,7 +40,7 @@ if __name__ == '__main__': ...@@ -40,7 +40,7 @@ if __name__ == '__main__':
state_file = ERA5Dataset.state_file( state_file = ERA5Dataset.state_file(
NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM, NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES, dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
decay=OPTIM_PARAMS['weight_decay']) decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM)
# path to model state # path to model state
state_file = MODEL_PATH.joinpath(PREDICTAND, state_file) state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
......
...@@ -44,7 +44,7 @@ if __name__ == '__main__': ...@@ -44,7 +44,7 @@ if __name__ == '__main__':
state_file = ERA5Dataset.state_file( state_file = ERA5Dataset.state_file(
NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM, NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES, dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
decay=OPTIM_PARAMS['weight_decay']) decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM)
# path to model state # path to model state
state_file = MODEL_PATH.joinpath(PREDICTAND, state_file) state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment