Skip to content
Snippets Groups Projects
Commit 85e41d45 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Implemented first version of downscaling script.

parent 65165449
No related branches found
No related tags found
No related merge requests found
......@@ -75,3 +75,17 @@ SHUFFLE = False
# batch size: number of time steps processed by the net in each iteration
BATCH_SIZE = 64
# learning rate
LR = 0.001
# network training configuration
TRAIN_CONFIG = {
'checkpoint_state': {},
'epochs': 50,
'save': True,
'early_stop': True,
'patience': 25,
'multi_gpu': True,
'classification': False
}
......@@ -3,7 +3,11 @@
# !/usr/bin/env python
# -*- coding: utf-8 -*-
# builtins
from logging.config import dictConfig
# externals
import torch
import xarray as xr
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
......@@ -11,15 +15,20 @@ from torch.utils.data import DataLoader
# locals
from pysegcnn.core.utils import search_files
from pysegcnn.core.models import SegNet
from pysegcnn.core.trainer import NetworkTrainer
from pysegcnn.core.logging import log_conf
from climax.core.dataset import ERA5Dataset, NetCDFDataset
from climax.core.constants import (ERA5_P_VARIABLES, ERA5_S_VARIABLES,
ERA5_VARIABLES)
from climax.core.constants import ERA5_VARIABLES
from climax.main.config import (ERA5_PATH, ERA5_PLEVELS, OBS_PATH, PREDICTAND,
CALIB_PERIOD, MODEL_PATH, SHUFFLE, BATCH_SIZE)
CALIB_PERIOD, MODEL_PATH, SHUFFLE, BATCH_SIZE,
LR, TRAIN_CONFIG)
if __name__ == '__main__':
# initialize logging
dictConfig(log_conf())
# initialize ERA5 predictor dataset
Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_VARIABLES,
plevels=ERA5_PLEVELS)
......@@ -45,10 +54,20 @@ if __name__ == '__main__':
valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
drop_last=False)
# initialize network: calculate number of input variables
in_channels = int(len(ERA5_P_VARIABLES) * len(ERA5_PLEVELS) +
len(ERA5_S_VARIABLES))
net = SegNet(MODEL_PATH.joinpath(PREDICTAND + '.pt'), in_channels, 1)
# initialize network
net = SegNet(MODEL_PATH.joinpath(PREDICTAND + '.pt'),
len(Era5_ds.data_vars), 1)
# initialize optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=LR)
# initialize loss function
loss_function = torch.nn.MSELoss()
# initialize network trainer
trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
valid_dl, loss_function=loss_function,
**TRAIN_CONFIG)
# initialize network training
# TODO: Extend ClassificationNetworkTrainer -> RegressionNetworkTrainer
# train model
state = trainer.train()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment