
Comparing machine learning models in R

[This article was first published on Revolutions, and kindly contributed to R-bloggers.]

by Joseph Rickert

While preparing for the DataWeek R Bootcamp that I conducted this week, I came across the following gem. This code, based directly on a Max Kuhn presentation from a couple of years back, compares the efficacy of two machine learning models on a training data set.

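#-----------------------------------------
# PACKAGES: caret provides train(); doParallel provides the parallel backend.
# The gbm and kernlab packages must also be installed (caret loads them for the two models).
library(caret)
library(doParallel)
#
# SET UP THE PARAMETER SPACE SEARCH GRID
ctrl <- trainControl(method = "repeatedcv",            # use repeated 10-fold cross-validation
                     repeats = 5,                       # do 5 repetitions of 10-fold CV
                     summaryFunction = twoClassSummary, # compute ROC AUC, sensitivity and specificity
                     classProbs = TRUE)                 # class probabilities are needed to compute AUC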
# Note that the default search grid selects 3 values of each tuning parameter
#
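grid <- expand.grid(interaction.depth = seq(1, 7, by = 2),  # tree depths 1, 3, 5 and 7
                    n.trees = seq(10, 100, by = 5),          # iterations from 10 to 100 in steps of 5
                    shrinkage = c(0.01, 0.1),                # try 2 values of the learning rate parameter
                    n.minobsinnode = 10)                     # current caret versions also require this gbm parameter
# 4 depths x 19 tree counts x 2 learning rates = 152 candidate parameter combinations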
 
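# BOOSTED TREE MODEL
set.seed(1)
names(trainData)               # trainData: the training data frame (prepared in the author's GitHub code)
trainX <- trainData[, 4:61]    # the predictor columns passed to train() as x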
 
registerDoParallel(cores = 4)  # register a parallel backend with 4 workers for train()
getDoParWorkers()              # check how many workers are registered
 
system.time(gbm.tune <- train(x=trainX,y=trainData$Class,
				method = "gbm",
				metric = "ROC",
				trControl = ctrl,
				tuneGrid=grid,
				verbose=FALSE))
 
#---------------------------------
# SUPPORT VECTOR MACHINE MODEL
#
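set.seed(1)                    # same seed as before gbm.tune, so train() draws the same CV folds
# The parallel backend registered above is still active; there is no need to register it again
getDoParWorkers()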
system.time(
	svm.tune <- train(x=trainX, y= trainData$Class,
				method = "svmRadial",
				tuneLength = 9,		# try 9 values of the cost parameter C (sigma is estimated from the data)
				preProc = c("center","scale"),
				metric="ROC",
				trControl=ctrl)	        # same as for gbm above
)	
#-----------------------------------
# COMPARE MODELS USING RESAMPLING
# Having set the seed to 1 before running gbm.tune and svm.tune, both models were evaluated
# on the same cross-validation folds, which gives paired samples for comparing the models using resampling.
#
# The resamples function in caret collates the resampling results from the two models
rValues <- resamples(list(svm=svm.tune,gbm=gbm.tune))
rValues$values
#---------------------------------------------
# BOXPLOTS COMPARING RESULTS
bwplot(rValues, metric = "ROC")  # boxplots of the resampled ROC (AUC) values for each model

After setting up a grid to search the parameter space of a model, the train() function from the caret package is used to train a generalized boosted regression model (gbm) and a support vector machine (svm). Setting the same seed before each call to train() makes both models use the same cross-validation folds. This produces paired samples and enables the two models to be compared using the resampling technique described in Hothorn et al., "The design and analysis of benchmark experiments", Journal of Computational and Graphical Statistics (2005), vol. 14 (3), pp. 675-699.
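To make the pairing idea concrete, here is a minimal base-R sketch of a paired resampling comparison in the spirit of Hothorn et al. It does not use caret: the simulated data, the hypothetical helpers make_folds() and cv_accuracy(), and the two logistic regressions (standing in for the gbm and svm) are all assumptions made for illustration, and accuracy replaces AUC as the per-resample metric only to keep the code short.

```r
# Simulated two-class data (illustrative stand-in for trainData)
set.seed(42)
n  <- 300
x1 <- rnorm(n)
x2 <- rnorm(n)
y  <- rbinom(n, 1, plogis(1.5 * x1 + 1.0 * x2))
dat <- data.frame(y, x1, x2)

# Hypothetical helper: fold labels for `times` repeats of k-fold CV
make_folds <- function(n, k = 10, times = 5) {
  lapply(seq_len(times), function(i) sample(rep(seq_len(k), length.out = n)))
}

# Hypothetical helper: held-out accuracy of a logistic regression,
# one value per fold per repeat (k * times values in total)
cv_accuracy <- function(formula, data, folds) {
  unlist(lapply(folds, function(fold_id) {
    sapply(sort(unique(fold_id)), function(f) {
      fit  <- glm(formula, family = binomial, data = data[fold_id != f, ])
      test <- data[fold_id == f, ]
      pred <- as.integer(predict(fit, test, type = "response") > 0.5)
      mean(pred == test$y)
    })
  }))
}

# Same seed before each "model" -> identical folds -> paired results
set.seed(1); folds_a <- make_folds(nrow(dat))
set.seed(1); folds_b <- make_folds(nrow(dat))
stopifnot(identical(folds_a, folds_b))

acc_full <- cv_accuracy(y ~ x1 + x2, dat, folds_a)  # model A: both predictors
acc_x1   <- cv_accuracy(y ~ x1,      dat, folds_b)  # model B: one predictor

# Paired comparison over the 50 matched resamples
t.test(acc_full, acc_x1, paired = TRUE)
```

Because both models see exactly the same 50 train/test splits, the per-resample differences can be tested directly with a paired t-test. Reusing a single folds object would make the pairing explicit rather than relying on the seed; in caret, the analogous approach should be to pass the same index list (for example from createMultiFolds()) to trainControl() for both models, and caret's diff() method for resamples objects offers a similar paired comparison.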

The performance metric for the comparison is the area under the ROC curve (AUC). From examining the boxplots of the resampling distributions for the two models, it is apparent that, in this case, the gbm has the advantage.
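The "ROC" value that twoClassSummary() reports for each resample is the area under the ROC curve: the probability that a randomly chosen positive case gets a higher predicted probability than a randomly chosen negative case. A minimal base-R sketch of that calculation uses the rank-sum (Mann-Whitney) identity; the auc() helper is hypothetical and is not caret's internal implementation.

```r
# Hypothetical helper: AUC via the Mann-Whitney rank-sum identity.
# score:  predicted probability of the positive class
# is_pos: logical vector, TRUE for positive cases
auc <- function(score, is_pos) {
  stopifnot(length(score) == length(is_pos), any(is_pos), any(!is_pos))
  r     <- rank(score)   # average ranks, so ties count as half-correct
  n_pos <- sum(is_pos)
  n_neg <- sum(!is_pos)
  (sum(r[is_pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Positives score 0.9 and 0.4, negatives 0.5 and 0.1: 3 of the 4
# positive/negative pairs are ordered correctly, so the AUC is 0.75.
auc(c(0.9, 0.4, 0.5, 0.1), c(TRUE, TRUE, FALSE, FALSE))
```

Each point in the boxplots is one such AUC, computed on the held-out fold of one resampling iteration.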

Also, notice that the call to registerDoParallel() registers a parallel backend, which lets train() run its resampling iterations in parallel. (The taskbar showed all four cores of my laptop maxed out at 100% utilization.)
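caret's train() runs its resampling loop through foreach, and registerDoParallel() is what points foreach at a pool of worker processes. The sketch below illustrates the same idea with the base parallel package instead (a swapped-in approach, not what caret does internally): independent resampling fits are farmed out to two workers with parLapply(). The one_resample() helper, the built-in mtcars data and the lm() model are illustrative assumptions.

```r
library(parallel)

# Hypothetical stand-in for one resampling iteration: fit on a random
# training subset and return the held-out RMSE.
one_resample <- function(seed) {
  set.seed(seed)                                  # each task draws its own split
  train_rows <- sample(nrow(mtcars), 24)
  fit  <- lm(mpg ~ wt, data = mtcars[train_rows, ])
  held <- mtcars[-train_rows, ]
  sqrt(mean((held$mpg - predict(fit, held))^2))   # held-out RMSE
}

cl <- makeCluster(2)                          # two worker processes (PSOCK; works on all OSes)
scores <- parLapply(cl, 1:8, one_resample)    # the 8 resamples run concurrently on the workers
stopCluster(cl)                               # release the workers

unlist(scores)
```

Because each task sets its own seed, the parallel results match a serial lapply() over the same seeds, which is the kind of reproducibility you want when resamples run on different workers.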

I chose this example because I wanted to show programmers coming to R for the first time that the power and productivity of R come not only from the large number of machine learning models implemented, but also from the tremendous amount of infrastructure that package authors have built up, making it relatively easy to do fairly sophisticated tasks in just a few lines of code.

All of the code for this example, along with the rest of my code from the DataWeek R Bootcamp, is available on GitHub.
