Skip to content
Snippets Groups Projects
Commit 7e80b2c9 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Fixed a bug: NoData class is now skipped in class distribution.

parent 77a4727c
No related branches found
No related tags found
No related merge requests found
......@@ -32,7 +32,6 @@ import pathlib
# externals
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
......@@ -41,8 +40,7 @@ from pysegcnn.core.constants import (MultiSpectralSensor, Landsat8, Sentinel2,
Label, SparcsLabels, Cloud95Labels,
AlcdLabels)
from pysegcnn.core.utils import (img2np, is_divisible, tile_topleft_corner,
parse_landsat_scene, parse_sentinel2_scene,
array_replace)
parse_landsat_scene, parse_sentinel2_scene)
# module level logger
LOGGER = logging.getLogger(__name__)
......@@ -632,10 +630,13 @@ class ImageDataset(Dataset):
def class_distribution(self):
# initialize dictionary of class spectral distribution
# exclude NoData class
labels = {k: v for k, v in self.labels.items() if
v['label'] != 'No data'}
# initialize dictionary of class spectral distribution
cls_ds = {k: np.empty(shape=(0, len(self.use_bands)), dtype=np.float32)
for k, v in self.labels.items() if v['label'] != 'No data'}
for k, v in labels.items()}
# iterate over the samples of the dataset
for i in range(len(self)):
......@@ -644,7 +645,7 @@ class ImageDataset(Dataset):
x, y = self[i]
# iterate over the different classes
for k, v in self.labels.items():
for k, v in labels.items():
# get values equal to the current class
mask = np.where(y == k)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment