diff --git a/climax/main/config.py b/climax/main/config.py
index bba8d4d5831531cf5f86e63bbf2f0d37e25d3483..068f82a4d946ae577594e090c986eff3aa4d9fb7 100644
--- a/climax/main/config.py
+++ b/climax/main/config.py
@@ -7,6 +7,7 @@
 import datetime
 
 # externals
+import torch
 import numpy as np
 
 # locals
@@ -29,6 +30,7 @@ assert all([var in ERA5_P_VARIABLES for var in ERA5_P_PREDICTORS])
 # ERA5_S_PREDICTORS = ['mean_sea_level_pressure', 'orography', '2m_temperature']
 # ERA5_S_PREDICTORS = ['mean_sea_level_pressure']
 ERA5_S_PREDICTORS = ['surface_pressure']
+# ERA5_S_PREDICTORS = ['total_precipitation']
 assert all([var in ERA5_S_VARIABLES for var in ERA5_S_PREDICTORS])
 
 # ERA5 predictor variables
@@ -81,8 +83,7 @@ STRATIFY = False
 VALID_SIZE = 0.1
 
-# whether to train using cross-validation
-# TODO: define number of folds, description
-CV = False
+# number of folds for cross-validation (False to disable)
+CV = 5
 
 # -----------------------------------------------------------------------------
 # Observations ----------------------------------------------------------------
@@ -115,6 +116,19 @@ LOSS = MSELoss()
 # LOSS = BernoulliGammaLoss(min_amount=1)
 # LOSS = BernoulliWeibullLoss(min_amount=1)
 
+# stochastic optimization algorithm
+OPTIM = torch.optim.SGD
+OPTIM_PARAMS = {'lr': 0.005,  # learning rate
+                'weight_decay': 1e-6  # regularization rate
+                }
+if OPTIM == torch.optim.SGD:
+    OPTIM_PARAMS['momentum'] = 0.9
+
+# learning rate scheduler
+# LR_SCHEDULER = torch.optim.lr_scheduler.ExponentialLR
+LR_SCHEDULER = None
+LR_SCHEDULER_PARAMS = {'gamma': 0.9}
+
 # whether to randomly shuffle time steps or to conserve time series for model
 # training
 SHUFFLE = True
@@ -126,23 +140,18 @@ NORM = True
 # batch size: number of time steps processed by the net in each iteration
 BATCH_SIZE = 16
 
-# learning rate
-LR = 0.0005
-
-# regularization rate
-LAMBDA = 0.05
-
 # network training configuration
 TRAIN_CONFIG = {
     'checkpoint_state': {},
-    'epochs': 250,
+    'epochs': 50,
     'save': True,
     'save_loaders': False,
     'early_stop': True,
-    'patience': 50,
+    'patience': 5,
     'multi_gpu': True,
     'classification': False,
-    'clip_gradients': True
+    'clip_gradients': True,
+    # 'lr_scheduler': torch.optim.lr_scheduler.
 }
 
 # whether to overwrite existing models
diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py
index 1f07b341e361e1560ad56df9f80fff02f1dcce5e..eb6cf2eaa02481aee6de5f5cb69dd2df972619f5 100644
--- a/climax/main/downscale_infer.py
+++ b/climax/main/downscale_infer.py
@@ -23,7 +23,8 @@ from climax.core.predict import predict_ERA5
 from climax.core.utils import split_date_range
 from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND,
                                 NET, VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS,
-                                DEM, DEM_FEATURES, LOSS, ANOMALIES, LAMBDA)
+                                DEM, DEM_FEATURES, LOSS, ANOMALIES,
+                                OPTIM_PARAMS)
 from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH
 
 # module level logger
@@ -39,7 +40,7 @@ if __name__ == '__main__':
     state_file = ERA5Dataset.state_file(
         NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
         dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
-        decay=LAMBDA)
+        decay=OPTIM_PARAMS['weight_decay'])
 
     # path to model state
     state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index 749257c602a450a29051157a889e2fbacbfb43ee..314ae378e8b7367cf1aa4627961e4b68a9e42a4a 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -11,7 +11,6 @@ from datetime import timedelta
 from logging.config import dictConfig
 
 # externals
-import torch
 import xarray as xr
 from sklearn.model_selection import train_test_split
 from torch.utils.data import DataLoader
@@ -24,10 +23,12 @@ from pysegcnn.core.logging import log_conf
 from climax.core.dataset import ERA5Dataset, NetCDFDataset
 from climax.core.loss import MSELoss, L1Loss
 from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
-                                CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
-                                LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
+                                CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, OPTIM,
+                                NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
                                 OVERWRITE, DEM, DEM_FEATURES, STRATIFY,
-                                WET_DAY_THRESHOLD, VALID_SIZE, ANOMALIES)
+                                WET_DAY_THRESHOLD, VALID_SIZE, ANOMALIES,
+                                OPTIM_PARAMS, LR_SCHEDULER,
+                                LR_SCHEDULER_PARAMS)
 from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH
 
 # module level logger
@@ -43,7 +44,7 @@ if __name__ == '__main__':
     state_file = ERA5Dataset.state_file(
         NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
         dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
-        decay=LAMBDA)
+        decay=OPTIM_PARAMS['weight_decay'])
 
     # path to model state
     state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
@@ -119,11 +120,12 @@ if __name__ == '__main__':
     inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
     net = NET(state_file, inputs, outputs, filters=FILTERS)
 
-    # initialize optimizer
-    # optimizer = torch.optim.Adam(net.parameters(), lr=LR,
-    #                              weight_decay=LAMBDA)
-    optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9,
-                                weight_decay=LAMBDA)
+    # initialize optimizer
+    optimizer = OPTIM(net.parameters(), **OPTIM_PARAMS)
+
+    # initialize learning rate scheduler
+    if LR_SCHEDULER is not None:
+        LR_SCHEDULER = LR_SCHEDULER(optimizer, **LR_SCHEDULER_PARAMS)
 
     # initialize training data
     LogConfig.init_log('Initializing training data.')
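
Note: downscale_train.py now builds the optimizer and scheduler from the config,
but the diff does not show where the scheduler is stepped (the 'lr_scheduler'
entry in TRAIN_CONFIG is still commented out, so stepping presumably happens
inside the training loop). Below is a minimal sketch of how the new OPTIM /
LR_SCHEDULER settings are consumed, assuming the scheduler is stepped once per
epoch; the toy network and data are hypothetical stand-ins for NET and the
NetCDF loaders:

    import torch

    # values mirror the new config.py entries above
    OPTIM = torch.optim.SGD
    OPTIM_PARAMS = {'lr': 0.005, 'weight_decay': 1e-6, 'momentum': 0.9}
    LR_SCHEDULER = torch.optim.lr_scheduler.ExponentialLR
    LR_SCHEDULER_PARAMS = {'gamma': 0.9}

    net = torch.nn.Linear(8, 1)  # toy stand-in for NET(state_file, ...)
    optimizer = OPTIM(net.parameters(), **OPTIM_PARAMS)
    scheduler = (LR_SCHEDULER(optimizer, **LR_SCHEDULER_PARAMS)
                 if LR_SCHEDULER is not None else None)

    x, y = torch.randn(16, 8), torch.randn(16, 1)  # dummy batch
    for epoch in range(5):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(net(x), y)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()  # lr *= gamma once per epoch

Keeping OPTIM, OPTIM_PARAMS, and LR_SCHEDULER in config.py means swapping SGD
for Adam (or disabling the scheduler) is a one-line config change, with the
weight decay shared consistently between training and inference via
OPTIM_PARAMS['weight_decay'].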