Commit 125090f0 authored by Frisinghelli Daniel

Implemented learning rate decay.
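Replaces the hard-coded learning rate (LR) and regularization rate (LAMBDA) with a configurable optimizer (OPTIM, OPTIM_PARAMS) and an optional learning rate scheduler (LR_SCHEDULER, LR_SCHEDULER_PARAMS). With torch.optim.lr_scheduler.ExponentialLR, the learning rate is multiplied by gamma after each scheduler step, i.e. lr_t = lr_0 * gamma ** t; LR_SCHEDULER = None keeps the learning rate constant.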

parent 4a081755
@@ -7,6 +7,7 @@
 import datetime

 # externals
+import torch
 import numpy as np

 # locals
@@ -29,6 +30,7 @@ assert all([var in ERA5_P_VARIABLES for var in ERA5_P_PREDICTORS])
 # ERA5_S_PREDICTORS = ['mean_sea_level_pressure', 'orography', '2m_temperature']
 # ERA5_S_PREDICTORS = ['mean_sea_level_pressure']
 ERA5_S_PREDICTORS = ['surface_pressure']
+# ERA5_S_PREDICTORS = ['total_precipitation']
 assert all([var in ERA5_S_VARIABLES for var in ERA5_S_PREDICTORS])

 # ERA5 predictor variables
@@ -81,8 +83,7 @@ STRATIFY = False
 VALID_SIZE = 0.1

 # whether to train using cross-validation
 # TODO: define number of folds, description
-CV = False
+CV = 5

 # -----------------------------------------------------------------------------
 # Observations ----------------------------------------------------------------
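CV now carries the number of folds instead of a boolean switch. The TODO above still applies: how the folds are built and consumed is not part of this commit. A minimal sketch of what a 5-fold split could look like, assuming scikit-learn's KFold over a hypothetical index of time steps (an illustration, not the project's code):

import numpy as np
from sklearn.model_selection import KFold

CV = 5
time_steps = np.arange(365)  # hypothetical daily index of the calibration period

for fold, (calib_idx, valid_idx) in enumerate(KFold(n_splits=CV).split(time_steps)):
    # train on time_steps[calib_idx], validate on time_steps[valid_idx]
    print('fold {}: {} calibration, {} validation'.format(
        fold, len(calib_idx), len(valid_idx)))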
@@ -115,6 +116,19 @@ LOSS = MSELoss()
 # LOSS = BernoulliGammaLoss(min_amount=1)
 # LOSS = BernoulliWeibullLoss(min_amount=1)

+# stochastic optimization algorithm
+OPTIM = torch.optim.SGD
+OPTIM_PARAMS = {'lr': 0.005, # learning rate
+                'weight_decay': 1e-6 # regularization rate
+                }
+if OPTIM == torch.optim.SGD:
+    OPTIM_PARAMS['momentum'] = 0.9
+
+# learning rate scheduler
+# LR_SCHEDULER = torch.optim.lr_scheduler.ExponentialLR
+LR_SCHEDULER = None
+LR_SCHEDULER_PARAMS = {'gamma': 0.9}
+
 # whether to randomly shuffle time steps or to conserve time series for model
 # training
 SHUFFLE = True
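To make the numbers concrete: with lr = 0.005 from OPTIM_PARAMS and gamma = 0.9, enabling ExponentialLR shrinks the learning rate by 10% per scheduler step, while the default LR_SCHEDULER = None keeps it constant. A quick check of the implied schedule:

lr0, gamma = 0.005, 0.9  # values from OPTIM_PARAMS and LR_SCHEDULER_PARAMS

# ExponentialLR: learning rate after t scheduler steps is lr0 * gamma ** t
for t in (0, 1, 10, 50):
    print(t, lr0 * gamma ** t)
# -> 0.005, 0.0045, ~0.0017, ~2.6e-05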
@@ -126,23 +140,18 @@ NORM = True
 # batch size: number of time steps processed by the net in each iteration
 BATCH_SIZE = 16

-# learning rate
-LR = 0.0005
-
-# regularization rate
-LAMBDA = 0.05
-
 # network training configuration
 TRAIN_CONFIG = {
     'checkpoint_state': {},
-    'epochs': 250,
+    'epochs': 50,
     'save': True,
     'save_loaders': False,
     'early_stop': True,
-    'patience': 50,
+    'patience': 5,
     'multi_gpu': True,
     'classification': False,
-    'clip_gradients': True
+    'clip_gradients': True,
+    # 'lr_scheduler': torch.optim.lr_scheduler.
 }

 # whether to overwrite existing models
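The tighter TRAIN_CONFIG stops training after 5 epochs without improvement instead of 50, over at most 50 epochs instead of 250. The early stopping itself is implemented in pysegcnn's trainer; the following is only a generic sketch of patience-based stopping, with made-up validation losses:

# made-up validation losses: improvement, then a plateau
valid_losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.60, 0.63, 0.64, 0.65]

best, stale, patience = float('inf'), 0, 5
for epoch, loss in enumerate(valid_losses):
    if loss < best:
        best, stale = loss, 0  # new best: reset the counter
    else:
        stale += 1  # no improvement this epoch
    if stale >= patience:
        print('early stop at epoch', epoch)  # -> early stop at epoch 7
        break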
@@ -23,7 +23,8 @@ from climax.core.predict import predict_ERA5
 from climax.core.utils import split_date_range
 from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND, NET,
                                 VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS,
-                                DEM, DEM_FEATURES, LOSS, ANOMALIES, LAMBDA)
+                                DEM, DEM_FEATURES, LOSS, ANOMALIES,
+                                OPTIM_PARAMS)
 from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH

 # module level logger
@@ -39,7 +40,7 @@ if __name__ == '__main__':
     state_file = ERA5Dataset.state_file(
         NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
         dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
-        decay=LAMBDA)
+        decay=OPTIM_PARAMS['weight_decay'])

     # path to model state
     state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
@@ -11,7 +11,6 @@ from datetime import timedelta
 from logging.config import dictConfig

 # externals
-import torch
 import xarray as xr
 from sklearn.model_selection import train_test_split
 from torch.utils.data import DataLoader
@@ -24,10 +23,12 @@ from pysegcnn.core.logging import log_conf
 from climax.core.dataset import ERA5Dataset, NetCDFDataset
 from climax.core.loss import MSELoss, L1Loss
 from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
-                                CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
-                                LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
+                                CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, OPTIM,
+                                NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
                                 OVERWRITE, DEM, DEM_FEATURES, STRATIFY,
-                                WET_DAY_THRESHOLD, VALID_SIZE, ANOMALIES)
+                                WET_DAY_THRESHOLD, VALID_SIZE, ANOMALIES,
+                                OPTIM_PARAMS, LR_SCHEDULER,
+                                LR_SCHEDULER_PARAMS)
 from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH

 # module level logger
@@ -43,7 +44,7 @@ if __name__ == '__main__':
     state_file = ERA5Dataset.state_file(
         NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
         dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
-        decay=LAMBDA)
+        decay=OPTIM_PARAMS['weight_decay'])

     # path to model state
     state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
@@ -119,11 +120,12 @@ if __name__ == '__main__':
     inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
     net = NET(state_file, inputs, outputs, filters=FILTERS)

-    # initialize optimizer
-    # optimizer = torch.optim.Adam(net.parameters(), lr=LR,
-    #                              weight_decay=LAMBDA)
-    optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9,
-                                weight_decay=LAMBDA)
+    # initialize optimizer
+    optimizer = OPTIM(net.parameters(), **OPTIM_PARAMS)
+
+    # initialize learning rate scheduler
+    if LR_SCHEDULER is not None:
+        LR_SCHEDULER = LR_SCHEDULER(optimizer, **LR_SCHEDULER_PARAMS)

     # initialize training data
     LogConfig.init_log('Initializing training data.')
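Note that the if block above rebinds the imported LR_SCHEDULER name from the scheduler class to the scheduler instance. Downstream, that instance would typically be stepped once per epoch; a self-contained sketch of the wiring with the new config values inlined and a dummy model (the actual loop lives in pysegcnn, so this is an assumption about usage, not its code):

import torch

# mirrors climax.main.config, with the scheduler enabled
OPTIM = torch.optim.SGD
OPTIM_PARAMS = {'lr': 0.005, 'weight_decay': 1e-6, 'momentum': 0.9}
LR_SCHEDULER = torch.optim.lr_scheduler.ExponentialLR
LR_SCHEDULER_PARAMS = {'gamma': 0.9}

model = torch.nn.Linear(8, 1)  # dummy stand-in for NET(...)
optimizer = OPTIM(model.parameters(), **OPTIM_PARAMS)
scheduler = (LR_SCHEDULER(optimizer, **LR_SCHEDULER_PARAMS)
             if LR_SCHEDULER is not None else None)

x, y = torch.randn(16, 8), torch.randn(16, 1)  # dummy batch
for epoch in range(3):
    optimizer.zero_grad()
    torch.nn.functional.mse_loss(model(x), y).backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()  # lr <- lr * gamma, once per epoch
    print(epoch, optimizer.param_groups[0]['lr'])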