Skip to content
Snippets Groups Projects
Commit 7c1732db authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Added a transfer learning option

parent 83dca2d5
No related branches found
No related tags found
No related merge requests found
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 15 09:45:49 2020
@author: Daniel
"""
# builtins
import os
import sys
# externals
import numpy as np
import torch
import torch.nn.functional as F
# append path to local files to the python search path
sys.path.append('..')
# local modules
from pytorch.trainer import NetworkTrainer
from pytorch.layers import Conv2dSame
from main.config import config
if __name__ == '__main__':
# instanciate the NetworkTrainer class
trainer = NetworkTrainer(config)
trainer.initialize()
# freeze the model state
trainer.model.freeze()
# get the number of input features to the model classifier
in_features = trainer.model.classifier.in_channels
# replace the classification layer
trainer.model.classifier = Conv2dSame(in_channels=in_features,
out_channels=len(trainer.dataset.labels),
kernel_size=1)
# train the model on the new dataset
trainer.train()
......@@ -20,7 +20,7 @@ sys.path.append('..')
# local modules
from pytorch.dataset import SparcsDataset, Cloud95Dataset
from pytorch.models import SegNet
from pytorch.constants import SparcsLabels, Cloud95Labels
class NetworkTrainer(object):
......@@ -31,6 +31,8 @@ class NetworkTrainer(object):
for k, v in config.items():
setattr(self, k, v)
def initialize(self):
# check which dataset the model is trained on
if self.dataset_name == 'Sparcs':
# instanciate the SparcsDataset
......@@ -62,11 +64,14 @@ class NetworkTrainer(object):
# instanciate the segmentation network
print('------------------- Network architecture ---------------------')
self.model = SegNet(in_channels=len(self.dataset.use_bands),
nclasses=len(self.dataset.labels),
filters=self.filters,
skip=self.skip_connection,
**self.kwargs)
if self.pretrained:
self.model = self.from_pretrained()
else:
self.model = self.net(in_channels=len(self.dataset.use_bands),
nclasses=len(self.dataset.labels),
filters=self.filters,
skip=self.skip_connection,
**self.kwargs)
print(self.model)
print('--------------------------------------------------------------')
......@@ -110,6 +115,44 @@ class NetworkTrainer(object):
self.loss_state = self.state.replace('.pt', '_loss.pt')
def from_pretrained(self):
# name of the dataset the pretrained model was trained on
dataset_name = self.pretrained_model.split('_')[1]
# input bands of the pretrained model
bands = self.pretrained_model.split('_')[-1].split('.')[0]
if dataset_name == SparcsDataset.__name__:
# number of input channels
in_channels = len(bands) if bands != 'all' else 10
# instanciate pretrained model architecture
model = self.net(in_channels=in_channels,
nclasses=len(SparcsLabels),
filters=self.filters,
skip=self.skip_connection,
**self.kwargs)
if dataset_name == Cloud95Dataset.__name__:
# number of input channels
in_channels = len(bands) if bands != 'all' else 4
# instanciate pretrained model architecture
model = self.net(in_channels=in_channels,
nclasses=len(Cloud95Labels),
filters=self.filters,
skip=self.skip_connection,
**self.kwargs)
# load pretrained model weights
model.load(self.pretrained_model, inpath=self.state_path)
return model
def ds_len(self, ds, ratio):
return int(np.round(len(ds) * ratio))
......@@ -164,8 +207,8 @@ class NetworkTrainer(object):
max_accuracy = 0
# whether to resume training from an existing model
if os.path.exists(self.state) and self.resume:
state = self.model.load(self.optimizer, self.state_file,
if os.path.exists(self.state) and self.checkpoint:
state = self.model.load(self.state_file, self.optimizer,
self.state_path)
print('Resuming training from {} ...'.format(state))
print('Model epoch: {:d}'.format(self.model.epoch))
......@@ -291,7 +334,7 @@ class NetworkTrainer(object):
# load the model state if evaluating a pretrained model is required
if pretrained:
state = self.model.load(self.optimizer, self.state_file,
state = self.model.load(self.state_file, self.optimizer,
self.state_path)
# send the model to the gpu if available
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment