diff --git a/main/data.py b/main/init.py
similarity index 72%
rename from main/data.py
rename to main/init.py
index 169407ef5eafde13d78fe8cc24a4cb71141b387f..503a2142d389d69a28f75dc841cf17b80645bf96 100755
--- a/main/data.py
+++ b/main/init.py
@@ -15,16 +15,23 @@ import torch
 sys.path.append('..')
 
 # local modules
-from pytorch.dataset import SparcsDataset
+from pytorch.dataset import SparcsDataset, Cloud95Dataset
 from pytorch.trainer import NetworkTrainer
 from pytorch.models import SegNet
-from main.config import (dataset_path, bands, tile_size, tvratio, filters,
-                         skip_connection, kwargs, loss_function, optimizer,
-                         lr, ttratio, batch_size, seed)
+from main.config import (dataset_name, dataset_path, bands, tile_size, tvratio,
+                         filters, skip_connection, kwargs, loss_function,
+                         optimizer, lr, ttratio, batch_size, seed, patches)
 
-
-# instanciate the SparcsDataset class
-dataset = SparcsDataset(dataset_path, bands, tile_size)
+# check which dataset the model is trained on
+if dataset_name == 'Sparcs':
+    # instanciate the SparcsDataset
+    dataset = SparcsDataset(dataset_path, use_bands=bands, tile_size=tile_size)
+elif dataset_name == 'Cloud95':
+    dataset = Cloud95Dataset(dataset_path, use_bands=bands,
+                             tile_size=tile_size, exclude=patches)
+else:
+    raise ValueError('{} is not a valid dataset. Available datasets are '
+                     '"Sparcs" and "Cloud95".'.format(dataset_name))
 
 # print the bands used for the segmentation
 print('------------------------ Input bands -----------------------------')