Skip to content
Snippets Groups Projects
Commit 70692c84 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Fixed division by while normalizing confusion matrix

parent 95949f3a
No related branches found
No related tags found
No related merge requests found
......@@ -114,6 +114,7 @@ def plot_confusion_matrix(cm, labels, normalize=True,
outpath=os.path.join(os.getcwd(), '_graphics/')):
# number of classes
labels = [label['label'] for label in labels.values()]
nclasses = len(labels)
# string format to plot values of confusion matrix
......@@ -125,7 +126,11 @@ def plot_confusion_matrix(cm, labels, normalize=True,
# check whether to normalize the confusion matrix
if normalize:
# normalize
cm = cm / cm.sum(axis=1, keepdims=True)
norm = cm.sum(axis=1, keepdims=True)
# check for division by zero
norm[norm == 0] = 1
cm = cm / norm
# change string format to floating point
fmt = '.2f'
......@@ -198,10 +203,10 @@ def plot_loss(loss_file, figsize=(10, 10), step=5,
# create axes for each parameter to plot
ax2 = ax1.twinx()
ax3 = ax1.twiny()
ax4 = ax2.twiny()
ax4 = ax3.twinx()
# list of axes
axes = [ax1, ax2, ax3, ax4]
axes = [ax2, ax1, ax4, ax3]
# plot running mean loss and accuracy of the training dataset
[ax.plot(v, color=c) for (k, v), ax, c in zip(rm.items(), axes, colors)
......@@ -209,8 +214,8 @@ def plot_loss(loss_file, figsize=(10, 10), step=5,
# axes properties and labels
nbatches = loss['tl'].shape[0]
for ax in [ax3, ax4]:
ax.set(xticks=[], xticklabels=[])
ax3.set(xticks=[], xticklabels=[])
ax4.set(xticks=[], xticklabels=[], yticks=[], yticklabels=[])
ax1.set(xticks=np.arange(0, nbatches * epochs[-1] + 1, nbatches * step),
xticklabels=epochs[::step],
xlabel='Epoch',
......@@ -231,8 +236,8 @@ def plot_loss(loss_file, figsize=(10, 10), step=5,
ha='right', color='grey')
# create a patch (proxy artist) for every color
ulabels = ['Training loss', 'Training accuracy',
'Validation loss', 'Validation accuracy']
ulabels = ['Training accuracy', 'Training loss',
'Validation accuracy', 'Validation loss']
patches = [mlines.Line2D([], [], color=c, label=l) for c, l in
zip(colors, ulabels)]
# plot patches as legend
......@@ -244,4 +249,4 @@ def plot_loss(loss_file, figsize=(10, 10), step=5,
outpath, os.path.basename(loss_file).replace('.pt', '.png')),
dpi=300, bbox_inches='tight')
return fig, ax
return fig
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment