From 7c1732db17e48061c36f74937c451da248b9b116 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Wed, 15 Jul 2020 12:25:01 +0200
Subject: [PATCH] Added a transfer learning option

---
 main/transfer.py   | 44 +++++++++++++++++++++++++++++++++
 pytorch/trainer.py | 61 +++++++++++++++++++++++++++++++++++++++-------
 2 files changed, 96 insertions(+), 9 deletions(-)
 create mode 100644 main/transfer.py

diff --git a/main/transfer.py b/main/transfer.py
new file mode 100644
index 0000000..fdce169
--- /dev/null
+++ b/main/transfer.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+"""Transfer learning script: replace and train the classification layer.
+
+Created on Wed Jul 15 09:45:49 2020
+@author: Daniel
+"""
+
+# builtins
+import os
+import sys
+
+# externals
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+# add the parent of the working directory to the python search path
+sys.path.append('..')
+
+# local modules
+from pytorch.trainer import NetworkTrainer
+from pytorch.layers import Conv2dSame
+from main.config import config
+
+
+if __name__ == '__main__':
+
+    # instantiate the NetworkTrainer class from the configuration
+    trainer = NetworkTrainer(config)
+    trainer.initialize()
+
+    # freeze the model before swapping in a new classification layer
+    trainer.model.freeze()
+
+    # get the number of input features to the model classifier
+    in_features = trainer.model.classifier.in_channels
+
+    # replace the classifier: 1x1 convolution, one output channel per label
+    trainer.model.classifier = Conv2dSame(in_channels=in_features,
+                                          out_channels=len(trainer.dataset.labels),
+                                          kernel_size=1)
+
+    # train the model on the new dataset
+    trainer.train()
diff --git a/pytorch/trainer.py b/pytorch/trainer.py
index e098481..e6ee801 100755
--- a/pytorch/trainer.py
+++ b/pytorch/trainer.py
@@ -20,7 +20,7 @@ sys.path.append('..')
 
 # local modules
 from pytorch.dataset import SparcsDataset, Cloud95Dataset
-from pytorch.models import SegNet
+from pytorch.constants import SparcsLabels, Cloud95Labels
 
 
 class NetworkTrainer(object):
@@ -31,6 +31,8 @@ class NetworkTrainer(object):
         for k, v in config.items():
             setattr(self, k, v)
 
+    def initialize(self):
+
         # check which dataset the model is trained on
         if self.dataset_name == 'Sparcs':
             # instanciate the SparcsDataset
@@ -62,11 +64,14 @@ class NetworkTrainer(object):
 
         # instanciate the segmentation network
         print('------------------- Network architecture ---------------------')
-        self.model = SegNet(in_channels=len(self.dataset.use_bands),
-                            nclasses=len(self.dataset.labels),
-                            filters=self.filters,
-                            skip=self.skip_connection,
-                            **self.kwargs)
+        if self.pretrained:
+            self.model = self.from_pretrained()
+        else:
+            self.model = self.net(in_channels=len(self.dataset.use_bands),
+                                  nclasses=len(self.dataset.labels),
+                                  filters=self.filters,
+                                  skip=self.skip_connection,
+                                  **self.kwargs)
         print(self.model)
         print('--------------------------------------------------------------')
 
@@ -110,6 +115,44 @@ class NetworkTrainer(object):
         self.loss_state = self.state.replace('.pt', '_loss.pt')
 
 
+    def from_pretrained(self):
+        """Instantiate the pretrained network and load its weights.
+
+        The dataset name (second ``_``-separated token) and the input bands
+        (last token, extension stripped) are parsed from the file name
+        ``self.pretrained_model``.
+
+        :return: instance of ``self.net`` after calling its ``load`` method.
+        :raises ValueError: if the dataset name is not supported.
+        """
+        # name of the dataset the pretrained model was trained on
+        dataset_name = self.pretrained_model.split('_')[1]
+
+        # input bands of the pretrained model
+        bands = self.pretrained_model.split('_')[-1].split('.')[0]
+
+        # labels and number of bands assumed for 'all', per dataset
+        if dataset_name == SparcsDataset.__name__:
+            labels, nbands = SparcsLabels, 10
+        elif dataset_name == Cloud95Dataset.__name__:
+            labels, nbands = Cloud95Labels, 4
+        else:
+            # fail early: no model architecture is known for this dataset
+            raise ValueError('Unsupported dataset: {}'.format(dataset_name))
+
+        # NOTE(review): len(bands) counts characters of the token -- confirm
+        in_channels = len(bands) if bands != 'all' else nbands
+
+        # instantiate the pretrained model architecture
+        model = self.net(in_channels=in_channels, nclasses=len(labels),
+                         filters=self.filters, skip=self.skip_connection,
+                         **self.kwargs)
+
+        # load pretrained model weights and return the model
+        model.load(self.pretrained_model, inpath=self.state_path)
+        return model
+
+
     def ds_len(self, ds, ratio):
         return int(np.round(len(ds) * ratio))
 
@@ -164,8 +207,8 @@ class NetworkTrainer(object):
             max_accuracy = 0
 
         # whether to resume training from an existing model
-        if os.path.exists(self.state) and self.resume:
-            state = self.model.load(self.optimizer, self.state_file,
+        if os.path.exists(self.state) and self.checkpoint:
+            state = self.model.load(self.state_file, self.optimizer,
                                     self.state_path)
             print('Resuming training from {} ...'.format(state))
             print('Model epoch: {:d}'.format(self.model.epoch))
@@ -291,7 +334,7 @@ class NetworkTrainer(object):
 
         # load the model state if evaluating a pretrained model is required
         if pretrained:
-            state = self.model.load(self.optimizer, self.state_file,
+            state = self.model.load(self.state_file, self.optimizer,
                                     self.state_path)
 
         # send the model to the gpu if available
-- 
GitLab