Cross-Validation for Predictive Analytics Using R
Introduction
Since ancient times, humankind has avidly sought ways to predict the future. One of the best-known examples is the Oracle of Delphi, who dispensed previews of the future to her petitioners in the form of divinely inspired prophecies1. Today the desire to know the future is still shared by many of us, although my feeling is that the rapid pace of technological innovation we observe every day has somewhat dulled this instinct: things that seemed futuristic only a few years ago (e.g. the World Wide Web) are now available to the general public.
Business decision making is one of the many areas where predictions are badly needed. The tools for making predictions about quantities of interest are commonly known as predictive analytics, an essential part of data science. At the heart of any prediction there is always a model, which typically depends on some unknown structural parameters (e.g. the coefficients of a regression model) as well as on one or more tuning parameters (e.g. the number of basis functions in a smoothing spline or the degree of a polynomial). The former are usually estimated from a sample of data, while the latter have to be chosen so that the model provides sufficiently accurate predictions. Tuning parameters usually regulate the model complexity and hence are a key ingredient of any predictive task. In this blog post we focus on the most common strategy for choosing reasonable values for the tuning parameters: cross-validation.
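To make the distinction concrete, here is a minimal sketch on simulated data (the sine-shaped signal, sample size and seed are illustrative assumptions, not data used elsewhere in this post). The degrees of freedom of a natural cubic spline play the role of the tuning parameter, fixed by the analyst, while the regression coefficients are the structural parameters that lm() estimates for each choice.

```r
library(splines)

set.seed(1)
# Simulated data (an illustrative assumption, not the data used later in the post)
x <- sort(runif(50, 0, 100))
y <- sin(x / 15) + rnorm(50, sd = 0.3)

# Tuning parameter: the spline's degrees of freedom, fixed by the analyst
tuning_df <- c(2, 4, 8)

# Structural parameters: regression coefficients estimated by least squares,
# one set for each value of the tuning parameter
fits <- lapply(tuning_df, function(d) lm(y ~ ns(x, df = d)))
n_coef <- sapply(fits, function(fit) length(coef(fit)))
n_coef  # intercept + one coefficient per basis function: 3 5 9
```

Raising the tuning parameter adds basis functions, and hence structural parameters to estimate; choosing its value is exactly the job that cross-validation will do.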
The Bias-Variance Dilemma
One should care about the choice of the tuning parameter values because they are intimately linked to the accuracy of the predictions returned by the model. What an analyst typically wants is a model that predicts well samples that have not been used for estimating the structural parameters (the so-called training sample). In other words, a predictive model is considered good when it predicts previously unseen samples with high accuracy. The accuracy of a model's predictions is usually gauged using a loss function. Popular choices are the mean squared error for continuous outcomes and the 0-1 loss for categorical outcomes2.
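Both losses are one-liners in R. The helper names below are hypothetical, chosen for this sketch; averaging the 0-1 loss over the observations gives the misclassification rate.

```r
# Mean squared error for a numeric outcome
mse_loss <- function(y, y_hat) {
  mean((y - y_hat)^2)
}

# 0-1 loss averaged over the observations, i.e. the misclassification rate
# (y and y_hat: character vectors, or factors with the same levels)
zero_one_loss <- function(y, y_hat) {
  mean(y != y_hat)
}

mse_loss(c(1, 2, 3), c(1, 2, 5))                   # (0 + 0 + 4) / 3, about 1.333
zero_one_loss(c("good", "bad", "good", "good"),
              c("good", "good", "good", "bad"))    # 2 errors out of 4: 0.5
```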
At this point, it is important to distinguish between different prediction error concepts:
- the training error, which is the average loss over the training sample,
- the test error, which is the prediction error over an independent test sample.
The training error is small when the predicted responses are close to the observed responses, and grows when, for some observations, the predicted and observed responses differ substantially. Since it is computed on the same sample used to fit the model, however, the training error says little about what we really care about: the model's ability to predict observations never seen during estimation. The test error provides a measure of this ability. In general, one should select the model with the lowest test error.
The R code below illustrates these ideas with simulated data. In particular, I simulate 100 training sets, each of size 50, from a cubic polynomial regression model (plus a test sample of 10,000 observations from the same model), and for each training set I fit a sequence of natural cubic spline models with degrees of freedom from 1 to 30.
# Generate the training and test samples
seed <- 1809
set.seed(seed)
gen_data <- function(n, beta, sigma_eps) {
  eps <- rnorm(n, 0, sigma_eps)
  x <- sort(runif(n, 0, 100))
  X <- cbind(1, poly(x, degree = (length(beta) - 1), raw = TRUE))
  y <- as.numeric(X %*% beta + eps)
  return(data.frame(x = x, y = y))
}

# Fit the models
require(splines)
n_rep <- 100
n_df <- 30
df <- 1:n_df
beta <- c(5, -0.1, 0.004, -3e-05)
n_train <- 50
n_test <- 10000
sigma_eps <- 0.5
xy <- res <- list()
xy_test <- gen_data(n_test, beta, sigma_eps)
for (i in 1:n_rep) {
  xy[[i]] <- gen_data(n_train, beta, sigma_eps)
  x <- xy[[i]][, "x"]
  y <- xy[[i]][, "y"]
  res[[i]] <- apply(t(df), 2, function(degf) lm(y ~ ns(x, df = degf)))
}
The next plot shows the first simulated training sample together with the true regression function (black line) and three fitted models, corresponding to natural cubic splines with 1 (green line), 4 (orange line) and 25 (blue line) degrees of freedom respectively. These values were chosen to show the three situations one may encounter in practice: a model with low variability but high bias (1 degree of freedom), a model with high variability but low bias (25 degrees of freedom), and a model that strikes a compromise between bias and variance (4 degrees of freedom).
# Plot the data
x <- xy[[1]]$x
X <- cbind(1, poly(x, degree = (length(beta) - 1), raw = TRUE))
y <- xy[[1]]$y
plot(y ~ x, col = "gray", lwd = 2)
lines(x, X %*% beta, lwd = 3, col = "black")
lines(x, fitted(res[[1]][[1]]), lwd = 3, col = "palegreen3")
lines(x, fitted(res[[1]][[4]]), lwd = 3, col = "darkorange")
lines(x, fitted(res[[1]][[25]]), lwd = 3, col = "steelblue")
legend(x = "topleft",
       legend = c("True function", "Linear fit (df = 1)", "Best model (df = 4)",
                  "Overfitted model (df = 25)"),
       lwd = rep(3, 4), col = c("black", "palegreen3", "darkorange", "steelblue"),
       text.width = 32, cex = 0.85)
Then, for each training sample and fitted model, I compute the training error and the test error, the latter using the large test sample generated from the same (known!) population. Both are shown in the following plot (training errors in gray, test errors in light red) together with their averages, which are drawn as thicker lines3. The solid points mark the three models illustrated in the previous plot.
# Compute the training and test errors for each model
pred <- list()
mse <- te <- matrix(NA, nrow = n_df, ncol = n_rep)
for (i in 1:n_rep) {
  # Training error: residual sum of squares divided by the sample size
  mse[, i] <- sapply(res[[i]], function(obj) deviance(obj)/nobs(obj))
  # Test error: mean squared error on the large test sample
  pred[[i]] <- mapply(function(obj, degf) predict(obj, data.frame(x = xy_test$x)),
                      res[[i]], df)
  te[, i] <- sapply(as.list(data.frame(pred[[i]])),
                    function(y_hat) mean((xy_test$y - y_hat)^2))
}

# Compute the average training and test errors
av_mse <- rowMeans(mse)
av_te <- rowMeans(te)

# Plot the errors
plot(df, av_mse, type = "l", lwd = 2, col = gray(0.4), ylab = "Prediction error",
     xlab = "Flexibility (spline's degrees of freedom [log scaled])",
     ylim = c(0, 1), log = "x")
# Irreducible error: the noise variance sigma_eps^2 (not sigma_eps)
abline(h = sigma_eps^2, lty = 2, lwd = 0.5)
for (i in 1:n_rep) {
  lines(df, te[, i], col = "lightpink")
}
for (i in 1:n_rep) {
  lines(df, mse[, i], col = gray(0.8))
}
lines(df, av_mse, lwd = 2, col = gray(0.4))
lines(df, av_te, lwd = 2, col = "darkred")
points(df[1], av_mse[1], col = "palegreen3", pch = 17, cex = 1.5)
points(df[1], av_te[1], col = "palegreen3", pch = 17, cex = 1.5)
points(df[which.min(av_te)], av_mse[which.min(av_te)], col = "darkorange", pch = 16, cex = 1.5)
points(df[which.min(av_te)], av_te[which.min(av_te)], col = "darkorange", pch = 16, cex = 1.5)
points(df[25], av_mse[25], col = "steelblue", pch = 15, cex = 1.5)
points(df[25], av_te[25], col = "steelblue", pch = 15, cex = 1.5)
legend(x = "top", legend = c("Training error", "Test error"), lwd = rep(2, 2),
       col = c(gray(0.4), "darkred"), text.width = 0.3, cex = 0.85)
One can see that the training errors decrease monotonically as the model gets more complex (and less smooth). The test error, on the other hand, first decreases and then, beyond a certain level of flexibility, starts increasing again. The turning point corresponds to the orange model, that is, the model that provides a good compromise between bias and variance. The test error increases for more than 3 or 4 degrees of freedom because of the so-called overfitting problem. Overfitting is the tendency of a model to adapt too closely to the training data, at the expense of generalization to previously unseen data points. In other words, an overfitted model fits the noise in the data rather than the actual underlying relationships among the variables. Overfitting usually occurs when a model is unnecessarily complex.
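It is possible to show that, for squared-error loss, the (expected) test error at a given test point $x_0$, with response $y_0 = f(x_0) + \varepsilon$ and noise variance $\mathrm{Var}(\varepsilon) = \sigma_\varepsilon^2$, can be decomposed into the sum of three components, namely

$$\mathrm{E}\left[\left(y_0 - \hat{f}(x_0)\right)^2\right] = \sigma_\varepsilon^2 + \left[\mathrm{E}\,\hat{f}(x_0) - f(x_0)\right]^2 + \mathrm{Var}\left(\hat{f}(x_0)\right),$$

that is, irreducible noise plus squared bias plus variance, where the expectations are taken over training samples and over the noise in $y_0$. The first term is the variance of the data-generating process: no model, not even the true one, can remove it. The second term (squared bias) measures how far the average fitted model is from the true function, and is large for overly rigid models such as the green line. The third term (variance) measures how much the fitted model changes from one training sample to another, and is large for overly flexible models such as the blue line. Minimizing the test error therefore requires balancing bias and variance.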
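The decomposition (irreducible noise plus squared bias plus variance) can be checked numerically. The Monte Carlo sketch below reuses the cubic regression function and noise level of the simulation above, but it is otherwise my own illustration: the test point x0 = 50, the number of simulated training samples and the seed are arbitrary assumptions. For each spline it records the prediction at x0 across many independent training samples and estimates the bias and variance terms from those predictions.

```r
library(splines)

set.seed(1809)
# True regression function and noise level (the same cubic used in this post)
f <- function(x) 5 - 0.1 * x + 0.004 * x^2 - 3e-05 * x^3
sigma_eps <- 0.5
x0 <- 50      # test point where the decomposition is evaluated (arbitrary)
n <- 50       # size of each simulated training sample
n_sim <- 500  # number of simulated training samples

# Prediction at x0 of a spline with 'degf' degrees of freedom,
# repeated over n_sim independent training samples
pred_at_x0 <- function(degf) {
  replicate(n_sim, {
    x <- runif(n, 0, 100)
    y <- f(x) + rnorm(n, 0, sigma_eps)
    fit <- lm(y ~ ns(x, df = degf))
    unname(predict(fit, data.frame(x = x0)))
  })
}

# Monte Carlo estimates of the three components; the noise of the new
# response y0 is independent of the fit, so it contributes sigma_eps^2
decompose <- function(f_hat) {
  bias2 <- (mean(f_hat) - f(x0))^2
  variance <- mean((f_hat - mean(f_hat))^2)
  c(noise = sigma_eps^2, bias2 = bias2, variance = variance,
    test_error = sigma_eps^2 + mean((f(x0) - f_hat)^2))
}

errs <- rbind(df_1 = decompose(pred_at_x0(1)),
              df_4 = decompose(pred_at_x0(4)),
              df_25 = decompose(pred_at_x0(25)))
round(errs, 4)
```

By construction the last column equals the sum of the first three (up to floating-point error), because the mean squared deviation of the predictions from f(x0) splits exactly into squared bias and variance. With these settings one would expect the bias term to dominate for df = 1 and the variance term to grow for df = 25, mirroring the green and blue fits.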
Clearly, the situation illustrated above is only ideal, because in practice:
- We do not know the true model that generates the data. Indeed, our models are typically misspecified to some degree.
- We only have a limited amount of data.
One way to overcome these hurdles and approximate the search for the optimal model is to use the cross-validation approach.
A Solution: Cross-Validation
In essence, all these ideas bring us to the conclusion that it is not advisable to compare the predictive accuracy of a set of models using the same observations used for estimating the models. Therefore, for assessing the models’ predictive performance we should use an independent set of data (the test sample). Then, the model showing the lowest error on the test sample (i.e., the lowest test error) is identified as the best.
Unfortunately, in many cases it is not possible to draw a (possibly large) independent set of observations for testing the models' performance, because collecting data is typically expensive. An obvious workaround is to split the available data into two sets, one used for training and the other for testing. The split is usually performed at random, so that the two parts follow the same distribution4.
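A random split takes only a few lines of base R. The helper below is hypothetical (not taken from any package); its optional strata argument implements the stratified variant mentioned in footnote 4 by sampling separately within each class. The toy data with 70 "good" and 30 "bad" cases are made up.

```r
# Hypothetical helper (not from any package): random train/test split
split_data <- function(data, p_train = 0.8, strata = NULL) {
  n <- nrow(data)
  if (is.null(strata)) {
    # Purely random split
    train_i <- sample.int(n, size = round(p_train * n))
  } else {
    # Stratified split (footnote 4): sample separately within each class
    groups <- split(seq_len(n), strata)
    train_i <- unlist(lapply(groups, function(idx) {
      idx[sample.int(length(idx), size = round(p_train * length(idx)))]
    }), use.names = FALSE)
  }
  list(train = data[train_i, , drop = FALSE],
       test = data[-train_i, , drop = FALSE])
}

set.seed(1809)
toy <- data.frame(x = runif(100),
                  class = factor(rep(c("good", "bad"), times = c(70, 30))))
plain <- split_data(toy)
strat <- split_data(toy, strata = toy$class)
table(strat$train$class)  # 24 "bad" and 56 "good": the 30/70 ratio is kept
```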
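Even if data splitting provides an unbiased estimate of the test error, it is often quite noisy, since the estimate depends on which observations happen to end up in the test set. A possible solution5 is cross-validation (CV). In its basic version, the so-called k-fold cross-validation, the samples are randomly partitioned into k sets (called folds) of roughly equal size. A model is fit using all the samples except those in the first fold, and its prediction error is computed on the held-out fold. The same operation is repeated for each fold, and the model's performance is measured by averaging the errors across the k held-out folds. Common choices are k = 5 or k = 10. Cross-validation thus provides an estimate of the test error for each model6, and it is one of the most widely used methods for model selection and for choosing tuning parameter values.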
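The code below illustrates 10-fold cross-validation on a single simulated training sample of 100 observations, this time without pretending to know the data-generating process. For comparison, the plot also shows the training error and the test error computed on a large independent test sample; the error bars correspond to one standard error of the CV estimate, computed across the folds.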
set.seed(seed)

# A single training sample of size 100
n_train <- 100
xy <- gen_data(n_train, beta, sigma_eps)
x <- xy$x
y <- xy$y

# Training error of each spline model
fitted_models <- apply(t(df), 2, function(degf) lm(y ~ ns(x, df = degf)))
mse <- sapply(fitted_models, function(obj) deviance(obj)/nobs(obj))

# Test error on a large independent test sample
n_test <- 10000
xy_test <- gen_data(n_test, beta, sigma_eps)
pred <- mapply(function(obj, degf) predict(obj, data.frame(x = xy_test$x)),
               fitted_models, df)
te <- sapply(as.list(data.frame(pred)), function(y_hat) mean((xy_test$y - y_hat)^2))

# 10-fold cross-validation
n_folds <- 10
folds_i <- sample(rep(1:n_folds, length.out = n_train))
cv_tmp <- matrix(NA, nrow = n_folds, ncol = length(df))
for (k in 1:n_folds) {
  test_i <- which(folds_i == k)
  train_xy <- xy[-test_i, ]
  test_xy <- xy[test_i, ]
  x <- train_xy$x
  y <- train_xy$y
  fitted_models <- apply(t(df), 2, function(degf) lm(y ~ ns(x, df = degf)))
  x <- test_xy$x
  y <- test_xy$y
  # Predict the held-out fold; predict() reuses the knots of the training fit
  pred <- mapply(function(obj, degf) predict(obj, data.frame(x = x)),
                 fitted_models, df)
  cv_tmp[k, ] <- sapply(as.list(data.frame(pred)), function(y_hat) mean((y - y_hat)^2))
}
cv <- colMeans(cv_tmp)

require(Hmisc)  # for errbar()
plot(df, mse, type = "l", lwd = 2, col = gray(0.4), ylab = "Prediction error",
     xlab = "Flexibility (spline's degrees of freedom [log scaled])",
     main = paste0(n_folds, "-fold Cross-Validation"), ylim = c(0.1, 0.8), log = "x")
lines(df, te, lwd = 2, col = "darkred", lty = 2)
# Standard error of the CV estimate, computed across the folds
cv_sd <- apply(cv_tmp, 2, sd)/sqrt(n_folds)
errbar(df, cv, cv + cv_sd, cv - cv_sd, add = TRUE, col = "steelblue2", pch = 19, lwd = 0.5)
lines(df, cv, lwd = 2, col = "steelblue2")
points(df, cv, col = "steelblue2", pch = 19)
legend(x = "topright", legend = c("Training error", "Test error", "Cross-validation error"),
       lty = c(1, 2, 1), lwd = rep(2, 3), col = c(gray(0.4), "darkred", "steelblue2"),
       text.width = 0.4, cex = 0.85)
Often a “one-standard-error” rule is used with cross-validation: one chooses the most parsimonious model whose CV error is no more than one standard error above the CV error of the best model. In the example above, the best model (the one minimizing the CV error) uses 3 degrees of freedom, which also satisfies the requirement of the one-standard-error rule.
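The rule is easy to code. The helper below is a hypothetical sketch, not a function from any package, and the toy CV errors and standard errors are made up so that the rule picks a simpler model than the minimizer.

```r
# Hypothetical helper implementing the one-standard-error rule
# cv: CV error of each model; cv_se: standard error of each CV error;
# complexity: flexibility of each model (e.g. degrees of freedom), smaller = simpler
one_se_rule <- function(cv, cv_se, complexity) {
  best <- which.min(cv)
  threshold <- cv[best] + cv_se[best]
  candidates <- which(cv <= threshold)
  # Among the models within one SE of the best, pick the simplest
  candidates[which.min(complexity[candidates])]
}

# Made-up numbers for illustration only
cv_toy <- c(0.50, 0.27, 0.25, 0.26, 0.30)
se_toy <- c(0.02, 0.02, 0.03, 0.02, 0.02)
one_se_rule(cv_toy, se_toy, complexity = 1:5)  # 2, although model 3 has the lowest CV error
```

With the objects created by the 10-fold CV code, `df[one_se_rule(cv, cv_sd, df)]` should return the degrees of freedom chosen by the rule, since `cv_sd` there already holds standard errors (the fold-wise standard deviation divided by the square root of the number of folds).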
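The case where k equals the number of observations n corresponds to the so-called leave-one-out cross-validation (LOOCV), in which each test set contains a single observation and the model is refit n times. LOOCV does not require a random assignment of the observations to folds, so it gives the same result every time it is applied, and each training set contains n - 1 observations, which makes its estimate of the test error less biased than that of k-fold CV. On the other hand, it can be computationally intensive, and since any two training sets share n - 2 observations, the n fitted models are strongly correlated with each other.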
The code below implements LOOCV on the same example I have discussed so far. The next plot shows that LOOCV often does not give dramatically different results from 10-fold CV.
require(splines)
# Leave-one-out CV: each observation is held out in turn
loocv_tmp <- matrix(NA, nrow = n_train, ncol = length(df))
for (k in 1:n_train) {
  train_xy <- xy[-k, ]
  test_xy <- xy[k, ]
  x <- train_xy$x
  y <- train_xy$y
  fitted_models <- apply(t(df), 2, function(degf) lm(y ~ ns(x, df = degf)))
  pred <- mapply(function(obj, degf) predict(obj, data.frame(x = test_xy$x)),
                 fitted_models, df)
  loocv_tmp[k, ] <- (test_xy$y - pred)^2
}
loocv <- colMeans(loocv_tmp)

plot(df, mse, type = "l", lwd = 2, col = gray(.4), ylab = "Prediction error",
     xlab = "Flexibility (spline's degrees of freedom [log scaled])",
     main = "Leave-One-Out Cross-Validation", ylim = c(.1, .8), log = "x")
lines(df, cv, lwd = 2, col = "steelblue2", lty = 2)
lines(df, loocv, lwd = 2, col = "darkorange")
legend(x = "topright", legend = c("Training error", "10-fold CV error", "LOOCV error"),
       lty = c(1, 2, 1), lwd = rep(2, 3), col = c(gray(.4), "steelblue2", "darkorange"),
       text.width = .3, cex = .85)
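The loop above refits every model n times. For least-squares fits with a fixed design matrix there is a well-known shortcut (see James et al. 2013, Chapter 5): the LOOCV error equals the mean of ((y_i - ŷ_i) / (1 - h_i))², where ŷ_i are the fitted values and h_i the leverages of a single fit on all the data. One caveat, which is my own remark: in the loop above the spline knots are recomputed on each training set, so the design is not exactly fixed and the shortcut would give a slightly different number there. The sketch below therefore builds the spline basis once on all the x values (an assumption of this illustration), so that the shortcut and the brute-force loop agree exactly.

```r
library(splines)

set.seed(1809)
n <- 60
x <- sort(runif(n, 0, 100))
y <- 5 - 0.1 * x + 0.004 * x^2 - 3e-05 * x^3 + rnorm(n, 0, 0.5)

# Spline basis (and its knots) computed once on all observations: fixed design
B <- ns(x, df = 4)
fit <- lm(y ~ B)

# Shortcut: a single fit, residuals rescaled by the leverages h_i
loocv_shortcut <- mean((residuals(fit) / (1 - hatvalues(fit)))^2)

# Brute force: n refits, each dropping one row of the same design matrix
X <- model.matrix(fit)
loocv_brute <- mean(vapply(seq_len(n), function(i) {
  beta_i <- lm.fit(X[-i, , drop = FALSE], y[-i])$coefficients
  (y[i] - sum(X[i, ] * beta_i))^2
}, numeric(1)))

c(shortcut = loocv_shortcut, brute_force = loocv_brute)
```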
Doing Cross-Validation With R: the caret Package
There are many R packages that provide functions for performing different flavors of CV. In my opinion, one of the best implementations of these ideas is available in the caret package by Max Kuhn (see Kuhn and Johnson 2013)7. The aim of the caret package (an acronym of classification and regression training) is to provide a very general and efficient suite of commands for building and assessing predictive models. It makes it possible to compare the predictive accuracy of a multitude of models (currently more than 200), including the most recent ones from machine learning. The comparison of different models can be done using cross-validation as well as with other approaches. The package also provides many options for data pre-processing. It is not my aim to provide here a thorough presentation of all the package features. Rather, I will focus only on the handful of functions used to perform CV. For more details on the other functions, you can inspect the package documentation and its website. To illustrate these features I will use data from a credit scoring application, available in the gastonstat/CreditScoring repository on GitHub.
Since credit scoring is a classification problem, I will use the proportion of misclassified observations as the loss measure (equivalently, the accuracy, i.e. the proportion of correctly classified observations). The original data set contains information on 4,455 individuals and the following variables:
Variable | Description |
---|---|
Status | credit status |
Seniority | job seniority (years) |
Home | type of home ownership |
Time | time of requested loan |
Age | client’s age |
Marital | marital status |
Records | existence of records |
Job | type of job |
Expenses | amount of expenses |
Income | amount of income |
Assets | amount of assets |
Debt | amount of debt |
Amount | amount requested of loan |
Price | price of good |
Here I use the “cleaned” version of the data set, where some pre-processing has already been performed (removal of a few observations, imputation of missing values and categorization of continuous predictors), which leaves 4,446 observations. The tidy data are contained in the file CleanCreditScoring.csv.
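require(RCurl)
require(prettyR)
url <- "https://raw.githubusercontent.com/gastonstat/CreditScoring/master/CleanCreditScoring.csv"
cs_data <- getURL(url)
# stringsAsFactors = TRUE keeps Status and the other text columns as factors
# (the default became FALSE in R 4.0.0)
cs_data <- read.csv(textConnection(cs_data), stringsAsFactors = TRUE)
describe(cs_data)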
## Description of cs_data
##
## Numeric
## mean median var sd valid.n
## Seniority 7.99 5.00 66.85 8.18 4446
## Time 46.45 48.00 214.56 14.65 4446
## Age 37.08 36.00 120.70 10.99 4446
## Expenses 55.60 51.00 381.06 19.52 4446
## Income 140.63 124.00 6428.50 80.18 4446
## Assets 5354.95 3000.00 133040726.62 11534.33 4446
## Debt 342.26 0.00 1549264.52 1244.69 4446
## Amount 1038.76 1000.00 225385.62 474.75 4446
## Price 1462.48 1400.00 395081.60 628.56 4446
## Finrat 72.62 77.10 415.78 20.39 4446
## Savings 3.86 3.12 13.89 3.73 4446
##
## Factor
##
## Status good bad
## Count 3197.00 1249.00
## Percent 71.91 28.09
## Mode good
##
## Home owner rent parents other priv ignore
## Count 2106.00 973.00 782.00 319.00 246.00 20.00
## Percent 47.37 21.88 17.59 7.17 5.53 0.45
## Mode owner
##
## Marital married single separated widow divorced
## Count 3238.00 973.00 130.00 67.00 38.00
## Percent 72.83 21.88 2.92 1.51 0.85
## Mode married
##
## Records no_rec yes_rec
## Count 3677.0 769.0
## Percent 82.7 17.3
## Mode no_rec
##
## Job fixed freelance partime others
## Count 2803.00 1021.00 451.00 171.00
## Percent 63.05 22.96 10.14 3.85
## Mode fixed
##
## seniorityR sen (-1,1] sen (3,8] sen (14,99] sen (1,3] sen (8,14]
## Count 1042.00 978 880.00 789.00 757.00
## Percent 23.44 22 19.79 17.75 17.03
## Mode sen (-1,1]
##
## timeR time (48,99] time (24,36] time (36,48] time (12,24] time (0,12]
## Count 1949.00 991.00 885.00 441.00 180.00
## Percent 43.84 22.29 19.91 9.92 4.05
## Mode time (48,99]
##
## ageR age (30,40] age (40,50] age (25,30] age (0,25] age (50,99]
## Count 1415.00 900.00 781.00 699.00 651.00
## Percent 31.83 20.24 17.57 15.72 14.64
## Mode age (30,40]
##
## expensesR exp (0,40] exp (40,50] exp (50,60] exp (60,80] exp (80,1e+04]
## Count 1219.00 999.00 979.00 798.00 451.00
## Percent 27.42 22.47 22.02 17.95 10.14
## Mode exp (0,40]
##
## incomeR inc (80,110] inc (140,190] inc (0,80] inc (110,140]
## Count 954.00 915.00 886.00 866.00
## Percent 21.46 20.58 19.93 19.48
##
## incomeR inc (190,1e+04]
## Count 825.00
## Percent 18.56
## Mode inc (80,110]
##
## assetsR asset (-1,0] asset (3e+03,5e+03] asset (8e+03,1e+06]
## Count 1626.00 937.00 719.00
## Percent 36.57 21.08 16.17
##
## assetsR asset (0,3e+03] asset (5e+03,8e+03]
## Count 626.00 538.0
## Percent 14.08 12.1
## Mode asset (-1,0]
##
## debtR debt (-1,0] debt (500,1.5e+03] debt (2.5e+03,1e+06] debt (0,500]
## Count 3667.00 230.00 197.00 193.00
## Percent 82.48 5.17 4.43 4.34
##
## debtR debt (1.5e+03,2.5e+03]
## Count 159.00
## Percent 3.58
## Mode debt (-1,0]
##
## amountR am (900,1.1e+03] am (1.1e+03,1.4e+03] am (600,900] am (0,600]
## Count 945.00 925.00 911.00 895.00
## Percent 21.26 20.81 20.49 20.13
##
## amountR am (1.4e+03,1e+05]
## Count 770.00
## Percent 17.32
## Mode am (900,1.1e+03]
##
## priceR priz (1.5e+03,1.8e+03] priz (1e+03,1.3e+03] priz (0,1e+03]
## Count 1028.00 985.00 821.00
## Percent 23.12 22.15 18.47
##
## priceR priz (1.8e+03,1e+05] priz (1.3e+03,1.5e+03]
## Count 811.00 801.00
## Percent 18.24 18.02
## Mode priz (1.5e+03,1.8e+03]
##
## finratR finr (80,90] finr (90,100] finr (50,70] finr (70,80] finr (0,50]
## Count 995.00 960.00 954.00 821.00 716.0
## Percent 22.38 21.59 21.46 18.47 16.1
## Mode finr (80,90]
##
## savingsR sav (2,4] sav (0,2] sav (6,99] sav (4,6] sav (-99,0]
## Count 1396.0 1111.00 827.0 814.00 298.0
## Percent 31.4 24.99 18.6 18.31 6.7
## Mode sav (2,4]
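The caret package provides functions for splitting the data, as well as functions that automatically do all the work for us, namely functions that create the resampled data sets, fit the models and evaluate their performance.
Among the functions for data splitting I will just mention createDataPartition() and createFolds(). The former creates one or more random training/test partitions of the data, while the latter randomly splits the data into k subsets (folds). When the outcome is a factor, both functions sample within each class, so that the class proportions are roughly preserved in every subset (see footnote 4). In the code below, returnTrain = TRUE makes createFolds() return, for each fold, the indices of the observations used for training, which is why each element contains roughly nine tenths of the 4,446 observations.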
require(caret)
classes <- cs_data[, "Status"]
predictors <- cs_data[, -match(c("Status", "Seniority", "Time", "Age", "Expenses", "Income",
                                 "Assets", "Debt", "Amount", "Price", "Finrat", "Savings"),
                               colnames(cs_data))]
train_set <- createDataPartition(classes, p = 0.8, list = FALSE)
str(train_set)
## int [1:3558, 1] 1 2 3 4 5 6 7 8 9 11 ...
## - attr(*, "dimnames")=List of 2
## ..$ : NULL
## ..$ : chr "Resample1"
train_predictors <- predictors[train_set, ]
train_classes <- classes[train_set]
test_predictors <- predictors[-train_set, ]
test_classes <- classes[-train_set]
set.seed(seed)
cv_splits <- createFolds(classes, k = 10, returnTrain = TRUE)
str(cv_splits)
## List of 10
## $ Fold01: int [1:4002] 1 2 3 4 5 6 7 8 9 10 ...
## $ Fold02: int [1:4002] 2 3 4 5 6 7 8 9 10 11 ...
## $ Fold03: int [1:4001] 1 2 3 4 5 6 7 8 9 10 ...
## $ Fold04: int [1:4002] 1 2 3 4 5 6 7 8 9 10 ...
## $ Fold05: int [1:4001] 1 2 3 4 7 8 10 11 12 13 ...
## $ Fold06: int [1:4001] 1 2 3 4 5 6 7 8 9 10 ...
## $ Fold07: int [1:4001] 1 2 4 5 6 7 9 10 11 12 ...
## $ Fold08: int [1:4002] 1 2 3 4 5 6 7 8 9 10 ...
## $ Fold09: int [1:4001] 1 2 3 5 6 8 9 11 12 13 ...
## $ Fold10: int [1:4001] 1 3 4 5 6 7 8 9 10 11 ...
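To see what class-stratified fold assignment amounts to, the hypothetical helper below mimics it in base R by shuffling balanced fold labels separately within each class. It is a rough analogue written for illustration, not caret's implementation of createFolds(), and the toy outcome with 72 "good" and 28 "bad" cases is made up.

```r
# Hypothetical helper: stratified assignment of observations to k folds
stratified_folds <- function(y, k = 10) {
  fold <- integer(length(y))
  for (level in unique(y)) {  # for() turns factor values into character strings
    idx <- which(y == level)
    # Balanced fold labels within this class, in random order
    fold[idx] <- sample(rep(seq_len(k), length.out = length(idx)))
  }
  fold
}

set.seed(1809)
y_toy <- factor(rep(c("good", "bad"), times = c(72, 28)))
folds <- stratified_folds(y_toy, k = 4)
table(folds, y_toy)  # every fold gets 7 "bad" and 18 "good" cases
```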
To automatically split the data, fit the models and assess their performance, one can use the train() function of the caret package. The code below shows an example of train() on the credit scoring data, modeling the outcome with a penalized logistic regression that uses all the available predictors. More specifically, I use the glmnet package (Friedman, Hastie, and Tibshirani 2008), which fits generalized linear models via penalized maximum likelihood. The algorithm implemented in the package computes the regularization path for the elastic-net penalty (Zou and Hastie 2005) over a grid of values of the regularization parameter λ. A second tuning parameter, the mixing parameter α between 0 and 1, determines the type of penalty: α = 1 gives the lasso penalty, α = 0 the ridge penalty, and intermediate values mix the two.
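For reference, glmnet's documentation writes the elastic-net penalty as λ[(1 - α)/2 ‖β‖₂² + α‖β‖₁], added to the (scaled) negative log-likelihood, with the intercept left unpenalized. The sketch below simply evaluates that penalty for a given coefficient vector, to show how α moves between the ridge and lasso extremes; the function name and the coefficients are hypothetical.

```r
# Hypothetical helper: elastic-net penalty in glmnet's parametrization
# coefs: coefficients (intercept excluded); alpha in [0, 1]; lambda >= 0
elastic_net_penalty <- function(coefs, alpha, lambda) {
  lambda * ((1 - alpha) / 2 * sum(coefs^2) + alpha * sum(abs(coefs)))
}

coefs <- c(1, -2)
elastic_net_penalty(coefs, alpha = 1, lambda = 0.1)    # lasso: 0.1 * 3 = 0.3
elastic_net_penalty(coefs, alpha = 0, lambda = 0.1)    # ridge: 0.1 * 5 / 2 = 0.25
elastic_net_penalty(coefs, alpha = 0.5, lambda = 0.1)  # a mix of the two: 0.275
```

The ridge penalty shrinks all coefficients but sets none of them exactly to zero, whereas the lasso penalty can set some coefficients exactly to zero; the tuning grid used below explores both extremes and several mixtures.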
The train() function requires the model formula, the model to fit and the grid of tuning parameter values to use. In the code below the grid is specified through the tuneGrid argument (all combinations of seven values of α and twenty values of λ), while the trControl argument specifies the method used to choose the optimal tuning parameter values (in our case, 10-fold cross-validation, set up with trainControl()). Finally, the preProcess argument applies a series of pre-processing operations to the predictors (in our case, centering and scaling the predictor values).
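A point worth making explicit is that centering and scaling constants should be estimated on the training data only and then applied unchanged to new data; as far as I understand caret's design, train() performs this estimation within each resample. The base-R sketch below (with hypothetical helper names and made-up data) illustrates the idea.

```r
# Hypothetical helpers: estimate centering/scaling constants on training data only
fit_scaler <- function(train) {
  list(center = colMeans(train), scale = apply(train, 2, sd))
}

# Apply the training constants, unchanged, to any data set
apply_scaler <- function(scaler, newdata) {
  scale(newdata, center = scaler$center, scale = scaler$scale)
}

set.seed(1809)
train_x <- matrix(rnorm(200, mean = 10, sd = 3), ncol = 2)
test_x <- matrix(rnorm(40, mean = 10, sd = 3), ncol = 2)

scaler <- fit_scaler(train_x)
train_z <- apply_scaler(scaler, train_x)
test_z <- apply_scaler(scaler, test_x)  # uses the training means and sds

round(colMeans(train_z), 10)      # both 0 on the training data
round(apply(train_z, 2, sd), 10)  # both 1 on the training data
```

The test data are not exactly centered and scaled after this transformation, and that is intended: using their own means and standard deviations would leak information from the test set into the pre-processing.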
require(glmnet)
set.seed(seed)
cs_data_train <- cs_data[train_set, ]
cs_data_test <- cs_data[-train_set, ]
glmnet_grid <- expand.grid(alpha = c(0, .1, .2, .4, .6, .8, 1),
                           lambda = seq(.01, .2, length = 20))
glmnet_ctrl <- trainControl(method = "cv", number = 10)
glmnet_fit <- train(Status ~ ., data = cs_data_train,
                    method = "glmnet",
                    preProcess = c("center", "scale"),
                    tuneGrid = glmnet_grid,
                    trControl = glmnet_ctrl)
glmnet_fit
## glmnet
##
## 3558 samples
## 26 predictor
## 2 classes: 'bad', 'good'
##
## Pre-processing: centered (68), scaled (68)
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 3202, 3203, 3202, 3203, 3202, 3202, ...
## Resampling results across tuning parameters:
##
## alpha lambda Accuracy Kappa
## 0.0 0.01 0.8021427 0.4613907413
## 0.0 0.02 0.7998916 0.4520486081
## 0.0 0.03 0.7976412 0.4402614685
## 0.0 0.04 0.7987633 0.4407093800
## 0.0 0.05 0.7982015 0.4355350784
## 0.0 0.06 0.7979182 0.4313111542
## 0.0 0.07 0.7953893 0.4205306747
## 0.0 0.08 0.7931413 0.4105376360
## 0.0 0.09 0.7922978 0.4050557210
## 0.0 0.10 0.7892072 0.3920192662
## 0.0 0.11 0.7841454 0.3722554927
## 0.0 0.12 0.7824600 0.3640031420
## 0.0 0.13 0.7807739 0.3557226473
## 0.0 0.14 0.7793694 0.3482341774
## 0.0 0.15 0.7807746 0.3491274474
## 0.0 0.16 0.7810571 0.3472621824
## 0.0 0.17 0.7796511 0.3411817028
## 0.0 0.18 0.7796526 0.3373484610
## 0.0 0.19 0.7807746 0.3374411775
## 0.0 0.20 0.7785267 0.3288713594
## 0.1 0.01 0.8015794 0.4596697325
## 0.1 0.02 0.8010160 0.4506840219
## 0.1 0.03 0.7987672 0.4392696923
## 0.1 0.04 0.7962367 0.4270137579
## 0.1 0.05 0.7953909 0.4179208158
## 0.1 0.06 0.7922986 0.4039787586
## 0.1 0.07 0.7892056 0.3879423318
## 0.1 0.08 0.7880804 0.3782808312
## 0.1 0.09 0.7841454 0.3596294655
## 0.1 0.10 0.7807723 0.3432057135
## 0.1 0.11 0.7748718 0.3176779656
## 0.1 0.12 0.7726223 0.3058039618
## 0.1 0.13 0.7706544 0.2940236563
## 0.1 0.14 0.7672804 0.2743480275
## 0.1 0.15 0.7641905 0.2598076390
## 0.1 0.16 0.7613815 0.2443603691
## 0.1 0.17 0.7585718 0.2288614463
## 0.1 0.18 0.7549153 0.2107358484
## 0.1 0.19 0.7518231 0.1940038500
## 0.1 0.20 0.7495751 0.1809109004
## 0.2 0.01 0.8021396 0.4602812440
## 0.2 0.02 0.7982070 0.4415693991
## 0.2 0.03 0.7951147 0.4239150748
## 0.2 0.04 0.7917424 0.4066236599
## 0.2 0.05 0.7911742 0.3954568138
## 0.2 0.06 0.7875218 0.3766126319
## 0.2 0.07 0.7813388 0.3485492065
## 0.2 0.08 0.7765588 0.3236774842
## 0.2 0.09 0.7723398 0.3024505090
## 0.2 0.10 0.7706583 0.2883765872
## 0.2 0.11 0.7664425 0.2668721852
## 0.2 0.12 0.7627868 0.2474032680
## 0.2 0.13 0.7554779 0.2144108884
## 0.2 0.14 0.7509796 0.1882593847
## 0.2 0.15 0.7445165 0.1570413702
## 0.2 0.16 0.7419869 0.1374198064
## 0.2 0.17 0.7372108 0.1088243186
## 0.2 0.18 0.7346811 0.0884942046
## 0.2 0.19 0.7304645 0.0666304068
## 0.2 0.20 0.7293401 0.0587304848
## 0.4 0.01 0.7998916 0.4500521540
## 0.4 0.02 0.7951147 0.4262079087
## 0.4 0.03 0.7894920 0.3986179726
## 0.4 0.04 0.7833091 0.3669902064
## 0.4 0.05 0.7793694 0.3406658720
## 0.4 0.06 0.7765596 0.3202875649
## 0.4 0.07 0.7726246 0.2966774741
## 0.4 0.08 0.7639128 0.2564191601
## 0.4 0.09 0.7546352 0.2079469868
## 0.4 0.10 0.7456401 0.1554338724
## 0.4 0.11 0.7369299 0.1037171598
## 0.4 0.12 0.7324316 0.0745712422
## 0.4 0.13 0.7296218 0.0585154515
## 0.4 0.14 0.7282173 0.0515741506
## 0.4 0.15 0.7270937 0.0460444276
## 0.4 0.16 0.7237213 0.0275193624
## 0.4 0.17 0.7203474 0.0079998432
## 0.4 0.18 0.7189429 0.0000000000
## 0.4 0.19 0.7189429 0.0000000000
## 0.4 0.20 0.7189429 0.0000000000
## 0.6 0.01 0.7998940 0.4479646834
## 0.6 0.02 0.7900538 0.4056859000
## 0.6 0.03 0.7830274 0.3693670539
## 0.6 0.04 0.7765572 0.3347009487
## 0.6 0.05 0.7757161 0.3177186986
## 0.6 0.06 0.7655990 0.2676239096
## 0.6 0.07 0.7549153 0.2082278135
## 0.6 0.08 0.7422693 0.1330653578
## 0.6 0.09 0.7332758 0.0786721840
## 0.6 0.10 0.7296218 0.0585273347
## 0.6 0.11 0.7279364 0.0501825990
## 0.6 0.12 0.7223176 0.0195861846
## 0.6 0.13 0.7189429 0.0000000000
## 0.6 0.14 0.7189429 0.0000000000
## 0.6 0.15 0.7189429 0.0000000000
## 0.6 0.16 0.7189429 0.0000000000
## 0.6 0.17 0.7189429 0.0000000000
## 0.6 0.18 0.7189429 0.0000000000
## 0.6 0.19 0.7189429 0.0000000000
## 0.6 0.20 0.7189429 0.0000000000
## 0.8 0.01 0.7959582 0.4342802453
## 0.8 0.02 0.7869623 0.3901784671
## 0.8 0.03 0.7776832 0.3436570802
## 0.8 0.04 0.7745925 0.3177549651
## 0.8 0.05 0.7650348 0.2651758625
## 0.8 0.06 0.7518270 0.1887158977
## 0.8 0.07 0.7363681 0.0981410803
## 0.8 0.08 0.7299027 0.0607058454
## 0.8 0.09 0.7270937 0.0452185157
## 0.8 0.10 0.7189429 0.0008725808
## 0.8 0.11 0.7189429 0.0000000000
## 0.8 0.12 0.7189429 0.0000000000
## 0.8 0.13 0.7189429 0.0000000000
## 0.8 0.14 0.7189429 0.0000000000
## 0.8 0.15 0.7189429 0.0000000000
## 0.8 0.16 0.7189429 0.0000000000
## 0.8 0.17 0.7189429 0.0000000000
## 0.8 0.18 0.7189429 0.0000000000
## 0.8 0.19 0.7189429 0.0000000000
## 0.8 0.20 0.7189429 0.0000000000
## 1.0 0.01 0.7920241 0.4209892004
## 1.0 0.02 0.7807794 0.3655088353
## 1.0 0.03 0.7754360 0.3289251638
## 1.0 0.04 0.7686897 0.2864096482
## 1.0 0.05 0.7526713 0.1943914570
## 1.0 0.06 0.7349620 0.0898775655
## 1.0 0.07 0.7290600 0.0557404582
## 1.0 0.08 0.7209100 0.0108497047
## 1.0 0.09 0.7189429 0.0000000000
## 1.0 0.10 0.7189429 0.0000000000
## 1.0 0.11 0.7189429 0.0000000000
## 1.0 0.12 0.7189429 0.0000000000
## 1.0 0.13 0.7189429 0.0000000000
## 1.0 0.14 0.7189429 0.0000000000
## 1.0 0.15 0.7189429 0.0000000000
## 1.0 0.16 0.7189429 0.0000000000
## 1.0 0.17 0.7189429 0.0000000000
## 1.0 0.18 0.7189429 0.0000000000
## 1.0 0.19 0.7189429 0.0000000000
## 1.0 0.20 0.7189429 0.0000000000
##
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were alpha = 0 and lambda = 0.01.
trellis.par.set(caretTheme())
plot(glmnet_fit, scales = list(x = list(log = 2)))
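The previous plot shows the cross-validated “accuracy”, that is, the proportion of correctly classified observations, of the penalized logistic regression model for each combination of the two tuning parameters α and λ. The highest accuracy is obtained with α = 0 (a pure ridge penalty) and λ = 0.01, the smallest penalty in the grid.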
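The tuning table also reports Cohen's Kappa, which compares the observed accuracy with the accuracy expected by chance given the marginal class frequencies. The hypothetical helper below computes both measures from observed and predicted classes on a made-up toy example; it is a sketch of the standard formulas, not caret's code.

```r
# Hypothetical helper: accuracy and Cohen's kappa from observed and predicted classes
accuracy_kappa <- function(observed, predicted) {
  lev <- union(levels(factor(observed)), levels(factor(predicted)))
  tab <- table(factor(observed, levels = lev), factor(predicted, levels = lev))
  n <- sum(tab)
  p_obs <- sum(diag(tab)) / n                      # accuracy = 1 - misclassification rate
  p_exp <- sum(rowSums(tab) * colSums(tab)) / n^2  # agreement expected by chance
  c(accuracy = p_obs, kappa = (p_obs - p_exp) / (1 - p_exp))
}

obs <- c("bad", "bad", "good", "good")
accuracy_kappa(obs, c("bad", "good", "good", "good"))  # accuracy 0.75, kappa 0.5
accuracy_kappa(obs, rep("good", 4))                    # accuracy 0.5, kappa 0
```

The second call shows that a classifier that always predicts the majority class has kappa 0 whatever its accuracy. This is consistent with the rows of the tuning table where Kappa is exactly 0 and the accuracy is stuck at 0.7189429: with a strong enough penalty the fitted model apparently predicts every applicant as “good”, so its accuracy simply equals the share of good applicants in the training folds.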
Then, new samples can be predicted with the selected model using the predict() method (by default, train() refits the final model on the whole training set with the optimal tuning parameter values):
pred_classes <- predict(glmnet_fit, newdata = cs_data_test)
table(pred_classes)
## pred_classes
## bad good
## 172 716
pred_probs <- predict(glmnet_fit, newdata = cs_data_test, type = "prob")
head(pred_probs)
## bad good
## 1 0.07142151 0.9285785
## 2 0.04231067 0.9576893
## 3 0.03736701 0.9626330
## 4 0.14796622 0.8520338
## 5 0.12416939 0.8758306
## 6 0.42359516 0.5764048
If you need to deepen your knowledge of predictive analytics, you may find something interesting in the R Course Data Mining with R.
Stay tuned for the next article on the MilanoR blog!
References
Efron, B., and R. Tibshirani. 1993. An Introduction to the Bootstrap. CRC Press.
Friedman, J., T. Hastie, and R. Tibshirani. 2008. “Regularization Paths for Generalized Linear Models via Coordinate Descent.” Journal of Statistical Software 33 (1): 1–22.
Hastie, T., R. Tibshirani, and J. Friedman. 2009. The Elements of Statistical Learning. 2nd ed. Springer.
James, G., D. Witten, T. Hastie, and R. Tibshirani. 2013. An Introduction to Statistical Learning. Springer.
Kuhn, M., and K. Johnson. 2013. Applied Predictive Modeling. Springer.
Zou, H., and T. Hastie. 2005. “Regularization and Variable Selection via the Elastic Net.” Journal of the Royal Statistical Society: Series B 67 (2): 301–20.
- By the way, it seems that the oracular powers were associated with hallucinogenic gases that puffed out from the temple floor.↩
- You can find a thorough formal illustration of all these concepts in Hastie, Tibshirani, and Friedman (2009), Chapter 7. A somewhat simpler presentation can be found in James et al. (2013).↩
- More precisely, the light red curves correspond to what is called the conditional test error, which means that each curve is conditional on the corresponding training sample. The heavier red curve corresponds to the expected test error. In general, we would like to focus on the conditional test error for the particular training sample we have. However, this quantity is very difficult to estimate, and in practice the expected test error is typically targeted. As we will see, cross-validation is a method for estimating the expected test error. For more details see Hastie, Tibshirani, and Friedman (2009).↩
- A variant of the purely random split is to use stratified random sampling in order to create subsets that are balanced with respect to the outcome. This is useful in particular in classification problems when one class has a disproportionately small frequency compared to the others.↩
- An alternative approach for the same objective is the bootstrap, which is not illustrated here (see Efron and Tibshirani (1993)).↩
- More precisely, cross-validation provides an estimate of the expected test error.↩
- The boot package also contains a nice function called cv.glm(), which implements k-fold cross-validation for generalized linear models.↩
The post Cross-Validation for Predictive Analytics Using R appeared first on MilanoR.