Skip to content
Snippets Groups Projects
Commit 38e69b2e authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Added a high-level training class for easier model training

parent 472a11d4
No related branches found
No related tags found
No related merge requests found
# !/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 26 16:33:46 2020
@author: Daniel
"""
# externals
import torch
import torch.nn.functional as F
def predict(model, dataloader, optimizer, accuracy, state_file=None):
    """Evaluate ``model`` on ``dataloader`` and build a confusion matrix.

    Parameters
    ----------
    model
        Network to evaluate. It is called on each batch of inputs and must
        expose ``nclasses``, ``eval()`` and ``to(device)``. When
        ``state_file`` is given it must also expose
        ``load(optimizer, state_file)``.
    dataloader
        Iterable yielding ``(inputs, labels)`` batches; ``len(dataloader)``
        is used for the progress message.
    optimizer
        Only forwarded to ``model.load`` when ``state_file`` is given.
    accuracy
        Callable ``accuracy(pred, labels)`` returning the batch accuracy.
    state_file
        Optional model state to load before evaluating. When given, the
        confusion matrix and the per-batch accuracies are also written to
        ``<state_file without '.pt'>_cm.pt``.

    Returns
    -------
    cm
        ``nclasses x nclasses`` tensor; ``cm[i, j]`` counts the samples with
        true class ``i`` that were predicted as class ``j``.
    """
    # check whether a gpu is available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load the model state if provided
    if state_file is not None:
        model.load(optimizer, state_file)

    # the inputs are sent to ``device`` below, so the model has to live there
    # as well, otherwise the forward pass fails on a gpu machine
    model = model.to(device)

    # set the model to evaluation mode
    model.eval()

    # initialize confusion matrix (kept on the cpu)
    cm = torch.zeros(model.nclasses, model.nclasses)

    # number of batches, including a possibly smaller last batch
    nbatches = len(dataloader)

    # iterate over the validation/test set
    accuracies = []
    for batch, (inputs, labels) in enumerate(dataloader):

        # send the data to the gpu if available
        inputs = inputs.to(device)
        labels = labels.to(device)

        # calculate network outputs
        with torch.no_grad():
            outputs = model(inputs)

        # calculate predicted class labels
        pred = F.softmax(outputs, dim=1).argmax(dim=1)

        # calculate accuracy
        acc = accuracy(pred, labels)
        accuracies.append(acc)

        # print progress
        print('Batch: {:d}/{:d}, Accuracy: {:.2f}'.format(batch,
                                                          nbatches, acc))

        # update confusion matrix; indices are moved to the cpu because the
        # matrix itself lives there
        for ytrue, ypred in zip(labels.view(-1).cpu(), pred.view(-1).cpu()):
            cm[ytrue.long(), ypred.long()] += 1

    # save confusion matrix and accuracies to file; the output name can only
    # be derived when a state file was supplied
    if state_file is not None:
        torch.save({'cm': cm, 'accuracy': accuracies},
                   state_file.split('.pt')[0] + '_cm.pt')

    return cm
def accuracy_function(outputs, labels):
    """Return the fraction of entries in ``outputs`` that equal ``labels``.

    Both arguments are compared element-wise with ``==``; the boolean result
    is cast to float and averaged.
    """
    matches = outputs == labels
    return matches.float().mean()
......@@ -5,105 +5,303 @@ Created on Fri Jun 26 16:31:36 2020
@author: Daniel
"""
# builtins
import os
# externals
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import random_split
from torch.utils.data import random_split, DataLoader
def train(model, dataloader, loss_function, optimizer, accuracy, state_file,
epochs=1, nthreads=1):
class NetworkTrainer(object):
    """High-level helper to train and validate a classification model.

    The dataset is split into training, validation and test subsets on
    construction. ``train`` runs the optimisation loop and ``predict``
    evaluates the model on the validation subset.

    Parameters
    ----------
    model
        The model to train. The methods of this class use ``model.nclasses``,
        ``model.epoch``, ``model.train()``, ``model.eval()``, ``model.to()``,
        ``model.save(optimizer, state_file, state_path)`` and
        ``model.load(optimizer, state_file, state_path)``. The return values
        of ``save``/``load`` are split on ``'.pt'`` to name output files, so
        they are expected to be file-name strings.
    dataset
        The dataset to split; must support ``len()``.
    loss_function
        Callable ``loss_function(outputs, labels)`` returning the loss.
    optimizer
        Optimizer used to update the model weights.
    batch_size
        Batch size of the training and validation dataloaders.
    tvratio
        Fraction of the training/validation samples used for training.
    ttratio
        Fraction of the whole dataset used for training and validation; the
        remainder becomes the test set.
    seed
        Random seed used for the dataset split.
    """

    def __init__(self, model, dataset, loss_function, optimizer, batch_size=32,
                 tvratio=0.8, ttratio=1, seed=0):

        # the model to train
        self.model = model

        # the dataset to train the model on
        self.dataset = dataset

        # the training, validation and test dataset
        self.train_ds, self.valid_ds, self.test_ds = self.train_val_test_split(
            self.dataset, tvratio, ttratio, seed)

        # the batch size
        self.batch_size = batch_size

        # the training and validation dataloaders
        self.train_dl = DataLoader(self.train_ds, batch_size, shuffle=True)
        self.valid_dl = DataLoader(self.valid_ds, batch_size, shuffle=True)

        # the loss function to compute the model error
        self.loss_function = loss_function

        # the optimizer used to update the model weights
        self.optimizer = optimizer

        # whether to use the gpu
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else
                                   "cpu")

    def ds_len(self, ds, ratio):
        """Return ``len(ds) * ratio`` rounded to the nearest integer."""
        return int(np.round(len(ds) * ratio))

    def train_val_test_split(self, ds, tvratio, ttratio=1, seed=0):
        """Randomly split ``ds`` into training, validation and test subsets.

        ``ttratio`` of ``ds`` is used for training and validation, the rest
        for testing; ``tvratio`` of the training/validation part is used for
        training, the rest for validation.

        Returns
        -------
        train_ds, valid_ds, test_ds
            The three subsets produced by ``random_split``.
        """
        # set the random seed for reproducibility
        torch.manual_seed(seed)

        # length of the training and validation dataset
        trav_len = self.ds_len(ds, ttratio)

        # length of the test dataset: taken as the remainder, because
        # rounding both parts independently does not always add up to
        # len(ds), in which case random_split raises
        test_len = len(ds) - trav_len

        # split dataset into training and test set
        # (ttratio * 100) % will be used for training and validation
        train_val_ds, test_ds = random_split(ds, (trav_len, test_len))

        # length of the training set
        train_len = self.ds_len(train_val_ds, tvratio)

        # length of the validation dataset, again as the remainder
        valid_len = len(train_val_ds) - train_len

        # split the training set into training and validation set
        train_ds, valid_ds = random_split(train_val_ds, (train_len, valid_len))

        # print the dataset ratios
        print(*['{} set: {:.2f}%'.format(k, v * 100) for k, v in
                {'Training': ttratio * tvratio,
                 'Validation': ttratio * (1 - tvratio),
                 'Test': 1 - ttratio}.items()], sep='\n')

        return train_ds, valid_ds, test_ds

    def accuracy_function(self, outputs, labels):
        """Return the fraction of entries in ``outputs`` equal to ``labels``."""
        return (outputs == labels).float().mean()

    def train(self, state_path, state_file, epochs=1, resume=True,
              early_stop=True, nthreads=os.cpu_count(), **kwargs):
        """Train the model on the training set.

        Parameters
        ----------
        state_path, state_file
            Directory and file name of the model state. Both are forwarded
            to ``model.save``/``model.load``.
        epochs
            Maximum number of passes over the training set.
        resume
            Whether to load an existing state file before training.
        early_stop
            Whether to evaluate on the validation set after each epoch,
            save the model only when the validation accuracy improved and
            stop once ``EarlyStopping`` signals it.
        nthreads
            Number of threads passed to ``torch.set_num_threads``.
        **kwargs
            Forwarded to ``EarlyStopping``. ``mode`` defaults to ``'max'``
            here because the monitored metric is an accuracy.

        Returns
        -------
        losses, accuracies
            Arrays of shape ``(nbatches, epochs)``. Columns of epochs that
            were skipped by early stopping stay zero.
        """
        # set the number of threads; os.cpu_count() may return None
        if nthreads:
            torch.set_num_threads(nthreads)

        # instantiate early stopping class
        if early_stop:
            # the metric handed to es.stop() is an accuracy, i.e. larger is
            # better, unless the caller explicitly asks otherwise
            kwargs.setdefault('mode', 'max')
            es = EarlyStopping(**kwargs)

        # initial accuracy on the validation set
        max_accuracy = 0

        # name used to derive the loss file until save/load report their own;
        # prevents an unbound ``state`` if the model never improves
        state = os.path.join(state_path, state_file)

        # whether to resume training from an existing model
        if os.path.exists(os.path.join(state_path, state_file)) and resume:
            state = self.model.load(self.optimizer, state_file, state_path)
            print('Resuming training from {} ...'.format(state))
            print('Model epoch: {:d}'.format(self.model.epoch))

        # send the model to the gpu if available
        self.model = self.model.to(self.device)

        # number of batches per epoch, including a smaller last batch
        nbatches = len(self.train_dl)

        # arrays of the observed losses and accuracies; allocated once so
        # that the values of earlier epochs are kept
        losses = np.zeros(shape=(nbatches, epochs))
        accuracies = np.zeros(shape=(nbatches, epochs))

        # initialize the training: iterate over the entire training data set
        for epoch in range(epochs):

            # set the model to training mode
            print('Setting model to training mode ...')
            self.model.train()

            # iterate over the dataloader object
            for batch, (inputs, labels) in enumerate(self.train_dl):

                # send the data to the gpu if available
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                # reset the gradients
                self.optimizer.zero_grad()

                # perform forward pass
                outputs = self.model(inputs)

                # compute loss; .item() also works for tensors on the gpu
                loss = self.loss_function(outputs, labels.long())
                losses[batch, epoch] = loss.item()

                # compute the gradients of the loss function w.r.t.
                # the network weights
                loss.backward()

                # update the weights
                self.optimizer.step()

                # calculate predicted class labels
                ypred = F.softmax(outputs, dim=1).argmax(dim=1)

                # calculate accuracy on current batch
                acc = float(self.accuracy_function(ypred, labels))
                accuracies[batch, epoch] = acc

                # print progress
                print('Epoch: {:d}/{:d}, Batch: {:d}/{:d}, Loss: {:.2f}, '
                      'Accuracy: {:.2f}'.format(epoch, epochs, batch, nbatches,
                                                losses[batch, epoch], acc))

            # update the number of epochs trained
            self.model.epoch += 1

            # whether to evaluate model performance on the validation set and
            # early stop the training process
            stop = False
            if early_stop:

                # model predictions on the validation set
                _, validation_accuracies = self.predict()

                # metric to assess model performance on the validation set
                epoch_acc = validation_accuracies.mean()

                # whether the model improved with respect to the previous epoch
                if epoch_acc > max_accuracy:
                    max_accuracy = epoch_acc

                    # save model state if the model improved with
                    # respect to the previous epoch
                    state = self.model.save(self.optimizer, state_file,
                                            state_path)

                # whether the early stopping criterion is met
                stop = es.stop(epoch_acc)

            else:
                # if no early stopping is required, the model state is saved
                # after each epoch
                state = self.model.save(self.optimizer, state_file, state_path)

            # save losses and accuracy after each epoch to file; done before
            # leaving the loop so the last epoch is written as well
            torch.save({'loss': losses, 'accuracy': accuracies},
                       state.split('.pt')[0] + '_loss.pt')

            if stop:
                break

        return losses, accuracies

    def predict(self, state_path=None, state_file=None, confusion=False):
        """Evaluate the model on the validation set.

        Parameters
        ----------
        state_path, state_file
            Optional model state to load before evaluating.
        confusion
            Whether to fill the confusion matrix. It is written to
            ``<state>_cm.pt`` when a state file was loaded as well.

        Returns
        -------
        cm
            ``nclasses x nclasses`` tensor; ``cm[i, j]`` counts samples of
            true class ``i`` predicted as class ``j`` (all zero unless
            ``confusion`` is set).
        accuracies
            Array of shape ``(nbatches, 1)`` with the per-batch accuracies.
        """
        # load the model state if provided
        if state_file is not None:
            state = self.model.load(self.optimizer, state_file, state_path)

        # send the model to the gpu if available
        self.model = self.model.to(self.device)

        # set the model to evaluation mode
        print('Setting model to evaluation mode ...')
        self.model.eval()

        # initialize confusion matrix (kept on the cpu)
        cm = torch.zeros(self.model.nclasses, self.model.nclasses)

        # number of batches in the validation set, including a smaller last
        # batch
        nbatches = len(self.valid_dl)

        # iterate over the validation/test set
        print('Calculating accuracy on validation set ...')
        accuracies = np.zeros(shape=(nbatches, 1))
        for batch, (inputs, labels) in enumerate(self.valid_dl):

            # send the data to the gpu if available
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            # calculate network outputs
            with torch.no_grad():
                outputs = self.model(inputs)

            # calculate predicted class labels
            pred = F.softmax(outputs, dim=1).argmax(dim=1)

            # calculate accuracy on current batch
            acc = float(self.accuracy_function(pred, labels))
            accuracies[batch, 0] = acc

            # print progress
            print('Batch: {:d}/{:d}, Accuracy: {:.2f}'.format(batch,
                                                              nbatches, acc))

            # update confusion matrix; indices are moved to the cpu because
            # the matrix itself lives there
            if confusion:
                for ytrue, ypred in zip(labels.view(-1).cpu(),
                                        pred.view(-1).cpu()):
                    cm[ytrue.long(), ypred.long()] += 1

        # calculate overall accuracy on the validation set
        print('Current mean accuracy on the validation set: {:.2f}%'
              .format(accuracies.mean() * 100))

        # save confusion matrix to file
        if state_file is not None and confusion:
            torch.save({'cm': cm}, state.split('.pt')[0] + '_cm.pt')

        return cm, accuracies
