Commit 755d39c0 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Implemented a flag to use multiple GPUs, if available.

parent cc30d808
......@@ -1027,6 +1027,8 @@ class ClassificationNetworkTrainer(BaseConfig):
save : `bool`
Whether to save the model state to ``state_file``. The default is
`True`.
multi_gpu : `bool`
Whether to use multiple GPUs, if available. The default is `False`.
device : `str`
The device to train the model on, i.e. `cpu` or `cuda`.
cla_loss_function : :py:class:`torch.nn.Module`
......@@ -1078,6 +1080,7 @@ class ClassificationNetworkTrainer(BaseConfig):
patience: int = 10
checkpoint_state: dict = dataclasses.field(default_factory=dict)
save: bool = True
multi_gpu: bool = False
def __post_init__(self):
"""Check the type of each argument.
......@@ -1100,7 +1103,7 @@ class ClassificationNetworkTrainer(BaseConfig):
torch.set_num_threads(self.nthreads)
# check if multiple gpus are available
if torch.cuda.device_count() > 1:
if torch.cuda.device_count() > 1 and self.multi_gpu:
LOGGER.info('Using {} available GPUs.'.format(
torch.cuda.device_count()))
self.model = nn.DataParallel(self.model, dim=0)
......@@ -1338,7 +1341,7 @@ class ClassificationNetworkTrainer(BaseConfig):
# calculate overall accuracy on the validation/test set
epoch = (self.model.module.epoch if
isinstance(self.model,nn.DataParallel) else self.model.epoch)
isinstance(self.model, nn.DataParallel) else self.model.epoch)
LOGGER.info('Epoch: {:d}, Mean accuracy: {:.2f}%.'
.format(epoch, np.mean(accuracy) * 100))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment