diff --git a/pysegcnn/core/utils.py b/pysegcnn/core/utils.py
index d0c39d6f8f895323578804d65e90a4d64f133325..3529dc4dbcac4e90b7bc200387a0c241149ab31b 100644
--- a/pysegcnn/core/utils.py
+++ b/pysegcnn/core/utils.py
@@ -2255,7 +2255,8 @@ def dec2bin(number, nbits=8):
     return binary
 
 
-def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False):
+def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False,
+                    src_no_data=None, trg_no_data=0):
     """Extract raster values by mask.
 
     Extract the extent of ``mask_ds`` from ``src_ds``. The masked values of
@@ -2275,6 +2276,12 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False):
         The masked target dataset.
     overwrite : `bool`, optional
         Whether to overwrite ``trg_ds``, if it exists. The default is `False`.
+    src_no_data : `int` or `float`, optional
+        The value of NoData values in ``src_ds``. The default is `None`, which
+        means the value is read from ``src_ds``. If specified, values equal to
+        ``src_no_data`` are masked as ``trg_no_data`` in ``trg_ds``.
+    trg_no_data : `int` or `float`, optional
+        The value to assign to NoData values in ``trg_ds``. The default is `0`.
 
     """
     # convert path to source dataset and mask dataset to pathlib.Path object
@@ -2303,9 +2310,6 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False):
     # read the source dataset
     src_ds = gdal.Open(str(src_path))
 
-    # the source dataset NoData value
-    no_data = src_ds.GetRasterBand(1).GetNoDataValue()
-
     # source dataset spatial reference
     src_sr = osr.SpatialReference()
     src_sr.ImportFromWkt(src_ds.GetProjection())
@@ -2318,7 +2322,8 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False):
                   cropToCutline=True,
                   xRes=src_ds.GetGeoTransform()[1],
                   yRes=src_ds.GetGeoTransform()[5],
-                  dstNodata=no_data)
+                  srcNodata=src_no_data,
+                  dstNodata=trg_no_data)
 
     else:
         # mask is a raster dataset
@@ -2341,8 +2346,8 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False):
         # TransfromPoint expects input:
         #   - gdal >= 3.0: x, y, z = TransformPoint(y, x)
         #   - gdal < 3.0 : x, y, z = TransformPoint(x, y)
-        x_tl, y_tl, _ = crs_tr.TransformPoint(extent[-1], extent[0])
-        x_br, y_br, _ = crs_tr.TransformPoint(extent[2], extent[1])
+        x_tl, y_tl, _ = crs_tr.TransformPoint(extent[0], extent[-1])
+        x_br, y_br, _ = crs_tr.TransformPoint(extent[1], extent[2])
 
         # extent of the mask in the source reference coordinate system:
         # (x_min, y_min, x_max, y_max)
@@ -2352,7 +2357,9 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False):
         gdal.Warp(str(trg_path), str(src_path),
                   outputBounds=extent,
                   xRes=src_ds.GetGeoTransform()[1],
-                  yRes=src_ds.GetGeoTransform()[5])
+                  yRes=src_ds.GetGeoTransform()[5],
+                  srcNodata=src_no_data,
+                  dstNodata=trg_no_data)
 
     # clear source and mask dataset
     del src_ds, mask_ds