Commit c3ac8c4b authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Generalized network trainer class to arbitrary tasks.

parent 2b7cd8fb
......@@ -965,7 +965,7 @@ class LogConfig(BaseConfig):
@dataclasses.dataclass
class ClassificationNetworkTrainer(BaseConfig):
class NetworkTrainer(BaseConfig):
"""Base model training class for classification problems.
Train an instance of :py:class:`pysegcnn.core.models.Network` on a
......@@ -999,6 +999,10 @@ class ClassificationNetworkTrainer(BaseConfig):
The source domain test :py:class:`torch.utils.data.DataLoader`
instance build from an instance of
:py:class:`torch.utils.data.Subset`.
loss_function : :py:class:`torch.nn.modules.loss._Loss`
The loss function to minimize. A subclass of
:py:class:`torch.nn.modules.loss._Loss`. The default is
:py:class:`torch.nn.CrossEntropyLoss`.
epochs : `int`
The maximum number of epochs to train. The default is `1`.
nthreads : `int`
......@@ -1072,6 +1076,7 @@ class ClassificationNetworkTrainer(BaseConfig):
src_train_dl: DataLoader
src_valid_dl: DataLoader
src_test_dl: DataLoader = DataLoader(None)
loss_function: nn.modules.loss._Loss = nn.CrossEntropyLoss()
epochs: int = 1
nthreads: int = torch.get_num_threads()
early_stop: bool = False
......@@ -1113,9 +1118,8 @@ class ClassificationNetworkTrainer(BaseConfig):
# instanciate multiclass classification loss function: categorical
# cross-entropy loss function
self.cla_loss_function = nn.CrossEntropyLoss()
LOGGER.info('Classification loss function: {}.'
.format(repr(nn.CrossEntropyLoss)))
.format(repr(self.loss_function)))
# instanciate metric tracker
self.tracker = MetricTracker(
......@@ -1178,7 +1182,7 @@ class ClassificationNetworkTrainer(BaseConfig):
outputs = self.model(inputs)
# compute loss
loss = self.cla_loss_function(outputs, labels.long())
loss = self.loss_function(outputs, labels.long())
# compute the gradients of the loss function w.r.t.
# the network weights
......@@ -1322,7 +1326,7 @@ class ClassificationNetworkTrainer(BaseConfig):
outputs = self.model(inputs)
# compute loss
cla_loss = self.cla_loss_function(outputs, labels.long())
cla_loss = self.loss_function(outputs, labels.long())
loss.append(cla_loss.item())
# calculate predicted class labels
......@@ -1416,7 +1420,7 @@ class ClassificationNetworkTrainer(BaseConfig):
# loss function
fs += '\n (loss function):' + '\n' + 8 * ' '
fs += ''.join(repr(self.cla_loss_function)).replace('\n',
fs += ''.join(repr(self.loss_function)).replace('\n',
'\n' + 8 * ' ')
# early stopping
......@@ -1445,7 +1449,7 @@ class ClassificationNetworkTrainer(BaseConfig):
@dataclasses.dataclass
class DomainAdaptationTrainer(ClassificationNetworkTrainer):
class DomainAdaptationTrainer(NetworkTrainer):
"""Model training class for domain adaptation.
Train an instance of :py:class:`pysegcnn.core.models.EncoderDecoderNetwork`
......@@ -1649,7 +1653,7 @@ class DomainAdaptationTrainer(ClassificationNetworkTrainer):
src_input, trg_input)
# compute classification loss
cla_loss = self.cla_loss_function(src_prdctn, src_label.long())
cla_loss = self.loss_function(src_prdctn, src_label.long())
# compute domain adaptation loss:
# the difference between source and target domain is computed
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment