Skip to content
Snippets Groups Projects
Commit adb6e52b authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Added a flag whether to compute the confusion matrix

parent ff6c0869
No related branches found
No related tags found
No related merge requests found
...@@ -21,12 +21,12 @@ import torch.optim as optim ...@@ -21,12 +21,12 @@ import torch.optim as optim
wd = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/' wd = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/'
# define which dataset to train on # define which dataset to train on
dataset_name = 'Sparcs' # dataset_name = 'Sparcs'
# dataset_name= 'Cloud95' dataset_name= 'Cloud95'
# path to the dataset # path to the dataset
dataset_path = os.path.join(wd, '_Datasets/Sparcs') # dataset_path = os.path.join(wd, '_Datasets/Sparcs')
# dataset_path = os.path.join(wd, '_Datasets/Cloud95/Training') dataset_path = os.path.join(wd, '_Datasets/Cloud95/Training')
# the csv file containing the names of the informative patches of the # the csv file containing the names of the informative patches of the
# Cloud95 dataset # Cloud95 dataset
...@@ -39,7 +39,7 @@ bands = ['red', 'green', 'blue', 'nir'] ...@@ -39,7 +39,7 @@ bands = ['red', 'green', 'blue', 'nir']
# define the size of the network input # define the size of the network input
# if None, the size will default to the size of a scene # if None, the size will default to the size of a scene
tile_size = 125 tile_size = 192
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
...@@ -76,7 +76,7 @@ ttratio = 1 ...@@ -76,7 +76,7 @@ ttratio = 1
# (ttratio * tvratio) * 100 % will be used as the training dataset # (ttratio * tvratio) * 100 % will be used as the training dataset
# (1 - ttratio * tvratio) * 100 % will be used as the validation dataset # (1 - ttratio * tvratio) * 100 % will be used as the validation dataset
tvratio = 0.8 tvratio = 0.2
# define the batch size # define the batch size
# determines how many samples of the dataset are processed until the weights # determines how many samples of the dataset are processed until the weights
...@@ -88,14 +88,14 @@ checkpoint = False ...@@ -88,14 +88,14 @@ checkpoint = False
# whether to early stop training if the accuracy (loss) on the validation set # whether to early stop training if the accuracy (loss) on the validation set
# does not increase (decrease) more than delta over patience epochs # does not increase (decrease) more than delta over patience epochs
early_stop = True early_stop = False
mode = 'max' mode = 'max'
delta = 0.005 delta = 0
patience = 10 patience = 10
# define the number of epochs: the number of maximum iterations over the whole # define the number of epochs: the number of maximum iterations over the whole
# training dataset # training dataset
epochs = 200 epochs = 5
# define the number of threads # define the number of threads
nthreads = os.cpu_count() nthreads = os.cpu_count()
...@@ -117,6 +117,9 @@ state_path = os.path.join(os.getcwd(), '_models/') ...@@ -117,6 +117,9 @@ state_path = os.path.join(os.getcwd(), '_models/')
# ------------------------- Plotting configuration ---------------------------- # ------------------------- Plotting configuration ----------------------------
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# whether to compute and plot confusion matrix for the entire validation set
plot_cm = False
# whether to save plots of (input, ground truth, prediction) of the validation # whether to save plots of (input, ground truth, prediction) of the validation
# dataset to disk # dataset to disk
# output path is: current_working_directory/_samples/ # output path is: current_working_directory/_samples/
......
...@@ -12,23 +12,25 @@ import matplotlib.pyplot as plt ...@@ -12,23 +12,25 @@ import matplotlib.pyplot as plt
sys.path.append('..') sys.path.append('..')
# local modules # local modules
from main.config import state_path, plot_samples, nsamples, plot_bands, seed from main.config import (state_path, plot_samples, nsamples, plot_bands, seed,
plot_cm)
from main.init import state_file, trainer from main.init import state_file, trainer
if __name__ == '__main__': if __name__ == '__main__':
# predict each batch in the validation set if plot_cm:
cm, accuracy, loss = trainer.predict(state_path, state_file, # predict each batch in the validation set
confusion=True) cm, accuracy, loss = trainer.predict(state_path, state_file,
confusion=True)
# # calculate overal accuracy # calculate overal accuracy
acc = (cm.diag().sum() / cm.sum()).numpy().item() acc = (cm.diag().sum() / cm.sum()).numpy().item()
print('After training for {:d} epochs, we achieved an overall accuracy of ' print('After training for {:d} epochs, we achieved an overall '
'{:.2f}% on the validation set!'.format(trainer.model.epoch, 'accuracy of {:.2f}% on the validation set!'
acc * 100)) .format(trainer.model.epoch, acc * 100))
# # plot confusion matrix # plot confusion matrix
trainer.dataset.plot_confusion_matrix(cm, state=state_file) trainer.dataset.plot_confusion_matrix(cm, state=state_file)
# plot loss and accuracy # plot loss and accuracy
trainer.dataset.plot_loss( trainer.dataset.plot_loss(
...@@ -37,6 +39,10 @@ if __name__ == '__main__': ...@@ -37,6 +39,10 @@ if __name__ == '__main__':
# whether to plot the samples of the validation dataset # whether to plot the samples of the validation dataset
if plot_samples: if plot_samples:
# load pretrained model
state = trainer.model.load(trainer.optimizer, state_file, state_path)
trainer.model.eval()
# base filename for each sample # base filename for each sample
fname = state_file.split('.pt')[0] fname = state_file.split('.pt')[0]
...@@ -68,4 +74,4 @@ if __name__ == '__main__': ...@@ -68,4 +74,4 @@ if __name__ == '__main__':
bands=plot_bands, bands=plot_bands,
state=sname, state=sname,
stretch=True, stretch=True,
alpha=5) alpha=5)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment