Skip to content
Snippets Groups Projects
Commit 22a63f44 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Confusion matrix is now saved to file

parent cafc978b
No related branches found
No related tags found
Loading
......@@ -30,6 +30,7 @@ def predict(model, dataloader, optimizer, accuracy, state_file=None):
nbatches = int(len(dataloader.dataset) / dataloader.batch_size)
# iterate over the validation/test set
accuracies = []
for batch, (inputs, labels) in enumerate(dataloader):
# send the data to the gpu if available
......@@ -45,6 +46,7 @@ def predict(model, dataloader, optimizer, accuracy, state_file=None):
# calculate accuracy
acc = accuracy(pred, labels)
accuracies.append(acc)
# print progress
print('Batch: {:d}/{:d}, Accuracy: {:.2f}'.format(batch,
......@@ -54,6 +56,10 @@ def predict(model, dataloader, optimizer, accuracy, state_file=None):
for ytrue, ypred in zip(labels.view(-1), pred.view(-1)):
cm[ytrue.long(), ypred.long()] += 1
# save confusion matrix and accuracies to file
torch.save({'cm': cm, 'accuracy': accuracies},
state.split('.pt')[0] + '_cm.pt')
return cm
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment