Skip to content
Snippets Groups Projects
Commit 0646dda2 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Fixed bug for loss in state file.

parent 04805dcc
No related branches found
No related tags found
No related merge requests found
...@@ -124,13 +124,12 @@ class EoDataset(torch.utils.data.Dataset): ...@@ -124,13 +124,12 @@ class EoDataset(torch.utils.data.Dataset):
state_file = '_'.join([state_file, 'doy']) if doy else state_file state_file = '_'.join([state_file, 'doy']) if doy else state_file
# check which loss function is used # check which loss function is used
if predictand == 'pr': if (isinstance(loss, BernoulliGammaLoss) or
if (isinstance(loss, BernoulliGammaLoss) or isinstance(loss, BernoulliWeibullLoss)):
isinstance(loss, BernoulliWeibullLoss)): # adjust state file for precipitation
# adjust state file for precipitation state_file = '_'.join([state_file, '{}mm'.format(
state_file = '_'.join([state_file, '{}mm'.format( str(loss.min_amount).replace('.', '')),
str(loss.min_amount).replace('.', '')), repr(loss).strip('()')])
repr(loss).strip('()')])
else: else:
# add name of loss function to state file # add name of loss function to state file
state_file = '_'.join([state_file, repr(loss).strip('()')]) state_file = '_'.join([state_file, repr(loss).strip('()')])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment