Learning curves are a widely used diagnostic tool in machine learning for algorithms that learn incrementally, such as deep learning models. During training, we evaluate model performance on both the training dataset and a hold-out validation dataset, and we plot this performance for each training step (e.g., each epoch of a deep learning model or each added tree of an ensemble tree model). Reviewing learning curves during training can help diagnose problems with learning, such as an underfit or overfit model, as well as whether the training and validation datasets are suitably representative. In this notebook, I will illustrate how you can use learning curves to:

- diagnose underfit, overfit, and optimally fit models; and
- diagnose unrepresentative training and validation datasets.
This notebook will demonstrate these issues primarily with learning curve plots, supplemented with brief code sketches where they help illustrate a remedy.1
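As a minimal sketch of the basic workflow, assuming the keras R package and a made-up synthetic regression problem (the data, network architecture, and names like `x_train` here are illustrative, not from any real analysis):

```r
library(keras)

# Illustrative synthetic regression data
set.seed(42)
x_train <- matrix(rnorm(1000 * 10), ncol = 10)
y_train <- rnorm(1000)

# A small feed-forward network
model <- keras_model_sequential() %>%
  layer_dense(units = 32, activation = "relu", input_shape = c(10)) %>%
  layer_dense(units = 1)

model %>% compile(loss = "mse", optimizer = "adam")

# fit() evaluates on the held-out split at the end of every epoch
history <- model %>% fit(
  x_train, y_train,
  epochs = 100,
  validation_split = 0.2,  # hold out 20% of the data for validation
  verbose = 0
)

# The history object stores per-epoch train and validation loss;
# plot() draws the learning curves
plot(history)
```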
The shape and dynamics of a learning curve can be used to diagnose the behavior of a machine learning model and, in turn, suggest the types of configuration changes that may improve learning and/or performance. There are three common dynamics that you are likely to observe in learning curves:

1. Underfit
2. Overfit
3. Optimal fit
We will take a closer look at each with examples. The examples assume we are tracking a loss metric to be minimized, meaning that lower values on the y-axis indicate better performance.
Underfitting refers to a model that has not learned the training dataset well enough to obtain a sufficiently low training error. There are two common signals of underfitting. First, the training learning curve may show a flat line or noisy values at a relatively high loss, indicating that the model was unable to learn the training dataset at all. An example of this is provided below; the pattern is common when the model does not have suitable capacity for the complexity of the dataset.
Solution: increase the capacity of the model (e.g., more units or layers) so it can capture the complexity of the dataset, or improve the feature representation.
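A minimal sketch of the capacity remedy, reusing the synthetic setup from the first example:

```r
# A higher-capacity version of the earlier model: deeper and wider
model <- keras_model_sequential() %>%
  layer_dense(units = 128, activation = "relu", input_shape = c(10)) %>%
  layer_dense(units = 64, activation = "relu") %>%
  layer_dense(units = 1)
```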
An underfit model may also be identified by training and validation losses that are still decreasing at the end of the plot. This indicates that the model is capable of further learning and that the training process was halted prematurely.
Solution: increase the number of training epochs (or trees, for an ensemble) and continue training until the validation loss levels off.
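A minimal sketch, again reusing names from the first example, is simply to raise the epoch budget and re-examine the curves:

```r
# Same model, larger epoch budget; keep training while both
# curves are still decreasing
history <- model %>% fit(
  x_train, y_train,
  epochs = 500,            # previously 100
  validation_split = 0.2,
  verbose = 0
)
```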
Overfitting refers to a model that has learned the training dataset too well, including the statistical noise or random fluctuations in the training dataset.
“… fitting a more flexible model requires estimating a greater number of parameters. These more complex models can lead to a phenomenon known as overfitting the data, which essentially means they follow the errors, or noise, too closely.”2
The problem with overfitting is that the more specialized the model becomes to the training data, the less well it generalizes to new data, resulting in an increase in generalization error. Overfitting is apparent when:

- the training loss continues to decrease with further training, and
- the validation loss decreases to a point and then begins to increase.
However, a model that overfits is not necessarily a bad thing. In fact, it signals that the model has extracted all of the signal that this particular model is able to learn from the training data. The issues to be concerned about with overfitting are its magnitude and the location of the inflection point (the epoch at which the validation loss begins to rise).
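To make those two concerns concrete, the following sketch (assuming the keras `history` object from the first example) locates the inflection point and measures the train/validation gap there:

```r
val_loss   <- history$metrics$val_loss
train_loss <- history$metrics$loss

# Inflection point: the epoch where validation loss bottoms out
best_epoch <- which.min(val_loss)

# Magnitude: the train/validation gap at that epoch
gap <- val_loss[best_epoch] - train_loss[best_epoch]

cat("Validation loss bottoms out at epoch", best_epoch,
    "with a generalization gap of", round(gap, 4), "\n")
```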
A model that overfits early, with a validation loss showing a sharp “U” shape, often indicates overcapacity and/or a learning rate that is too high.
Solution: reduce the capacity of the model, add regularization (e.g., dropout or an L2 weight penalty), and/or lower the learning rate.
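A sketch of these remedies, reusing the synthetic setup from the first example; note that the `learning_rate` argument name applies to recent versions of the keras package (older versions call it `lr`):

```r
# Dropout and an L2 weight penalty constrain the model, and a smaller
# learning rate slows how aggressively it chases the training loss
model <- keras_model_sequential() %>%
  layer_dense(units = 64, activation = "relu", input_shape = c(10),
              kernel_regularizer = regularizer_l2(0.001)) %>%
  layer_dropout(rate = 0.3) %>%
  layer_dense(units = 1)

model %>% compile(
  loss = "mse",
  optimizer = optimizer_adam(learning_rate = 5e-4)  # smaller step size
)
```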
Often, we can minimize overfitting but rarely can we completely eliminate it and still minimize our loss. The following illustrates an example where we have minimized overfitting, yet some overfitting still exists.
Solution: add `restore_best_weights = TRUE` to your early stopping callback so that your final model uses the weights from the epoch with the best validation loss score.

An optimal fit is the goal of the learning algorithm. The loss of the model will almost always be lower on the training dataset than on the validation dataset, so we should expect some gap between the train and validation loss learning curves. This gap is referred to as the generalization gap. An optimal fit is one where:

- the training loss decreases to a point of stability, and
- the validation loss decreases to a point of stability with a small generalization gap relative to the training loss.
Continued training of an optimal fit will likely lead to overfitting. The example plot below demonstrates a case of an optimal fit assuming we have found a global minimum of our loss function.
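Early stopping is a common way to halt training near this point automatically. A sketch, reusing names from the earlier examples and the `restore_best_weights` option mentioned above:

```r
# Stop once validation loss has not improved for `patience` epochs,
# then roll back to the best epoch's weights
early_stop <- callback_early_stopping(
  monitor = "val_loss",
  patience = 10,
  restore_best_weights = TRUE
)

history <- model %>% fit(
  x_train, y_train,
  epochs = 500,
  validation_split = 0.2,
  callbacks = list(early_stop),
  verbose = 0
)
```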
Learning curves can also be used to diagnose properties of a dataset, such as whether it is relatively representative. An unrepresentative dataset is one that does not capture the statistical characteristics of another dataset drawn from the same domain, such as a training dataset relative to a validation dataset. This commonly occurs when the number of samples in a dataset is too small, or when certain characteristics are not adequately represented relative to the other dataset.
There are two common cases that could be observed; they are:

1. an unrepresentative training dataset, and
2. an unrepresentative validation dataset.
An unrepresentative training dataset means that the training dataset does not provide sufficient information to learn the problem, relative to the validation dataset used to evaluate it. This situation can be identified by a learning curve for training loss that shows improvement and, similarly, a learning curve for validation loss that shows improvement, but with a large gap remaining between both curves. This can occur when the training dataset has too few examples as compared to the validation dataset.
Solution: add more examples to the training dataset, or ensure the split between training and validation data preserves the key characteristics of the domain.
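One way to keep a split representative is to stratify it on the outcome. A hypothetical sketch using caret's `createDataPartition()` (the data frame `df` and its `target` column are illustrative names):

```r
library(caret)

# Illustrative data frame with an imbalanced class outcome
set.seed(123)
df <- data.frame(
  x1     = rnorm(200),
  target = factor(sample(c("yes", "no"), 200, replace = TRUE,
                         prob = c(0.1, 0.9)))
)

# createDataPartition() samples within each outcome level, so both
# splits preserve the original class distribution
idx   <- createDataPartition(df$target, p = 0.8, list = FALSE)
train <- df[idx, ]
valid <- df[-idx, ]

prop.table(table(train$target))  # both splits keep roughly the
prop.table(table(valid$target))  # same 10/90 class balance
```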
An unrepresentative validation dataset means that the validation dataset does not provide sufficient information to evaluate the ability of the model to generalize. This may occur if the validation dataset has too few examples as compared to the training dataset. This case can be identified by a learning curve for training loss that looks like a good fit (or other fits) and a learning curve for validation loss that shows noisy movements and little or no improvement.
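One possible mitigation, reusing names from the earlier sketches, is to allocate a larger share of the data to validation so the per-epoch validation loss is averaged over more examples:

```r
# A larger validation share makes the validation loss estimate less
# noisy, at the cost of some training data
history <- model %>% fit(
  x_train, y_train,
  epochs = 100,
  validation_split = 0.3,  # up from 0.2; trades training data for a steadier signal
  verbose = 0
)
```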