diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index 09239c92fdd47e23767c02c8e9d0f3955459dd37..fa8715ed8882478792d8095bc07ad55ba0ce1546 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -292,11 +292,18 @@ class NetworkTrainer(object): def _init_state(self): # file to save model state to - # format: networkname_datasetname_t(tilesize)_b(batchsize)_bands.pt - bformat = ''.join([b[0] for b in self.bands]) if self.bands else 'all' - self.state_file = ('{}_{}_t{}_b{}_{}.pt' + # format: network_dataset_seed_tilesize_batchsize_bands.pt + + # get the band numbers + bformat = ''.join(band[0] + + str(self.dataset.sensor.__members__[band].value) for + band in self.bands) + + # model state filename + self.state_file = ('{}_{}_s{}_t{}_b{}_{}.pt' .format(self.model.__name__, self.dataset.__class__.__name__, + self.seed, self.tile_size, self.batch_size, bformat))