Cross Validation in R with Example

[This article was first published on finnstats, and kindly contributed to R-bloggers.]
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.

What Does Cross-Validation Mean?

Cross-validation is a statistical approach for determining how well the results of a statistical investigation generalize to a different data set.

Cross-validation is commonly used when the goal is prediction and we need to estimate how accurately a predictive model will perform in practice.
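As a minimal illustration of the underlying idea, the base-R sketch below holds out a random quarter of mtcars, fits a linear model on the remaining rows, and measures the squared prediction error on the held-out rows only. The seed, the 8-row hold-out size and the names test_idx and holdout_mse are arbitrary choices for illustration, not part of the original workflow.

```r
# Minimal hold-out sketch using only base R (datasets + stats)
set.seed(42)                                   # arbitrary seed, for reproducibility
n        <- nrow(mtcars)                       # 32 cars
test_idx <- sample(n, size = 8)                # hold out 8 rows at random
train    <- mtcars[-test_idx, ]                # 24 rows used for fitting
test     <- mtcars[test_idx, ]                 # 8 rows the model never sees

fit  <- lm(mpg ~ wt + cyl + hp, data = train)  # fit only on the training rows
pred <- predict(fit, newdata = test)           # predict the unseen rows
holdout_mse <- mean((test$mpg - pred)^2)       # out-of-sample mean squared error
holdout_mse
```

A single split like this depends heavily on which rows happen to be held out; k-fold cross-validation, used below, averages the error over several such splits.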

In a previous article we explored several stepwise regressions and arrived at different candidate models. Now let's see how cross-validation can help us choose the best one.

Which model is the most accurate at forecasting?

To begin, we load the packages we need and take a look at the built-in mtcars dataset:

library(purrr)
library(dplyr)
head(mtcars)
                   mpg cyl disp  hp drat    wt  qsec vs am gear carb
Mazda RX4         21.0   6  160 110 3.90 2.620 16.46  0  1    4    4
Mazda RX4 Wag     21.0   6  160 110 3.90 2.875 17.02  0  1    4    4
Datsun 710        22.8   4  108  93 3.85 2.320 18.61  1  1    4    1
Hornet 4 Drive    21.4   6  258 110 3.08 3.215 19.44  1  0    3    1
Hornet Sportabout 18.7   8  360 175 3.15 3.440 17.02  0  0    3    2
Valiant           18.1   6  225 105 2.76 3.460 20.22  1  0    3    1

There are several ways to do this, but here we'll use the modelr package.

First, we split our data into training and test sets using 5-fold cross-validation:

K-Fold Cross-Validation in R

library(modelr)
cv  <- crossv_kfold(mtcars, k = 5)
cv
  train                test                .id  
  <named list>         <named list>        <chr>
1 <resample [25 x 11]> <resample [7 x 11]> 1    
2 <resample [25 x 11]> <resample [7 x 11]> 2    
3 <resample [26 x 11]> <resample [6 x 11]> 3    
4 <resample [26 x 11]> <resample [6 x 11]> 4    
5 <resample [26 x 11]> <resample [6 x 11]> 5    

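Our data has now been split into five folds. Each row of cv pairs a training set (the other four folds, 25 or 26 cars) with a test set (the held-out fold, 6 or 7 cars), so every car appears in exactly one test set.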
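To make the split concrete, here is a base-R sketch of one way to assign rows to five folds. It is an illustration of the idea rather than modelr's own code, and the name fold_id is our own. Because 32 rows do not divide evenly by 5, two folds get 7 rows and three get 6, which matches the 7- and 6-row test sets printed above.

```r
# Base-R sketch of k-fold assignment (illustrative, not modelr's implementation)
set.seed(1)                                     # arbitrary seed
k <- 5
n <- nrow(mtcars)

# Label rows 1..k in turn, then shuffle the labels randomly
fold_id <- sample(rep(seq_len(k), length.out = n))

# Test-fold sizes; each training set is every row outside that fold
test_sizes  <- as.vector(table(fold_id))        # 7, 7, 6, 6, 6
train_sizes <- n - test_sizes                   # 25, 25, 26, 26, 26
rbind(test = test_sizes, train = train_sizes)
```

Shuffling a balanced vector of labels, rather than sampling each label independently, keeps the fold sizes as equal as possible.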

For each training set, we now use purrr's map() to fit a model. Each of our three candidate models is fitted separately, giving five fits per model.

Model Fitting

models1  <- map(cv$train, ~lm(mpg ~ wt + cyl + hp, data = .))
models2  <- map(cv$train, ~lm(mpg ~ wt + qsec + am, data = .))
models3  <- map(cv$train, ~lm(mpg ~ wt + qsec + hp, data = .))
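If you prefer to avoid purrr, the same fitting step can be sketched in base R with lapply(). This sketch builds its own random fold labels (fold_id) instead of using the cv object, and models1_base is a hypothetical name; it fits the first formula once per fold, each time on the rows outside that fold.

```r
# Base-R sketch: one lm() per fold, trained on the other k - 1 folds
set.seed(1)                                     # arbitrary seed
k <- 5
fold_id <- sample(rep(seq_len(k), length.out = nrow(mtcars)))

models1_base <- lapply(seq_len(k), function(i) {
  lm(mpg ~ wt + cyl + hp, data = mtcars[fold_id != i, ])
})

length(models1_base)                            # one fitted model per fold
sapply(models1_base, nobs)                      # training rows per fit: 25 25 26 26 26
```

Swapping in the other two formulas gives base-R counterparts of models2 and models3.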

Now it's time to make some predictions. To do this, I wrote a small function that takes a fitted model and a test set and returns the test data with a pred column of predictions, added by modelr's add_predictions(). Note that I use as.data.frame() to convert each resample object into an ordinary data frame first.

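get_pred  <- function(model, test_data){
  # convert the resample object into a regular data frame
  data  <- as.data.frame(test_data)
  # add a "pred" column holding the model's predictions
  pred  <- add_predictions(data, model)
  return(pred)
}
# pair each fold's model with its test set, stack the results,
# and record the fold number in a "Run" column
pred1  <- map2_df(models1, cv$test, get_pred, .id = "Run")
pred2  <- map2_df(models2, cv$test, get_pred, .id = "Run")
pred3  <- map2_df(models3, cv$test, get_pred, .id = "Run")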
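For readers curious about what add_predictions() and map2_df() are doing here, the base-R sketch below performs the same steps for the first model under its own random fold assignment (fold_id, pred_list and pred1_base are hypothetical names). predict() plays the role of add_predictions(), and do.call(rbind, ...) stacks the per-fold results with a Run label, much like .id = "Run".

```r
# Base-R sketch of predicting each held-out fold and stacking the results
set.seed(1)                                     # arbitrary seed
k <- 5
fold_id <- sample(rep(seq_len(k), length.out = nrow(mtcars)))

pred_list <- lapply(seq_len(k), function(i) {
  train <- mtcars[fold_id != i, ]               # fit on the other folds
  test  <- mtcars[fold_id == i, ]               # predict only the held-out fold
  fit   <- lm(mpg ~ wt + cyl + hp, data = train)
  test$pred <- predict(fit, newdata = test)     # analogous to add_predictions()
  test$Run  <- as.character(i)                  # analogous to .id = "Run"
  test
})

pred1_base <- do.call(rbind, pred_list)         # one data frame for all folds
nrow(pred1_base)                                # 32: every car predicted exactly once
```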

Next, we calculate the mean squared error (MSE) of the predictions within each fold:

MSE1  <- pred1 %>% group_by(Run) %>%
  summarise(MSE = mean( (mpg - pred)^2))
MSE1
  Run     MSE
  <chr> <dbl>
1 1      7.36
2 2      1.27
3 3      5.31
4 4      8.84
5 5     13.8 
MSE2  <- pred2 %>% group_by(Run) %>%
  summarise(MSE = mean( (mpg - pred)^2))
MSE2
  Run     MSE
  <chr> <dbl>
1 1      6.45
2 2      2.27
3 3      7.71
4 4      9.56
5 5     15.4 
MSE3  <- pred3 %>% group_by(Run) %>%
  summarise(MSE = mean( (mpg - pred)^2))
MSE3
  Run     MSE
  <chr> <dbl>
1 1      6.45
2 2      2.27
3 3      7.71
4 4      9.56
5 5     15.4 
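For reference, the per-fold quantity computed above and the overall cross-validation estimate compared at the end can be written as follows. The notation is our own: $T_j$ is the set of rows in test fold $j$, $n_j$ its size, $\hat{y}_i$ the prediction from the model trained without fold $j$, and $k = 5$.

$$
\mathrm{MSE}_j = \frac{1}{n_j} \sum_{i \in T_j} \left( y_i - \hat{y}_i \right)^2,
\qquad
\mathrm{CV}_{(k)} = \frac{1}{k} \sum_{j=1}^{k} \mathrm{MSE}_j
$$

Taking mean() of the five MSE values, as done below, computes exactly $\mathrm{CV}_{(k)}$, the unweighted average of the fold errors.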

Note that because the folds are assigned at random, and your R session will use a different random state than mine, your numbers may differ somewhat from mine.
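If you want the folds, and therefore the MSE values, to be reproducible, call set.seed() before the random split (for example, just before crossv_kfold()). The base-R sketch below shows the principle with sample(); the seed value 2023 and the names folds_a and folds_b are arbitrary.

```r
# Same seed -> same random fold assignment
set.seed(2023)
folds_a <- sample(rep(1:5, length.out = nrow(mtcars)))

set.seed(2023)
folds_b <- sample(rep(1:5, length.out = nrow(mtcars)))

identical(folds_a, folds_b)   # TRUE: the split is reproducible
```

Keep in mind that a given seed can still produce different samples across R versions, because the default sampling method changed in R 3.6.0.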

Finally, we compare the three models by their average MSE across the five folds:

mean(MSE1$MSE)
[1] 7.31312
mean(MSE2$MSE)
[1] 8.277929
mean(MSE3$MSE)
[1] 9.333679
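The whole comparison can also be sketched end to end in base R. This version is our own construction rather than the modelr workflow above: it uses one fold assignment for all three formulas so they are scored on identical splits, and it also reports the fold-to-fold standard deviation, which helps judge whether gaps in average MSE are large relative to the noise. The names formulas, fold_mse, summary_tab and best are hypothetical, and the numbers will not match those above because the folds differ.

```r
# Base-R sketch: 5-fold CV for three candidate formulas on the same folds
set.seed(1)                                     # arbitrary seed
k <- 5
fold_id <- sample(rep(seq_len(k), length.out = nrow(mtcars)))

formulas <- list(
  model1 = mpg ~ wt + cyl + hp,
  model2 = mpg ~ wt + qsec + am,
  model3 = mpg ~ wt + qsec + hp
)

# Test-fold MSE for one formula and one fold
fold_mse <- function(f, i) {
  fit  <- lm(f, data = mtcars[fold_id != i, ])
  test <- mtcars[fold_id == i, ]
  mean((test$mpg - predict(fit, newdata = test))^2)
}

# Matrix of fold MSEs: rows = folds, columns = models
mse <- sapply(formulas, function(f) sapply(seq_len(k), function(i) fold_mse(f, i)))

summary_tab <- data.frame(
  model    = names(formulas),
  mean_mse = unname(colMeans(mse)),             # CV estimate per model
  sd_mse   = unname(apply(mse, 2, sd))          # spread across folds
)
summary_tab

best <- summary_tab$model[which.min(summary_tab$mean_mse)]
best
```

Reporting the spread alongside the mean is a simple way to see whether the lowest average MSE is a clear win or could plausibly come from the particular splits.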

In this case the values are fairly close; however, model 1 has the lowest average cross-validated MSE, so it appears to be the best of the three for prediction.

The post Cross Validation in R with Example appeared first on finnstats.
