Skip to content
Snippets Groups Projects
Commit 8da0f39a authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Fixed a bug: model is now correctly passed to predict function.

parent 463be8e9
No related branches found
No related tags found
No related merge requests found
......@@ -955,9 +955,9 @@ class LogConfig(BaseConfig):
The string to write to the model log file.
"""
LOGGER.info(80 * '-')
LOGGER.info(200 * '-')
LOGGER.info('{}: '.format(LogConfig.now()) + init_str)
LOGGER.info(80 * '-')
LOGGER.info(200 * '-')
@dataclasses.dataclass
......@@ -2459,6 +2459,11 @@ class NetworkInference(BaseConfig):
def predict(self, model):
"""Classify the samples of the target dataset.
Parameters
----------
model : :py:class:`pysegcnn.core.models.Network`
The model to evaluate on the target dataset.
Returns
-------
output : `dict` [`str`, `dict`]
......@@ -2581,7 +2586,7 @@ class NetworkInference(BaseConfig):
# initialize logging
log = LogConfig(state)
dictConfig(log_conf(log.log_file))
log.init_log('Evaluating model: {}.'.format(state.name))
log.init_log('Evaluating model: {}.'.format(state))
# check whether model was already evaluated
if self.eval_file(state).exists():
......@@ -2607,7 +2612,7 @@ class NetworkInference(BaseConfig):
model, _ = Network.load_pretrained_model(state)
# evaluate the model on the target dataset
output = self.predict()
output = self.predict(model)
# check whether to calculate confusion matrix
if self.cm:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment