diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py
index 0a83e0fc20b1508b1b150daf28ecc26b929a3769..2918217a762e200471737489b773bf45dbe36972 100644
--- a/climax/main/downscale_infer.py
+++ b/climax/main/downscale_infer.py
@@ -75,7 +75,7 @@ if __name__ == '__main__':
             dem = dem.drop_vars(['slope', 'aspect'])
 
         # add dem to set of predictor variables
-        Era5_ds = xr.merge([Era5_ds, dem])
+        Era5_ds = xr.merge([Era5_ds, dem]).chunk(Era5_ds.chunks)
 
     # load pretrained model
     if state_file.exists():
diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index 7468fa839c5e16f41f94e4e5addb90d552e53457..1535759bd3d24c586ef3228e1eda4d394a7cca81 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -101,7 +101,7 @@ if __name__ == '__main__':
 
         # check whether to use slope and aspect
         if not DEM_FEATURES:
-            dem = dem.drop_vars(['slope', 'aspect'])
+            dem = dem.drop_vars(['slope', 'aspect']).chunk(Era5_ds.chunks)
 
         # add dem to set of predictor variables
         Era5_ds = xr.merge([Era5_ds, dem])