Skip to content
Snippets Groups Projects
Commit 8be76bef authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Added smarter import strategy

parent 8fdc92fd
No related branches found
No related tags found
No related merge requests found
...@@ -6,17 +6,14 @@ Created on Tue Jun 30 11:40:35 2020 ...@@ -6,17 +6,14 @@ Created on Tue Jun 30 11:40:35 2020
@author: Daniel @author: Daniel
""" """
# builtins # builtins
from __future__ import absolute_import
import os import os
import sys
import inspect import inspect
# externals # externals
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
# append path to local files to the python search path
sys.path.append('..')
from pytorch.models import UNet from pytorch.models import UNet
# ------------------------- Dataset configuration ----------------------------- # ------------------------- Dataset configuration -----------------------------
......
# builtins # builtins
from __future__ import absolute_import
import os import os
import sys
# externals # externals
import numpy as np import numpy as np
...@@ -8,9 +8,6 @@ import torch ...@@ -8,9 +8,6 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
# append path to local files to the python search path
sys.path.append('..')
# local modules # local modules
from pytorch.trainer import NetworkTrainer from pytorch.trainer import NetworkTrainer
from pytorch.graphics import plot_confusion_matrix, plot_loss, plot_sample from pytorch.graphics import plot_confusion_matrix, plot_loss, plot_sample
......
...@@ -6,15 +6,12 @@ Created on Tue Jun 30 09:33:38 2020 ...@@ -6,15 +6,12 @@ Created on Tue Jun 30 09:33:38 2020
@author: Daniel @author: Daniel
""" """
# builtins # builtins
from __future__ import absolute_import
import os import os
import sys
# externals # externals
import torch import torch
# append path to local files to the python search path
sys.path.append('..')
# local modules # local modules
from pytorch.trainer import NetworkTrainer from pytorch.trainer import NetworkTrainer
from main.config import config from main.config import config
......
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 15 09:45:49 2020
@author: Daniel
"""
# builtins
import os
import sys
# externals
import numpy as np
import torch
import torch.nn.functional as F
# append path to local files to the python search path
sys.path.append('..')
# local modules
from pytorch.trainer import NetworkTrainer
from pytorch.layers import Conv2dSame
from main.config import config
if __name__ == '__main__':
# instanciate the NetworkTrainer class
trainer = NetworkTrainer(config)
trainer.initialize()
# freeze the model state
trainer.model.freeze()
# get the number of input features to the model classifier
in_features = trainer.model.classifier.in_channels
# replace the classification layer
trainer.model.classifier = Conv2dSame(in_channels=in_features,
out_channels=len(trainer.dataset.labels),
kernel_size=1)
# train the model on the new dataset
trainer.train()
...@@ -7,6 +7,7 @@ Created on Tue Jul 14 10:58:20 2020 ...@@ -7,6 +7,7 @@ Created on Tue Jul 14 10:58:20 2020
# builtins # builtins
import enum import enum
# Landsat 8 bands # Landsat 8 bands
class Landsat8(enum.Enum): class Landsat8(enum.Enum):
violet = 1 violet = 1
...@@ -38,6 +39,7 @@ class Sentinel2(enum.Enum): ...@@ -38,6 +39,7 @@ class Sentinel2(enum.Enum):
swir1 = 11 swir1 = 11
swir2 = 12 swir2 = 12
# labels of the Sparcs dataset # labels of the Sparcs dataset
class SparcsLabels(enum.Enum): class SparcsLabels(enum.Enum):
Shadow = 0, 'black' Shadow = 0, 'black'
......
...@@ -6,17 +6,14 @@ Created on Fri Jun 26 16:31:36 2020 ...@@ -6,17 +6,14 @@ Created on Fri Jun 26 16:31:36 2020
@author: Daniel @author: Daniel
""" """
# builtins # builtins
from __future__ import absolute_import
import os import os
import sys
# externals # externals
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
# append path to local files to the python search path
sys.path.append('..')
# locals # locals
from pytorch.layers import (Encoder, Decoder, Conv2dPool, Conv2dUnpool, from pytorch.layers import (Encoder, Decoder, Conv2dPool, Conv2dUnpool,
Conv2dUpsample, Conv2dSame) Conv2dUpsample, Conv2dSame)
......
...@@ -6,8 +6,8 @@ Created on Fri Jun 26 16:31:36 2020 ...@@ -6,8 +6,8 @@ Created on Fri Jun 26 16:31:36 2020
@author: Daniel @author: Daniel
""" """
# builtins # builtins
from __future__ import absolute_import
import os import os
import sys
# externals # externals
import numpy as np import numpy as np
...@@ -15,9 +15,6 @@ import torch ...@@ -15,9 +15,6 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader from torch.utils.data import random_split, DataLoader
# append path to local files to the python search path
sys.path.append('..')
# local modules # local modules
from pytorch.dataset import SparcsDataset, Cloud95Dataset from pytorch.dataset import SparcsDataset, Cloud95Dataset
from pytorch.layers import Conv2dSame from pytorch.layers import Conv2dSame
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment