From 970832cd2221e38eedcb44b4d8231357d230861c Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 21 Oct 2021 14:37:21 +0200
Subject: [PATCH] LR-scheduler to statefile.

---
 climax/core/dataset.py         | 6 +++++-
 climax/main/downscale_infer.py | 5 +++--
 climax/main/downscale_train.py | 3 ++-
 3 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/climax/core/dataset.py b/climax/core/dataset.py
index 4e4ad63..c18df02 100644
--- a/climax/core/dataset.py
+++ b/climax/core/dataset.py
@@ -113,7 +113,7 @@ class EoDataset(torch.utils.data.Dataset):
     def state_file(model, predictand, predictors, plevels, dem=False,
                    dem_features=False, doy=False, loss=None, cv=None,
                    season=None, anomalies=False, decay=None, optim=None,
-                   lr=None):
+                   lr=None, lr_scheduler=None):
 
         # naming convention:
         # <model>_<predictand>_<Ppredictors>_<plevels>_<Spredictors>.pt
@@ -155,6 +155,10 @@ class EoDataset(torch.utils.data.Dataset):
         state_file = ('_'.join([state_file, 'lr{:.0e}'.format(lr)]) if lr
                       is not None else state_file)
 
+        # add suffix for learning rate scheduler
+        state_file = ('_'.join([state_file, lr_scheduler.__name__]) if
+                      lr_scheduler is not None else state_file)
+
         # add suffix for training with anomalies
         state_file = ('_'.join([state_file, 'anom']) if anomalies else
                       state_file)
diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py
index 47f104d..5d26894 100644
--- a/climax/main/downscale_infer.py
+++ b/climax/main/downscale_infer.py
@@ -24,7 +24,7 @@ from climax.core.utils import split_date_range
 from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND, NET,
                                 VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS,
                                 DEM, DEM_FEATURES, LOSS, ANOMALIES, OPTIM,
-                                OPTIM_PARAMS, CHUNKS)
+                                OPTIM_PARAMS, CHUNKS, LR_SCHEDULER)
 from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH
 
 # module level logger
@@ -40,7 +40,8 @@ if __name__ == '__main__':
     state_file = ERA5Dataset.state_file(
         NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
         dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
-        decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr'])
+        decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr'],
+        lr_scheduler=LR_SCHEDULER)
 
     # path to model state
     state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index 770750a..a076496 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -44,7 +44,8 @@ if __name__ == '__main__':
     state_file = ERA5Dataset.state_file(
         NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
         dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
-        decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr'])
+        decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr'],
+        lr_scheduler=LR_SCHEDULER)
 
     # path to model state
     state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
-- 
GitLab