class EarlyStopping(object):
    """Signal when a monitored metric has stopped improving.

    Parameters
    ----------
    mode
        ``'min'`` if a decrease of the metric is an improvement (e.g. loss),
        ``'max'`` if an increase is an improvement (e.g. accuracy).
    min_delta
        Minimum change in the metric to be classified as an improvement.
    patience
        Number of consecutive non-improving calls to ``stop`` after which
        the early stopping criterion is met.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``'min'`` nor ``'max'``.
    """

    def __init__(self, mode='min', min_delta=0, patience=5):

        # check if mode is correctly specified
        if mode not in ['min', 'max']:
            raise ValueError('Mode "{}" not supported. '
                             'Mode is either "min" (check whether the metric '
                             'decreased, e.g. loss) or "max" (check whether '
                             'the metric increased, e.g. accuracy).'
                             .format(mode))

        # whether to check for an increase or a decrease in a given metric
        self.is_better = self.decreased if mode == 'min' else self.increased

        # minimum change in metric to be classified as an improvement
        self.min_delta = min_delta

        # number of epochs to wait for improvement
        self.patience = patience

        # initialize best metric
        self.best = None

        # number of consecutive epochs without improvement; has to exist
        # before the first non-improving metric is reported to stop()
        self.counter = 0

        # initialize early stopping flag
        self.early_stop = False

    def stop(self, metric):
        """Record ``metric`` and return whether training should stop.

        The first call only stores ``metric`` as the best value. The flag,
        once set, is not reset by a later improvement.
        """
        if self.best is not None:

            # if the metric improved, reset the epochs counter, else, advance
            if self.is_better(metric, self.best, self.min_delta):
                self.counter = 0
                self.best = metric
            else:
                self.counter += 1

            # if the metric did not improve over the last patience epochs,
            # the early stopping criterion is met
            if self.counter >= self.patience:
                self.early_stop = True

        else:
            self.best = metric

        return self.early_stop

    def decreased(self, metric, best, min_delta):
        """Return whether ``metric`` is below ``best`` by more than ``min_delta``."""
        return metric < best - min_delta

    def increased(self, metric, best, min_delta):
        """Return whether ``metric`` is above ``best`` by more than ``min_delta``."""
        return metric > best + min_delta
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment