From 463e73d3d32f4aea3de9014d5a7968bdc795cbbb Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Mon, 24 Aug 2020 10:49:58 +0200
Subject: [PATCH] Network.load() now instantiates the optimizer saved by
 Network.save().

---
 pysegcnn/core/models.py  | 33 ++++++++++++++++++++-------------
 pysegcnn/core/trainer.py |  2 +-
 2 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/pysegcnn/core/models.py b/pysegcnn/core/models.py
index c04fa4c..39ba1a1 100644
--- a/pysegcnn/core/models.py
+++ b/pysegcnn/core/models.py
@@ -116,8 +116,9 @@ class Network(nn.Module):
         # store the spectral bands the model is trained with
         model_state['bands'] = bands
 
-        # store model class
+        # store model and optimizer class
         model_state['cls'] = self.__class__
+        model_state['optim_cls'] = optimizer.__class__
 
         # store construction parameters to instanciate the network
         model_state['params'] = {
@@ -127,6 +128,9 @@ class Network(nn.Module):
             'in_channels': self.in_channels
             }
 
+        # store optimizer construction parameters
+        model_state['optim_params'] = optimizer.defaults
+
         # store optional keyword arguments
         model_state['params'] = {**model_state['params'], **self.kwargs}
 
@@ -144,21 +148,17 @@ class Network(nn.Module):
         return model_state
 
     @staticmethod
-    def load(state_file, optimizer=None):
+    def load(state_file):
         """Load a model state.
 
-        Returns the model in ``state_file`` with the pretrained model weights.
-        If ``optimizer`` is specified, the optimizer parameters are also loaded
-        from ``state_file``. This is useful when resuming training an existing
-        model.
+        Returns the model in ``state_file`` with its pretrained weights and
+        the optimizer in its saved state. Useful to resume training a model.
 
         Parameters
         ----------
         state_file : `str` or `pathlib.Path`
            The model state file. Model state files are stored in
            pysegcnn/main/_models.
-        optimizer : `torch.optim.Optimizer` or `None`, optional
-           The optimizer used to train the model.
 
         Raises
         ------
@@ -169,7 +169,7 @@ class Network(nn.Module):
         -------
         model : `pysegcnn.core.models.Network`
             The pretrained model.
-        optimizer : `torch.optim.Optimizer` or `None`
+        optimizer : `torch.optim.Optimizer`
            The optimizer used to train the model.
         model_state : '`dict`
             A dictionary containing the model and optimizer state, as
@@ -185,8 +185,9 @@ class Network(nn.Module):
         # load the model state
         model_state = torch.load(state_file)
 
-        # the model class
+        # the model and optimizer class
         model_class = model_state['cls']
+        optim_class = model_state['optim_cls']
 
         # instanciate pretrained model architecture
         model = model_class(**model_state['params'])
@@ -200,9 +201,11 @@ class Network(nn.Module):
         model.epoch = model_state['epoch']
 
         # resume optimizer parameters
-        if optimizer is not None:
-            LOGGER.info('Loading optimizer parameters ...')
-            optimizer.load_state_dict(model_state['optim_state_dict'])
+        LOGGER.info('Loading optimizer parameters ...')
+        optimizer = optim_class(model.parameters(),
+                                **model_state['optim_params'])
+        optimizer.load_state_dict(model_state['optim_state_dict'])
+
         LOGGER.info('Model epoch: {:d}'.format(model.epoch))
 
         return model, optimizer, model_state
@@ -223,6 +226,10 @@ class Network(nn.Module):
 class UNet(Network):
     """A PyTorch implementation of `U-Net`_.
 
+    Slightly modified version of U-Net:
+        - each convolution is followed by a batch normalization layer
+        - the upsampling is implemented by a 2x2 max unpooling operation
+
     .. _U-Net:
         https://arxiv.org/abs/1505.04597
 
diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 92014fb..bc9714f 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -680,7 +680,7 @@ class ModelConfig(BaseConfig):
                            .format(state_file.name))
         else:
             # load model checkpoint
-            model, optimizer, model_state = Network.load(state_file, optimizer)
+            model, optimizer, model_state = Network.load(state_file)
 
             # load model loss and accuracy
 
-- 
GitLab