diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py index e0499c397d1e57de8c4fab5975d62c36c94465f7..381f954a08a1545b7dc745f86e9699dd5f31132e 100644 --- a/climax/main/downscale_train.py +++ b/climax/main/downscale_train.py @@ -117,8 +117,10 @@ if __name__ == '__main__': net = NET(state_file, inputs, outputs, filters=FILTERS) # initialize optimizer - optimizer = torch.optim.Adam(net.parameters(), lr=LR, - weight_decay=LAMBDA) + # optimizer = torch.optim.Adam(net.parameters(), lr=LR, + # weight_decay=LAMBDA) + optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9, + weight_decay=LAMBDA) # initialize training data LogConfig.init_log('Initializing training data.')