Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
PySegCNN
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
earth_observation_public
PySegCNN
Commits
405feffb
Commit
405feffb
authored
4 years ago
by
Frisinghelli Daniel
Browse files
Options
Downloads
Patches
Plain Diff
Changed default path to save model output
parent
a12d474a
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
pysegcnn/core/graphics.py
+4
-3
4 additions, 3 deletions
pysegcnn/core/graphics.py
pysegcnn/core/models.py
+4
-3
4 additions, 3 deletions
pysegcnn/core/models.py
pysegcnn/main/eval.py
+1
-1
1 addition, 1 deletion
pysegcnn/main/eval.py
with
9 additions
and
7 deletions
pysegcnn/core/graphics.py
+
4
−
3
View file @
405feffb
...
@@ -19,6 +19,7 @@ from matplotlib import cm as colormap
...
@@ -19,6 +19,7 @@ from matplotlib import cm as colormap
# locals
# locals
from
pysegcnn.core.trainer
import
accuracy_function
from
pysegcnn.core.trainer
import
accuracy_function
from
pysegcnn.core.config
import
HERE
# this function applies percentile stretching at the alpha level
# this function applies percentile stretching at the alpha level
...
@@ -49,7 +50,7 @@ def running_mean(x, w):
...
@@ -49,7 +50,7 @@ def running_mean(x, w):
# with the model prediction and the corresponding ground truth
# with the model prediction and the corresponding ground truth
def
plot_sample
(
x
,
y
,
use_bands
,
labels
,
y_pred
=
None
,
figsize
=
(
10
,
10
),
def
plot_sample
(
x
,
y
,
use_bands
,
labels
,
y_pred
=
None
,
figsize
=
(
10
,
10
),
bands
=
[
'
nir
'
,
'
red
'
,
'
green
'
],
stretch
=
False
,
state
=
None
,
bands
=
[
'
nir
'
,
'
red
'
,
'
green
'
],
stretch
=
False
,
state
=
None
,
outpath
=
os
.
path
.
join
(
os
.
getcwd
()
,
'
_samples/
'
),
**
kwargs
):
outpath
=
os
.
path
.
join
(
HERE
,
'
_samples/
'
),
**
kwargs
):
# check whether to apply constrast stretching
# check whether to apply constrast stretching
stretch
=
True
if
kwargs
else
stretch
stretch
=
True
if
kwargs
else
stretch
...
@@ -111,7 +112,7 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
...
@@ -111,7 +112,7 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
# set returned by the pytorch.predict function
# set returned by the pytorch.predict function
def
plot_confusion_matrix
(
cm
,
labels
,
normalize
=
True
,
def
plot_confusion_matrix
(
cm
,
labels
,
normalize
=
True
,
figsize
=
(
10
,
10
),
cmap
=
'
Blues
'
,
state
=
None
,
figsize
=
(
10
,
10
),
cmap
=
'
Blues
'
,
state
=
None
,
outpath
=
os
.
path
.
join
(
os
.
getcwd
()
,
'
_graphics/
'
)):
outpath
=
os
.
path
.
join
(
HERE
,
'
_graphics/
'
)):
# number of classes
# number of classes
labels
=
[
label
[
'
label
'
]
for
label
in
labels
.
values
()]
labels
=
[
label
[
'
label
'
]
for
label
in
labels
.
values
()]
...
@@ -180,7 +181,7 @@ def plot_confusion_matrix(cm, labels, normalize=True,
...
@@ -180,7 +181,7 @@ def plot_confusion_matrix(cm, labels, normalize=True,
def
plot_loss
(
loss_file
,
figsize
=
(
10
,
10
),
step
=
5
,
def
plot_loss
(
loss_file
,
figsize
=
(
10
,
10
),
step
=
5
,
colors
=
[
'
lightgreen
'
,
'
green
'
,
'
skyblue
'
,
'
steelblue
'
],
colors
=
[
'
lightgreen
'
,
'
green
'
,
'
skyblue
'
,
'
steelblue
'
],
outpath
=
os
.
path
.
join
(
os
.
getcwd
()
,
'
_graphics/
'
)):
outpath
=
os
.
path
.
join
(
HERE
,
'
_graphics/
'
)):
# load the model loss
# load the model loss
state
=
torch
.
load
(
loss_file
)
state
=
torch
.
load
(
loss_file
)
...
...
This diff is collapsed.
Click to expand it.
pysegcnn/core/models.py
+
4
−
3
View file @
405feffb
...
@@ -16,6 +16,7 @@ import torch.nn as nn
...
@@ -16,6 +16,7 @@ import torch.nn as nn
# locals
# locals
from
pysegcnn.core.layers
import
(
Encoder
,
Decoder
,
Conv2dPool
,
Conv2dUnpool
,
from
pysegcnn.core.layers
import
(
Encoder
,
Decoder
,
Conv2dPool
,
Conv2dUnpool
,
Conv2dUpsample
,
Conv2dSame
)
Conv2dUpsample
,
Conv2dSame
)
from
pysegcnn.main.config
import
HERE
class
Network
(
nn
.
Module
):
class
Network
(
nn
.
Module
):
...
@@ -31,8 +32,8 @@ class Network(nn.Module):
...
@@ -31,8 +32,8 @@ class Network(nn.Module):
for
param
in
self
.
parameters
():
for
param
in
self
.
parameters
():
param
.
requires_grad
=
True
param
.
requires_grad
=
True
def
save
(
self
,
state_file
,
optimizer
,
bands
,
def
save
(
self
,
state_file
,
optimizer
,
bands
=
None
,
outpath
=
os
.
path
.
join
(
os
.
getcwd
()
,
'
_models
'
)):
outpath
=
os
.
path
.
join
(
HERE
,
'
_models
/
'
)):
# check if the output path exists and if not, create it
# check if the output path exists and if not, create it
if
not
os
.
path
.
isdir
(
outpath
):
if
not
os
.
path
.
isdir
(
outpath
):
...
@@ -70,7 +71,7 @@ class Network(nn.Module):
...
@@ -70,7 +71,7 @@ class Network(nn.Module):
return
state
return
state
def
load
(
self
,
state_file
,
optimizer
=
None
,
def
load
(
self
,
state_file
,
optimizer
=
None
,
inpath
=
os
.
path
.
join
(
os
.
getcwd
()
,
'
_models
'
)):
inpath
=
os
.
path
.
join
(
HERE
,
'
_models
/
'
)):
# load the model state file
# load the model state file
state
=
os
.
path
.
join
(
inpath
,
state_file
)
state
=
os
.
path
.
join
(
inpath
,
state_file
)
...
...
This diff is collapsed.
Click to expand it.
pysegcnn/main/eval.py
+
1
−
1
View file @
405feffb
...
@@ -38,7 +38,7 @@ if __name__ == '__main__':
...
@@ -38,7 +38,7 @@ if __name__ == '__main__':
trainer
.
cm
,
trainer
.
cm
,
trainer
.
plot_scenes
,
trainer
.
plot_scenes
,
bands
=
trainer
.
plot_bands
,
bands
=
trainer
.
plot_bands
,
outpath
=
os
.
path
.
join
(
HERE
,
'
_s
ampl
es/
'
),
outpath
=
os
.
path
.
join
(
HERE
,
'
_s
cen
es/
'
),
stretch
=
True
,
stretch
=
True
,
alpha
=
5
)
alpha
=
5
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment