The Geometry of Classifiers
As John mentioned in his last post, we have been quite interested in the recent study by Fernandez-Delgado et al., “Do we Need Hundreds of Classifiers to Solve Real World Classification Problems?” (the “DWN study” for short), which evaluated 179 popular implementations of common classification algorithms over 120 or so data sets, mostly from the UCI Machine Learning Repository. For fun, we decided to do a follow-up study, using their data and several classifier implementations from scikit-learn, the Python machine learning library. We were interested not just in classifier accuracy, but also in seeing if there is a “geometry” of classifiers: which classifiers produce prediction patterns that look similar to each other, and which classifiers produce predictions that are quite different? To examine these questions, we put together a Shiny app to interactively explore how the relative behavior of classifiers changes for different types of data sets.
The Classifiers
We looked at seven classifiers from scikit-learn:
- SVM (sklearn.svm.SVC) with the radial basis function kernel, gamma=0.001 and C=10
- Random Forest (sklearn.ensemble.RandomForestClassifier) with 100 trees, each limited to a maximum depth of 10
- Gradient Boosting (sklearn.ensemble.GradientBoostingClassifier)
- Decision Tree (sklearn.tree.DecisionTreeClassifier)
- Gaussian Naive Bayes (sklearn.naive_bayes.GaussianNB)
- Logistic Regression (sklearn.linear_model.LogisticRegression)
- K-Nearest Neighbors (sklearn.neighbors.KNeighborsClassifier) with K=5
We predicted class probabilities for each target class (using the predict_proba() method), rather than simply predicting class. Note that the decision tree implementation only returns 0/1 predictions (that is, it only predicts class), even when using its predict_proba() method, and K-Nearest Neighbors with K=5 returns only coarse probabilities (the fraction of the five nearest neighbors that belong to each class). We made no effort to optimize the classifier performance on a per-data-set basis.
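The setup above can be sketched roughly as follows. This is a minimal illustration, not the scoring script from our repository: the toy data from make_classification, the 200/100 train/test split, random_state, and max_iter=1000 for logistic regression are assumptions. SVC needs probability=True, because its predict_proba() method is only available when the model is fit that way.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier


def make_models(seed=0):
    """The seven classifiers with the settings stated in the text.

    Anything not stated there (random_state, max_iter) is an assumption."""
    return {
        # probability=True is required for SVC.predict_proba()
        "SVM": SVC(kernel="rbf", gamma=0.001, C=10, probability=True, random_state=seed),
        "Random Forest": RandomForestClassifier(n_estimators=100, max_depth=10, random_state=seed),
        "Gradient Boosting": GradientBoostingClassifier(random_state=seed),
        "Decision Tree": DecisionTreeClassifier(random_state=seed),
        "Naive Bayes": GaussianNB(),
        "Logistic Regression": LogisticRegression(max_iter=1000),
        "KNN": KNeighborsClassifier(n_neighbors=5),
    }


# Hypothetical toy data standing in for one of the DWN data sets.
X, y = make_classification(n_samples=300, n_features=8, n_informative=4,
                           n_classes=3, random_state=0)
X_train, X_test, y_train = X[:200], X[200:], y[:200]

probs = {}
for name, model in make_models().items():
    model.fit(X_train, y_train)
    # One row per test instance, one column per class (ordered as model.classes_).
    probs[name] = model.predict_proba(X_test)

# A fully grown tree has pure leaves, so its "probabilities" are only 0 or 1.
tree_values = np.unique(probs["Decision Tree"])
```

Each entry of probs is the class-probability matrix for one classifier; these matrices are what the accuracy and distance comparisons later in the post operate on.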
The Data Sets
We used the 123 pre-prepared data sets compiled by the authors of the DWN study; these data sets have been centered, scaled, and stored in a format (.arff) that can be easily read into Python. The data sets vary in size from 16 to over 130,000 rows, and from 3 to 262 variables. They range from 2 to 100 target classes (two-class problems were the most common; the median number of classes was three; the average, about seven).
As we noted in our previous post, eight of the 123 data sets in the collection encode categorical variables as real numbers, by hand-converting them to ordered levels (a variation on this trick is to hash the strings that describe each category). This is not the correct way to encode categoricals; you should instead convert the categories to indicator variables. However, rather than re-doing the encodings, we left them as-is for our study. This disadvantages logistic regression and SVM, which cannot undo the information loss that results from this encoding, and probably K-Nearest Neighbors as well; however, the number of affected data sets is relatively small.
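To make the encoding issue concrete, here is a small sketch using pandas. The color column and its numeric levels are hypothetical, not taken from the DWN data. Ordered-level coding forces one numeric scale onto the categories, while pd.get_dummies() produces one 0/1 indicator column per category.

```python
import pandas as pd

# Hypothetical categorical variable (not from the DWN data sets).
colors = pd.Series(["red", "green", "blue", "green"], name="color")

# Ordered-level coding: each category becomes a single real number. A linear
# model with one coefficient on this column can only fit an effect that rises
# or falls monotonically from red to green to blue, an ordering the data
# does not justify.
levels = {"red": 0.0, "green": 1.0, "blue": 2.0}
ordinal = colors.map(levels)

# Indicator coding: one 0/1 column per category, so a linear model can
# learn a separate coefficient for each category.
indicators = pd.get_dummies(colors, prefix="color", dtype=int)
print(indicators)
```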
Some, but not all, of the data sets in the DWN study were broken into training and test sets; we only looked at the training sets. If the training set had more than 500 rows, we randomly selected 100 rows, held them out as the test set, and used the remaining data to train the classifiers. If the training set had fewer than 500 rows, we used leave-one-out cross-validation on 100 random rows (or on all the rows, for data sets with fewer than 100 rows): that is, for each selected row, we trained on all the data except that row, and then classified the held-out row.
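The hold-out rules above can be sketched as a generator of (train, test) index pairs. This is an illustrative reconstruction, not the repository code: the function name evaluation_splits is hypothetical, and because the text does not say which branch a training set of exactly 500 rows takes, the sketch treats it as small.

```python
import numpy as np


def evaluation_splits(n_rows, rng, threshold=500, n_eval=100):
    """Yield (train_idx, test_idx) index pairs for one data set.

    Hypothetical helper. Assumption: exactly `threshold` rows counts as small."""
    all_idx = np.arange(n_rows)
    if n_rows > threshold:
        # Large data set: one random hold-out of n_eval rows; train on the rest.
        test_idx = rng.choice(n_rows, size=n_eval, replace=False)
        yield np.setdiff1d(all_idx, test_idx), test_idx
    else:
        # Small data set: leave-one-out on n_eval random rows (all rows if fewer).
        for i in rng.choice(n_rows, size=min(n_eval, n_rows), replace=False):
            yield np.delete(all_idx, i), np.array([i])


rng = np.random.default_rng(42)
n_splits_large = sum(1 for _ in evaluation_splits(10_000, rng))  # 1 hold-out split
n_splits_small = sum(1 for _ in evaluation_splits(300, rng))     # 100 leave-one-out fits
```

Either way each data set ends up with at most 100 scored rows, which is what makes per-data-set counts of correct predictions comparable across classifiers.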
The Questions
We considered two questions in our study. First, which classification methods are most accurate in general — that is, which methods identify the correct class most of the time. Second, which classifiers behave most like each other, in terms of the class probabilities that they assign to each of the target classes.
To answer the first question, we start with each classifier’s accuracy rate over all data sets:
Over all the data sets we considered, random forest (highlighted in blue) had the highest accuracy, correctly identifying the class of a held-out instance about 82% of the time (considered over all test sets).
Note that we are using accuracy as a quick and convenient indication of classifier quality. See this post (and our discussion later in the present article) for why accuracy should not be considered an end-all measure.
We also counted up how often a classifier did the best on a given test set. For each data set D, we call the best that any classifier did on that set N_best. For example, if the best any classifier did on data set D was 97/100, then N_best for D is 97/100. Every classifier that achieved 97/100 on D is then counted as part of “the winner’s circle”. We then tallied how many times each classifier ended up in the winner’s circle over all the data sets under consideration. Here are the results for all the data sets:
This graph shows that over all the data sets we considered, gradient boosting and random forest (highlighted in blue) reached the best achievable accuracy (N_best) the most often: on 39 of the 123 data sets. This is consistent with the findings of the DWN paper, which noted that Random Forest is often among the most accurate classifiers over this sample of data sets. Notice that the sum of counts over all classifiers is greater than 123, meaning that quite often more than one classifier achieved N_best.
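The winner's-circle tally is simple to express in code. This is a sketch with a hypothetical helper name (winners_circle) and made-up counts; the input is assumed to be, for each data set, the number of test rows each classifier got right. Comparing raw counts works because all classifiers on a given data set are scored on the same test rows.

```python
from collections import Counter


def winners_circle(correct_counts):
    """correct_counts maps data set -> {classifier: rows classified correctly}.

    Returns, per classifier, the number of data sets on which it matched
    N_best (the best count any classifier achieved on that data set).
    Ties all count, so the totals can exceed the number of data sets."""
    tally = Counter()
    for scores in correct_counts.values():
        n_best = max(scores.values())
        tally.update(clf for clf, n in scores.items() if n == n_best)
    return tally


# Made-up counts out of 100 held-out rows per data set.
counts = {
    "D1": {"Random Forest": 97, "Gradient Boosting": 97, "Logistic Regression": 90},
    "D2": {"Random Forest": 80, "Gradient Boosting": 85, "Logistic Regression": 85},
}
print(winners_circle(counts))
```

With two data sets the tally sums to four, because each data set has a two-way tie for N_best.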
To answer the second question, we took the vector of class probabilities returned by each classifier for a given data set D, and measured the squared Euclidean distance between each pair of models, and between each model and ground truth. The (squared) distance between two models is then the sum of these squared Euclidean distances over all data sets. We can use these distances to determine how similar the models are. One way to visualize this is via a dendrogram:
Over this set of data sets, gradient boosting and random forest behaved very similarly (that is, tended to return similar class probability estimates), which is not too surprising, since they are both ensembles over decision trees. It’s not visible from the dendrogram, but they are also the closest to ground truth. Logistic regression and SVM are also quite similar to each other. We found this second observation somewhat surprising, but the similarity of logistic regression and SVM has been observed previously, for example by Jian Zhang et al. for text classification (“Modified Logistic Regression: An Approximation to SVM and Its Applications in Large-Scale Text Classification”, ICML 2003). K-Nearest Neighbors behaves somewhat similarly to these four classifiers, but the naive Bayes and decision tree classifiers do not.
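The distance computation and clustering can be sketched in Python. Assumptions: probabilities are stored per data set as {model_name: array of shape (n_test_rows, n_classes)}, ground truth is included as a one-hot "truth" entry, and the tree is built with average linkage; the text does not say which linkage method produced the dendrogram. The probabilities below are made up for illustration.

```python
import itertools

import numpy as np
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform


def one_hot(labels, n_classes):
    """Ground truth as 0/1 class-indicator rows, one per test instance."""
    return np.eye(n_classes)[labels]


def distance_matrix(probs):
    """probs: {data_set: {model: (n_rows, n_classes) array}}.

    Returns sorted model names and the matrix of squared Euclidean
    distances between models, summed over all data sets."""
    names = sorted(next(iter(probs.values())))
    d = np.zeros((len(names), len(names)))
    for i, j in itertools.combinations(range(len(names)), 2):
        total = sum(np.sum((p[names[i]] - p[names[j]]) ** 2) for p in probs.values())
        d[i, j] = d[j, i] = total
    return names, d


# Made-up probabilities for two tiny data sets (2 and 3 classes).
probs = {
    "D1": {"truth": one_hot([0, 1], 2),
           "RF": np.array([[0.9, 0.1], [0.2, 0.8]]),
           "GB": np.array([[0.8, 0.2], [0.3, 0.7]]),
           "NB": np.array([[0.4, 0.6], [0.6, 0.4]])},
    "D2": {"truth": one_hot([2, 0], 3),
           "RF": np.array([[0.1, 0.1, 0.8], [0.7, 0.2, 0.1]]),
           "GB": np.array([[0.2, 0.1, 0.7], [0.6, 0.3, 0.1]]),
           "NB": np.array([[0.5, 0.3, 0.2], [0.2, 0.2, 0.6]])},
}

names, d = distance_matrix(probs)
# squareform() converts the symmetric matrix to the condensed form linkage() expects.
tree = linkage(squareform(d), method="average")
leaf_order = dendrogram(tree, labels=names, no_plot=True)["ivl"]
```

On these made-up numbers RF and GB are merged first (summed distance 0.08), and NB joins last.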
Alternatively, we can visualize the approximate distances between the classifiers using multidimensional scaling. This is a screenshot of a rotatable visualization of the model distances in 3-D that uses the package rgl (and shinyRGL).
These distances are only approximate, but they are consistent with the dendrogram above.
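The post does not say which MDS variant the app uses; as an assumption, here is a sketch of classical (Torgerson) MDS in NumPy. Because each summed squared distance is itself the squared Euclidean distance between the models' concatenated probability vectors, the matrix can be double-centered directly. Keeping only the top three eigenvectors is what makes the 3-D distances approximate.

```python
import numpy as np


def classical_mds(sq_dist, k=3):
    """Classical (Torgerson) MDS from an (n, n) matrix of squared distances.

    Returns (n, k) coordinates whose pairwise Euclidean distances
    approximate sqrt(sq_dist); exact only if the points fit in k dimensions."""
    n = sq_dist.shape[0]
    centering = np.eye(n) - np.ones((n, n)) / n
    gram = -0.5 * centering @ sq_dist @ centering   # double centering
    vals, vecs = np.linalg.eigh(gram)               # eigenvalues in ascending order
    top = np.argsort(vals)[::-1][:k]                # indices of the k largest
    scale = np.sqrt(np.clip(vals[top], 0.0, None))  # clip tiny negatives from round-off
    return vecs[:, top] * scale


# Sanity check: points that really live in 3-D are recovered up to rotation/reflection.
rng = np.random.default_rng(0)
pts = rng.normal(size=(5, 3))
sq = ((pts[:, None, :] - pts[None, :, :]) ** 2).sum(axis=-1)
coords = classical_mds(sq, k=3)
```

Given the model-to-model matrix of summed squared distances, each row of classical_mds() would be one model's 3-D position, ready for a plotting package such as rgl.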
As a side note, there are other ways to define similarity between classifiers. Since we are looking at distributions over classes, we could consider Kullback-Leibler divergence (although KL divergence is not symmetric, and hence not a metric); or we could try cosine similarity. We chose squared Euclidean distance as the most straightforward measure.
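To make the comparison concrete, here are the three measures applied to a pair of made-up class-probability vectors. This is an illustrative sketch; the small eps that keeps the KL divergence away from log(0) is an assumption, needed because some classifiers (the decision tree, for example) can assign zero probability to a class.

```python
import numpy as np


def sq_euclidean(p, q):
    """Symmetric: sq_euclidean(p, q) == sq_euclidean(q, p)."""
    return float(np.sum((p - q) ** 2))


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q); not symmetric, so not a metric. Clipping to eps avoids
    log(0) when a class gets zero probability (vectors are not renormalized)."""
    p = np.clip(p, eps, 1.0)
    q = np.clip(q, eps, 1.0)
    return float(np.sum(p * np.log(p / q)))


def cosine_similarity(p, q):
    """1.0 for vectors pointing the same way; ignores overall scale."""
    return float(p @ q / (np.linalg.norm(p) * np.linalg.norm(q)))


p = np.array([0.7, 0.2, 0.1])
q = np.array([0.4, 0.4, 0.2])
print(sq_euclidean(p, q), kl_divergence(p, q), kl_divergence(q, p), cosine_similarity(p, q))
```

Here KL(p || q) is about 0.184 while KL(q || p) is about 0.192, whereas the squared Euclidean distance (0.14) is the same in both directions.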
Drilling Down Further
Since we’ve collected all the measurements, we can further explore how the results vary with different properties of the data sets: their size, the number of target classes, or even the “shape” of the data: narrow (few variables relative to the number of rows) or wide (many variables relative to the number of rows). To do that, we built a Shiny app that lets us produce the above visualizations for different slices of the data. Since we only have 123 data sets, not all combinations of data set parameters are well represented, but we can still find interesting results. For example, here we see that for small, moderately narrow data sets, logistic regression is frequently a good choice.
Overall, we noticed that random forest and gradient boosting were strong performers over a variety of data set conditions. You can explore for yourself; our app is online at https://win-vector.shinyapps.io/ExploreModels/ .
Caveats
While the results we got were suggestive, and are consistent with the results of the DWN paper, there are a lot of caveats. First of all, the data sets we used do not really represent all the kinds of data sets that we might encounter in the wild. There are no text data sets in this collection, and the variables tend to be numeric rather than categorical (and when they are categorical, the current data treatment is not ideal, as we discussed above). Many of the data sets are far, far smaller than we would expect to encounter in a typical data science project, and they are all fairly clean; much of the data treatment and feature engineering was done before the data was submitted to the UCI repository.
Second, even though we used accuracy to evaluate the classifiers, this may not be the criterion you want. Classifier performance is (at least) two-dimensional; generally you are trading off one performance metric (like precision) for another (like recall). Simple accuracy may not capture what is most important to you in a classifier.
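A made-up example shows why one number is not enough: moving the decision threshold on a classifier's predicted probability trades precision against recall, and accuracy alone does not reveal which trade was made. The labels, scores, and thresholds are hypothetical; the metric functions are from sklearn.metrics.

```python
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score

# Made-up labels and predicted probabilities of the positive class.
y_true = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1])
p_pos = np.array([0.1, 0.2, 0.3, 0.4, 0.55, 0.6, 0.45, 0.7, 0.8, 0.9])

results = {}
for threshold in (0.3, 0.5, 0.7, 0.85):
    # Predict positive when the probability reaches the threshold.
    y_pred = (p_pos >= threshold).astype(int)
    results[threshold] = (
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred),
        recall_score(y_true, y_pred),
    )
    acc, prec, rec = results[threshold]
    print(f"threshold={threshold}: accuracy={acc:.2f} precision={prec:.2f} recall={rec:.2f}")
```

Raising the threshold from 0.3 to 0.85 takes precision from 0.5 to 1.0 while recall falls from 1.0 to 0.25; thresholds 0.5 and 0.85 have the same accuracy (0.7) but very different precision and recall.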
Other points: we could probably get better performance out of many of the classifiers with some per-data-set tuning, rather than picking one setting for all the data sets. It’s also worth remembering that our results are strictly observations about the scikit-learn implementations of these algorithms; different implementations could behave differently relative to each other. For example, the decision tree classifier does not actually return class probabilities, despite the fact that it could (R’s decision tree implementation can return class probabilities, for one). In addition, many of these implementations do not implement true multinomial classification for the multiclass case (even though in theory, there may be a multinomial version of the algorithm); instead they use a set of binomial classifiers in either a one-against-one (compare all pairs of classes) or one-against-all (compare each class against all the rest) approach. In some situations one approach may work better than the other.
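scikit-learn makes the one-against-all versus one-against-one choice explicit through the wrappers in sklearn.multiclass. A minimal sketch, assuming a synthetic four-class problem and logistic regression as the binary learner: one-vs-rest fits one binary model per class, one-vs-one fits one per pair of classes. (For reference, scikit-learn documents that SVC uses one-vs-one internally for multiclass problems.)

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsOneClassifier, OneVsRestClassifier

# Synthetic four-class problem (an assumption; not one of the DWN data sets).
X, y = make_classification(n_samples=400, n_features=8, n_informative=4,
                           n_classes=4, random_state=0)

# One-against-all: one binary model per class, "this class vs. the rest".
ovr = OneVsRestClassifier(LogisticRegression(max_iter=1000)).fit(X, y)
# One-against-one: one binary model per pair of classes; predictions by voting.
ovo = OneVsOneClassifier(LogisticRegression(max_iter=1000)).fit(X, y)

n_ovr = len(ovr.estimators_)  # 4 classes -> 4 models
n_ovo = len(ovo.estimators_)  # 4 * 3 / 2 -> 6 models

# How often the two strategies pick the same class on the training data.
agreement = np.mean(ovr.predict(X) == ovo.predict(X))
```

Note that OneVsOneClassifier offers no predict_proba(), so a study built on class probabilities, like this one, cannot use it directly.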
One data set characteristic that we didn’t investigate, but feel would be interesting, is the rarity of the target class (assuming there is a single target class). In fraud detection, for example, the target class is (hopefully) rare. Some classification approaches will do much better in that situation than others.
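A quick illustration of why rare target classes deserve separate study: on made-up data with 1% positives, a baseline that always predicts the majority class scores 99% accuracy while catching none of the rare cases. DummyClassifier is scikit-learn's built-in baseline; the data here are hypothetical.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, recall_score

# Hypothetical fraud-like labels: 10 positives among 1,000 rows.
y = np.zeros(1000, dtype=int)
y[:10] = 1
X = np.zeros((1000, 1))  # features are irrelevant to this baseline

baseline = DummyClassifier(strategy="most_frequent").fit(X, y)
pred = baseline.predict(X)
acc = accuracy_score(y, pred)  # 0.99
rec = recall_score(y, pred)    # 0.0: no rare case is ever flagged
```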
Caveats notwithstanding, we feel that the DWN paper (and our little follow-on) represent a good, useful step toward characterizing classifier performance. We don’t expect that there is a single, one-size-fits-all, best classification algorithm, and we’d like to see some science around what types of data sets (and which problem situations or use cases) represent “sweet spots” for different algorithms. We would also like to see more studies about the similarities of different algorithms. After all, why use a computationally expensive algorithm in a situation where a simpler approach is likely to do just as well?
In short, we’d like to see further studies similar to the DWN paper, ideally done over data sets that better represent the situations that data scientists are likely to encounter. Hopefully, exploring the current results with our Shiny app will give you some ideas for further work.
Materials
We have made our source code available on GitHub. The repository includes:
- The Python script for scoring all the data sets against the different models
- The R script for creating the summary tables used by the Shiny app
- The code for the Shiny app itself
Some of the relative paths for reading or writing files may not be correct, but it should be easy to figure out how to fix them. You will need the ggplot2, sqldf, rgl and shinyRGL packages to run the R code. Tutorials and documentation for Shiny, as well as directions for building and launching an app (it’s quite easy in RStudio, and not much harder without it), are available here.