diff --git a/pysegcnn/core/utils.py b/pysegcnn/core/utils.py index 61b69e8812efa14fc8d33c6cf5cae6c1589e8a0d..8be2d19dda47be3e14d9cad8d9335ab98f6ff468 100644 --- a/pysegcnn/core/utils.py +++ b/pysegcnn/core/utils.py @@ -11,6 +11,7 @@ import datetime # externals import gdal +import torch import numpy as np # the following functions are utility functions for common image @@ -296,6 +297,14 @@ def reconstruct_scene(tiles, img_size, tile_size=None, nbands=1): return scene.squeeze() +# function calculating prediction accuracy +def accuracy_function(outputs, labels): + if isinstance(outputs, torch.Tensor): + return (outputs == labels).float().mean().item() + else: + return (np.asarray(outputs) == np.asarray(labels)).mean().item() + + def parse_landsat_scene(scene_id): # Landsat Collection 1 naming convention in regular expression