Error metrics for multi-class problems in R: beyond Accuracy and Kappa
[This article was first published on Modern Toolmaking, and kindly contributed to R-bloggers.]
The caret package for R provides a wide variety of error metrics for regression models and two-class classification models, but it only calculates Accuracy and Kappa for multi-class models. I therefore wrote the following function, which allows caret::train to calculate a wide variety of error metrics for multi-class problems:
#Multi-Class Summary Function
#Based on caret:::twoClassSummary
require(compiler)
multiClassSummary <- cmpfun(function (data, lev = NULL, model = NULL){

  #Load Libraries
  require(Metrics)
  require(caret)

  #Check data
  if (!all(levels(data[, "pred"]) == levels(data[, "obs"])))
    stop("levels of observed and predicted data do not match")

  #Calculate custom one-vs-all stats for each class
  prob_stats <- lapply(levels(data[, "pred"]), function(class){

    #Grab one-vs-all data for the class
    pred <- ifelse(data[, "pred"] == class, 1, 0)
    obs <- ifelse(data[, "obs"] == class, 1, 0)
    prob <- data[, class]

    #Calculate one-vs-all AUC and logLoss and return
    cap_prob <- pmin(pmax(prob, .000001), .999999)
    prob_stats <- c(auc(obs, prob), logLoss(obs, cap_prob))
    names(prob_stats) <- c('ROC', 'logLoss')
    return(prob_stats)
  })
  prob_stats <- do.call(rbind, prob_stats)
  rownames(prob_stats) <- paste('Class:', levels(data[, "pred"]))

  #Calculate confusion matrix-based statistics
  CM <- confusionMatrix(data[, "pred"], data[, "obs"])

  #Aggregate and average class-wise stats
  #Todo: add weights
  class_stats <- cbind(CM$byClass, prob_stats)
  class_stats <- colMeans(class_stats)

  #Aggregate overall stats
  overall_stats <- c(CM$overall)

  #Combine overall with class-wise stats and remove some stats we don't want
  stats <- c(overall_stats, class_stats)
  stats <- stats[! names(stats) %in% c('AccuracyNull',
                                       'Prevalence', 'Detection Prevalence')]

  #Clean names and return
  names(stats) <- gsub('[[:blank:]]+', '_', names(stats))
  return(stats)
})
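If you want to sanity-check the function outside of train, you can call it directly: caret passes summary functions a data frame with factor columns pred and obs plus one probability column per class. Here is a minimal sketch on made-up predictions for the iris classes (the observations and probabilities below are random and purely illustrative):

#Sketch: call multiClassSummary directly on hand-built predictions
#(the data here is made up purely for illustration)
library(caret)
library(Metrics)
set.seed(42)
classes <- levels(iris$Species)
obs <- factor(rep(classes, each = 7), levels = classes)
probs <- matrix(runif(length(obs) * length(classes)), ncol = length(classes))
probs <- probs / rowSums(probs) #normalize so each row sums to 1
colnames(probs) <- classes
pred <- factor(classes[max.col(probs)], levels = classes)
fake <- data.frame(pred = pred, obs = obs, probs)
multiClassSummary(fake)

The result is a named numeric vector of all the overall and (averaged) class-wise statistics, which is exactly the format train expects from a summary function.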
This function was prompted by a question on Cross Validated asking for the optimal value of k for a knn model fit to the iris dataset. I wanted to look at statistics besides Accuracy and Kappa, so I wrote a wrapper around caret::confusionMatrix and the auc and logLoss functions from the Metrics package. Use the following code to fit a knn model to the iris dataset, aggregate all of the metrics, and save a plot of each metric to a pdf file:
#CLEAR WORKSPACE
rm(list = ls(all = TRUE))
gc(reset=TRUE)

#Setup parallel cluster
#If running on the command line of linux, use type='FORK'
library(doParallel)
cl <- makeCluster(detectCores(), type='PSOCK')
registerDoParallel(cl)

#Fit model
library(caret)
set.seed(19556)
model <- train(
  Species~.,
  data=iris,
  method='knn',
  tuneGrid=expand.grid(.k=1:30),
  metric='Accuracy',
  trControl=trainControl(
    method='repeatedcv',
    number=10,
    repeats=15,
    classProbs=TRUE,
    summaryFunction=multiClassSummary))

#Stop parallel cluster
stopCluster(cl)

#Save a pdf with one plot per metric
pdf('plots.pdf')
for(stat in c('Accuracy', 'Kappa', 'AccuracyLower', 'AccuracyUpper', 'AccuracyPValue',
              'Sensitivity', 'Specificity', 'Pos_Pred_Value',
              'Neg_Pred_Value', 'Detection_Rate', 'ROC', 'logLoss')) {
  print(plot(model, metric=stat))
}
dev.off()
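Once train finishes, every metric returned by multiClassSummary is also available as a column of model$results (one row per candidate value of k), so you can compare tuning choices numerically as well as graphically. A small sketch, assuming the column names produced by the summary function above:

#Resampled metrics per value of k (column names come from multiClassSummary)
head(model$results[, c('k', 'Accuracy', 'Kappa', 'ROC', 'logLoss')])

#Best k under different metrics: Accuracy and ROC are maximized...
model$results$k[which.max(model$results$Accuracy)]
model$results$k[which.max(model$results$ROC)]

#...while logLoss should be minimized
model$results$k[which.min(model$results$logLoss)]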
This demonstrates that, depending on which metric you use, you will end up with a different model: Accuracy seems to peak around k = 17, while AUC and logLoss seem to peak around k = 6 (see the plots saved in plots.pdf).
You can also increase the number of cross-validation repeats, or use a different resampling method, such as the bootstrap.
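Switching to the bootstrap only requires a change to trainControl. Here is a minimal sketch, reusing the same tuning grid and summary function; the choice of 50 resamples and of ROC as the selection metric is arbitrary:

#Sketch: the same knn model, tuned with bootstrap resampling instead of repeated CV
set.seed(19556)
model_boot <- train(
  Species~.,
  data=iris,
  method='knn',
  tuneGrid=expand.grid(.k=1:30),
  metric='ROC',
  trControl=trainControl(
    method='boot',
    number=50,
    classProbs=TRUE,
    summaryFunction=multiClassSummary))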