Skip to content
Snippets Groups Projects
Commit 9c912abe authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Set learning and weight decay rates for supersampling task.

parent 4729b62e
No related branches found
No related tags found
No related merge requests found
......@@ -22,14 +22,14 @@ from climax.core.loss import (BernoulliGammaLoss, MSELoss, L1Loss,
# -----------------------------------------------------------------------------
# ERA5 predictor variables on pressure levels
ERA5_P_PREDICTORS = ['geopotential', 'temperature', 'u_component_of_wind',
'v_component_of_wind', 'specific_humidity']
# ERA5_P_PREDICTORS = []
# ERA5_P_PREDICTORS = ['geopotential', 'temperature', 'u_component_of_wind',
# 'v_component_of_wind', 'specific_humidity']
ERA5_P_PREDICTORS = []
assert all([var in ERA5_P_VARIABLES for var in ERA5_P_PREDICTORS])
# ERA5 predictor variables on single levels
ERA5_S_PREDICTORS = ['surface_pressure']
# ERA5_S_PREDICTORS = ['total_precipitation']
# ERA5_S_PREDICTORS = ['surface_pressure']
ERA5_S_PREDICTORS = ['total_precipitation']
assert all([var in ERA5_S_VARIABLES for var in ERA5_S_PREDICTORS])
# ERA5 predictor variables
......@@ -50,10 +50,10 @@ CHUNKS = {'time': 365}
# -----------------------------------------------------------------------------
# include day of year as predictor
DOY = True
DOY = False
# use digital elevation model instead of model orography
DEM = True
DEM = False
if DEM:
# remove model orography when using DEM
if 'orography' in ERA5_S_PREDICTORS:
......@@ -91,14 +91,14 @@ VALID_SIZE = 0.2
CV = 5
# number of bootstrapped model trainings
BOOTSTRAP = 10
BOOTSTRAP=10
# -----------------------------------------------------------------------------
# Observations ----------------------------------------------------------------
# -----------------------------------------------------------------------------
# target variable: check if target variable is valid
PREDICTAND = 'pr'
PREDICTAND='pr'
assert PREDICTAND in PREDICTANDS
# threshold defining the minimum amount of precipitation (mm) for a wet day
......@@ -122,10 +122,10 @@ FILTERS = [32, 64, 128, 256]
# for precipitation:
# BernoulliGammaLoss (NLL of Bernoulli-Gamma distribution)
# BernoulliWeibullLoss (NLL of Bernoulli-Weibull distribution)
# LOSS = L1Loss()
LOSS = MSELoss()
# LOSS = BernoulliGammaLoss(min_amount=1)
# LOSS = BernoulliWeibullLoss(min_amount=1)
# LOSS=MSELoss()
LOSS=MSELoss()
# LOSS=MSELoss()
# LOSS=MSELoss()
# batch size: number of time steps processed by the net in each iteration
BATCH_SIZE = 16
......@@ -164,17 +164,20 @@ if PREDICTAND is 'pr':
WEIGHT_DECAY = 1e-5
if isinstance(LOSS, MSELoss):
MAX_LR = 0.0004
WEIGHT_DECAY = 1e-3
# weight decay rates for supersampling task
if not any(ERA5_P_PREDICTORS) and ERA5_S_PREDICTORS == 'pr':
WEIGHT_DECAY = 1e-5
else:
WEIGHT_DECAY = 1e-3
if isinstance(LOSS, BernoulliGammaLoss):
# learning rates for supersampling task
if not any(ERA5_P_PREDICTORS) and ERA5_S_PREDICTORS == 'pr':
MAX_LR = 0.001
WEIGHT_DECAY = 1e-4
else:
MAX_LR = 0.0005 if OPTIM is torch.optim.Adam else 0.001
# weight decay
WEIGHT_DECAY = 1e-2
WEIGHT_DECAY = 1e-2
# base learning rate: MAX_LR / 4 (Smith L. (2017))
BASE_LR = MAX_LR / 4
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment