diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py index e419a3ca73534075384b8c71ea08e68362463e4e..bd93df9712829f5e6ece29d27f5632b0391a0731 100644 --- a/pysegcnn/core/graphics.py +++ b/pysegcnn/core/graphics.py @@ -718,13 +718,14 @@ def plot_classification_report(report, labels, figsize=(10, 10), # convert to DataFrame df = report2df(report, labels) - # drop overall accuracy/micro avg + # get overall accuracy try: overall_accuracy = df.loc['accuracy'].loc['f1-score'] - metrics = df.drop(index='accuracy') except KeyError: overall_accuracy = df.loc['micro avg'].loc['f1-score'] - metrics = df.drop(index='micro avg') + + # drop micro avg/accuracy from dataframe + metrics = df.loc[labels + ['macro avg', 'weighted avg']] # create a figure fig, ax = plt.subplots(1, 1, figsize=figsize)