diff --git a/main/data.py b/main/init.py similarity index 72% rename from main/data.py rename to main/init.py index 169407ef5eafde13d78fe8cc24a4cb71141b387f..503a2142d389d69a28f75dc841cf17b80645bf96 100755 --- a/main/data.py +++ b/main/init.py @@ -15,16 +15,23 @@ import torch sys.path.append('..') # local modules -from pytorch.dataset import SparcsDataset +from pytorch.dataset import SparcsDataset, Cloud95Dataset from pytorch.trainer import NetworkTrainer from pytorch.models import SegNet -from main.config import (dataset_path, bands, tile_size, tvratio, filters, - skip_connection, kwargs, loss_function, optimizer, - lr, ttratio, batch_size, seed) +from main.config import (dataset_name, dataset_path, bands, tile_size, tvratio, + filters, skip_connection, kwargs, loss_function, + optimizer, lr, ttratio, batch_size, seed, patches) - -# instanciate the SparcsDataset class -dataset = SparcsDataset(dataset_path, bands, tile_size) +# check which dataset the model is trained on +if dataset_name == 'Sparcs': + # instanciate the SparcsDataset + dataset = SparcsDataset(dataset_path, use_bands=bands, tile_size=tile_size) +elif dataset_name == 'Cloud95': + dataset = Cloud95Dataset(dataset_path, use_bands=bands, + tile_size=tile_size, exclude=patches) +else: + raise ValueError('{} is not a valid dataset. Available datasets are ' + '"Sparcs" and "Cloud95".'.format(dataset_name)) # print the bands used for the segmentation print('------------------------ Input bands -----------------------------')