From 7c1732db17e48061c36f74937c451da248b9b116 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Wed, 15 Jul 2020 12:25:01 +0200
Subject: [PATCH] Added a transfer learning option

---
Review notes (this section is not part of the commit message):
* Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch. Indentation and blank lines of the context lines are
  inferred: verify them against pytorch/trainer.py, or apply the patch with
  --ignore-whitespace.
* from_pretrained() now raises ValueError for an unsupported dataset name
  instead of failing with UnboundLocalError on "model". Hunk headers and
  the diffstat were updated for the added lines; the post-image blob ids on
  the "index" lines are stale.

 main/transfer.py   | 52 ++++++++++++++++++++++++++++++
 pytorch/trainer.py | 80 ++++++++++++++++++++++++++++++++++++++++------
 2 files changed, 123 insertions(+), 9 deletions(-)
 create mode 100644 main/transfer.py

diff --git a/main/transfer.py b/main/transfer.py
new file mode 100644
index 0000000..fdce169
--- /dev/null
+++ b/main/transfer.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+"""Transfer learning: train a new classification layer on a frozen model.
+
+After ``NetworkTrainer.initialize``, the model is frozen and its classifier is
+replaced by a ``Conv2dSame`` layer of kernel size 1 with one output channel
+per label of the trainer's dataset. The model is then trained.
+
+Created on Wed Jul 15 09:45:49 2020
+
+@author: Daniel
+"""
+
+# builtins
+import os
+import sys
+
+# externals
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+# append path to local files to the python search path
+sys.path.append('..')
+
+# local modules
+from pytorch.trainer import NetworkTrainer
+from pytorch.layers import Conv2dSame
+from main.config import config
+
+
+if __name__ == '__main__':
+
+    # instantiate the NetworkTrainer class
+    trainer = NetworkTrainer(config)
+    trainer.initialize()
+
+    # freeze the model state
+    trainer.model.freeze()
+
+    # get the number of input features to the model classifier
+    in_features = trainer.model.classifier.in_channels
+
+    # replace the classification layer
+    # NOTE(review): if ``initialize()`` already created the optimizer from the
+    # model parameters, the new layer is presumably not registered with it --
+    # confirm against NetworkTrainer.initialize
+    trainer.model.classifier = Conv2dSame(in_channels=in_features,
+                                          out_channels=len(trainer.dataset.labels),
+                                          kernel_size=1)
+
+    # train the model on the new dataset
+    trainer.train()
diff --git a/pytorch/trainer.py b/pytorch/trainer.py
index e098481..e6ee801 100755
--- a/pytorch/trainer.py
+++ b/pytorch/trainer.py
@@ -20,7 +20,7 @@ sys.path.append('..')
 
 # local modules
 from pytorch.dataset import SparcsDataset, Cloud95Dataset
-from pytorch.models import SegNet
+from pytorch.constants import SparcsLabels, Cloud95Labels
 
 
 class NetworkTrainer(object):
@@ -31,6 +31,8 @@ class NetworkTrainer(object):
         for k, v in config.items():
             setattr(self, k, v)
 
+    def initialize(self):
+
         # check which dataset the model is trained on
         if self.dataset_name == 'Sparcs':
             # instanciate the SparcsDataset
@@ -62,11 +64,14 @@ class NetworkTrainer(object):
 
         # instanciate the segmentation network
         print('------------------- Network architecture ---------------------')
-        self.model = SegNet(in_channels=len(self.dataset.use_bands),
-                            nclasses=len(self.dataset.labels),
-                            filters=self.filters,
-                            skip=self.skip_connection,
-                            **self.kwargs)
+        if self.pretrained:
+            self.model = self.from_pretrained()
+        else:
+            self.model = self.net(in_channels=len(self.dataset.use_bands),
+                                  nclasses=len(self.dataset.labels),
+                                  filters=self.filters,
+                                  skip=self.skip_connection,
+                                  **self.kwargs)
         print(self.model)
         print('--------------------------------------------------------------')
 
@@ -110,6 +115,63 @@ class NetworkTrainer(object):
 
         self.loss_state = self.state.replace('.pt', '_loss.pt')
 
+    def from_pretrained(self):
+        """Build the pretrained model and load its weights.
+
+        The dataset name and the input bands of the pretrained model are
+        parsed from ``self.pretrained_model``: the dataset name is the second
+        ``_``-separated token, the bands are the last ``_``-separated token
+        up to its first ``.``.
+
+        Returns
+        -------
+        model
+            Instance of ``self.net`` on which ``load`` was called with
+            ``self.pretrained_model`` and ``inpath=self.state_path``.
+
+        Raises
+        ------
+        ValueError
+            If the dataset name parsed from the file name is not supported.
+
+        """
+
+        # name of the dataset the pretrained model was trained on
+        dataset_name = self.pretrained_model.split('_')[1]
+
+        # input bands of the pretrained model
+        bands = self.pretrained_model.split('_')[-1].split('.')[0]
+
+        # supported datasets: (class labels, number of bands for 'all')
+        supported = {SparcsDataset.__name__: (SparcsLabels, 10),
+                     Cloud95Dataset.__name__: (Cloud95Labels, 4)}
+
+        # fail early on an unknown dataset: the model could not be
+        # instantiated and ``model`` would be unbound further down
+        if dataset_name not in supported:
+            raise ValueError('Unsupported dataset "{}" in pretrained model '
+                             '"{}".'.format(dataset_name,
+                                            self.pretrained_model))
+        labels, all_bands = supported[dataset_name]
+
+        # number of input channels
+        # NOTE(review): len(bands) counts characters, i.e. presumably one
+        # character per band in the file name -- confirm
+        in_channels = len(bands) if bands != 'all' else all_bands
+
+        # instantiate pretrained model architecture
+        model = self.net(in_channels=in_channels,
+                         nclasses=len(labels),
+                         filters=self.filters,
+                         skip=self.skip_connection,
+                         **self.kwargs)
+
+        # load pretrained model weights
+        model.load(self.pretrained_model, inpath=self.state_path)
+
+        return model
+
+
     def ds_len(self, ds, ratio):
         return int(np.round(len(ds) * ratio))
 
@@ -164,8 +226,8 @@ class NetworkTrainer(object):
         max_accuracy = 0
 
         # whether to resume training from an existing model
-        if os.path.exists(self.state) and self.resume:
-            state = self.model.load(self.optimizer, self.state_file,
+        if os.path.exists(self.state) and self.checkpoint:
+            state = self.model.load(self.state_file, self.optimizer,
                                     self.state_path)
             print('Resuming training from {} ...'.format(state))
             print('Model epoch: {:d}'.format(self.model.epoch))
@@ -291,7 +353,7 @@ class NetworkTrainer(object):
 
         # load the model state if evaluating a pretrained model is required
         if pretrained:
-            state = self.model.load(self.optimizer, self.state_file,
+            state = self.model.load(self.state_file, self.optimizer,
                                     self.state_path)
 
         # send the model to the gpu if available
-- 
GitLab