
It’s the interactions

[This article was first published on R – Michael's and Christian's Blog, and kindly contributed to R-bloggers.]

What makes an ML model a black box? It is the interactions. Without any interactions, the ML model is additive and can be described exactly by its main effects.
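
A minimal base-R sketch of this claim, under stated assumptions: the toy model `f_add` and the helper `pd()` are hypothetical and not part of any package. For an additive model, the centered partial dependence functions of the features add up to the centered predictions, so the model is fully described by its main effects.

```r
set.seed(1)
n <- 200
X <- data.frame(x1 = runif(n), x2 = runif(n))

# Toy additive "model" (hypothetical): no interaction between x1 and x2
f_add <- function(X) sin(2 * pi * X$x1) + X$x2^2

# Empirical partial dependence of feature v, evaluated at each observed value
pd <- function(f, X, v) {
  sapply(X[[v]], function(z) {
    Xz <- X
    Xz[[v]] <- z   # fix feature v at z for all rows
    mean(f(Xz))    # average the predictions over the other features
  })
}

pred <- f_add(X)
pd1 <- pd(f_add, X, "x1")
pd2 <- pd(f_add, X, "x2")

# Sum of centered main effects vs. centered predictions
recon <- (pd1 - mean(pd1)) + (pd2 - mean(pd2))
max_abs_error <- max(abs((pred - mean(pred)) - recon))
max_abs_error  # zero up to floating point error
```

As soon as the model contains a term such as `X$x1 * X$x2`, this reconstruction no longer holds, and the gap is what interaction statistics try to quantify.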

Studying interaction effects of ML models is challenging. The main XAI approaches are:

  1. Look at ICE plots, stratified PDP, and/or 2D PDP.
  2. Study vertical scatter in SHAP dependence plots, or even consider SHAP interaction values.
  3. Check partial-dependence based H-statistics introduced in Friedman and Popescu (2008), or related statistics.

This post is mainly about the third approach. Its beauty is that we get information about all interactions. The downsides: the statistics are only as good or bad as the partial dependence functions they are built on, and they are computationally very expensive (of order n^2).
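
To make the idea and its cost concrete, here is a minimal base-R sketch of Friedman's pairwise H-squared, computed by brute force. The toy model `f` and the helpers `pd_centered()` and `h2_pair()` are hypothetical names for illustration. This is not how {hstats} is implemented. Each partial dependence function is evaluated at all n observed rows and averaged over all n rows, which is where the order n^2 comes from.

```r
set.seed(1)
n <- 100
X <- data.frame(x1 = runif(n), x2 = runif(n), x3 = runif(n))

# Toy model (hypothetical): x1 and x2 interact, x3 enters additively
f <- function(X) X$x1 * X$x2 + X$x3

# Centered empirical partial dependence of the feature(s) in v,
# evaluated at every observed row: n evaluations on n rows each
pd_centered <- function(f, X, v) {
  out <- sapply(seq_len(nrow(X)), function(i) {
    Xi <- X
    for (col in v) Xi[[col]] <- X[[col]][i]  # fix features v at row i's values
    mean(f(Xi))
  })
  out - mean(out)
}

# Pairwise H-squared: share of the joint effect's variability
# that is not explained by the two main effects
h2_pair <- function(f, X, j, k) {
  pd_jk <- pd_centered(f, X, c(j, k))
  pd_j <- pd_centered(f, X, j)
  pd_k <- pd_centered(f, X, k)
  sum((pd_jk - pd_j - pd_k)^2) / sum(pd_jk^2)
}

h12 <- h2_pair(f, X, "x1", "x2")  # clearly positive: x1 and x2 interact
h13 <- h2_pair(f, X, "x1", "x3")  # zero up to floating point error
c(h12 = h12, h13 = h13)
```

In practice, one would evaluate the partial dependence functions on a subsample of rows to keep the cost manageable.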

Different R packages offer some of these H-statistics, including {iml}, {gbm}, {flashlight}, and {vivid}. They all have their limitations. This is why I wrote the new R package {hstats}.

In Python, there is the very interesting project artemis. I will write a post on it later.

Statistics supported by {hstats}

Furthermore, a global measure of non-additivity (proportion of prediction variability unexplained by main effects) and a measure of feature importance are available. For technical details and references, check the following pdf or GitHub.
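
The non-additivity measure can be sketched in base R as follows. This is a simplified illustration, not the package's code: the helpers `pd_centered()` and `non_additivity()` are hypothetical, and the two linear models are only used because their interaction structure is known. The statistic compares the centered predictions with the sum of the centered main effects (partial dependence functions).

```r
# Centered partial dependence of one feature at each observed value (hypothetical helper)
pd_centered <- function(fit, X, v) {
  out <- sapply(X[[v]], function(z) {
    Xz <- X
    Xz[[v]] <- z
    mean(predict(fit, Xz))
  })
  out - mean(out)
}

# Proportion of prediction variability unexplained by the main effects of v
non_additivity <- function(fit, X, v) {
  pred <- predict(fit, X)
  pred <- pred - mean(pred)
  main <- rowSums(sapply(v, function(z) pd_centered(fit, X, z)))
  sum((pred - main)^2) / sum(pred^2)
}

X <- iris[c("Sepal.Width", "Petal.Length", "Petal.Width")]

# One model without and one with an interaction term
fit_add <- lm(Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width, data = iris)
fit_int <- lm(Sepal.Length ~ Sepal.Width * Petal.Length + Petal.Width, data = iris)

h2_add <- non_additivity(fit_add, X, names(X))  # zero up to floating point error
h2_int <- non_additivity(fit_int, X, names(X))  # positive
c(additive = h2_add, interaction = h2_int)
```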

Classification example

Let’s fit a probability random forest on iris species.

library(ranger)
library(ggplot2)
library(hstats)

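# Feature names: all columns except the response
v <- setdiff(colnames(iris), "Species")
# Probability forest: predicts class probabilities instead of hard labels
fit <- ranger(Species ~ ., data = iris, probability = TRUE, seed = 1)
# H-statistics for the features in v
s <- hstats(fit, v = v, X = iris)  # 8 seconds run-time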
s
# Proportion of prediction variability unexplained by main effects of v:
#      setosa  versicolor   virginica 
# 0.002705945 0.065629375 0.046742035

plot(s, normalize = FALSE, squared = FALSE) +
  ggtitle("Unnormalized statistics") +
  scale_fill_viridis_d(begin = 0.1, end = 0.9)

ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width", n_max = 150) |> 
  plot(center = TRUE) +
  ggtitle("Centered ICE plots")

Figure: Unnormalized H-statistics, i.e., values are roughly on the scale of the predictions (here: probabilities).
Figure: Centered ICE plots per class.
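
What a centered ICE plot shows can be sketched by hand in base R. The linear model with a known interaction is an assumption chosen for illustration, and centering each curve at its mean over the grid is one possible convention, not necessarily the one used by {hstats}. If there were no interaction, all centered curves would coincide. The more they fan out, the stronger the interaction.

```r
# Model with a known interaction between Sepal.Width and Petal.Width (illustration only)
fit <- lm(Sepal.Length ~ Sepal.Width * Petal.Width + Petal.Length, data = iris)
X <- iris[c("Sepal.Width", "Petal.Length", "Petal.Width")]
grid <- seq(min(X$Sepal.Width), max(X$Sepal.Width), length.out = 20)

# ICE: one curve per observation, i.e., predictions while Sepal.Width moves
# along the grid and all other features stay at the observed values
ice_mat <- sapply(grid, function(z) {
  Xz <- X
  Xz$Sepal.Width <- z
  predict(fit, Xz)
})  # matrix with one row per observation and one column per grid value

# Center each curve at its own mean over the grid
ice_centered <- ice_mat - rowMeans(ice_mat)

# Largest spread between the centered curves across grid values
# (would be zero for a model without interactions involving Sepal.Width)
spread <- max(apply(ice_centered, 2, sd))
spread
```

Calling `matplot(grid, t(ice_centered), type = "l")` draws the curves. Coloring them by `Petal.Width` would correspond to the role of the `BY` argument in the example above.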

Interpretation:

DALEX example

Here, we consider a random forest regression on “Sepal.Length”, wrapped in a DALEX explainer.

library(DALEX)
library(ranger)
library(hstats)

set.seed(1)

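# Random forest regression with Sepal.Length as response
fit <- ranger(Sepal.Length ~ ., data = iris)
# DALEX explainer: features without the response (column 1), and the response itself
ex <- explain(fit, data = iris[-1], y = iris[, 1])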

s <- hstats(ex)  # 2 seconds
s  # Non-additivity index 0.054
plot(s)
plot(ice(ex, v = "Sepal.Width", BY = "Petal.Width"), center = TRUE)

Figure: H-statistics
Figure: Centered ICE plot of strongest relative interactions.

Interpretation

Try it out!

The complete R script can be found here. More examples and background can be found on the GitHub page of the project.

