From bac6bc17571fae0b2cfd1b3a9b41219f687f8e7b Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Wed, 15 Jul 2020 17:14:01 +0200
Subject: [PATCH] Improved saving method to rebuild model in case of checkpoint
 or transfer learning

---
 pytorch/models.py | 38 +++++++++++++++++++++++++++++++-------
 1 file changed, 31 insertions(+), 7 deletions(-)

diff --git a/pytorch/models.py b/pytorch/models.py
index a0eb58a..e1add80 100644
--- a/pytorch/models.py
+++ b/pytorch/models.py
@@ -35,20 +35,37 @@ class Network(nn.Module):
         for param in self.parameters():
             param.requires_grad = True
 
-    def save(self, optimizer, state_file,
+    def save(self, state_file, optimizer, bands,
              outpath=os.path.join(os.getcwd(), '_models')):
 
         # check if the output path exists and if not, create it
         if not os.path.isdir(outpath):
             os.makedirs(outpath, exist_ok=True)
 
-        # create a dictionary that stores the model state
-        model_state = {
-            'epoch': self.epoch,
-            'model_state_dict': self.state_dict(),
-            'optim_state_dict': optimizer.state_dict()
+        # initialize dictionary to store network parameters
+        model_state = {}
+
+        # store input bands
+        model_state['bands'] = bands
+
+        # store construction parameters to instanciate the network
+        model_state['params'] = {
+            'skip': self.skip,
+            'filters': self.nfilters,
+            'nclasses': self.nclasses,
+            'in_channels': self.in_channels
             }
 
+        # store optional keyword arguments
+        model_state['kwargs'] = self.kwargs
+
+        # store model epoch
+        model_state['epoch'] = self.epoch
+
+        # store model and optimizer state
+        model_state['model_state_dict'] = self.state_dict()
+        model_state['optim_state_dict'] = optimizer.state_dict()
+
         # model state dictionary stores the values of all trainable parameters
         state = os.path.join(outpath, state_file)
         torch.save(model_state, state)
@@ -87,9 +104,16 @@ class UNet(Network):
         # number of classes
         self.nclasses = nclasses
 
-        # get the configuration for the convolutional layers of the encoder
+        # configuration of the convolutional layers in the network
+        self.kwargs = kwargs
+        self.nfilters = filters
+
+        # convolutional layers of the encoder
         self.filters = np.hstack([np.array(in_channels), np.array(filters)])
 
+        # whether to apply skip connections
+        self.skip = skip
+
         # number of epochs trained
         self.epoch = 0
 
-- 
GitLab