Skip to content
Snippets Groups Projects
Commit 8ff83719 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Made transformation classes callable for easier execution

parent a991fd99
No related branches found
No related tags found
No related merge requests found
......@@ -10,16 +10,16 @@ import numpy as np
from scipy import ndimage
class Transformation(object):
class Transform(object):
def __init__(self):
pass
def apply(self, image):
def __call__(self, image):
raise NotImplementedError
class VariantTransformation(Transformation):
class VariantTransform(Transform):
def __init__(self):
......@@ -27,7 +27,7 @@ class VariantTransformation(Transformation):
self.invariant = False
class InvariantTransformation(Transformation):
class InvariantTransform(Transform):
def __init__(self):
......@@ -35,30 +35,37 @@ class InvariantTransformation(Transformation):
self.invariant = True
class FlipLr(VariantTransformation):
class FlipLr(VariantTransform):
def __init__(self):
super().__init__()
def apply(self, image):
def __call__(self, image):
return np.asarray(image)[..., ::-1]
def __repr__(self):
return self.__class__.__name__
class FlipUd(VariantTransformation):
class FlipUd(VariantTransform):
def __init__(self):
super().__init__()
def _transform(self, image):
def __call__(self, image):
return np.asarray(image)[..., ::-1, :]
def __repr__(self):
return self.__class__.__name__
class Rotate(VariantTransformation):
def __init__(self):
class Rotate(VariantTransform):
def __init__(self, angle):
self.angle = angle
super().__init__()
def apply(self, image, angle):
def __call__(self, image):
# check dimension of input image
ndim = np.asarray(image).ndim
......@@ -68,4 +75,67 @@ class Rotate(VariantTransformation):
if ndim > 2:
rot_axes = (ndim - 2, ndim - 1)
return ndimage.rotate(image, angle, axes=rot_axes, reshape=False)
return ndimage.rotate(image, self.angle, axes=rot_axes, reshape=False)
def __repr__(self):
return self.__class__.__name__ + '(angle = {})'.format(self.angle)
class Noise(InvariantTransform):
def __init__(self, mode, mean=0, var=0.05):
super().__init__()
# check which kind of noise to apply
modes = ['gaussian', 'speckle']
if mode not in modes:
raise ValueError('Supported noise types are: {}.'.format(modes))
self.mode = mode
# mean and variance of the gaussian distribution the noise signal is
# sampled from
self.mean = mean
self.var = var
def __call__(self, image):
# generate gaussian noise
noise = np.random.normal(self.mean, self.var, image.shape)
if self.mode == 'gaussian':
return (np.asarray(image) + noise).clip(0, 1)
if self.mode == 'speckle':
return (np.asarray(image) + np.asarray(image) * noise).clip(0, 1)
def __repr__(self):
return self.__class__.__name__ + ('(mode = {}, mean = {}, var = {})'
.format(self.mode, self.mean,
self.var))
class Augment(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, image, gt):
# apply transformations to the input image in specified order
for t in self.transforms:
image = t(image)
# check whether the transformations are invariant and if not, apply
# the transformation also to the ground truth
if not t.invariant:
gt = t(gt)
return image, gt
def __repr__(self):
fstring = self.__class__.__name__ + '('
for t in self.transforms:
fstring += '\n'
fstring += ' {0}'.format(t)
fstring += '\n)'
return fstring
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment