From cd2b660aef27bd71d86212a9b8020d1442f4be2a Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Tue, 7 Jul 2020 15:35:26 +0000 Subject: [PATCH] Renamed file --- main/{data.py => init.py} | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) rename main/{data.py => init.py} (72%) diff --git a/main/data.py b/main/init.py similarity index 72% rename from main/data.py rename to main/init.py index 169407e..503a214 100755 --- a/main/data.py +++ b/main/init.py @@ -15,16 +15,23 @@ import torch sys.path.append('..') # local modules -from pytorch.dataset import SparcsDataset +from pytorch.dataset import SparcsDataset, Cloud95Dataset from pytorch.trainer import NetworkTrainer from pytorch.models import SegNet -from main.config import (dataset_path, bands, tile_size, tvratio, filters, - skip_connection, kwargs, loss_function, optimizer, - lr, ttratio, batch_size, seed) +from main.config import (dataset_name, dataset_path, bands, tile_size, tvratio, + filters, skip_connection, kwargs, loss_function, + optimizer, lr, ttratio, batch_size, seed, patches) - -# instanciate the SparcsDataset class -dataset = SparcsDataset(dataset_path, bands, tile_size) +# check which dataset the model is trained on +if dataset_name == 'Sparcs': + # instanciate the SparcsDataset + dataset = SparcsDataset(dataset_path, use_bands=bands, tile_size=tile_size) +elif dataset_name == 'Cloud95': + dataset = Cloud95Dataset(dataset_path, use_bands=bands, + tile_size=tile_size, exclude=patches) +else: + raise ValueError('{} is not a valid dataset. Available datasets are ' + '"Sparcs" and "Cloud95".'.format(dataset_name)) # print the bands used for the segmentation print('------------------------ Input bands -----------------------------') -- GitLab