diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index f23a72aac4e41cea97e95e730b13b536a7ad66ae..50b91d62dfbb2d5135566a4bb56b4d00f7bd9d01 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -2583,7 +2583,7 @@ class NetworkInference(BaseConfig): '_eval.pt')) def report_name(self, state_file): - return str(state_file).replace(state_file.suffix, '_cr.tex') + return str(state_file.name).replace(state_file.suffix, '_cr.tex') def evaluate(self): """Evaluate the models on a defined dataset. @@ -2663,8 +2663,7 @@ class NetworkInference(BaseConfig): cr_labels = [v['label'] for _, v in self.source_labels.items()] # calculate classification report from sklearn - report_name = self.report_path.joinpath( - self.report_name(state.name)) + report_name = self.report_path.joinpath(self.report_name(state)) LOGGER.info('Calculating classification report: {}' .format(report_name)) @@ -2701,8 +2700,10 @@ class NetworkInference(BaseConfig): if self.aggregate: # base name for all models - base_name = str(self.state_files[0].name) - fold_number = re.search('f[0-9]', base_name)[0] + base_name = self.state_files[0] + fold_number = re.search('f[0-9]', base_name.name)[0] + kfold = base_name.parent.joinpath( + str(base_name.name).replace(fold_number, 'kfold')) # predictions of the different models y_true = [output['y_true'] for output in inference.values()] @@ -2714,8 +2715,7 @@ class NetworkInference(BaseConfig): ['{}'.format(mstate.name) for mstate in self.state_files])) # calculate classification report from sklearn - report_name = self.report_path.joinpath( - self.report_name(base_name.replace(fold_number, 'kfold'))) + report_name = self.report_path.joinpath(self.report_name(kfold)) LOGGER.info('Calculating classification report: {}' .format(report_name)) @@ -2737,12 +2737,9 @@ class NetworkInference(BaseConfig): # save aggregated confusion matrix to dictionary inference['cm_agg'] = cm_agg - # create file name for aggregated confusion matrix - cm_name = base_name.replace(fold_number, 'kfold') - # plot aggregated confusion matrix and save to file plot_confusion_matrix( - cm_agg, self.source_labels, state_file=cm_name, + cm_agg, self.source_labels, state_file=kfold, outpath=self.perfmc_path) return inference