From cd2b660aef27bd71d86212a9b8020d1442f4be2a Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Tue, 7 Jul 2020 15:35:26 +0000
Subject: [PATCH] Renamed file

---
 main/{data.py => init.py} | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)
 rename main/{data.py => init.py} (72%)

diff --git a/main/data.py b/main/init.py
similarity index 72%
rename from main/data.py
rename to main/init.py
index 169407e..503a214 100755
--- a/main/data.py
+++ b/main/init.py
@@ -15,16 +15,23 @@ import torch
 sys.path.append('..')
 
 # local modules
-from pytorch.dataset import SparcsDataset
+from pytorch.dataset import SparcsDataset, Cloud95Dataset
 from pytorch.trainer import NetworkTrainer
 from pytorch.models import SegNet
-from main.config import (dataset_path, bands, tile_size, tvratio, filters,
-                         skip_connection, kwargs, loss_function, optimizer,
-                         lr, ttratio, batch_size, seed)
+from main.config import (dataset_name, dataset_path, bands, tile_size, tvratio,
+                         filters, skip_connection, kwargs, loss_function,
+                         optimizer, lr, ttratio, batch_size, seed, patches)
 
-
-# instanciate the SparcsDataset class
-dataset = SparcsDataset(dataset_path, bands, tile_size)
+# check which dataset the model is trained on
+if dataset_name == 'Sparcs':
+    # instanciate the SparcsDataset
+    dataset = SparcsDataset(dataset_path, use_bands=bands, tile_size=tile_size)
+elif dataset_name == 'Cloud95':
+    dataset = Cloud95Dataset(dataset_path, use_bands=bands,
+                             tile_size=tile_size, exclude=patches)
+else:
+    raise ValueError('{} is not a valid dataset. Available datasets are '
+                     '"Sparcs" and "Cloud95".'.format(dataset_name))
 
 # print the bands used for the segmentation
 print('------------------------ Input bands -----------------------------')
-- 
GitLab