From 8be76bef1f341166f57b97d5c8c30b76260f25b8 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Fri, 17 Jul 2020 17:05:43 +0200 Subject: [PATCH] Added smarter import strategy --- main/config.py | 5 +---- main/eval.py | 5 +---- main/train.py | 5 +---- main/transfer.py | 44 -------------------------------------------- pytorch/constants.py | 2 ++ pytorch/models.py | 5 +---- pytorch/trainer.py | 5 +---- 7 files changed, 7 insertions(+), 64 deletions(-) delete mode 100644 main/transfer.py diff --git a/main/config.py b/main/config.py index 277c11f..1aaa85d 100755 --- a/main/config.py +++ b/main/config.py @@ -6,17 +6,14 @@ Created on Tue Jun 30 11:40:35 2020 @author: Daniel """ # builtins +from __future__ import absolute_import import os -import sys import inspect # externals import torch.nn as nn import torch.optim as optim -# append path to local files to the python search path -sys.path.append('..') - from pytorch.models import UNet # ------------------------- Dataset configuration ----------------------------- diff --git a/main/eval.py b/main/eval.py index da3fa14..d5b18d7 100755 --- a/main/eval.py +++ b/main/eval.py @@ -1,6 +1,6 @@ # builtins +from __future__ import absolute_import import os -import sys # externals import numpy as np @@ -8,9 +8,6 @@ import torch import torch.nn.functional as F import matplotlib.pyplot as plt -# append path to local files to the python search path -sys.path.append('..') - # local modules from pytorch.trainer import NetworkTrainer from pytorch.graphics import plot_confusion_matrix, plot_loss, plot_sample diff --git a/main/train.py b/main/train.py index a14e22a..840fe24 100755 --- a/main/train.py +++ b/main/train.py @@ -6,15 +6,12 @@ Created on Tue Jun 30 09:33:38 2020 @author: Daniel """ # builtins +from __future__ import absolute_import import os -import sys # externals import torch -# append path to local files to the python search path -sys.path.append('..') - # local modules from pytorch.trainer import NetworkTrainer from main.config import config diff --git a/main/transfer.py b/main/transfer.py deleted file mode 100644 index fdce169..0000000 --- a/main/transfer.py +++ /dev/null @@ -1,44 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Created on Wed Jul 15 09:45:49 2020 - -@author: Daniel -""" - -# builtins -import os -import sys - -# externals -import numpy as np -import torch -import torch.nn.functional as F - -# append path to local files to the python search path -sys.path.append('..') - -# local modules -from pytorch.trainer import NetworkTrainer -from pytorch.layers import Conv2dSame -from main.config import config - - -if __name__ == '__main__': - - # instanciate the NetworkTrainer class - trainer = NetworkTrainer(config) - trainer.initialize() - - # freeze the model state - trainer.model.freeze() - - # get the number of input features to the model classifier - in_features = trainer.model.classifier.in_channels - - # replace the classification layer - trainer.model.classifier = Conv2dSame(in_channels=in_features, - out_channels=len(trainer.dataset.labels), - kernel_size=1) - - # train the model on the new dataset - trainer.train() diff --git a/pytorch/constants.py b/pytorch/constants.py index ccb29c1..6f7cd56 100644 --- a/pytorch/constants.py +++ b/pytorch/constants.py @@ -7,6 +7,7 @@ Created on Tue Jul 14 10:58:20 2020 # builtins import enum + # Landsat 8 bands class Landsat8(enum.Enum): violet = 1 @@ -38,6 +39,7 @@ class Sentinel2(enum.Enum): swir1 = 11 swir2 = 12 + # labels of the Sparcs dataset class SparcsLabels(enum.Enum): Shadow = 0, 'black' diff --git a/pytorch/models.py b/pytorch/models.py index e1add80..17f5bac 100644 --- a/pytorch/models.py +++ b/pytorch/models.py @@ -6,17 +6,14 @@ Created on Fri Jun 26 16:31:36 2020 @author: Daniel """ # builtins +from __future__ import absolute_import import os -import sys # externals import numpy as np import torch import torch.nn as nn -# append path to local files to the python search path -sys.path.append('..') - # locals from pytorch.layers import (Encoder, Decoder, Conv2dPool, Conv2dUnpool, Conv2dUpsample, Conv2dSame) diff --git a/pytorch/trainer.py b/pytorch/trainer.py index dbcfb80..7376ccd 100755 --- a/pytorch/trainer.py +++ b/pytorch/trainer.py @@ -6,8 +6,8 @@ Created on Fri Jun 26 16:31:36 2020 @author: Daniel """ # builtins +from __future__ import absolute_import import os -import sys # externals import numpy as np @@ -15,9 +15,6 @@ import torch import torch.nn.functional as F from torch.utils.data import random_split, DataLoader -# append path to local files to the python search path -sys.path.append('..') - # local modules from pytorch.dataset import SparcsDataset, Cloud95Dataset from pytorch.layers import Conv2dSame -- GitLab