From 8ff83719e921322338916ad768a49a97a987ab8a Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Mon, 20 Jul 2020 17:12:40 +0200
Subject: [PATCH] Made transformation classes callable for easier execution

---
 pytorch/transforms.py | 94 +++++++++++++++++++++++++++++++++++++------
 1 file changed, 82 insertions(+), 12 deletions(-)

diff --git a/pytorch/transforms.py b/pytorch/transforms.py
index b4b0187..739a3e5 100644
--- a/pytorch/transforms.py
+++ b/pytorch/transforms.py
@@ -10,16 +10,16 @@ import numpy as np
 from scipy import ndimage
 
 
-class Transformation(object):
+class Transform(object):
 
     def __init__(self):
         pass
 
-    def apply(self, image):
+    def __call__(self, image):
         raise NotImplementedError
 
 
-class VariantTransformation(Transformation):
+class VariantTransform(Transform):
 
     def __init__(self):
 
@@ -27,7 +27,7 @@ class VariantTransformation(Transformation):
         self.invariant = False
 
 
-class InvariantTransformation(Transformation):
+class InvariantTransform(Transform):
 
     def __init__(self):
 
@@ -35,30 +35,37 @@ class InvariantTransformation(Transformation):
         self.invariant = True
 
 
-class FlipLr(VariantTransformation):
+class FlipLr(VariantTransform):
 
     def __init__(self):
         super().__init__()
 
-    def apply(self, image):
+    def __call__(self, image):
         return np.asarray(image)[..., ::-1]
 
+    def __repr__(self):
+        return self.__class__.__name__
 
-class FlipUd(VariantTransformation):
+
+class FlipUd(VariantTransform):
 
     def __init__(self):
         super().__init__()
 
-    def _transform(self, image):
+    def __call__(self, image):
         return np.asarray(image)[..., ::-1, :]
 
+    def __repr__(self):
+        return self.__class__.__name__
 
-class Rotate(VariantTransformation):
 
-    def __init__(self):
+class Rotate(VariantTransform):
+
+    def __init__(self, angle):
+        self.angle = angle
         super().__init__()
 
-    def apply(self, image, angle):
+    def __call__(self, image):
 
         # check dimension of input image
         ndim = np.asarray(image).ndim
@@ -68,4 +75,67 @@ class Rotate(VariantTransformation):
         if ndim > 2:
             rot_axes = (ndim - 2, ndim - 1)
 
-        return ndimage.rotate(image, angle, axes=rot_axes, reshape=False)
+        return ndimage.rotate(image, self.angle, axes=rot_axes, reshape=False)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(angle = {})'.format(self.angle)
+
+
+class Noise(InvariantTransform):
+
+    def __init__(self, mode, mean=0, var=0.05):
+        super().__init__()
+
+        # check which kind of noise to apply
+        modes = ['gaussian', 'speckle']
+        if mode not in modes:
+            raise ValueError('Supported noise types are: {}.'.format(modes))
+        self.mode = mode
+
+        # mean and variance of the gaussian distribution the noise signal is
+        # sampled from
+        self.mean = mean
+        self.var = var
+
+    def __call__(self, image):
+
+        # generate gaussian noise
+        noise = np.random.normal(self.mean, self.var, image.shape)
+
+        if self.mode == 'gaussian':
+            return (np.asarray(image) + noise).clip(0, 1)
+
+        if self.mode == 'speckle':
+            return (np.asarray(image) + np.asarray(image) * noise).clip(0, 1)
+
+    def __repr__(self):
+        return self.__class__.__name__ + ('(mode = {}, mean = {}, var = {})'
+                                          .format(self.mode, self.mean,
+                                                  self.var))
+
+
+class Augment(object):
+
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, image, gt):
+
+        # apply transformations to the input image in specified order
+        for t in self.transforms:
+            image = t(image)
+
+            # check whether the transformations are invariant and if not, apply
+            # the transformation also to the ground truth
+            if not t.invariant:
+                gt = t(gt)
+
+        return image, gt
+
+    def __repr__(self):
+        fstring = self.__class__.__name__ + '('
+        for t in self.transforms:
+            fstring += '\n'
+            fstring += '    {0}'.format(t)
+        fstring += '\n)'
+        return fstring
-- 
GitLab