earth_observation_public / PySegCNN

Commit 4338ddfb, authored 4 years ago by Frisinghelli Daniel
Adapted main modules to changes in pysegcnn.core.trainer.py
Parent: 5128033c
Changes: 3 changed files, with 139 additions and 56 deletions

    pysegcnn/main/config.py    +29 −15
    pysegcnn/main/eval.py      +53 −33
    pysegcnn/main/train.py     +57 −8
pysegcnn/main/config.py  (+29 −15)

@@ -16,14 +16,19 @@ import os
 # path to this file
 HERE = os.path.abspath(os.path.dirname(__file__))

-# path to the datasets
-DATASET_PATH = 'C:/Eurac/2020/_Datasets/'
-# DATASET_PATH = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/_Datasets/'
+# path to the datasets on the current machine
+DRIVE_PATH = 'C:/Eurac/2020/_Datasets/'
+# DRIVE_PATH = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/_Datasets/'

 # name of the datasets
 DATASET_NAME = 'Sparcs'
-# DATASET_NAME = 'Cloud95/Training/'
-# DATASET_NAME =' ProSnow/Garmisch/
+# DATASET_NAME = 'Cloud95'
+# DATASET_NAME = 'Garmisch'
+
+# path to the dataset
+DATASET_PATH = os.path.join(DRIVE_PATH, DATASET_NAME)
+# DATASET_PATH = os.path.join(DRIVE_PATH, DATASET_NAME, 'Training')
+# DATASET_PATH = os.path.join(DRIVE_PATH, 'ProSnow', DATASET_NAME)

 # the dataset configuration dictionary
 dataset_config = {

@@ -32,8 +37,11 @@ dataset_config = {
     # -------------------------------------------------------------------------

     # name of the dataset
     'dataset_name': DATASET_NAME,
+
     # path to the dataset
-    'root_dir': os.path.join(DATASET_PATH, DATASET_NAME),
+    'root_dir': DATASET_PATH,
+
     # a pattern to match the ground truth file naming convention
     'gt_pattern': '*mask.png',

@@ -53,8 +61,8 @@ dataset_config = {
     # tiles of size (tile_size, tile_size)
     'pad': True,

-    # set random seed for reproducibility of the training, validation
-    # and test data split
+    # the random seed for the numpy random number generator
+    # ensures reproducibility of the training, validation and test data split
     # used if split_mode='random' and split_mode='scene'
     'seed': 0,

@@ -134,12 +142,12 @@ split_config = {
     # (ttratio * 100) % of the dataset will be used for training and
     # validation
     # used if split_mode='random' and split_mode='scene'
-    'ttratio': 1,
+    'ttratio': 0.05,

     # (ttratio * tvratio) * 100 % will be used as for training
     # (1 - ttratio * tvratio) * 100 % will be used for validation
     # used if split_mode='random' and split_mode='scene'
-    'tvratio': 0.8,
+    'tvratio': 0.5,

     # the date to split the scenes
     # format: 'yyyymmdd'

@@ -211,7 +219,10 @@ model_config = {
     # define the batch size
     # determines how many samples of the dataset are processed until the
     # weights of the network are updated (via mini-batch gradient descent)
-    'batch_size': 64
+    'batch_size': 64,
+
+    # the seed for the random number generator intializing the network weights
+    'torch_seed': 0
     }

@@ -223,6 +234,9 @@ train_config = {
     # -------------------------------------------------------------------------

     # whether to save the model state to disk
     'save': True,

+    # whether to early stop training if the accuracy on the validation set
+    # does not increase more than delta over patience epochs
+    'early_stop': True,

@@ -246,7 +260,7 @@ train_config = {
     }

 # the evaluation configuration file
-evaluation_config = {
+eval_config = {

     # ----------------------------- Evaluation --------------------------------

@@ -256,8 +270,8 @@ evaluation_config = {
     # pysegcnn.main.eval.py

     # the dataset to evaluate the model on
-    # test=False means evaluating on the validation set
-    # test=True means evaluating on the test set
+    # test=False, 0 means evaluating on the validation set
+    # test=True, 1 means evaluating on the test set
     # test=None means evaluating on the training set
     'test': False,

@@ -294,4 +308,4 @@ config = {**dataset_config,
           **split_config,
           **model_config,
           **train_config,
-          **evaluation_config}
+          **eval_config}
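The commit renames evaluation_config to eval_config; the dictionary is still merged into the aggregate config shown in the last hunk above, and the updated main scripts additionally unpack it straight into the new EvalConfig class from pysegcnn.core.trainer. A minimal sketch of the unpacking and merging pattern, using a hypothetical ExampleConfig stand-in rather than the real class:

# hypothetical stand-in for EvalConfig and the other *Config classes
from dataclasses import dataclass


@dataclass
class ExampleConfig:
    test: bool = False
    cm: bool = True


eval_config = {'test': False, 'cm': True}

# ** unpacks the dictionary entries as keyword arguments
ec = ExampleConfig(**eval_config)

# the aggregate dictionary is built the same way; for duplicate keys,
# the right-most dictionary wins
config = {**{'test': None}, **eval_config}
print(ec, config['test'])   # ExampleConfig(test=False, cm=True) False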
pysegcnn/main/eval.py  (+53 −33)

@@ -8,63 +8,83 @@ Created on Wed Jul 29 15:57:01 2020
 import os

 # locals
-from pysegcnn.core.trainer import NetworkTrainer
+from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig,
+                                   TrainConfig, EvalConfig)
 from pysegcnn.core.predict import predict_samples, predict_scenes
-from pysegcnn.main.config import config, HERE
+from pysegcnn.main.config import (dataset_config, split_config, model_config,
+                                  train_config, eval_config, HERE)
 from pysegcnn.core.graphics import plot_confusion_matrix, plot_loss


 if __name__ == '__main__':

-    # instanciate the NetworkTrainer class
-    trainer = NetworkTrainer(config)
-    trainer
+    # (i) instanciate the configurations
+    dc = DatasetConfig(**dataset_config)
+    sc = SplitConfig(**split_config)
+    mc = ModelConfig(**model_config)
+    tc = TrainConfig(**train_config)
+    ec = EvalConfig(**eval_config)
+
+    # (ii) instanciate the dataset
+    ds = dc.init_dataset()
+    ds
+
+    # (iii) instanciate the training, validation and test datasets
+    train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)
+
+    # (iv) instanciate the model state files
+    state_file, loss_state = mc.init_state(ds, sc, tc)
+
+    # (v) instanciate the model
+    model = mc.init_model(ds)
+
+    # (vi) instanciate the optimizer
+    optimizer = tc.init_optimizer(model)

     # plot loss and accuracy
-    plot_loss(trainer.loss_state, outpath=os.path.join(HERE, '_graphics/'))
+    plot_loss(loss_state, outpath=os.path.join(HERE, '_graphics/'))

     # check whether to evaluate the model on the training set, validation set
     # or the test set
-    if trainer.test is None:
-        ds = trainer.train_ds
+    if ec.test is None:
+        ds = train_ds
     else:
-        ds = trainer.test_ds if trainer.test else trainer.valid_ds
+        ds = test_ds if ec.test else valid_ds
+
+    # keyword arguments for plotting
+    kwargs = {'bands': ec.plot_bands,
+              'outpath': os.path.join(HERE, '_scenes/'),
+              'stretch': True,
+              'alpha': 5}

     # whether to predict each sample or each scene individually
-    if trainer.predict_scene:
+    if ec.predict_scene:
         # reconstruct and predict the scenes in the validation/test set
-        scenes, cm = predict_scenes(ds, trainer.model, trainer.optimizer,
-                                    trainer.state_path, trainer.state_file,
-                                    None, trainer.cm, trainer.plot_scenes,
-                                    bands=trainer.plot_bands,
-                                    outpath=os.path.join(HERE, '_scenes/'),
-                                    stretch=True, alpha=5)
+        scenes, cm = predict_scenes(ds, model, optimizer, state_file,
+                                    scene_id=None, cm=ec.cm,
+                                    plot_scenes=ec.plot_scenes, **kwargs)

     else:
         # predict the samples in the validation/test set
-        samples, cm = predict_samples(ds, trainer.model, trainer.optimizer,
-                                      trainer.state_path, trainer.state_file,
-                                      trainer.cm, trainer.plot_samples,
-                                      bands=trainer.plot_bands,
-                                      outpath=os.path.join(HERE, '_samples/'),
-                                      stretch=True, alpha=5)
+        samples, cm = predict_samples(ds, model, optimizer, state_file,
+                                      cm=ec.cm, plot_scenes=ec.plot_scenes,
+                                      **kwargs)

     # whether to plot the confusion matrix
-    if trainer.cm:
+    if ec.cm:
         plot_confusion_matrix(cm, ds.dataset.labels, normalize=True,
-                              state=trainer.state_file,
+                              state=state_file.name.replace('.pt', '.png'),
                               outpath=os.path.join(HERE, '_graphics/'))
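The new kwargs dictionary gathers the plotting options that were previously repeated in both the predict_scenes and predict_samples calls and forwards them with **kwargs. A minimal sketch of that forwarding, using a hypothetical plot() helper instead of the real prediction functions:

import os

# as in config.py: the directory of the current file
HERE = os.path.abspath(os.path.dirname(__file__))


def plot(scene, bands=None, outpath='.', stretch=False, alpha=0):
    """Hypothetical helper; only illustrates keyword forwarding."""
    print(scene, bands, outpath, stretch, alpha)


# shared plotting options, as collected in the new eval.py
kwargs = {'bands': ['red', 'green', 'blue'],
          'outpath': os.path.join(HERE, '_scenes/'),
          'stretch': True,
          'alpha': 5}

# **kwargs expands to bands=..., outpath=..., stretch=True, alpha=5
plot('scene_01', **kwargs)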
pysegcnn/main/train.py  (+57 −8)

@@ -6,19 +6,68 @@ Created on Tue Jun 30 09:33:38 2020
 @author: Daniel
 """

 # locals
-from pysegcnn.core.initconf import NetworkTrainer
+from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig,
+                                   TrainConfig, NetworkTrainer)
 from pysegcnn.main.config import (dataset_config, split_config, model_config,
                                   train_config)


 if __name__ == '__main__':

-    # instanciate the NetworkTrainer class
-    trainer = NetworkTrainer(dconfig=dataset_config, sconfig=split_config,
-                             mconfig=model_config, tconfig=train_config)
-    trainer
-
-    # write code that checks for list of seeds, band combinations etc. here.
-
-    # train the network
+    # (i) instanciate the configurations
+    dc = DatasetConfig(**dataset_config)
+    sc = SplitConfig(**split_config)
+    mc = ModelConfig(**model_config)
+    tc = TrainConfig(**train_config)
+
+    # (ii) instanciate the dataset
+    ds = dc.init_dataset()
+    ds
+
+    # (iii) instanciate the training, validation and test datasets and
+    # dataloaders
+    train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)
+    train_dl, valid_dl, test_dl = sc.dataloaders(train_ds, valid_ds, test_ds,
+                                                 batch_size=mc.batch_size,
+                                                 shuffle=True, drop_last=False)
+
+    # (iv) instanciate the model state files
+    state_file, loss_state = mc.init_state(ds, sc, tc)
+
+    # (v) instanciate the model
+    model = mc.init_model(ds)
+
+    # (vi) instanciate the optimizer and the loss function
+    optimizer = tc.init_optimizer(model)
+    loss_function = tc.init_loss_function()
+
+    # (vii) resume training from an existing model checkpoint
+    checkpoint_state, max_accuracy = mc.load_checkpoint(state_file, loss_state,
+                                                        model, optimizer)
+
+    # (viii) initialize network trainer class for eays model training
+    trainer = NetworkTrainer(model=model,
+                             optimizer=optimizer,
+                             loss_function=loss_function,
+                             train_dl=train_dl,
+                             valid_dl=valid_dl,
+                             state_file=state_file,
+                             loss_state=loss_state,
+                             epochs=tc.epochs,
+                             nthreads=tc.nthreads,
+                             early_stop=tc.early_stop,
+                             mode=tc.mode,
+                             delta=tc.delta,
+                             patience=tc.patience,
+                             max_accuracy=max_accuracy,
+                             checkpoint_state=checkpoint_state,
+                             save=tc.save)
+
+    # (ix) train model
+    training_state = trainer.train()
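The old script carried a placeholder comment ("write code that checks for list of seeds, band combinations etc. here.") that this commit removes. With the configuration split shown above, such a sweep could be written by rebuilding the configurations per run; a minimal sketch, assuming the classes can simply be re-instantiated for each seed (not part of this commit):

from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig,
                                   TrainConfig)
from pysegcnn.main.config import (dataset_config, split_config, model_config,
                                  train_config)

for seed in [0, 1, 2]:
    # vary the seed controlling the train/validation/test split
    dataset_config['seed'] = seed

    # re-instantiate the configurations for this run
    dc = DatasetConfig(**dataset_config)
    sc = SplitConfig(**split_config)
    mc = ModelConfig(**model_config)
    tc = TrainConfig(**train_config)

    # steps (ii)-(ix) from train.py would follow here, e.g.:
    ds = dc.init_dataset()
    train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)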