Commit 307b0962 authored by Frisinghelli Daniel
Browse files

Generalized network trainer to regression tasks.

parent af796f29
......@@ -3,9 +3,6 @@
This module provides an end-to-end framework of dataclasses designed to train
segmentation models on image datasets.
See :py:meth:`pysegcnn.core.trainer.NetworkTrainer.init_network_trainer` for a
complete walkthrough.
License
-------
......@@ -966,16 +963,7 @@ class LogConfig(BaseConfig):
@dataclasses.dataclass
class NetworkTrainer(BaseConfig):
"""Base model training class for classification problems.
Train an instance of :py:class:`pysegcnn.core.models.Network` on a
classification problem. The `categorical cross-entropy loss`_
is used as the loss function in combination with the `softmax`_ output
layer activation function.
In case of a binary classification problem, the categorical cross-entropy
loss reduces to the binary cross-entropy loss and the softmax function to
the standard `logistic function`_.
"""Train an instance of :py:class:`pysegcnn.core.models.Network`.
Attributes
----------
......@@ -1033,11 +1021,11 @@ class NetworkTrainer(BaseConfig):
`True`.
multi_gpu : `bool`
Whether to use multiple GPUs, if available. The default is `False`.
classification : `bool`
Whether the task to solve is a classification task. The default is
`True`.
device : `str`
The device to train the model on, i.e. `cpu` or `cuda`.
cla_loss_function : :py:class:`torch.nn.Module`
The classification loss function to compute the model error. An
instance of :py:class:`torch.nn.CrossEntropyLoss`.
tracker : :py:class:`pysegcnn.core.trainer.MetricTracker`
A :py:class:`pysegcnn.core.trainer.MetricTracker` instance tracking
training metrics, i.e. loss and accuracy.
......@@ -1086,6 +1074,7 @@ class NetworkTrainer(BaseConfig):
checkpoint_state: dict = dataclasses.field(default_factory=dict)
save: bool = True
multi_gpu: bool = False
classification: bool = True
def __post_init__(self):
"""Check the type of each argument.
......@@ -1116,24 +1105,29 @@ class NetworkTrainer(BaseConfig):
# send the model to the gpu(s)
self.model = self.model.to(self.device)
# instantiate multiclass classification loss function: categorical
# cross-entropy loss function
LOGGER.info('Classification loss function: {}.'
.format(repr(self.loss_function)))
# instantiate loss function
LOGGER.info('Loss function: {}.'.format(repr(self.loss_function)))
# instantiate metric tracker
self.tracker = MetricTracker(
train_metrics=['train_loss', 'train_accu'],
valid_metrics=['valid_loss', 'valid_accu'])
self.tracker = MetricTracker(train_metrics=['train_loss'],
valid_metrics=['valid_loss'])
# check if solving a classification task
if self.classification:
# include classification accuracy metrics
self.tracker.train_metrics.append('train_accu')
self.tracker.valid_metrics.append('valid_accu')
# check which metric to use for early stopping
self.best, self.metric, self.mfn = (
(0, 'valid_accu', np.max) if self.mode == 'max' else
(np.inf, 'valid_loss', np.min))
else:
self.best, self.metric, self.mfn = (np.inf, 'valid_loss', np.min)
# initialize metric tracker
self.tracker.initialize()
# check which metric to use for early stopping
self.best, self.metric, self.mfn = (
(0, 'valid_accu', np.max) if self.mode == 'max' else
(np.inf, 'valid_loss', np.min))
# best metric score on the validation set
if self.checkpoint_state:
self.best = self.mfn(
......@@ -1169,6 +1163,7 @@ class NetworkTrainer(BaseConfig):
"""
# iterate over the dataloader object
log = 'Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, Loss: {:.2f}'
for batch, (inputs, labels) in enumerate(self.src_train_dl):
# send the data to the gpu if available
......@@ -1191,21 +1186,27 @@ class NetworkTrainer(BaseConfig):
# update the weights
self.optimizer.step()
# calculate predicted class labels
ypred = F.softmax(outputs, dim=1).argmax(dim=1)
# update training loss
self.tracker.update('train_loss', loss.item())
progress = log.format(epoch + 1, self.epochs, batch + 1,
self.tmbatch, loss.item())
# calculate accuracy on current batch
acc = accuracy_function(ypred, labels)
# calculate model predictions for classification tasks
if self.classification:
# calculate predicted class labels
ypred = F.softmax(outputs, dim=1).argmax(dim=1)
# calculate accuracy on current batch
acc = accuracy_function(ypred, labels)
# update training accuracy
self.tracker.update('train_accu', acc)
progress = ', '.join([progress,
'Accuracy: {:.2f}'.format(acc)])
# print progress
LOGGER.info('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, '
'Loss: {:.2f}, Accuracy: {:.2f}'
.format(epoch + 1, self.epochs, batch + 1,
self.tmbatch, loss.item(), acc))
LOGGER.info(progress)
# update training metrics
self.tracker.batch_update(self.tracker.train_metrics,
[loss.item(), acc])
def train_epoch(self, epoch):
"""Train a model for a single epoch on the source domain.
......@@ -1254,6 +1255,8 @@ class NetworkTrainer(BaseConfig):
# update validation metrics
self.tracker.batch_update(self.tracker.valid_metrics,
# save model state if the model improved with
[valid_loss, valid_accu])
# metric to assess model performance on the validation set
......@@ -1262,8 +1265,6 @@ class NetworkTrainer(BaseConfig):
# whether the model improved with respect to the previous epoch
if self.es.is_better(epoch_best, self.best, self.delta):
self.best = epoch_best
# save model state if the model improved with
# respect to the previous epoch
if self.save:
self.save_state()
......@@ -1315,6 +1316,7 @@ class NetworkTrainer(BaseConfig):
# iterate over the validation/test set
LOGGER.info('Calculating accuracy on the validation set ...')
log = 'Mini-batch: {:d}/{:d}, Loss: {:.2f}'
for batch, (inputs, labels) in enumerate(dataloader):
# send the data to the gpu if available
......@@ -1326,30 +1328,35 @@ class NetworkTrainer(BaseConfig):
outputs = self.model(inputs)
# compute loss
cla_loss = self.loss_function(outputs, labels.long())
loss.append(cla_loss.item())
val_loss = self.loss_function(outputs, labels.long())
loss.append(val_loss.item())
# calculate predicted class labels
pred = F.softmax(outputs, dim=1).max(dim=1)
if return_pred:
predictions['y_pred'].append(pred.indices)
predictions['y_prob'].append(pred.values)
# calculate accuracy on current batch
acc = accuracy_function(pred.indices, labels)
accuracy.append(acc)
progress = log.format(batch + 1, len(dataloader), val_loss)
if self.classification:
pred = F.softmax(outputs, dim=1).max(dim=1)
if return_pred:
predictions['y_pred'].append(pred.indices)
predictions['y_prob'].append(pred.values)
# calculate accuracy on current batch
acc = accuracy_function(pred.indices, labels)
accuracy.append(acc)
progress = ', '.join([progress,
'Accuracy: {:.2f}'.format(acc)])
# print progress
LOGGER.info('Mini-batch: {:d}/{:d}, Accuracy: {:.2f}'
.format(batch + 1, len(dataloader), acc))
LOGGER.info(progress)
# calculate overall accuracy on the validation/test set
# calculate overall accuracy/loss on the validation/test set
epoch = (self.model.module.epoch if
isinstance(self.model, nn.DataParallel) else self.model.epoch)
LOGGER.info('Epoch: {:d}, Mean accuracy: {:.2f}%.'
.format(epoch, np.mean(accuracy) * 100))
overall = ('Mean accuracy: {:.2f}%'.format(np.mean(accuracy) * 100) if
self.classification else
'Mean loss: {:.2f}%'.format(np.mean(loss)))
LOGGER.info(', '.join(['Epoch: {:d}'.format(epoch), overall]))
if return_pred:
if self.classification and return_pred:
# return only predictions, if specified
return predictions
else:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment