Skip to content
Snippets Groups Projects
Commit 59e9d02e authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

More explicit input name

parent c335d997
No related branches found
No related tags found
No related merge requests found
......@@ -446,12 +446,12 @@ class ImageDataset(Dataset):
return fig, ax
def plot_loss(self, state_file, figsize=(10, 10),
def plot_loss(self, loss_file, figsize=(10, 10),
colors=['lightgreen', 'skyblue', 'darkgreen', 'steelblue'],
outpath=os.path.join(os.getcwd(), '_graphics/')):
# load the model loss
state = torch.load(state_file)
state = torch.load(loss_file)
# get all non-zero elements, i.e. get number of epochs trained before
# early stop
......@@ -516,7 +516,7 @@ class ImageDataset(Dataset):
# save figure
os.makedirs(outpath, exist_ok=True)
fig.savefig(os.path.join(
outpath, os.path.basename(state_file).replace('.pt', '.png')),
outpath, os.path.basename(loss_file).replace('.pt', '.png')),
dpi=300, bbox_inches='tight')
return fig, ax
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment