I have spent the last couple of days adding functionality for performing repeated cross-validation to cvms and groupdata2. In this quick post I will show an example.
In cross-validation, we split our training set into a number (often denoted “k”) of groups called folds. We repeatedly train our machine learning model on k-1 folds and test it on the remaining fold, so that each fold serves as the test set exactly once. Then we average the results and celebrate with food and music.
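To make the procedure concrete, here is a minimal base R sketch of k-fold cross-validation. It does not use cvms or groupdata2; the simulated `toy` data, the `fold` column, and the 0.5 cutoff are all assumptions made for illustration.

```r
# Minimal k-fold cross-validation in base R on simulated data
set.seed(1)
n <- 120
k <- 4

# Simulated data: a binary 'diagnosis' that depends on 'score'
toy <- data.frame(score = rnorm(n))
toy$diagnosis <- rbinom(n, size = 1, prob = plogis(1.5 * toy$score))

# Randomly assign each row to one of k equally sized folds
toy$fold <- sample(rep(seq_len(k), length.out = n))

# For each fold: train on the other k-1 folds, test on this one
fold_accuracy <- sapply(seq_len(k), function(i) {
  train <- toy[toy$fold != i, ]
  test  <- toy[toy$fold == i, ]
  model <- glm(diagnosis ~ score, data = train, family = binomial)
  prob  <- predict(model, newdata = test, type = "response")
  mean(as.integer(prob > 0.5) == test$diagnosis)
})

# Average the per-fold results
mean(fold_accuracy)
```

Accuracy is used here only because it is easy to compute by hand; cvms reports a much richer set of metrics.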
The benefits of using groupdata2 to create the folds are 1) that it allows us to balance the ratios of our output classes (or simply a categorical column, if we are working with linear regression instead of classification), and 2) that it allows us to keep all observations with a specific ID (e.g. participant/user ID) in the same fold to avoid leakage between the folds.
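The two ideas can be sketched by hand in base R. This is not how groupdata2 is implemented, just an illustration of what balancing on a categorical column and grouping by ID mean; the `toy` data frame and its column values are made up.

```r
set.seed(2)
k <- 3

# Hypothetical toy data: 12 participants with 3 observations each,
# and one diagnosis (0 or 1) per participant
toy <- data.frame(
  participant = rep(1:12, each = 3),
  diagnosis   = rep(rep(c(0, 1), each = 6), each = 3)
)

# Work on one row per participant, so folds are assigned to IDs, not rows
ids <- unique(toy[, c("participant", "diagnosis")])

# Within each diagnosis, shuffle the fold labels across the participants
ids$fold <- ave(ids$participant, ids$diagnosis, FUN = function(x) {
  sample(rep(seq_len(k), length.out = length(x)))
})

# Give every observation the fold of its participant
toy$fold <- ids$fold[match(toy$participant, ids$participant)]

# Check 1: no participant is spread over several folds
folds_per_id <- tapply(toy$fold, toy$participant, function(f) length(unique(f)))
all(folds_per_id == 1)

# Check 2: both diagnoses are equally represented in every fold
table(fold = toy$fold, diagnosis = toy$diagnosis)
```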
The benefit of cvms is that it trains all the models and outputs a tibble (data frame) with results, predictions, model coefficients, and other sweet stuff, which is easy to add to a report or do further analyses on. It even allows us to cross-validate multiple model formulas at once to quickly compare them and select the best model.
Repeated Cross-validation
In repeated cross-validation we simply repeat this process several times, each time with a different split into folds, so the model is trained on more combinations of our training set observations. The more combinations we use, the less a single bad split of the data will impact our evaluation of the model.
For each repetition, we evaluate our model as we would have in regular cross-validation. Then we average the results from the repetitions and go back to food and music.
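As a rough base R sketch of this idea (again without cvms or groupdata2), we can wrap one round of cross-validation in a function and call it several times. The helper name `cv_once`, the simulated data, and the use of accuracy as the metric are assumptions for illustration.

```r
set.seed(3)
n <- 120
k <- 4
repetitions <- 3

# Simulated data: a binary 'diagnosis' that depends on 'score'
toy <- data.frame(score = rnorm(n))
toy$diagnosis <- rbinom(n, size = 1, prob = plogis(1.5 * toy$score))

# Hypothetical helper: one round of k-fold cross-validation
# with a fresh random split, returning the mean accuracy
cv_once <- function(data, k) {
  folds <- sample(rep(seq_len(k), length.out = nrow(data)))
  accuracies <- sapply(seq_len(k), function(i) {
    model <- glm(diagnosis ~ score, data = data[folds != i, ], family = binomial)
    prob  <- predict(model, newdata = data[folds == i, ], type = "response")
    mean(as.integer(prob > 0.5) == data$diagnosis[folds == i])
  })
  mean(accuracies)
}

# Repeat with different splits, then average over the repetitions
rep_accuracy <- replicate(repetitions, cv_once(toy, k))
rep_accuracy
mean(rep_accuracy)
```

The spread of `rep_accuracy` gives a feel for how much a single split can move the estimate.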
groupdata2
As stated, the role of groupdata2 is to create the folds. Normally it creates one column in the dataset called “.folds”, which contains a fold identifier for each observation (e.g. 1,1,2,2,3,3,1,1,3,3,2,2). In repeated cross-validation it simply creates several such fold columns (“.folds_1”, “.folds_2”, etc.). It also makes sure they are unique, so we actually train on different subsets.
# Install groupdata2 and cvms from GitHub
devtools::install_github("ludvigolsen/groupdata2")
devtools::install_github("ludvigolsen/cvms")
# Attach packages
library(cvms) # cross_validate()
library(groupdata2) # fold()
library(knitr) # kable()
library(dplyr) # %>%
# Set seed for reproducibility
set.seed(7)
# Fold data
# Create 3 fold columns, each splitting the data into k = 4 folds
# cat_col is the categorical column to balance between folds
# id_col is the column with IDs. Observations with the same ID will be put in the same fold.
# num_fold_cols determines the number of fold columns, and thereby the number of repetitions.
# Note: 'data' is assumed to be a data frame that is already loaded and
# contains the columns 'diagnosis', 'participant' and 'score'
data <- fold(data, k = 4, cat_col = 'diagnosis', id_col = 'participant', num_fold_cols = 3)
# Show first 10 rows of data
data %>% head(10) %>% kable()
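If you want to convince yourself that two fold columns really describe different splits, note that comparing the raw values is not enough: two columns can contain the same groups under different labels. The sketch below shows one way to check this in base R. The helper `same_partition` and the example vectors are hypothetical, and this is not necessarily how groupdata2 performs its own check.

```r
# Hypothetical helper: TRUE if two fold columns describe the same
# partition of the rows, even when the fold labels differ
same_partition <- function(a, b) {
  counts <- table(a, b)
  # Same partition: every fold in 'a' overlaps exactly one fold in 'b',
  # and every fold in 'b' overlaps exactly one fold in 'a'
  all(rowSums(counts > 0) == 1) && all(colSums(counts > 0) == 1)
}

folds_1 <- c(1, 1, 2, 2, 3, 3)
folds_2 <- c(3, 3, 1, 1, 2, 2)  # same groups as folds_1, only relabelled
folds_3 <- c(1, 2, 2, 3, 3, 1)  # a genuinely different split

same_partition(folds_1, folds_2)  # TRUE
same_partition(folds_1, folds_3)  # FALSE
```

On the folded data, the same helper could be applied to pairs such as `data$.folds_1` and `data$.folds_2`.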
cvms
In the cross_validate function, we specify our model formula for a logistic regression that classifies diagnosis. cvms currently supports linear regression and logistic regression, including mixed effects modelling. In the fold_cols argument (previously called folds_col), we pass the names of the fold columns.
CV <- cross_validate(data, "diagnosis~score",
                     fold_cols = c('.folds_1', '.folds_2', '.folds_3'),
                     family = 'binomial',
                     REML = FALSE)
# Show results
CV
Due to the number of metrics and useful information, it helps to break up the output into parts:
CV %>% select(1:7) %>% kable()
CV %>% select(8:14) %>% kable()
CV$Predictions[[1]] %>% head() %>% kable()
CV$`Confusion Matrix`[[1]] %>% head() %>% kable()
CV$Coefficients[[1]] %>% head() %>% kable()
CV$Results[[1]] %>% select(1:8) %>% kable()
We could have trained multiple models at once by simply adding more model formulas. That would add rows to the output, making it easy to compare the models.
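The idea of comparing several formulas on the same folds can be sketched in base R. This is an illustration of the principle only, not the cvms API or its output format; the simulated `toy` data (including the `age` column), the helper `cv_accuracy`, and the three formulas are assumptions.

```r
set.seed(4)
n <- 120
k <- 4

# Simulated data: 'diagnosis' depends on 'score' but not on 'age'
toy <- data.frame(score = rnorm(n), age = rnorm(n, mean = 30, sd = 5))
toy$diagnosis <- rbinom(n, size = 1, prob = plogis(1.5 * toy$score))

# One shared fold column, so every formula is evaluated on the same splits
toy$fold <- sample(rep(seq_len(k), length.out = n))

# Hypothetical helper: mean accuracy of one formula across the k folds
cv_accuracy <- function(formula, data, k) {
  mean(sapply(seq_len(k), function(i) {
    model <- glm(formula, data = data[data$fold != i, ], family = binomial)
    prob  <- predict(model, newdata = data[data$fold == i, ], type = "response")
    mean(as.integer(prob > 0.5) == data$diagnosis[data$fold == i])
  }))
}

formulas <- c("diagnosis ~ score", "diagnosis ~ age", "diagnosis ~ score + age")

# One row per formula, which makes the models easy to compare
comparison <- data.frame(
  formula  = formulas,
  accuracy = sapply(formulas, function(f) cv_accuracy(as.formula(f), toy, k),
                    USE.NAMES = FALSE)
)
comparison
```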
The linear regression version has different evaluation metrics. These are listed on the help page, which you can open with ?cross_validate.
Conclusion
cvms and groupdata2 now support repeated cross-validation. We have briefly talked about this technique and gone through a short example. Check out cvms for more.