Commit 92476ba3 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Generalized network trainer to regression tasks.

parent 307b0962
......@@ -1109,15 +1109,12 @@ class NetworkTrainer(BaseConfig):
LOGGER.info('Loss function: {}.'.format(repr(self.loss_function)))
# instanciate metric tracker
self.tracker = MetricTracker(train_metrics=['train_loss'],
valid_metrics=['valid_loss'])
self.tracker = MetricTracker(
train_metrics=['train_loss', 'train_accu'],
valid_metrics=['valid_loss', 'valid_accu'])
# check if solving a classification task
if self.classification:
# include classification accuracy metrics
self.tracker.train_metrics.append('train_accu')
self.tracker.valid_metrics.append('valid_accu')
# check which metric to use for early stopping
self.best, self.metric, self.mfn = (
(0, 'valid_accu', np.max) if self.mode == 'max' else
......@@ -1207,7 +1204,6 @@ class NetworkTrainer(BaseConfig):
# print progress
LOGGER.info(progress)
def train_epoch(self, epoch):
"""Train a model for a single epoch on the source domain.
......@@ -1255,8 +1251,6 @@ class NetworkTrainer(BaseConfig):
# update validation metrics
self.tracker.batch_update(self.tracker.valid_metrics,
# save model state if the model improved with
[valid_loss, valid_accu])
# metric to assess model performance on the validation set
......@@ -1265,7 +1259,8 @@ class NetworkTrainer(BaseConfig):
# whether the model improved with respect to the previous epoch
if self.es.is_better(epoch_best, self.best, self.delta):
self.best = epoch_best
# respect to the previous epoch
# save model if it improved with respect to the previous
# epoch
if self.save:
self.save_state()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment