Comparing machine learning models in R
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.
by Joseph Rickert
While preparing for the DataWeek R Bootcamp that I conducted this week, I came across the following gem. This code, based directly on a Max Kuhn presentation from a couple of years back, compares the efficacy of two machine learning models on a training data set.
```r
library(caret)       # train(), trainControl(), twoClassSummary(), resamples(), bwplot()
library(doParallel)  # registerDoParallel(), getDoParWorkers()
# method = "gbm" also needs the gbm package; method = "svmRadial" needs kernlab

#-----------------------------------------
# SET UP THE PARAMETER SPACE SEARCH GRID
ctrl <- trainControl(method = "repeatedcv",            # use repeated 10-fold cross-validation
                     number = 10,                      # 10 folds (the default)
                     repeats = 5,                      # do 5 repetitions of 10-fold CV
                     summaryFunction = twoClassSummary, # use AUC to pick the best model
                     classProbs = TRUE)
# Note that the default search grid selects 3 values of each tuning parameter
#
grid <- expand.grid(interaction.depth = seq(1, 7, by = 2),  # look at tree depths from 1 to 7
                    n.trees = seq(10, 100, by = 5),         # let iterations go from 10 to 100
                    shrinkage = c(0.01, 0.1),               # try 2 values of the learning rate
                    n.minobsinnode = 10)                    # column required by current caret

#-----------------------------------------
# BOOSTED TREE MODEL
set.seed(1)
names(trainData)
trainX <- trainData[, 4:61]   # predictor columns; trainData$Class is the two-class outcome

registerDoParallel(4)         # register a parallel backend for train (once is enough)
getDoParWorkers()

system.time(gbm.tune <- train(x = trainX, y = trainData$Class,
                              method = "gbm",
                              metric = "ROC",
                              trControl = ctrl,
                              tuneGrid = grid,
                              verbose = FALSE))

#---------------------------------
# SUPPORT VECTOR MACHINE MODEL
#
set.seed(1)
system.time(
  svm.tune <- train(x = trainX, y = trainData$Class,
                    method = "svmRadial",
                    tuneLength = 9,                  # 9 values of the cost parameter C
                    preProc = c("center", "scale"),
                    metric = "ROC",
                    trControl = ctrl)                # same as for gbm above
)

#-----------------------------------
# COMPARE MODELS USING RESAMPLING
#
# Having set the seed to 1 before running gbm.tune and svm.tune, we have
# generated paired samples for comparing models using resampling.
#
# The resamples() function in caret collates the resampling results from the two models
rValues <- resamples(list(svm = svm.tune, gbm = gbm.tune))
rValues$values

#---------------------------------------------
# BOXPLOTS COMPARING RESULTS
bwplot(rValues, metric = "ROC")  # boxplot
```
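To get a feel for how much work this search asks for, the short base-R sketch below counts the candidates. It rebuilds the gbm grid from the code and assumes the 5 × 10-fold resampling set up in trainControl(). For the SVM it assumes, based on my understanding of caret's defaults rather than anything stated in the original code, that svmRadial with tuneLength = 9 tries the cost values 2^((1:9) - 3) while estimating sigma once and holding it fixed.

```r
# Size of the gbm tuning grid used in the code above
gbm_grid <- expand.grid(interaction.depth = seq(1, 7, by = 2),  # 1, 3, 5, 7
                        n.trees           = seq(10, 100, by = 5), # 19 values
                        shrinkage         = c(0.01, 0.1),         # 2 values
                        n.minobsinnode    = 10)                   # 1 value
n_candidates <- nrow(gbm_grid)   # 4 * 19 * 2 = 152 parameter combinations
n_resamples  <- 10 * 5           # 10 folds x 5 repeats = 50 resamples
n_candidates * n_resamples       # 7600 held-out ROC estimates for gbm

# Assumed default svmRadial cost grid for tuneLength = 9
svm_cost <- 2^((1:9) - 3)        # 0.25, 0.5, 1, 2, 4, 8, 16, 32, 64
svm_cost
```

Every one of those held-out estimates comes from a resampled fit, which is why the parallel backend matters here. caret may score several n.trees values from a single gbm fit, so the number of actual model fits can be smaller than this count.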
After setting up a grid to search the parameter space of a model, the train() function from the caret package is used to train a generalized boosted regression model (gbm) and a support vector machine (svm). Setting the seed to the same value before each call to train() makes both models use identical resampling folds (paired samples), which enables the two models to be compared using the resampling technique described in Hothorn et al., "The design and analysis of benchmark experiments", Journal of Computational and Graphical Statistics (2005) vol 14 (3) pp 675-699.
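Seeding works because both train() calls draw the same fold assignments. A more explicit alternative, sketched below under some assumptions, is to build the resampling indices once with createMultiFolds() and hand the same list to both models through the index argument of trainControl(); diff() on the resamples object then runs paired t-tests on the per-resample differences. To keep the sketch self-contained and quick, it uses caret's simulated twoClassSim() data in place of the author's trainData, and two fast models (glm and lda) as stand-ins for gbm and svm. None of these choices come from the original code.

```r
library(caret)   # twoClassSim(), createMultiFolds(), trainControl(), train(), resamples()
# method = "lda" uses the MASS package, which ships with R

set.seed(1)
dat <- twoClassSim(300)                        # simulated stand-in for trainData
x   <- dat[, setdiff(names(dat), "Class")]     # predictors
y   <- dat$Class                               # factor with levels Class1, Class2

# Build the 5 x 10-fold resampling indices once, so both models are
# scored on exactly the same held-out rows (paired samples)
set.seed(1)
folds <- createMultiFolds(y, k = 10, times = 5)

ctrl <- trainControl(method = "repeatedcv", number = 10, repeats = 5,
                     index = folds,             # shared resampling indices
                     summaryFunction = twoClassSummary,
                     classProbs = TRUE)

fit_glm <- train(x, y, method = "glm", metric = "ROC", trControl = ctrl)
fit_lda <- train(x, y, method = "lda", metric = "ROC", trControl = ctrl,
                 preProc = c("center", "scale"))

rv <- resamples(list(glm = fit_glm, lda = fit_lda))
summary(rv)                     # numeric summary behind a bwplot() of rv

d <- diff(rv, metric = "ROC")   # paired differences between the two models
summary(d)                      # estimated differences and t-test p-values
```

Because both fits share one index list, each row of rv$values pairs the two models on the same held-out fold, which is the pairing the Hothorn et al. comparison relies on.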
The performance metric for the comparison is the area under the ROC curve (AUC), which caret reports under the name "ROC". From examining the boxplots of the resampling distributions for the two models, it is apparent that, in this case, the gbm has the advantage.
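The AUC equals the probability that a randomly chosen case from the positive class receives a higher predicted probability than a randomly chosen case from the negative class, with ties counting one half. The base-R sketch below computes it from ranks (the Mann-Whitney form) and checks it against a brute-force count over all pairs. The function names auc_rank and auc_pairs and the toy scores are hypothetical, made up for illustration; this is not claimed to be the code caret uses internally.

```r
# Hypothetical helper: AUC via the Mann-Whitney rank-sum identity
auc_rank <- function(scores, is_pos) {
  r     <- rank(scores)            # average ranks, so tied scores count as 1/2
  n_pos <- sum(is_pos)
  n_neg <- sum(!is_pos)
  (sum(r[is_pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Hypothetical brute-force check: share of (positive, negative) pairs ordered correctly
auc_pairs <- function(scores, is_pos) {
  d <- outer(scores[is_pos], scores[!is_pos], "-")
  mean((d > 0) + 0.5 * (d == 0))
}

scores <- c(0.9, 0.8, 0.7, 0.4, 0.3)          # toy predicted probabilities
is_pos <- c(TRUE, TRUE, FALSE, TRUE, FALSE)   # toy true labels
auc_rank(scores, is_pos)                       # 0.8333333 (= 5/6)
auc_pairs(scores, is_pos)                      # 0.8333333 (= 5/6)
```

An AUC of 0.5 means the model ranks cases no better than chance and 1 means perfect separation, so the higher gbm boxplot indicates better ranking of the two classes across the resamples.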
Also, notice that the call to registerDoParallel() permits parallel execution of the training algorithms. (The taskbar showed all four cores of my laptop maxed out at 100% utilization.)
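A minimal sketch of the backend registration, assuming the doParallel package is installed. Registering once is enough: train() runs its resampling loop through foreach's %dopar%, so it picks up whatever backend is registered (trainControl()'s allowParallel argument defaults to TRUE). Two cores are used here only so the sketch runs on most machines; the author's laptop used four.

```r
library(doParallel)   # also attaches foreach and parallel

registerDoParallel(cores = 2)   # register a backend with 2 workers
getDoParWorkers()               # number of workers foreach will use

# Any %dopar% loop, including the ones inside caret::train(), now runs on the workers
sq <- foreach(i = 1:4, .combine = c) %dopar% i^2
sq                              # 1 4 9 16

stopImplicitCluster()           # release the workers when finished (matters on Windows)
```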
I chose this example because I wanted to show programmers coming to R for the first time that the power and productivity of R come not only from the large number of machine learning models implemented, but also from the tremendous amount of infrastructure that package authors have built up, making it relatively easy to do fairly sophisticated tasks in just a few lines of code.
All of the code for this example, along with the rest of my code from the DataWeek R Bootcamp, is available on GitHub.