Skip to content
Snippets Groups Projects
Commit 3f14aa4c authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Loss file is deprecated; loss is now also stored in state file

parent 42efced5
No related branches found
No related tags found
No related merge requests found
......@@ -174,17 +174,17 @@ def plot_confusion_matrix(cm, labels, normalize=True,
return fig, ax
def plot_loss(loss_file, figsize=(10, 10), step=5,
def plot_loss(state_file, figsize=(10, 10), step=5,
colors=['lightgreen', 'green', 'skyblue', 'steelblue'],
outpath=os.path.join(HERE, '_graphics/')):
# load the model loss
state = torch.load(loss_file)
# load the model state
model_state = torch.load(state_file)
# get all non-zero elements, i.e. get number of epochs trained before
# early stop
loss = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for k, v in
state.items() if k != 'epoch'}
model_state['state'].items()}
# compute running mean with a window equal to the number of batches in
# an epoch
......@@ -245,7 +245,7 @@ def plot_loss(loss_file, figsize=(10, 10), step=5,
# save figure
os.makedirs(outpath, exist_ok=True)
fig.savefig(os.path.join(
outpath, os.path.basename(loss_file).replace('.pt', '.png')),
outpath, os.path.basename(state_file).replace('.pt', '.png')),
dpi=300, bbox_inches='tight')
return fig
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment