
caret::createFolds() vs. createMultiFolds()

[This article was first published on Data Science Depot, and kindly contributed to R-bloggers].

Summary

Recently a user posted a question on Stack Overflow asserting that caret::createFolds() behaves differently from createMultiFolds(). The questioner argued that createFolds() samples without replacement, while createMultiFolds() samples with replacement. Our analysis shows that the two functions behave consistently: both create k folds in which each observation in the data frame belongs to exactly one of the k hold-out groups. The difference lies in what they return. By default, createFolds() returns a list of the indices of the held-out observations in each fold, whereas createMultiFolds() returns a list of the indices of the observations used for training (the in-sample rows) in each fold of each repetition.

Behavior of createFolds()

createFolds() splits the data into k folds. By default, the function returns a list of the row indices held out in each fold, not the rows included in each fold. We can see this by tabulating all of the fold indices for the mtcars data frame as follows.

library(caret)

set.seed(123)
folds <- createFolds(mtcars$am, k = 5)
table(unlist(folds))

…and the output:

> table(unlist(folds))

 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
 1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
27 28 29 30 31 32
 1  1  1  1  1  1
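Because the hold-out sets from createFolds() partition the rows, we can also check this directly instead of reading the table. This is a minimal sketch, assuming caret is installed. The object names holdout and all_rows are our own. It confirms that every row appears in exactly one hold-out set and that the 32 rows are spread over 5 sets of 6 or 7 rows each.

```r
library(caret)

set.seed(123)
holdout <- createFolds(mtcars$am, k = 5)   # list of held-out row indices

# every row appears in exactly one hold-out set ...
all_rows <- sort(unlist(holdout, use.names = FALSE))
stopifnot(identical(all_rows, seq_len(nrow(mtcars))))

# ... and the hold-out sets have (nearly) equal sizes
lengths(holdout)
```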

If we set returnTrain = TRUE, createFolds() instead returns the indices of the observations used for training in each fold, as another answer to the Stack Overflow question illustrated. With k = 5, each observation should appear in 4 of the 5 training sets, and we confirm this with the following code.

set.seed(123)
folds <- createFolds(mtcars$am, k = 5, returnTrain = TRUE)
table(unlist(folds))

…and the output:

> table(unlist(folds))

 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
 4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4
27 28 29 30 31 32
 4  4  4  4  4  4
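The training indices are simply the complement of the hold-out indices. The sketch below checks this fold by fold, assuming caret is installed and that both calls start from the same seed so they draw the same fold assignment. The names holdout, train and complement_ok are illustrative only.

```r
library(caret)

# Same seed, so both calls draw the same fold assignment
set.seed(123)
holdout <- createFolds(mtcars$am, k = 5)
set.seed(123)
train <- createFolds(mtcars$am, k = 5, returnTrain = TRUE)

all_rows <- seq_len(nrow(mtcars))

# For each fold, the training rows should be every row that is not held out
complement_ok <- mapply(function(tr, ho) identical(tr, setdiff(all_rows, ho)),
                        train, holdout)
complement_ok
```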

Behavior of createMultiFolds()

The createMultiFolds() function is used to define resampling plans for repeated k-fold cross-validation. It returns a list with k × times elements, one for each fold (k =) within each repetition (times =). Each element holds the training indices for that fold.
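To see that structure, the sketch below builds a plan with 3 repetitions and inspects its length and element names. It assumes caret is installed. The object name multi is our own. Each element should hold 25 or 26 training rows, which are the 32 rows minus a hold-out set of 6 or 7.

```r
library(caret)

set.seed(123)
multi <- createMultiFolds(y = mtcars$am, k = 5, times = 3)

length(multi)       # 15 = 5 folds x 3 repetitions
head(names(multi))  # names combine the fold and the repetition
lengths(multi)      # training-set size of every fold in every repetition
```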

We can illustrate that each observation is used for training in 4 of the 5 folds as follows.

set.seed(123)
folds1 <- createMultiFolds(y = mtcars$am, k = 5, times = 1)
table(unlist(folds1))

…and the output:

> table(unlist(folds1))

 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
 4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4
27 28 29 30 31 32
 4  4  4  4  4  4

Generating equivalent results with createFolds() and createMultiFolds()

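Setting returnTrain = TRUE with createFolds() makes it return the same fold contents as createMultiFolds() with times = 1, provided both are run after the same set.seed() call.

We can compare the contents of folds (from createFolds() with returnTrain = TRUE) and folds1 with lapply() and identical() as follows.

# compare folds (createFolds, returnTrain = TRUE) to folds1 (createMultiFolds)
lapply(1:5, function(x) {
  identical(folds1[[x]], folds[[x]])
})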

[[1]]
[1] TRUE

[[2]]
[1] TRUE

[[3]]
[1] TRUE

[[4]]
[1] TRUE

[[5]]
[1] TRUE

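A stricter check compares the two lists in one call. The list names differ ("Fold1" vs. "Fold1.Rep1"), so we drop them with unname() before calling identical(). This is a sketch assuming caret is installed. The names train_cf and train_mf are our own.

```r
library(caret)

set.seed(123)
train_cf <- createFolds(mtcars$am, k = 5, returnTrain = TRUE)
set.seed(123)
train_mf <- createMultiFolds(y = mtcars$am, k = 5, times = 1)

names(train_cf)   # fold names only
names(train_mf)   # fold names with a ".Rep1" suffix

# Ignoring the names, every index in every fold matches
identical(unname(train_cf), unname(train_mf))
```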
If we set times = 2, we expect each observation to be in the training sets of 8 of the 10 folds, 4 per repetition.

set.seed(123)
folds <- createMultiFolds(y = mtcars$am, k = 5, times = 2)
table(unlist(folds))

…and the output:

> table(unlist(folds))

 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
 8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8
27 28 29 30 31 32
 8  8  8  8  8  8
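The table pools both repetitions. To confirm that each repetition is a complete 5-fold split on its own, we can group the training sets by their ".RepN" name suffix, convert them back to hold-out sets, and check that each repetition holds out every row exactly once. This is a sketch assuming caret is installed and relying on the element names shown by createMultiFolds(). The names multi2, rep_id, by_rep and per_rep_ok are illustrative.

```r
library(caret)

set.seed(123)
multi2 <- createMultiFolds(y = mtcars$am, k = 5, times = 2)
all_rows <- seq_len(nrow(mtcars))

# Group the 10 training sets by repetition using the ".RepN" suffix
rep_id <- sub("^Fold[0-9]+\\.", "", names(multi2))
by_rep <- split(multi2, rep_id)

# Within each repetition, turn training sets back into hold-out sets;
# together they should cover every row exactly once
per_rep_ok <- sapply(by_rep, function(folds) {
  holdout <- unlist(lapply(folds, function(tr) setdiff(all_rows, tr)),
                    use.names = FALSE)
  identical(sort(holdout), all_rows)
})
per_rep_ok
```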

Conclusions

In both functions, caret assigns folds so that each observation lands in the hold-out group exactly once across the k folds of each repetition (times =). It also stratifies the folds by the outcome passed as y, so that each class of the outcome is spread as evenly as possible across the in-sample and hold-out parts of every fold.
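The class-balanced allocation described above can be sketched in a few lines. stratified_fold_ids() is a hypothetical helper of our own, not a caret function and not caret's source. Within each class, it deals out fold labels 1..k as evenly as possible, sends any leftover rows to distinct, randomly chosen folds, and then shuffles the labels. It always treats y as a factor and does not reproduce how caret handles numeric outcomes.

```r
# Hypothetical helper for illustration only: a simplified sketch of
# class-balanced (stratified) fold assignment, not caret's implementation.
stratified_fold_ids <- function(y, k) {
  y <- factor(y)                     # treat the outcome as classes
  fold_id <- integer(length(y))
  for (cls in levels(y)) {
    idx <- which(y == cls)
    n <- length(idx)
    # one full round of labels 1..k for every k rows of this class,
    # plus the leftover rows sent to distinct, randomly chosen folds
    labels <- c(rep(seq_len(k), n %/% k), sample.int(k, n %% k))
    # shuffle the labels across the rows of this class
    fold_id[idx] <- labels[sample.int(n)]
  }
  fold_id
}

set.seed(123)
ids <- stratified_fold_ids(mtcars$am, k = 5)
table(fold = ids, am = mtcars$am)
```

With this scheme, the per-fold counts of any one class differ by at most one. That is the property the stratification aims for.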

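In a small data set such as mtcars, the hold-out sets are only roughly balanced, as we can see when we tabulate hold-out status against mtcars$am. Part of the reason is that mtcars$am is stored as numeric 0/1. caret bins a numeric outcome into quantile groups before stratifying, and with only the values 0 and 1 those groups collapse into one, so these folds are not actually stratified by am.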

set.seed(123)
folds <- createFolds(mtcars$am, k = 5)
# cross-tabulate hold-out status against transmission type for each fold
lapply(folds, function(x) {
  holdout <- rep(FALSE, nrow(mtcars))
  holdout[x] <- TRUE
  table(holdout, mtcars$am)
})

$Fold1

holdout  0  1
  FALSE 16 10
  TRUE   3  3

$Fold2

holdout  0  1
  FALSE 15 10
  TRUE   4  3

$Fold3

holdout  0  1
  FALSE 14 11
  TRUE   5  2

$Fold4

holdout  0  1
  FALSE 15 11
  TRUE   4  2

$Fold5

holdout  0  1
  FALSE 16 10
  TRUE   3  3

Each fold contains 6 or 7 observations in its hold-out set, with at least 2 manual-transmission cars (am = 1) in every hold-out set.
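Converting am to a factor lets createFolds() stratify on the two transmission classes directly. This is a minimal sketch assuming caret is installed. The names folds_f and counts are our own. With 19 automatic and 13 manual cars spread over 5 folds, each hold-out set should contain 3 or 4 automatic cars and 2 or 3 manual cars.

```r
library(caret)

set.seed(123)
# am as a factor: caret balances each class across the folds
folds_f <- createFolds(factor(mtcars$am), k = 5)

# Count automatic (am = 0) and manual (am = 1) cars in each hold-out set
counts <- vapply(folds_f, function(idx) {
  c(auto = sum(mtcars$am[idx] == 0), manual = sum(mtcars$am[idx] == 1))
}, integer(2))
counts
```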

With default arguments, createFolds() returns the indices of held-out observations rather than training observations. Given the same random seed, createFolds(y, k, returnTrain = TRUE) returns the same index sets as createMultiFolds(y, k, times = 1). Only the list names differ (Fold1 vs. Fold1.Rep1).
