From 8ff83719e921322338916ad768a49a97a987ab8a Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Mon, 20 Jul 2020 17:12:40 +0200 Subject: [PATCH] Made transformation classes callable for easier execution --- pytorch/transforms.py | 94 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 82 insertions(+), 12 deletions(-) diff --git a/pytorch/transforms.py b/pytorch/transforms.py index b4b0187..739a3e5 100644 --- a/pytorch/transforms.py +++ b/pytorch/transforms.py @@ -10,16 +10,16 @@ import numpy as np from scipy import ndimage -class Transformation(object): +class Transform(object): def __init__(self): pass - def apply(self, image): + def __call__(self, image): raise NotImplementedError -class VariantTransformation(Transformation): +class VariantTransform(Transform): def __init__(self): @@ -27,7 +27,7 @@ class VariantTransformation(Transformation): self.invariant = False -class InvariantTransformation(Transformation): +class InvariantTransform(Transform): def __init__(self): @@ -35,30 +35,37 @@ class InvariantTransformation(Transformation): self.invariant = True -class FlipLr(VariantTransformation): +class FlipLr(VariantTransform): def __init__(self): super().__init__() - def apply(self, image): + def __call__(self, image): return np.asarray(image)[..., ::-1] + def __repr__(self): + return self.__class__.__name__ -class FlipUd(VariantTransformation): + +class FlipUd(VariantTransform): def __init__(self): super().__init__() - def _transform(self, image): + def __call__(self, image): return np.asarray(image)[..., ::-1, :] + def __repr__(self): + return self.__class__.__name__ -class Rotate(VariantTransformation): - def __init__(self): +class Rotate(VariantTransform): + + def __init__(self, angle): + self.angle = angle super().__init__() - def apply(self, image, angle): + def __call__(self, image): # check dimension of input image ndim = np.asarray(image).ndim @@ -68,4 +75,67 @@ class Rotate(VariantTransformation): if ndim > 2: rot_axes = (ndim - 2, ndim - 1) - return ndimage.rotate(image, angle, axes=rot_axes, reshape=False) + return ndimage.rotate(image, self.angle, axes=rot_axes, reshape=False) + + def __repr__(self): + return self.__class__.__name__ + '(angle = {})'.format(self.angle) + + +class Noise(InvariantTransform): + + def __init__(self, mode, mean=0, var=0.05): + super().__init__() + + # check which kind of noise to apply + modes = ['gaussian', 'speckle'] + if mode not in modes: + raise ValueError('Supported noise types are: {}.'.format(modes)) + self.mode = mode + + # mean and variance of the gaussian distribution the noise signal is + # sampled from + self.mean = mean + self.var = var + + def __call__(self, image): + + # generate gaussian noise + noise = np.random.normal(self.mean, self.var, image.shape) + + if self.mode == 'gaussian': + return (np.asarray(image) + noise).clip(0, 1) + + if self.mode == 'speckle': + return (np.asarray(image) + np.asarray(image) * noise).clip(0, 1) + + def __repr__(self): + return self.__class__.__name__ + ('(mode = {}, mean = {}, var = {})' + .format(self.mode, self.mean, + self.var)) + + +class Augment(object): + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, gt): + + # apply transformations to the input image in specified order + for t in self.transforms: + image = t(image) + + # check whether the transformations are invariant and if not, apply + # the transformation also to the ground truth + if not t.invariant: + gt = t(gt) + + return image, gt + + def __repr__(self): + fstring = self.__class__.__name__ + '(' + for t in self.transforms: + fstring += '\n' + fstring += ' {0}'.format(t) + fstring += '\n)' + return fstring -- GitLab