From 801e7f3e6520dec6c42af8e77a78f878f09d4906 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Mon, 18 Oct 2021 16:23:44 +0200
Subject: [PATCH] Learning rate to statefile name.

---
 climax/core/dataset.py         | 7 ++++++-
 climax/main/downscale_infer.py | 2 +-
 climax/main/downscale_train.py | 2 +-
 3 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/climax/core/dataset.py b/climax/core/dataset.py
index d90a72c..7bb306c 100644
--- a/climax/core/dataset.py
+++ b/climax/core/dataset.py
@@ -104,7 +104,8 @@ class EoDataset(torch.utils.data.Dataset):
     @staticmethod
     def state_file(model, predictand, predictors, plevels, dem=False,
                    dem_features=False, doy=False, loss=None, cv=None,
-                   season=None, anomalies=False, decay=None, optim=None):
+                   season=None, anomalies=False, decay=None, optim=None,
+                   lr=None):
 
         # naming convention:
         # <model>_<predictand>_<Ppredictors>_<plevels>_<Spredictors>.pt
@@ -142,6 +143,10 @@ class EoDataset(torch.utils.data.Dataset):
         state_file = ('_'.join([state_file, 'd{:.0e}'.format(decay)]) if decay
                       is not None else state_file)
 
+        # add suffix for learning rate values
+        state_file = ('_'.join([state_file, 'd{:.0e}'.format(lr)]) if lr
+                      is not None else state_file)
+
         # add suffix for training with anomalies
         state_file = ('_'.join([state_file, 'anom']) if anomalies else
                       state_file)
diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py
index beff9e7..0a83e0f 100644
--- a/climax/main/downscale_infer.py
+++ b/climax/main/downscale_infer.py
@@ -40,7 +40,7 @@ if __name__ == '__main__':
     state_file = ERA5Dataset.state_file(
         NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
         dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
-        decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM)
+        decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr'])
 
     # path to model state
     state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index 65c1530..7468fa8 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -44,7 +44,7 @@ if __name__ == '__main__':
     state_file = ERA5Dataset.state_file(
         NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
         dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
-        decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM)
+        decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr'])
 
     # path to model state
     state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
-- 
GitLab