diff --git a/pysegcnn/core/models.py b/pysegcnn/core/models.py
index 2d5cc84097f4b0bea00d565c1460581823a67092..1f9fa2b645f057be27f7e86b4d21609695f038aa 100644
--- a/pysegcnn/core/models.py
+++ b/pysegcnn/core/models.py
@@ -9,6 +9,7 @@ Created on Fri Jun 26 16:31:36 2020
 import os
 import enum
 import logging
+import pathlib
 
 # externals
 import numpy as np
@@ -29,6 +30,9 @@ class Network(nn.Module):
     def __init__(self):
         super().__init__()
 
+        # initialize state file
+        self.state_file = None
+
     def freeze(self):
         for param in self.parameters():
             param.requires_grad = False
@@ -37,22 +41,22 @@ class Network(nn.Module):
         for param in self.parameters():
             param.requires_grad = True
 
-    def save(self, state_file, optimizer, bands=None,
-             outpath=os.path.join(os.getcwd(), '_models/')):
+    def save(self, state_file, optimizer, bands=None, **kwargs):
 
         # check if the output path exists and if not, create it
-        if not os.path.isdir(outpath):
-            os.makedirs(outpath, exist_ok=True)
+        state_file = pathlib.Path(state_file)
+        if not state_file.parent.is_dir():
+           state_file.parent.mkdir(parents=True, exist_ok=True)
 
         # initialize dictionary to store network parameters
-        model_state = {}
-
-        # store model name
-        model_state['cls'] = self.__class__
+        model_state = {**kwargs}
 
-        # store the bands the model was trained with
+        # store the spectral bands the model is trained with
         model_state['bands'] = bands
 
+        # store model class
+        model_state['cls'] = self.__class__
+
         # store construction parameters to instanciate the network
         model_state['params'] = {
             'skip': self.skip,
@@ -62,7 +66,7 @@ class Network(nn.Module):
             }
 
         # store optional keyword arguments
-        model_state['kwargs'] = self.kwargs
+        model_state['params'] = {**model_state['params'], **self.kwargs}
 
         # store model epoch
         model_state['epoch'] = self.epoch
@@ -72,30 +76,48 @@ class Network(nn.Module):
         model_state['optim_state_dict'] = optimizer.state_dict()
 
         # model state dictionary stores the values of all trainable parameters
-        state = os.path.join(outpath, state_file)
-        torch.save(model_state, state)
-        LOGGER.info('Network parameters saved in {}'.format(state))
+        torch.save(model_state, state_file)
+        LOGGER.info('Network parameters saved in {}'.format(state_file))
 
-        return state
+        return state_file
 
-    def load(self, state_file, optimizer=None,
-             inpath=os.path.join(os.getcwd(), '_models/')):
+    @staticmethod
+    def load(state_file, optimizer=None):
 
-        # load the model state file
-        state = os.path.join(inpath, state_file)
-        model_state = torch.load(state)
+        # load the pretrained model
+        state_file = pathlib.Path(state_file)
+        if not state_file.exists():
+            raise FileNotFoundError('{} does not exist.'.format(state_file))
+        LOGGER.info('Loading pretrained weights from: {}'.format(state_file))
 
-        # resume network parameters
-        LOGGER.info('Loading model parameters ...'.format(state))
-        self.load_state_dict(model_state['model_state_dict'])
-        self.epoch = model_state['epoch']
+        # load the model state
+        model_state = torch.load(state_file)
+
+        # the model class
+        model_class = model_state['cls']
+
+        # instanciate pretrained model architecture
+        model = model_class(**model_state['params'])
+
+        # store state file as instance attribute
+        model.state_file = state_file
+
+        # load pretrained model weights
+        LOGGER.info('Loading model parameters ...')
+        model.load_state_dict(model_state['model_state_dict'])
+        model.epoch = model_state['epoch']
 
         # resume optimizer parameters
         if optimizer is not None:
-            LOGGER.info('Loading optimizer parameters ...'.format(state))
+            LOGGER.info('Loading optimizer parameters ...')
             optimizer.load_state_dict(model_state['optim_state_dict'])
+        LOGGER.info('Model epoch: {:d}'.format(model.epoch))
+
+        return model, optimizer, model_state
 
-        return state
+    @property
+    def state(self):
+        return self.state_file
 
 
 class UNet(Network):