Why Big Data? Learning Curves
by Bob Horton
Microsoft Senior Data Scientist
Learning curves are an elaboration of the idea of validating a model on a test set, and have been widely popularized by Andrew Ng’s Machine Learning course on Coursera. Here I present a simple simulation that illustrates this idea.
Imagine you use a sample of your data to train a model, then use the model to predict the outcomes on data where you know what the real outcome is. Since you know the “real” answer, you can calculate the overall error in your predictions. The error on the same data set used to train the model is called the training error, and the error on an independent sample is called the validation error.
A model will commonly perform better (that is, have lower error) on the data it was trained on than on an independent sample. The difference between the training error and the validation error reflects overfitting of the model. Overfitting is like memorizing the answers for a test instead of learning the principles (to borrow a metaphor from the Wikipedia article). Memorizing works fine if the test is exactly like the study guide, but it doesn’t work very well if the test questions are different; that is, it doesn’t generalize. In fact, the more a model is overfitted, the higher its validation error is likely to be. This is because the spurious correlations the overfitted model memorized from the training set most likely don’t apply in the validation set.
Overfitting is usually more extreme with small training sets. In large training sets the random noise tends to average out, so the underlying patterns are clearer. In small training sets there is less opportunity for the noise to average out, so accidental correlations have more influence on the model. Learning curves let us visualize this relationship between training set size and the degree of overfitting.
We start with a function to generate simulated data:
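```r
# Simulate N rows: three 10-level categorical inputs and a numeric outcome y.
sim_data <- function(N, noise_level = 1) {
  X1 <- sample(LETTERS[1:10], N, replace = TRUE)
  X2 <- sample(LETTERS[1:10], N, replace = TRUE)
  X3 <- sample(LETTERS[1:10], N, replace = TRUE)
  # Base level 100, plus 10 when X1 equals X2, plus Gaussian noise
  y <- 100 + ifelse(X1 == X2, 10, 0) + rnorm(N, sd = noise_level)
  data.frame(X1, X2, X3, y)
}
```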
The input columns X1, X2, and X3 are categorical variables, each with 10 possible values represented by the capital letters A through J. The outcome is cleverly named y; it has a base level of 100, but if the values in the first two X variables are equal, it is increased by 10. On top of this we add some normally distributed noise. Any other pattern that might appear in the data, including any apparent effect of X3, is accidental.
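As a quick sanity check (a sketch, not part of the original analysis), we can confirm that the simulated data behave as described: the mean of y should be about 110 when X1 equals X2 and about 100 otherwise, roughly 1 row in 10 should have matching X1 and X2, and X3 should have no systematic effect. The function is repeated so the block runs on its own; the sample size, seed, and the names `check`, `group_means`, and `x3_means` are illustrative choices.

```r
# sim_data as defined above, repeated so this block is self-contained
sim_data <- function(N, noise_level = 1) {
  X1 <- sample(LETTERS[1:10], N, replace = TRUE)
  X2 <- sample(LETTERS[1:10], N, replace = TRUE)
  X3 <- sample(LETTERS[1:10], N, replace = TRUE)
  y <- 100 + ifelse(X1 == X2, 10, 0) + rnorm(N, sd = noise_level)
  data.frame(X1, X2, X3, y)
}

set.seed(1)
check <- sim_data(10000, noise_level = 1)

# Mean of y when X1 and X2 differ (FALSE) or match (TRUE): about 100 and 110
group_means <- tapply(check$y, check$X1 == check$X2, mean)
print(round(group_means, 1))

# Fraction of rows where X1 == X2: about 0.1, since each is uniform over 10 letters
print(mean(check$X1 == check$X2))

# Mean of y for each level of X3: all roughly equal, because X3 plays no role in y
x3_means <- tapply(check$y, check$X3, mean)
print(round(x3_means, 2))
```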
Now we can use this function to generate a simulated data set for experiments.
```r
set.seed(123)
data <- sim_data(25000, noise_level = 10)
```
There are many possible error functions, but I prefer the root mean squared error:
```r
rmse <- function(actual, predicted) sqrt(mean((actual - predicted)^2))
```
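A small worked example shows what this function computes, next to one alternative error function. The `mae` (mean absolute error) helper is a hypothetical addition for comparison only; the analysis here uses RMSE. Because RMSE squares each error before averaging, a single large miss counts for more than it does under MAE.

```r
rmse <- function(actual, predicted) sqrt(mean((actual - predicted)^2))

# Hypothetical alternative, for comparison only: mean absolute error
mae <- function(actual, predicted) mean(abs(actual - predicted))

actual    <- c(1, 2, 3)
predicted <- c(1, 2, 5)

# Errors are 0, 0, -2; squared errors 0, 0, 4; mean 4/3; RMSE = sqrt(4/3), about 1.155
print(rmse(actual, predicted))

# Absolute errors 0, 0, 2; mean 2/3, about 0.667
print(mae(actual, predicted))
```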
To generate a learning curve, we fit models at a series of different training set sizes and calculate the training error and validation error for each model. Then we plot these errors against the training set size. The function takes a model formula, the data frame of simulated data, the validation set size (vss), the number of different training set sizes to evaluate (num_tss), and the smallest training set size to start with (min_tss). The largest training set uses all the rows of the dataset that are not held out for validation.
```r
library(data.table)  # for rbindlist()

run_learning_curve <- function(model_formula, data, vss = 5000, num_tss = 30, min_tss = 1000) {
  # The largest training set uses every row not held out for validation
  max_tss <- nrow(data) - vss
  # Evenly spaced training set sizes, rounded to whole numbers of rows
  tss_vector <- round(seq(min_tss, max_tss, length.out = num_tss))
  data.table::rbindlist(
    lapply(tss_vector, function(tss) {
      # Draw a fresh validation set for each training set size...
      vs_idx <- sample(seq_len(nrow(data)), vss)
      vs <- data[vs_idx, ]
      # ...and a training set from the remaining rows
      ts_eligible <- setdiff(seq_len(nrow(data)), vs_idx)
      ts <- data[sample(ts_eligible, tss), ]
      fit <- lm(model_formula, ts)
      training_error <- rmse(ts$y, predict(fit, ts))
      validation_error <- rmse(vs$y, predict(fit, vs))
      data.frame(tss = tss,
                 error_type = factor(c("training", "validation"),
                                     levels = c("validation", "training")),
                 error = c(training_error, validation_error))
    })
  )
}
```
We’ll use a formula that considers all combinations of the input columns. Since these are categorical inputs, they will be represented by dummy variables in the model, with each combination of variable values getting its own coefficient.
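To see how many coefficients this formula asks for, one option (a sketch, not from the original post) is to build the model matrix for a grid containing every combination of the three inputs. With R's default treatment contrasts, the intercept, three 9-column main effects, three 81-column two-way interactions, and one 729-column three-way interaction add up to 1000 columns: one per combination of levels. The name `grid` is illustrative.

```r
# One row for every combination of the three 10-level inputs (1000 rows);
# expand.grid turns the character vectors into factors
grid <- expand.grid(X1 = LETTERS[1:10], X2 = LETTERS[1:10], X3 = LETTERS[1:10])

# Dummy-variable model matrix for the full three-way interaction
mm <- model.matrix(~ X1 * X2 * X3, data = grid)

print(ncol(mm))                  # total number of coefficients
print(table(attr(mm, "assign"))) # columns per term: intercept, main effects, interactions
print(qr(mm)$rank)               # full rank when every combination is present once
```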
```r
learning_curve <- run_learning_curve(y ~ X1*X2*X3, data)
```
With this example, you get a series of warnings:
```
## Warning in predict.lm(fit, vs): prediction from a rank-deficient fit may be
## misleading
```
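This is R telling you that you don’t have enough rows to reliably fit all those coefficients. The full three-way interaction asks for 10 × 10 × 10 = 1000 coefficients, one per combination of input values, and the fit becomes rank-deficient whenever some combination never appears in the training set; lm reports the coefficients it cannot estimate as NA. In this simulation, training set sizes above about 7500 don’t trigger the warning, though, as we’ll see, the curve still shows some evidence of overfitting.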
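A minimal way to see the rank deficiency directly (a sketch under illustrative assumptions: 500 training rows, an arbitrary seed, and the names `small` and `n_cells`) is to fit the same formula to a small sample and count the coefficients lm could not estimate. With only 500 rows, at most 500 of the 1000 coefficients can be estimated, and no more than the number of distinct (X1, X2, X3) combinations actually observed.

```r
# sim_data as defined above, repeated so this block is self-contained
sim_data <- function(N, noise_level = 1) {
  X1 <- sample(LETTERS[1:10], N, replace = TRUE)
  X2 <- sample(LETTERS[1:10], N, replace = TRUE)
  X3 <- sample(LETTERS[1:10], N, replace = TRUE)
  y <- 100 + ifelse(X1 == X2, 10, 0) + rnorm(N, sd = noise_level)
  data.frame(X1, X2, X3, y)
}

set.seed(42)
small <- sim_data(500, noise_level = 10)
fit <- lm(y ~ X1 * X2 * X3, data = small)

# Coefficients requested by the formula vs. those lm left as NA (aliased)
print(length(coef(fit)))
print(sum(is.na(coef(fit))))

# Distinct combinations of the inputs present in this small training set
n_cells <- nrow(unique(small[, c("X1", "X2", "X3")]))
print(n_cells)
```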
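```r
library(ggplot2)
ggplot(learning_curve, aes(x = tss, y = error, linetype = error_type)) +
  geom_line(linewidth = 1, col = "blue") +   # ggplot2 >= 3.4.0 uses `linewidth` (formerly `size`) for lines
  xlab("training set size") +
  geom_hline(yintercept = 10, linetype = 3)  # dotted reference line at the noise level
```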
In this figure, the X-axis represents different training set sizes and the Y-axis represents error. Validation error is shown by the solid blue line in the upper part of the figure, and training error by the dashed blue line in the lower part; the dotted horizontal line marks an error of 10. As the training set size grows, these curves converge toward a level representing the amount of irreducible error in the data. This plot was generated from a simulated dataset where we know exactly what the irreducible error is: it is the standard deviation of the Gaussian noise we added to the outcome in the simulation (10; for reasonably large samples, the root mean squared error of this noise is essentially the same as its standard deviation). Because that noise is completely random, we don’t expect any model to predict it reliably.
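The noise floor can also be checked numerically. The sketch below uses a hypothetical "oracle" that predicts with the exact rule used to generate y (something only possible with simulated data) and shows that even this perfect model has an RMSE close to the noise standard deviation of 10. The seed, sample size, and names `holdout` and `oracle` are illustrative assumptions.

```r
# sim_data and rmse as defined above, repeated so this block is self-contained
sim_data <- function(N, noise_level = 1) {
  X1 <- sample(LETTERS[1:10], N, replace = TRUE)
  X2 <- sample(LETTERS[1:10], N, replace = TRUE)
  X3 <- sample(LETTERS[1:10], N, replace = TRUE)
  y <- 100 + ifelse(X1 == X2, 10, 0) + rnorm(N, sd = noise_level)
  data.frame(X1, X2, X3, y)
}
rmse <- function(actual, predicted) sqrt(mean((actual - predicted)^2))

set.seed(7)
holdout <- sim_data(20000, noise_level = 10)

# Hypothetical oracle: the true data-generating rule, without the noise term
oracle <- 100 + ifelse(holdout$X1 == holdout$X2, 10, 0)

print(rmse(holdout$y, oracle))  # close to 10: the irreducible error
print(sd(holdout$y - oracle))   # nearly identical for a sample this large
```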
One interesting thing about this simulation is that the underlying system is very simple, yet it can take many thousands of training examples before the validation error of this model gets very close to optimum. In real life, you can easily encounter systems with many more variables, much higher cardinality, far more complex patterns, and of course lots and lots of those unpredictable variations we call “noise”. Such systems can require truly enormous numbers of samples to train a model without excessive overfitting. On the other hand, if your training and validation error curves have already converged, more data may be superfluous. Learning curves can help you see if you are in a situation where more data is likely to be of benefit for training your model better.
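One way to read convergence off a learning curve without eyeballing the plot is to compute the gap between validation and training error at each training set size. The `error_gap` helper below is hypothetical (not from the original post) and assumes the column layout that `run_learning_curve` returns (`tss`, `error_type`, `error`). The toy input uses made-up numbers purely to show the calculation; it is not output from the simulation. A gap that is still large at the biggest training set suggests more data could help, while a gap near zero suggests the curves have converged.

```r
# Hypothetical helper: validation minus training error for each training set size.
# Assumes columns tss, error_type ("training"/"validation"), and error.
error_gap <- function(lc) {
  lc  <- as.data.frame(lc)
  val <- lc[lc$error_type == "validation", c("tss", "error")]
  trn <- lc[lc$error_type == "training",   c("tss", "error")]
  m <- merge(val, trn, by = "tss", suffixes = c("_validation", "_training"))
  data.frame(tss = m$tss, gap = m$error_validation - m$error_training)
}

# Toy input with made-up numbers, only to illustrate the calculation
toy <- data.frame(
  tss = rep(c(1000, 10000, 20000), each = 2),
  error_type = factor(rep(c("training", "validation"), 3),
                      levels = c("validation", "training")),
  error = c(5, 30, 9, 11, 9.6, 10.3)
)

print(error_gap(toy))  # gaps shrink as the training set grows
```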