Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
PySegCNN
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
earth_observation_public
PySegCNN
Commits
fe57ae65
Commit
fe57ae65
authored
4 years ago
by
Frisinghelli Daniel
Browse files
Options
Downloads
Patches
Plain Diff
Adjusting main executables to changes in core.trainer.py
parent
f7112679
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
pysegcnn/main/config.py
+14
-10
14 additions, 10 deletions
pysegcnn/main/config.py
pysegcnn/main/eval.py
+20
-30
20 additions, 30 deletions
pysegcnn/main/eval.py
pysegcnn/main/train.py
+14
-20
14 additions, 20 deletions
pysegcnn/main/train.py
with
48 additions
and
60 deletions
pysegcnn/main/config.py
+
14
−
10
View file @
fe57ae65
...
...
@@ -30,6 +30,12 @@ DATASET_PATH = os.path.join(DRIVE_PATH, DATASET_NAME)
# DATASET_PATH = os.path.join(DRIVE_PATH, DATASET_NAME, 'Training')
# DATASET_PATH = os.path.join(DRIVE_PATH, 'ProSnow', DATASET_NAME)
# path to store the model states
MODEL_PATH
=
os
.
path
.
join
(
HERE
,
'
_models/
'
)
# path to store model logs
LOG_PATH
=
os
.
path
.
join
(
HERE
,
'
_logs/
'
)
# the dataset configuration dictionary
dataset_config
=
{
...
...
@@ -189,9 +195,6 @@ model_config = {
'
dilation
'
:
1
# the field of view of the kernel
},
# path to save trained models
'
state_path
'
:
os
.
path
.
join
(
HERE
,
'
_models/
'
),
# Transfer learning -------------------------------------------------------
# Use pretrained=True only if you wish to fine-tune a pre-trained model
...
...
@@ -203,7 +206,7 @@ model_config = {
# was trained on
# whether to use a pretrained model for transfer learning
'
transfer
'
:
Tru
e
,
'
transfer
'
:
Fals
e
,
# name of the pretrained model to apply to a different dataset
'
pretrained_model
'
:
'
UNet_SparcsDataset_t125_b64_rgbn.pt
'
,
...
...
@@ -226,6 +229,8 @@ model_config = {
# -------------------------------------------------------------------------
# whether to save the model state to disk
# model states are saved in: pysegcnn/main/_models
# model log files are saved in: pysegcnn/main/_logs
'
save
'
:
True
,
# whether to early stop training if the accuracy on the validation set
...
...
@@ -260,6 +265,10 @@ eval_config = {
# these options are only used for evaluating a trained model using
# pysegcnn.main.eval.py
# the model to evaluate
'
state_file
'
:
os
.
path
.
join
(
MODEL_PATH
,
'
UNet_SparcsDataset_Adam_SceneSplit_s0_t005v05_t125_b64_r4g3b2n5.pt
'
),
# the dataset to evaluate the model on
# test=False, 0 means evaluating on the validation set
# test=True, 1 means evaluating on the test set
...
...
@@ -267,6 +276,7 @@ eval_config = {
'
test
'
:
False
,
# whether to compute and plot the confusion matrix
# output path is: pysegcnn/main/_graphics/
'
cm
'
:
True
,
# whether to predict each sample or each scene individually
...
...
@@ -299,9 +309,3 @@ eval_config = {
'
alpha
'
:
5
}
# the complete configuration
config
=
{
**
dataset_config
,
**
split_config
,
**
model_config
,
**
eval_config
}
This diff is collapsed.
Click to expand it.
pysegcnn/main/eval.py
+
20
−
30
View file @
fe57ae65
...
...
@@ -5,52 +5,41 @@ Created on Wed Jul 29 15:57:01 2020
@author: Daniel
"""
# builtins
import
os
from
logging.config
import
dictConfig
# locals
from
pysegcnn.core.
trainer
import
(
DatasetConfig
,
SplitConfig
,
ModelConfig
,
State
Config
,
Eval
Config
)
from
pysegcnn.core.
models
import
Network
from
pysegcnn.core.trainer
import
Eval
Config
,
Log
Config
from
pysegcnn.core.predict
import
predict_samples
,
predict_scenes
from
pysegcnn.main.config
import
(
dataset_config
,
split_config
,
model_config
,
train_config
,
eval_config
,
HERE
)
from
pysegcnn.core.logging
import
log_conf
from
pysegcnn.core.graphics
import
plot_confusion_matrix
,
plot_loss
from
pysegcnn.main.config
import
eval_config
if
__name__
==
'
__main__
'
:
# (i) instanciate the configurations
dc
=
DatasetConfig
(
**
dataset_config
)
sc
=
SplitConfig
(
**
split_config
)
mc
=
ModelConfig
(
**
model_config
)
# instanciate the evaluation configuration
ec
=
EvalConfig
(
**
eval_config
)
# (ii) instanciate the dataset
ds
=
dc
.
init_dataset
()
# initialize logging
log
=
LogConfig
(
ec
.
state_file
)
dictConfig
(
log_conf
(
log
.
log_file
))
# (iii) instanciate the training, validation and test datasets
train_ds
,
valid_ds
,
test_ds
=
sc
.
train_val_test_split
(
ds
)
# (iv) instanciate the model state
state
=
StateConfig
(
ds
,
sc
,
mc
)
state_file
,
loss_state
=
state
.
init_state
()
# (vii) load pretrained model weights
model
,
_
=
mc
.
load_pretrained
(
state_file
)
model
.
state_file
=
state_file
# load the model state
model
,
_
,
model_state
=
Network
.
load
(
ec
.
state_file
)
# plot loss and accuracy
plot_loss
(
loss_state
,
outpath
=
os
.
path
.
join
(
HERE
,
'
_graphics/
'
)
)
plot_loss
(
ec
.
state_file
,
outpath
=
ec
.
models_path
)
# check whether to evaluate the model on the training set, validation set
# or the test set
if
ec
.
test
is
None
:
ds
=
train_ds
ds
=
model_state
[
'
train_ds
'
]
else
:
ds
=
test_ds
if
ec
.
test
else
valid_ds
ds
=
model_state
[
'
test_ds
'
]
if
ec
.
test
else
model_state
[
'
valid_ds
'
]
# keyword arguments for plotting
kwargs
=
{
'
bands
'
:
ec
.
plot_bands
,
'
outpath
'
:
os
.
path
.
join
(
HERE
,
'
_scenes/
'
),
'
alpha
'
:
ec
.
alpha
,
'
figsize
'
:
ec
.
figsize
}
...
...
@@ -58,16 +47,17 @@ if __name__ == '__main__':
if
ec
.
predict_scene
:
# reconstruct and predict the scenes in the validation/test set
scenes
,
cm
=
predict_scenes
(
ds
,
model
,
scene_id
=
None
,
cm
=
ec
.
cm
,
plot
=
ec
.
plot_scenes
,
**
kwargs
)
plot
=
ec
.
plot_scenes
,
outpath
=
ec
.
scenes_path
,
**
kwargs
)
else
:
# predict the samples in the validation/test set
samples
,
cm
=
predict_samples
(
ds
,
model
,
cm
=
ec
.
cm
,
plot
=
ec
.
plot_samples
,
**
kwargs
)
plot
=
ec
.
plot_samples
,
outpath
=
ec
.
sample_path
,
**
kwargs
)
# whether to plot the confusion matrix
if
ec
.
cm
:
plot_confusion_matrix
(
cm
,
ds
.
dataset
.
labels
,
state
=
state_file
.
name
.
replace
(
'
.pt
'
,
'
.png
'
),
outpath
=
os
.
path
.
join
(
HERE
,
'
_graphics/
'
)
)
state
=
ec
.
state_file
.
name
.
replace
(
'
.pt
'
,
'
.png
'
),
outpath
=
ec
.
models_path
)
This diff is collapsed.
Click to expand it.
pysegcnn/main/train.py
+
14
−
20
View file @
fe57ae65
...
...
@@ -6,13 +6,14 @@ Created on Tue Jun 30 09:33:38 2020
@author: Daniel
"""
# builtins
import
loggin
g
from
logging.config
import
dictConfi
g
# locals
from
pysegcnn.core.trainer
import
(
DatasetConfig
,
SplitConfig
,
ModelConfig
,
StateConfig
,
NetworkTrainer
)
StateConfig
,
LogConfig
,
NetworkTrainer
)
from
pysegcnn.core.logging
import
log_conf
from
pysegcnn.main.config
import
(
dataset_config
,
split_config
,
model_config
)
from
pysegcnn.main.config
import
(
dataset_config
,
split_config
,
model_config
,
LOG_PATH
)
if
__name__
==
'
__main__
'
:
...
...
@@ -26,49 +27,42 @@ if __name__ == '__main__':
# (ii) instanciate the dataset
ds
=
dc
.
init_dataset
()
ds
# (iii) instanciate the model state
state
=
StateConfig
(
ds
,
sc
,
mc
)
state_file
,
loss_state
=
state
.
init_state
()
state_file
=
state
.
init_state
()
# initialize logging
log
_file
=
str
(
state_file
).
replace
(
'
.pt
'
,
'
_train.log
'
)
logging
.
config
.
dictConfig
(
log_conf
(
log_file
))
#
(iv)
initialize logging
log
=
LogConfig
(
state_file
)
dictConfig
(
log_conf
(
log
.
log
_file
))
# (
i
v) instanciate the training, validation and test datasets and
# (v) instanciate the training, validation and test datasets and
# dataloaders
train_ds
,
valid_ds
,
test_ds
=
sc
.
train_val_test_split
(
ds
)
train_dl
,
valid_dl
,
test_dl
=
sc
.
dataloaders
(
train_ds
,
valid_ds
,
test_ds
,
batch_size
=
mc
.
batch_size
,
shuffle
=
True
,
drop_last
=
False
)
# (
i
v) instanciate the model
model
=
mc
.
init_model
(
ds
)
# (v
i
) instanciate the model
model
,
optimizer
,
checkpoint_state
=
mc
.
init_model
(
ds
,
state_file
)
# (vi) instanciate the optimizer and the loss function
optimizer
=
mc
.
init_optimizer
(
model
)
# (vii) instanciate the loss function
loss_function
=
mc
.
init_loss_function
()
# (vii) resume training from an existing model checkpoint
(
model
,
optimizer
,
checkpoint_state
,
max_accuracy
)
=
mc
.
from_checkpoint
(
model
,
optimizer
,
state_file
,
loss_state
)
# (viii) initialize network trainer class for eays model training
# (viii) initialize network trainer class for easy model training
trainer
=
NetworkTrainer
(
model
=
model
,
optimizer
=
optimizer
,
loss_function
=
loss_function
,
train_dl
=
train_dl
,
valid_dl
=
valid_dl
,
test_dl
=
test_dl
,
state_file
=
state_file
,
loss_state
=
loss_state
,
epochs
=
mc
.
epochs
,
nthreads
=
mc
.
nthreads
,
early_stop
=
mc
.
early_stop
,
mode
=
mc
.
mode
,
delta
=
mc
.
delta
,
patience
=
mc
.
patience
,
max_accuracy
=
max_accuracy
,
checkpoint_state
=
checkpoint_state
,
save
=
mc
.
save
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment