diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index 855b76b1559f878e0ea935d1451adc72833ad756..31c615f3156d7832917389c53059ce5f8e8bc315 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -253,11 +253,7 @@ class ModelConfig(BaseConfig): def init_model(self, ds, state_file): # write an initialization string to the log file - # now = datetime.datetime.strftime(datetime.datetime.now(), - # '%Y-%m-%dT%H:%M:%S') - # LOGGER.info(80 * '-') - # LOGGER.info('{}: Initializing model run. '.format(now) + 35 * '-') - # LOGGER.info(80 * '-') + LogConfig.init_log('{}: Initializing model run. ') # case (1): build a new model if not self.transfer: @@ -332,10 +328,10 @@ class ModelConfig(BaseConfig): ds.__class__.__name__)) # check whether the current dataset uses the correct spectral bands - if new_ds.use_bands != model_state['bands']: + if ds.use_bands != model_state['bands']: raise ValueError('The pretrained network was trained with ' 'bands {}, not with bands {}.' - .format(model_state['bands'], new_ds.use_bands)) + .format(model_state['bands'], ds.use_bands)) # get the number of convolutional filters filters = model_state['params']['filters'] @@ -379,7 +375,7 @@ class StateConfig(BaseConfig): # get the band numbers bformat = ''.join(band[0] + str(self.ds.sensor.__members__[band].value) for - band in self.ds.use_bands) + band in self.ds.use_bands) # check which split mode was used if self.sc.split_mode == 'date': @@ -449,9 +445,8 @@ class EvalConfig(BaseConfig): self.models_path = self.base_path.joinpath('_graphics') # write initialization string to log file - # LOGGER.info(80 * '-') - # LOGGER.info('{}') - # LOGGER.info(80 * '-') + LogConfig.init_log('{}: ' + 'Evaluating model: {}.'.format( + self.state_file.name)) @dataclasses.dataclass @@ -468,6 +463,17 @@ class LogConfig(BaseConfig): self.log_file = self.log_path.joinpath( self.state_file.name.replace('.pt', '.log')) + @staticmethod + def now(): + return datetime.datetime.strftime(datetime.datetime.now(), + '%Y-%m-%dT%H:%M:%S') + + @staticmethod + def init_log(init_str): + LOGGER.info(80 * '-') + LOGGER.info(init_str.format(LogConfig.now())) + LOGGER.info(80 * '-') + @dataclasses.dataclass class NetworkTrainer(BaseConfig): @@ -506,6 +512,9 @@ class NetworkTrainer(BaseConfig): self.es = EarlyStopping(self.mode, self.max_accuracy, self.delta, self.patience) + # log representation + LOGGER.info(repr(self)) + def train(self): LOGGER.info(35 * '-' + ' Training ' + 35 * '-')