diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index 88e0fc624c0999dd9a9ed2802a9e9f6dafa4b09d..bc20de3f6989fdf843dd3e3a04d2872d9ae339aa 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -2643,77 +2643,76 @@ class NetworkInference(BaseConfig): # load existing model evaluation if not self.overwrite: inference[state.stem] = torch.load(self.eval_file(state)) + continue else: # overwrite existing model evaluation LOGGER.info('Overwriting model evaluation: {}.' .format(self.eval_file(state))) self.eval_file(state).unlink() - else: + # plot loss and accuracy + plot_loss(check_filename_length(state), + outpath=self.perfmc_path) - # plot loss and accuracy - plot_loss(check_filename_length(state), - outpath=self.perfmc_path) + # load the target dataset to evaluate the model on + self.trg_ds = self.load_dataset( + state, implicit=self.implicit, test=self.test, + domain=self.domain) - # load the target dataset to evaluate the model on - self.trg_ds = self.load_dataset( - state, implicit=self.implicit, test=self.test, - domain=self.domain) + # load the source dataset the model was trained on + self.src_ds = self.load_dataset(state, test=None) - # load the source dataset the model was trained on - self.src_ds = self.load_dataset(state, test=None) + # load the pretrained model + model, _ = Network.load_pretrained_model(state) - # load the pretrained model - model, _ = Network.load_pretrained_model(state) - - # evaluate the model on the target dataset - output = self.predict(model) - - # merge predictions of the different samples - y_true = np.asarray([v['y_true'].flatten() for _, v - in output.items()]).flatten() - y_pred = np.asarray([v['y_pred'].flatten() for _, v - in output.items()]).flatten() - - # calculate classification report from sklearn - report_name = self.report_path.joinpath( - self.report_name(state)) - LOGGER.info('Calculating classification report: {}' - .format(report_name)) - report = classification_report( - y_true, y_pred, target_names=self.class_names, - output_dict=True) - - # store report in output dictionary - output['report'] = report2df(report, self.class_names) - - # plot classification report - fig = plot_classification_report(report, self.class_names) - fig.savefig(report_name, dpi=300, bbox_inches='tight') - - # check whether to calculate confusion matrix - if self.cm: - - # calculate confusion matrix - LOGGER.info('Computing confusion matrix ...') - conf_mat = confusion_matrix(y_true, y_pred) - - # add confusion matrix to model output - output['cm'] = conf_mat - - # plot confusion matrix - plot_confusion_matrix( - conf_mat, self.class_names, state_file=state, - subset=self.trg_ds.name, - outpath=self.perfmc_path) - - # save model predictions to file - LOGGER.info('Saving model evaluation: {}' - .format(self.eval_file(state))) - torch.save(output, self.eval_file(state)) + # evaluate the model on the target dataset + output = self.predict(model) + + # merge predictions of the different samples + y_true = np.asarray([v['y_true'].flatten() for _, v + in output.items()]).flatten() + y_pred = np.asarray([v['y_pred'].flatten() for _, v + in output.items()]).flatten() + + # calculate classification report from sklearn + report_name = self.report_path.joinpath( + self.report_name(state)) + LOGGER.info('Calculating classification report: {}' + .format(report_name)) + report = classification_report( + y_true, y_pred, target_names=self.class_names, + output_dict=True, zero_division=1) + + # store report in output dictionary + output['report'] = report2df(report, self.class_names) + + # plot classification report + fig = plot_classification_report(report, self.class_names) + fig.savefig(report_name, dpi=300, bbox_inches='tight') + + # check whether to calculate confusion matrix + if self.cm: + + # calculate confusion matrix + LOGGER.info('Computing confusion matrix ...') + conf_mat = confusion_matrix(y_true, y_pred) + + # add confusion matrix to model output + output['cm'] = conf_mat + + # plot confusion matrix + plot_confusion_matrix( + conf_mat, self.class_names, state_file=state, + subset=self.trg_ds.name, + outpath=self.perfmc_path) + + # save model predictions to file + LOGGER.info('Saving model evaluation: {}' + .format(self.eval_file(state))) + torch.save(output, self.eval_file(state)) - # save model predictions to list - inference[state.stem] = output + # save model predictions to list + inference[state.stem] = output # check whether to aggregate the results of the different model runs if self.aggregate: