diff --git a/pysegcnn/core/dataset.py b/pysegcnn/core/dataset.py index b4ed855e27773db85038c17e33413c655d1e42ee..b9c21070447f15acd9c2153490a7c7bd38dfb490 100644 --- a/pysegcnn/core/dataset.py +++ b/pysegcnn/core/dataset.py @@ -32,7 +32,6 @@ import pathlib # externals import numpy as np -import pandas as pd import torch from torch.utils.data import Dataset @@ -41,8 +40,7 @@ from pysegcnn.core.constants import (MultiSpectralSensor, Landsat8, Sentinel2, Label, SparcsLabels, Cloud95Labels, AlcdLabels) from pysegcnn.core.utils import (img2np, is_divisible, tile_topleft_corner, - parse_landsat_scene, parse_sentinel2_scene, - array_replace) + parse_landsat_scene, parse_sentinel2_scene) # module level logger LOGGER = logging.getLogger(__name__) @@ -632,10 +630,13 @@ class ImageDataset(Dataset): def class_distribution(self): - # initialize dictionary of class spectral distribution # exclude NoData class + labels = {k: v for k, v in self.labels.items() if + v['label'] != 'No data'} + + # initialize dictionary of class spectral distribution cls_ds = {k: np.empty(shape=(0, len(self.use_bands)), dtype=np.float32) - for k, v in self.labels.items() if v['label'] != 'No data'} + for k, v in labels.items()} # iterate over the samples of the dataset for i in range(len(self)): @@ -644,7 +645,7 @@ class ImageDataset(Dataset): x, y = self[i] # iterate over the different classes - for k, v in self.labels.items(): + for k, v in labels.items(): # get values equal to the current class mask = np.where(y == k)