Skip to content
Snippets Groups Projects
Commit 3ec4bf24 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Adjusted sparcs evaluation to confusion matrix output of predict function

parent dcfce5a2
No related branches found
No related tags found
No related merge requests found
......@@ -18,13 +18,13 @@ from sparcs.sparcs_02_dataset import (net, valid_ds, valid_dl, optimizer,
if __name__ == '__main__':
# calculate accuracy for each batch in the validation set
accuracies = predict(net, valid_dl, optimizer, accuracy_function,
state_file)
mean_acc = np.mean(accuracies)
print('After training for {:d} epochs, we achieved an overall mean '
'accuracy of {:.2f}% on the validation set!'.format(net.epoch,
mean_acc * 100))
# predict each batch in the validation set
cm = predict(net, valid_dl, optimizer, accuracy_function, state_file)
# calculate overal accuracy
acc = (cm.diag().sum() / cm.sum()).numpy().item()
print('After training for {:d} epochs, we achieved an overall accuracy of '
'{:.2f}% on the validation set!'.format(net.epoch, acc * 100))
# number of samples to plot
n = 5
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment