Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.
As advanced machine learning algorithms gain acceptance across many organizations and domains, machine learning interpretability is growing in importance. It helps extract insight and clarity regarding how these algorithms perform and why one prediction is made over another. There are many methodologies to interpret machine learning results (e.g. variable importance via permutation, partial dependence plots, local interpretable model-agnostic explanations), and many machine learning R packages implement their own versions of one or more of them. However, some recent R packages that focus purely on ML interpretability, agnostic to any specific ML algorithm, are gaining popularity. One such package is DALEX, and this post covers what this package does (and does not do) so that you can determine if it should become part of your preferred machine learning toolbox.
We implement machine learning models using H2O, a high-performance ML toolkit. Let’s see how DALEX and H2O work together to get the best of both worlds: high performance and feature explainability!
Learning Trajectory
We’ll cover the following topics on DALEX in this article:
- Advantages & disadvantages: a quick breakdown of what DALEX does and does not do.
- Replication requirements: what you’ll need to reproduce the analysis.
- DALEX procedures: necessary functions for downstream explainers.
- Residual diagnostics: understanding and comparing errors.
- Variable importance: permutation-based importance score.
- Predictor-response relationship: PDP and ALE plots.
- Local interpretation: explanations for a single prediction.
Get The Best Resources In Data Science. Every Friday!
Sign up for our free “5 Topic Friday” Newsletter. Every week, I’ll send you the five coolest topics in data science for business that I’ve found that week. These could be new R packages, free books, or just some fun to end the week on.
DALEX and H2O: Machine Learning Model Interpretability And Feature Explanation
By Brad Boehmke, Director of Data Science at 84.51°
1.0 Advantages & disadvantages
DALEX is an R package with a set of tools that help to provide Descriptive mAchine Learning EXplanations, ranging from global to local interpretability methods. In particular, it makes comparing performance across multiple models convenient. However, as is, some of its functions scale poorly to the wide data sets commonly used by organizations. The following provides a quick list of its pros and cons:
Advantages
- ML model and package agnostic: can be used for any supervised regression and binary classification ML model where you can customize the format of the predicted output.
- Provides convenient approaches to compare results across multiple models.
- Residual diagnostics: allows you to compare residual distributions.
- Variable importance: uses a permutation-based approach for variable importance, which is model agnostic, and accepts any loss function to assess importance.
- Partial dependence plots: leverages the pdp package.
- Provides an alternative to PDPs for categorical predictor variables (merging path plots).
- Includes a unique and intuitive approach for local interpretation.
Disadvantages
- Some functions do not scale well to wide data (many predictor variables).
- Currently only supports regression and binary classification problems (i.e. no multinomial support).
- Only provides permutation-based variable importance scores (which become slow as the number of features increases).
- PDP plots can only be performed one variable at a time (no option for two-way interaction PDP plots).
- Does not provide ICE curves.
- Does not provide alternative local interpretation algorithms (e.g. LIME, SHAP values).
2.0 Replication requirements
We leverage the following packages:
To demonstrate DALEX’s capabilities we’ll use the employee attrition data that has been included in the rsample package. This demonstrates a binary classification problem (“Yes” vs. “No”), but the same process that you’ll observe can be used for a regression problem.
I perform a few housekeeping tasks on the data prior to converting it to an h2o object and splitting it.
NOTE: To use some of DALEX’s functions, categorical predictor variables need to be converted to factors. Also, I force ordered factors to be unordered as h2o does not support ordered categorical variables.
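A minimal base-R sketch of that clean-up, assuming a tiny hypothetical data frame in place of the attrition data (the column values below are made up for illustration): character columns become factors, and ordered factors lose their ordering while keeping every level in its original order.

```r
# Hypothetical toy rows standing in for the attrition data
df <- data.frame(
  Age       = c(41, 49, 37),
  Education = factor(c("College", "Bachelor", "College"),
                     levels = c("Below_College", "College", "Bachelor"),
                     ordered = TRUE),
  OverTime  = c("Yes", "No", "Yes"),
  stringsAsFactors = FALSE
)

# Convert character columns to factors and drop the ordering from ordered
# factors, keeping every level (including unused ones) in its original order
df[] <- lapply(df, function(x) {
  if (is.character(x)) return(factor(x))
  if (is.ordered(x)) return(factor(x, levels = levels(x), ordered = FALSE))
  x
})

str(df)
```

Passing `levels = levels(x)` explicitly matters: without it, `factor()` would silently drop levels that happen not to appear in the data.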
We will explore how to visualize a few of the more common machine learning algorithms implemented with h2o. For brevity I train default models and do not emphasize hyperparameter tuning. The following produces regularized logistic regression, random forest, and gradient boosting machine models, all of which have AUCs between .75 and .79. Although these models have distinct AUC scores, our objective is to understand how they come to their conclusions in similar or different ways based on underlying logic and data structure.
3.0 DALEX procedures
The DALEX architecture can be split into three primary operations:
- Any supervised regression or binary classification model with defined input (X) and output (Y), where the output can be customized to a defined format, can be used.
- The machine learning model is converted to an “explainer” object via DALEX::explain(), which is just a list that contains the training data and information about the machine learning model.
- The explainer object can be passed onto multiple functions that explain different components of the given model.
DALEX Application Process and Architecture
Although DALEX does have native support for some ML model objects (e.g. lm, randomForest), it does not natively support many of the preferred ML packages produced more recently (e.g. h2o, xgboost, ranger). To make DALEX compatible with these objects, we need three things:
- x_valid: Our feature set needs to be in its original form, not as an h2o object.
- y_valid: Our response variable needs to be a numeric vector. For regression problems this is simple, as it will already be in this format. For binary classification this requires you to convert the responses to 0/1.
- pred: a custom predict function that returns a vector of numeric values. For binary classification problems, this means extracting the probability of the response.
Once you have these three components, you can now create your explainer objects for each ML model. Considering I used a validation set to compute the AUC, we want to use that same validation set for ML interpretability.
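As a sketch of these three components, the block below uses a base-R `glm()` on simulated data as a stand-in for the h2o models (the data, `fit_glm`, and `explainer_glm` are hypothetical), and bundles the pieces into a plain list to mirror what an explainer holds.

```r
set.seed(123)

# Hypothetical validation data standing in for the attrition validation set
n <- 200
valid <- data.frame(
  Age       = round(runif(n, 18, 60)),
  OverTime  = factor(sample(c("No", "Yes"), n, replace = TRUE)),
  Attrition = factor(sample(c("No", "Yes"), n, replace = TRUE))
)

# A base-R glm stands in for the h2o model in this sketch
fit_glm <- glm(Attrition ~ Age + OverTime, data = valid, family = binomial)

# 1. x_valid: features as a regular data frame (not an h2o object)
x_valid <- valid[, setdiff(names(valid), "Attrition")]

# 2. y_valid: response recoded to a numeric 0/1 vector
y_valid <- as.numeric(valid$Attrition == "Yes")

# 3. pred: returns a plain numeric vector of P(Attrition = "Yes")
pred <- function(model, newdata) {
  as.numeric(predict(model, newdata = newdata, type = "response"))
}

# A plain list bundling the pieces, mimicking what an explainer holds
explainer_glm <- list(model = fit_glm, data = x_valid, y = y_valid,
                      predict_function = pred, label = "glm")

head(explainer_glm$predict_function(explainer_glm$model, explainer_glm$data))
```

With DALEX, these same three pieces are what you would pass to `DALEX::explain()` through its `data`, `y`, and `predict_function` arguments. For an h2o model, `pred` would instead wrap `h2o.predict()` on `as.h2o(newdata)` and return the positive-class probability column.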
4.0 Residual diagnostics
As we saw earlier, the GLM model had the highest AUC, followed by the random forest model and then the GBM. However, a single accuracy metric can be a poor indicator of performance. Assessing residuals of predicted versus actual values can allow you to identify where models deviate in their predictive accuracy. We can use DALEX::model_performance to compute the predictions and residuals. Printing the output returns residual quantiles, and plotting the output allows for easy comparison of absolute residual values across models.
In this example, the residuals compare the predicted probability of attrition to the binary attrition value (1 = yes, 0 = no). Looking at the quantiles, you can see that the median residuals are lowest for the GBM model. And looking at the boxplots, you can see that the GBM model also had the lowest median absolute residual value. Thus, although the GBM model had the lowest AUC score, it actually performs best when considering the median absolute residuals. However, you can also see a higher number of residuals in the tail of the GBM residual distribution (left plot), suggesting that there may be a higher number of large residuals compared to the GLM model. This helps to illustrate how your residuals behave similarly and differently across models.
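To make the two summaries concrete, here is a sketch of the underlying arithmetic on simulated data (the outcomes and both models' probabilities are hypothetical): residuals are the 0/1 outcome minus the predicted probability, summarised by quantiles and by the median absolute residual.

```r
set.seed(42)

# Hypothetical 0/1 outcomes and predicted probabilities from two models
n <- 500
y <- rbinom(n, 1, 0.16)
signal <- ifelse(y == 1, 1, -1)
p_model_a <- plogis(-1.7 + 0.8 * signal + rnorm(n, 0, 1))
p_model_b <- plogis(-2.5 + 1.2 * signal + rnorm(n, 0, 1.5))

# Residual = observed 0/1 value minus predicted probability
res <- list(model_a = y - p_model_a, model_b = y - p_model_b)

# Residual quantiles per model
sapply(res, quantile)

# Median absolute residual per model (the statistic compared in the boxplots)
sapply(res, function(r) median(abs(r)))
```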
5.0 Variable importance
An important task in ML interpretation is to understand which predictor variables are relatively influential on the predicted outcome. Many ML algorithms have their own unique ways to quantify the importance or relative influence of each feature (e.g. coefficients for linear models, impurity for tree-based models). However, other algorithms like naive Bayes classifiers and support vector machines do not. This makes it difficult to compare variable importance across multiple models.
DALEX uses a model-agnostic variable importance measure computed via permutation. This approach follows these steps:
For any given loss function do
  1: compute loss function for full model (denote _full_model_)
  2: randomize response variable, apply given ML, and compute loss function (denote _baseline_)
  3: for variable j
       | randomize values
       | apply given ML model
       | compute & record loss function
     end
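The steps above can be sketched in base R as follows. This is a hypothetical helper (`permutation_importance`, `loss_rmse`, and the simulated data are assumptions, and a glm stands in for the h2o models), not DALEX's own code. It re-scores the model p + 2 times for p predictors, which is why the approach slows down on wide data.

```r
# Permutation-based variable importance, following the steps above
permutation_importance <- function(model, x, y, predict_function, loss_function) {
  # 1: loss of the full model on unaltered data
  full_model <- loss_function(y, predict_function(model, x))
  # 2: loss after randomizing the response (no predictive signal left)
  baseline <- loss_function(sample(y), predict_function(model, x))
  # 3: for each variable, permute its values and record the loss
  per_variable <- vapply(names(x), function(v) {
    x_perm <- x
    x_perm[[v]] <- sample(x_perm[[v]])
    loss_function(y, predict_function(model, x_perm))
  }, numeric(1))
  c(`_full_model_` = full_model, per_variable, `_baseline_` = baseline)
}

# Hypothetical data: x1 carries signal, x2 is pure noise
set.seed(1)
n <- 1000
x <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
y <- rbinom(n, 1, plogis(2 * x$x1))
fit <- glm(y ~ x1 + x2, data = cbind(x, y = y), family = binomial)

pred <- function(model, newdata) as.numeric(predict(model, newdata, type = "response"))
loss_rmse <- function(observed, predicted) sqrt(mean((observed - predicted)^2))

imp <- permutation_importance(fit, x, y, pred, loss_rmse)
sort(imp)
```

Permuting `x1` should push the loss up toward `_baseline_`, while permuting the noise column `x2` should leave it close to `_full_model_`.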
To compute the permuted variable importance we use DALEX::variable_importance(). The printed output just provides a data frame with the output, and plotting the three variable importance objects allows us to compare the most influential variables for each model. How do we interpret this plot?
- The left edge of the x-axis is the loss function for the _full_model_. The default loss function is squared error, but any custom loss function can be supplied.
- The first item listed in each plot is _baseline_. This value represents the loss function when our response values are randomized and should be a good indication of the worst-possible loss function value when there is no predictive signal in the data.
- The length of each remaining variable's line segment represents its variable importance. The longer the segment, the larger the loss when that variable is randomized.
The results provide some interesting insights. First, the shifted left edge of the x-axis illustrates the difference in RMSE loss between the three models (e.g. the GLM model has the lowest RMSE, suggesting that the greater number of tail residuals in the GBM model is likely penalizing its RMSE score). Second, we can see which variables are consistently influential across all models (e.g. OverTime, EnvironmentSatisfaction, Age), variables that are influential in two but not all three (e.g. BusinessTravel, WorkLifeBalance), and variables that are influential in only one model (e.g. DailyRate, YearsInCurrentRole). This helps you to see if models are picking up unique structure in the data or if they are using common logic.
In this example, all three models appear to be largely influenced by the OverTime, EnvironmentSatisfaction, Age, TotalWorkingYears, and JobLevel variables. This gives us confidence that these features have strong predictive signals.
TIP: You can incorporate custom loss functions using the loss_function argument.
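As a sketch of such a custom loss, assuming the function receives observed values first and predictions second (the order used by DALEX's built-in losses such as `loss_root_mean_square`), a binary log loss could look like this; `loss_logloss` is a hypothetical name.

```r
# Hypothetical custom loss: binary cross-entropy (log loss).
# Argument order (observed, predicted) is assumed to match DALEX's built-in losses.
loss_logloss <- function(observed, predicted, eps = 1e-15) {
  p <- pmin(pmax(predicted, eps), 1 - eps)  # clip so log() never sees 0
  -mean(observed * log(p) + (1 - observed) * log(1 - p))
}

loss_logloss(observed = c(1, 0, 1), predicted = c(0.9, 0.2, 0.6))
```

It would then be supplied as, for example, `variable_importance(explainer_glm, loss_function = loss_logloss)`, where `explainer_glm` stands for any of your explainer objects.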
One downfall of the permutation-based approach to variable importance is that it can become slow. Since the algorithm loops through and re-applies the model for each predictor variable, the more features in your model, the longer it will take. For this example, which includes 30 features, it takes 81 seconds to compute variable importance for all three models. However, when tested on a data set with 100 predictors, it took nearly 5 minutes to compute.
TIP: variable_importance includes an n_sample argument that, by default, samples only 1000 observations to try to increase the speed of computation. Setting n_sample = -1, as I did in the above code chunk, just means to use all observations.
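The sampling behaviour described in the TIP presumably amounts to something like the hypothetical helper below (`subsample_rows` is not a DALEX function): draw `n_sample` rows at random, or keep every row when `n_sample` is negative or at least as large as the data.

```r
# Hypothetical helper mirroring the n_sample idea:
# a negative n_sample (or one >= the number of rows) keeps every row
subsample_rows <- function(data, n_sample = 1000) {
  if (n_sample < 0 || n_sample >= nrow(data)) return(data)
  data[sample(nrow(data), n_sample), , drop = FALSE]
}

set.seed(10)
nrow(subsample_rows(mtcars, n_sample = 10))  # 10
nrow(subsample_rows(mtcars, n_sample = -1))  # 32, all rows of mtcars
```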
6.0 Predictor-response relationship
Once we’ve identified influential variables across all three models, we likely want to understand how the relationship between these influential variables and the predicted response differs between the models. This helps to indicate whether each model responds to the predictor signal similarly or whether one or more models respond differently. For example, we saw that the Age variable was one of the most influential variables across all three models. The below partial dependence plot illustrates that the GBM and random forest models are using the Age signal in a similar non-linear manner; however, the GLM model is not able to capture this same non-linear relationship. So although the GLM model may perform better (re: AUC score), it may be using features in biased or misleading ways.
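A partial dependence curve can be computed by hand: set the variable to each grid value for every row, then average the predictions. The sketch below is illustrative only (simulated data with a U-shaped Age effect, glm fits standing in for the h2o models, hypothetical function names). Because the first glm is linear in Age on the logit scale, its curve can only be monotone, the same kind of limitation the GLM shows in the plot.

```r
# Partial dependence: fix one variable at each grid value for every row,
# then average the model's predictions
partial_dependence <- function(model, x, variable, grid, predict_function) {
  vapply(grid, function(value) {
    x_mod <- x
    x_mod[[variable]] <- value
    mean(predict_function(model, x_mod))
  }, numeric(1))
}

# Hypothetical data with a non-linear (U-shaped) Age effect
set.seed(99)
n <- 1000
x <- data.frame(Age = runif(n, 18, 60), OverTime = rbinom(n, 1, 0.3))
y <- rbinom(n, 1, plogis(-2 + 0.004 * (x$Age - 35)^2 + x$OverTime))

pred <- function(model, newdata) as.numeric(predict(model, newdata, type = "response"))

# A glm linear in Age versus one with a quadratic Age term
fit_linear <- glm(y ~ Age + OverTime, data = cbind(x, y = y), family = binomial)
fit_quad   <- glm(y ~ poly(Age, 2) + OverTime, data = cbind(x, y = y), family = binomial)

age_grid <- seq(20, 60, by = 5)
data.frame(
  Age    = age_grid,
  linear = partial_dependence(fit_linear, x, "Age", age_grid, pred),
  quad   = partial_dependence(fit_quad,   x, "Age", age_grid, pred)
)
```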
Although you can use PDPs for categorical predictor variables, DALEX provides merging path plots originally provided by the factorMerger package. For example, the EnvironmentSatisfaction variable captures the level of satisfaction regarding the working environment among employees. This variable showed up in all three models’ top 10 most influential variable lists. We can use type = "factor" to create a merging path plot, and it shows very similar results for each model. Employees with a low level of satisfaction have, on average, higher probabilities of attrition, whereas employees with medium to very high satisfaction have about the same likelihood of attriting. The left side of the plot is the merging path plot, which shows the similarity between groups via hierarchical clustering. It illustrates that employees with medium and high satisfaction are most similar, and that these employees are next most similar to employees with very high satisfaction. Finally, the least similar group is the low satisfaction employees.
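factorMerger's merging procedure is more involved than what follows, so treat this only as a rough, hypothetical stand-in for the idea: average the predicted probabilities within each factor level, then cluster the levels with base R's `hclust()`. The simulated probabilities are made up so that “Low” satisfaction sits apart from the other levels.

```r
set.seed(7)

# Hypothetical satisfaction levels and predicted attrition probabilities,
# with "Low" satisfaction given a higher average probability
lv <- c("Low", "Medium", "High", "Very_High")
satisfaction <- factor(sample(lv, 400, replace = TRUE), levels = lv)
p_hat <- plogis(-2 + 1 * (satisfaction == "Low") + rnorm(400, 0, 0.3))

# Mean predicted probability per level
level_means <- sapply(split(p_hat, satisfaction), mean)

# Hierarchical clustering of the levels on those means
tree <- hclust(dist(level_means), method = "average")
level_means
cutree(tree, k = 2)
```

Cutting the tree into two groups isolates “Low”, echoing the pattern described for the merging path plot.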
7.0 Local interpretation
The previous plots help us to understand our model from a global perspective by illustrating errors, identifying the variables with the largest overall impact, and understanding predictor-response relationships across all observations. However, often, we also need to perform local interpretation which allows us to understand why a particular prediction was made for an observation. Understanding and comparing how a model uses the predictor variables to make a given prediction can provide trust to you (the analyst) and also the stakeholder(s) that will be using the model output for decision making purposes.
Although LIME and SHAP (1, 2) values have recently become popular for local ML interpretation, DALEX uses a process called break down to compute localized variable importance scores.
There are two break down approaches that can be applied. The default is called step up and the algorithm performs the following steps:
existing_data = validation data set used in explainer
new_ob = single observation to perform local interpretation on
p = number of predictors
l = list of predictors
baseline = mean predicted response of existing_data
for variable i in {1,...,p} do
  for variable j in {1,...,l} do
    | substitute variable j in existing_data with variable j value in new_ob
    | predicted_j = mean predicted response of altered existing_data
    | diff_j = absolute difference between baseline and predicted_j
    | reset existing_data
  end
  | t = variable j with largest diff value
  | contribution for variable t = diff value for variable t
  | remove variable t from l
end
This is called step up because, essentially, it sweeps through each column, identifies the column with the largest difference score, adds that variable to the list as the most important, sweeps through the remaining columns, identifies the column with the largest score, adds that variable to the list as second most important, etc. until all variables have been assessed.
An alternative approach is called the step down which follows a similar algorithm but rather than remove the variable with the largest difference score on each sweep, it removes the variable with the smallest difference score. Both approaches are analogous to backward stepwise selection where step up removes variables with largest impact and step down removes variables with smallest impact.
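A minimal base-R sketch of the step-up procedure is below (`break_down_up` and the simulated data are hypothetical, with a glm standing in for the h2o models). It makes one assumption the pseudocode leaves implicit: once a variable is selected it stays substituted in later sweeps, and its contribution is the signed change in the mean prediction, which is how the cummulative column in the prediction_breakdown output adds up.

```r
# Step-up break down for one observation: greedily substitute the variable
# that moves the mean prediction the most, keep it fixed, and repeat
break_down_up <- function(model, data, new_ob, predict_function) {
  current   <- data
  remaining <- names(data)
  baseline  <- mean(predict_function(model, current))
  previous  <- baseline
  contribution <- numeric(0)

  while (length(remaining) > 0) {
    # Sweep: mean prediction after substituting each remaining variable
    means <- vapply(remaining, function(v) {
      trial <- current
      trial[[v]] <- new_ob[[v]]
      mean(predict_function(model, trial))
    }, numeric(1))

    # Variable with the largest absolute change from the current mean
    best <- remaining[which.max(abs(means - previous))]

    current[[best]] <- new_ob[[best]]   # keep it substituted from now on
    contribution[best] <- means[[best]] - previous
    previous <- means[[best]]
    remaining <- setdiff(remaining, best)
  }
  list(baseline = baseline, contribution = contribution)
}

# Hypothetical data and model
set.seed(2018)
n <- 300
x <- data.frame(
  Age              = round(runif(n, 18, 60)),
  OverTime         = factor(sample(c("No", "Yes"), n, replace = TRUE)),
  DistanceFromHome = round(runif(n, 1, 30))
)
y <- rbinom(n, 1, plogis(-1 - 0.04 * (x$Age - 40) + 1.2 * (x$OverTime == "Yes")))
fit <- glm(y ~ ., data = cbind(x, y = y), family = binomial)
pred <- function(model, newdata) as.numeric(predict(model, newdata, type = "response"))

bd <- break_down_up(fit, x, new_ob = x[1, ], predict_function = pred)
bd
```

Because every variable ends up substituted, the baseline plus the sum of the contributions equals the model's prediction for `new_ob`, which makes a handy sanity check.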
To perform the break down algorithm on a single observation, use the DALEX::prediction_breakdown function. The output is a data frame with class “prediction_breakdown_explainer” that lists the contribution for each variable.
TIP: The default approach is step up, but you can perform step down by adding the argument direction = "down".
| | variable | contribution | variable_name | variable_value | cummulative |
|---|---|---|---|---|---|
| 1 | (Intercept) | 0.0000000 | Intercept | 1 | 0.0000000 |
| JobRole | + JobRole = Laboratory_Technician | 0.0377084 | JobRole | Laboratory_Technician | 0.0377084 |
| StockOptionLevel | + StockOptionLevel = 0 | 0.0243714 | StockOptionLevel | 0 | 0.0620798 |
| MaritalStatus | + MaritalStatus = Single | 0.0242334 | MaritalStatus | Single | 0.0863132 |
| JobLevel | + JobLevel = 1 | 0.0318771 | JobLevel | 1 | 0.1181902 |
| Age | + Age = 32 | 0.0261924 | Age | 32 | 0.1443826 |
| BusinessTravel | + BusinessTravel = Travel_Frequently | 0.0210466 | BusinessTravel | Travel_Frequently | 0.1654292 |
| RelationshipSatisfaction | + RelationshipSatisfaction = High | 0.0108112 | RelationshipSatisfaction | High | 0.1762404 |
| Education | + Education = College | 0.0016912 | Education | College | 0.1779315 |
| PercentSalaryHike | + PercentSalaryHike = 13 | 0.0001158 | PercentSalaryHike | 13 | 0.1780473 |
We can plot the entire list of contributions for each variable of a particular model. We can see that several predictors have zero contribution, while others have positive and negative contributions. For the GBM model, the predicted value for this individual observation was positively influenced (increased probability of attrition) by variables such as JobRole, StockOptionLevel, and MaritalStatus. Conversely, variables such as JobSatisfaction, OverTime, and EnvironmentSatisfaction reduced this observation’s probability of attriting.
For data sets with a small number of predictors, you can compare across multiple models in a similar way as with earlier plotting (plot(new_cust_glm, new_cust_rf, new_cust_gbm)). However, with wider data sets, this becomes cluttered and difficult to interpret. Alternatively, you can filter for the largest absolute contribution values. Filtering drops the prediction_breakdown_explainer class from the output, so we plot the results with ggplot instead.
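Filtering for the largest absolute contributions needs nothing more than base R. The sketch below reuses the contribution values from the prediction_breakdown table above; `top_contributions` is a hypothetical helper.

```r
# Contributions copied from the prediction_breakdown table above
bd <- data.frame(
  variable = c("JobRole", "StockOptionLevel", "MaritalStatus", "JobLevel",
               "Age", "BusinessTravel", "RelationshipSatisfaction",
               "Education", "PercentSalaryHike"),
  contribution = c(0.0377084, 0.0243714, 0.0242334, 0.0318771, 0.0261924,
                   0.0210466, 0.0108112, 0.0016912, 0.0001158),
  stringsAsFactors = FALSE
)

# Keep the k variables with the largest absolute contribution
top_contributions <- function(df, k = 5) {
  ordered <- df[order(abs(df$contribution), decreasing = TRUE), ]
  head(ordered, k)
}

top_contributions(bd, k = 3)
```

The result is an ordinary data frame, so it can be passed straight to ggplot2 (for example, a `geom_col()` bar chart, faceted by model once the outputs of several models are stacked with `rbind()`).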
All three models predict a low probability of attrition for this new observation:
- GLM: 0.12
- random forest: 0.18
- GBM: 0.06
However, each model comes to that conclusion in a slightly different way. Still, several predictors consistently have a positive or negative impact on this observation’s probability of attriting (e.g. OverTime, EnvironmentSatisfaction, and JobSatisfaction are reducing this employee’s probability of attriting, while JobLevel, MaritalStatus, and StockOptionLevel are all increasing it). Consequently, we can have a decent amount of trust that these are strong signals for this observation regardless of model. However, when each model picks up unique signals in variables that the other models do not capture (e.g. DistanceFromHome, NumCompaniesWorked), it’s important to be careful how we communicate these signals to stakeholders. Since these variables do not provide consistent signals across all models, we should use domain experts or other sources to help validate whether or not these predictors are trustworthy. This will help us understand if the model is using proper logic that translates well to business decisions.
Unfortunately, a major drawback to DALEX’s implementation of these algorithms is that they are not parallelized. Consequently, wide data sets become extremely slow. For example, performing the previous three prediction_breakdown calls on this attrition data set with 30 predictors takes about 12 minutes. The run time grows rapidly as more predictors are added, because each sweep re-scores the data once for every remaining predictor (roughly p(p + 1)/2 re-scorings for p predictors in the step-up approach). When we apply a single instance of prediction_breakdown to the Ames housing data (80 predictors), it took over 3 hours to execute!
Looking at the underlying code for the prediction_breakdown function (it simply calls breakDown::broken.default), there are opportunities for integrating parallelization capabilities (e.g. via the foreach package). Consequently, prior to adding it to your preferred ML toolkit, you should determine:
- whether you are satisfied with its general algorithmic approach,
- whether you typically use wide data sets, and if so…
- what your appetite and bandwidth are for integrating parallelization (either in your own version or by collaborating with the package authors),
- and how it performs after parallelization (whether you see enough speed improvement to justify its use).
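The article points to foreach; as one alternative, this sketch uses base R's parallel package to run a single break-down sweep across cores (`sweep_means_parallel` is hypothetical, and an lm on mtcars stands in for an h2o model). `mclapply()` relies on forking, so `cores > 1` only helps on Unix-alikes, and whether forked workers can safely share an h2o connection is something you would need to test.

```r
library(parallel)

# One break-down sweep, run in parallel over the remaining variables.
# mclapply() forks processes, so cores > 1 only helps on Unix-alikes;
# on Windows it must stay at 1.
sweep_means_parallel <- function(model, current, new_ob, remaining,
                                 predict_function, cores = 1L) {
  means <- mclapply(remaining, function(v) {
    trial <- current
    trial[[v]] <- new_ob[[v]]
    mean(predict_function(model, trial))
  }, mc.cores = cores)
  setNames(unlist(means), remaining)
}

# Usage with a base-R linear model on mtcars (a stand-in for an h2o model)
fit <- lm(mpg ~ wt + hp + qsec, data = mtcars)
pred <- function(model, newdata) as.numeric(predict(model, newdata))
x <- mtcars[, c("wt", "hp", "qsec")]

sweep_means_parallel(fit, x, x[1, ], names(x), pred, cores = 1L)
```

Only the inner sweep is parallelized here, since each sweep depends on the variable chosen in the previous one.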
Next Steps: Take The Data Science For Business With R Course!
If you are interested in learning more about modeling using H2O and model explanations, definitely check out Data Science For Business With R (DS4B 201-R). Over the course of 10 weeks, students learn all of the steps to solve a $15M employee turnover problem with H2O, along with a host of other tools, frameworks, and techniques.
The students love it. Here’s a comment we received from one of our students, Siddhartha Choudhury, Data Architect at Accenture.
“To be honest, this course is the best example of an end to end project I have seen from business understanding to communication.”
Siddhartha Choudhury, Data Architect at Accenture
See for yourself why our students have rated Data Science For Business With R (DS4B 201-R) a 9.0 of 10.0 for Course Satisfaction!
Data Science For Business With Python (DS4B 201-P)
P.S. Did we mention we have a DS4B Python Course coming?!?! Well we do! Favio Vazquez, Principal Data Scientist at OXXO, is building the Python equivalent of DS4B 201. The problem changes: Customer Churn! The tools will be H2O, LIME, and a host of other tools implemented in Python. More information is forthcoming. Sign up for Business Science University to stay updated.
About The Author
This MACHINE LEARNING TUTORIAL comes from Brad Boehmke, Director of Data Science at 84.51°, where he and his team develop algorithmic processes, solutions, and tools that enable 84.51° and its analysts to efficiently extract insights from data and provide solution alternatives to decision-makers. Brad is not only a talented data scientist, he’s an adjunct professor at the University of Cincinnati, Wake Forest, and Air Force Institute of Technology. Most importantly, he’s an active contributor to the Data Science Community and he enjoys giving back via advanced machine learning education available at the UC Business Analytics R Programming Guide!
Additional DALEX Resources
The following provides resources to learn more about the DALEX package:
- DALEX GitHub repo: https://github.com/pbiecek/DALEX
- breakDown package, which is called by DALEX: https://github.com/pbiecek/breakDown
- Paper that explains the prediction break down algorithm
Business Science University
If you are looking to take the next step and learn Data Science For Business (DS4B), Business Science University is for you! Our goal is to empower data scientists through teaching the tools and techniques we implement every day. You’ll learn:
- TO SOLVE A REAL WORLD CHURN PROBLEM: Employee Turnover!
- Data Science Framework: Business Science Problem Framework
- Tidy Eval
- H2O Automated Machine Learning
- LIME Feature Explanations
- Sensitivity Analysis
- Tying data science to financial improvement (ROI-Driven Data Science)
Data Science For Business With R Virtual Workshop
Did you know that an organization that loses 200 high performing employees per year is essentially losing $15M/year in lost productivity? Many organizations don’t realize this because it’s an indirect cost. It goes unnoticed.
What if you could use data science to predict and explain turnover in a way that managers could make better decisions and executives would see results? You will learn the tools to do so in our Virtual Workshop using the R Statistical Programming Language. Here’s an example of a Shiny app you will create.
Shiny App That Predicts Attrition and Recommends Management Strategies, Taught in DS4B 301 (Building A Shiny Web App)
Our first Data Science For Business Virtual Workshop teaches you how to solve this employee attrition problem in four courses that are fully integrated:
- DS4B 201-R: Predicting Employee Attrition with h2o and lime
- DS4B 301-R (Coming Soon): Building A Shiny Web Application
- DS4B 302-R (EST Q4): Data Communication With RMarkdown Reports and Presentations
- DS4B 303-R (EST Q4): Building An R Package For Your Organization, tidyattrition
The Virtual Workshop is intended for intermediate and advanced R users. It’s code intensive (like these articles), but also teaches you fundamentals of data science consulting including CRISP-DM and the Business Science Problem Framework. The content bridges the gap between data science and the business, making you even more effective and improving your organization in the process.
Don’t Miss A Beat
- Sign up for the Business Science “5 Topic Friday” Newsletter!
- Get started with Business Science University to learn how to solve real-world data science problems from Business Science
- Check out our Open Source Software
Connect With Business Science
If you like our software (anomalize, tidyquant, tibbletime, timetk, and sweep), our courses, and our company, you can connect with us:
- business-science on GitHub
- Business Science, LLC on LinkedIn
- bizScienc on twitter
- Business Science, LLC on Facebook