Skip to content
Snippets Groups Projects
Commit 8be76bef authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Added smarter import strategy

parent 8fdc92fd
No related branches found
No related tags found
No related merge requests found
......@@ -6,17 +6,14 @@ Created on Tue Jun 30 11:40:35 2020
@author: Daniel
"""
# builtins
from __future__ import absolute_import
import os
import sys
import inspect
# externals
import torch.nn as nn
import torch.optim as optim
# append path to local files to the python search path
sys.path.append('..')
from pytorch.models import UNet
# ------------------------- Dataset configuration -----------------------------
......
# builtins
from __future__ import absolute_import
import os
import sys
# externals
import numpy as np
......@@ -8,9 +8,6 @@ import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
# append path to local files to the python search path
sys.path.append('..')
# local modules
from pytorch.trainer import NetworkTrainer
from pytorch.graphics import plot_confusion_matrix, plot_loss, plot_sample
......
......@@ -6,15 +6,12 @@ Created on Tue Jun 30 09:33:38 2020
@author: Daniel
"""
# builtins
from __future__ import absolute_import
import os
import sys
# externals
import torch
# append path to local files to the python search path
sys.path.append('..')
# local modules
from pytorch.trainer import NetworkTrainer
from main.config import config
......
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 15 09:45:49 2020
@author: Daniel
"""
# builtins
import os
import sys
# externals
import numpy as np
import torch
import torch.nn.functional as F
# append path to local files to the python search path
sys.path.append('..')
# local modules
from pytorch.trainer import NetworkTrainer
from pytorch.layers import Conv2dSame
from main.config import config
if __name__ == '__main__':
# instanciate the NetworkTrainer class
trainer = NetworkTrainer(config)
trainer.initialize()
# freeze the model state
trainer.model.freeze()
# get the number of input features to the model classifier
in_features = trainer.model.classifier.in_channels
# replace the classification layer
trainer.model.classifier = Conv2dSame(in_channels=in_features,
out_channels=len(trainer.dataset.labels),
kernel_size=1)
# train the model on the new dataset
trainer.train()
......@@ -7,6 +7,7 @@ Created on Tue Jul 14 10:58:20 2020
# builtins
import enum
# Landsat 8 bands
class Landsat8(enum.Enum):
violet = 1
......@@ -38,6 +39,7 @@ class Sentinel2(enum.Enum):
swir1 = 11
swir2 = 12
# labels of the Sparcs dataset
class SparcsLabels(enum.Enum):
Shadow = 0, 'black'
......
......@@ -6,17 +6,14 @@ Created on Fri Jun 26 16:31:36 2020
@author: Daniel
"""
# builtins
from __future__ import absolute_import
import os
import sys
# externals
import numpy as np
import torch
import torch.nn as nn
# append path to local files to the python search path
sys.path.append('..')
# locals
from pytorch.layers import (Encoder, Decoder, Conv2dPool, Conv2dUnpool,
Conv2dUpsample, Conv2dSame)
......
......@@ -6,8 +6,8 @@ Created on Fri Jun 26 16:31:36 2020
@author: Daniel
"""
# builtins
from __future__ import absolute_import
import os
import sys
# externals
import numpy as np
......@@ -15,9 +15,6 @@ import torch
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader
# append path to local files to the python search path
sys.path.append('..')
# local modules
from pytorch.dataset import SparcsDataset, Cloud95Dataset
from pytorch.layers import Conv2dSame
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment