diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py index 31c03d56b95892e4a370b369430bc29f65e5d09a..e419a3ca73534075384b8c71ea08e68362463e4e 100644 --- a/pysegcnn/core/graphics.py +++ b/pysegcnn/core/graphics.py @@ -718,9 +718,13 @@ def plot_classification_report(report, labels, figsize=(10, 10), # convert to DataFrame df = report2df(report, labels) - # drop overall accuracy - overall_accuracy = df.loc['accuracy'].loc['f1-score'] - metrics = df.drop(index='accuracy') + # drop overall accuracy/micro avg + try: + overall_accuracy = df.loc['accuracy'].loc['f1-score'] + metrics = df.drop(index='accuracy') + except KeyError: + overall_accuracy = df.loc['micro avg'].loc['f1-score'] + metrics = df.drop(index='micro avg') # create a figure fig, ax = plt.subplots(1, 1, figsize=figsize) diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index c8020f7730c87bbece9e11f218f4025b560408dc..d0c2a57182349868a690b755aee1dca94ecdbb3c 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -2777,9 +2777,13 @@ class NetworkInference(BaseConfig): LOGGER.info('Calculating k-fold estimate of metrics ...') report = df.groupby(df.index, sort=False).mean() - # labels to predict - labels = list(report.index.drop(['macro avg', 'weighted avg', - 'accuracy'])) + # labels to predict: drop averages + try: + labels = list(report.index.drop(['macro avg', 'weighted avg', + 'accuracy'])) + except KeyError: + labels = list(report.index.drop(['macro avg', 'weighted avg', + 'micro avg'])) # plot classification report fig = plot_classification_report(report, labels)