Interactions – where are you?

[This article was first published on R – Michael's and Christian's Blog, and kindly contributed to R-bloggers.]

This question sends shivers down the poor modeler's spine…

The {hstats} R package introduced in our last post measures their strength using Friedman’s H-statistics, a collection of statistics based on partial dependence functions.
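
To make the idea concrete, here is a minimal base-R sketch of Friedman's pairwise H² on a toy function. It is not the {hstats} implementation, just a brute-force illustration; the names f, pd_centered, and h2_pair are made up for this example, and the toy "model" is an assumption chosen so that x1 and x3 interact while x2 is purely additive.

```r
# Toy "model": x1 and x3 interact, x2 enters additively (hypothetical example)
set.seed(1)
n <- 200
X <- data.frame(x1 = runif(n), x2 = runif(n), x3 = runif(n))
f <- function(d) d$x1 + d$x2 + d$x1 * d$x3

# Centered partial dependence of the features in v, evaluated at each observed row
pd_centered <- function(f, X, v) {
  pd <- vapply(seq_len(nrow(X)), function(i) {
    Xi <- X
    for (nm in v) Xi[[nm]] <- X[[nm]][i]  # fix the features in v at row i's values
    mean(f(Xi))                           # average over the remaining features
  }, numeric(1))
  pd - mean(pd)
}

# Pairwise H^2: share of the joint effect variability of (j, k) that is
# not explained by the sum of the two main effects
h2_pair <- function(f, X, j, k) {
  pd_jk <- pd_centered(f, X, c(j, k))
  resid <- pd_jk - pd_centered(f, X, j) - pd_centered(f, X, k)
  sum(resid^2) / sum(pd_jk^2)
}

h2_additive <- h2_pair(f, X, "x1", "x2")  # numerically zero: no interaction
h2_interact <- h2_pair(f, X, "x1", "x3")  # clearly positive: x1 and x3 interact
```

This brute-force version needs n model evaluations per partial dependence function, which is why dedicated implementations care so much about speed.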

A preview version of {hstats} 1.0.0 is out on GitHub. I will try to bring it to CRAN in about one week (October 2023). Until then, try it via devtools::install_github("mayer79/hstats")

The current version offers:

  • H statistics per feature, feature pair, and feature triple
  • multivariate predictions at no additional cost
  • a convenient API
  • other important tools from explainable ML:
    • performance calculations
    • permutation importance (e.g., to select features for calculating H-statistics)
    • partial dependence plots (including grouping, multivariate, multivariable)
    • individual conditional expectations (ICE)
  • Case-weights are available for all methods, which is important, e.g., in insurance applications.
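
What "case weights" means for such statistics can be sketched in a few lines of base R: every average over observations becomes a weighted average. The snippet below is only an illustration under made-up assumptions (a toy prediction function f, random weights w, and a hypothetical helper pd_weighted); it does not use the {hstats} API.

```r
set.seed(1)
n <- 100
X <- data.frame(x1 = runif(n), x2 = runif(n))
w <- runif(n, 0.5, 2)              # hypothetical case weights, e.g., exposure
f <- function(d) 3 * d$x1 + d$x2   # stand-in for a fitted model's predictions
y <- f(X) + rnorm(n, sd = 0.1)

# Weighted performance: weighted mean of squared errors
weighted_mse <- weighted.mean((y - f(X))^2, w)

# Weighted partial dependence: fix feature v at each grid value,
# then take the weighted average of the predictions
pd_weighted <- function(f, X, v, grid, w) {
  sapply(grid, function(g) {
    Xg <- X
    Xg[[v]] <- g
    weighted.mean(f(Xg), w)
  })
}

pd <- pd_weighted(f, X, "x1", grid = c(0, 0.5, 1), w = w)
```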

This post has two parts:

  1. Example with house-prices and XGBoost
  2. Naive benchmark against {iml}, {DALEX}, and my old {flashlight}.

1. Example

Let’s model logarithmic sales prices of houses sold in Miami Dade County, a dataset prepared by Prof. Dr. Steven Bourassa, and available in {shapviz}. We use XGBoost with interaction constraints to provide a model additive in all structure information, but allowing for interactions between latitude/longitude for a flexible representation of geographic effects.

The following code prepares the data, splits the data into train and validation, and then fits an XGBoost model.

library(hstats)
library(shapviz)
library(xgboost)
library(ggplot2)

# Data preparation
colnames(miami) <- tolower(colnames(miami))
miami <- transform(miami, log_price = log(sale_prc))
x <- c("tot_lvg_area", "lnd_sqfoot", "latitude", "longitude",
       "structure_quality", "age", "month_sold")
coord <- c("longitude", "latitude")

# Modeling
set.seed(1)
ix <- sample(nrow(miami), 0.8 * nrow(miami))
train <- data.frame(miami[ix, ])
valid <- data.frame(miami[-ix, ])
y_train <- train$log_price
y_valid <- valid$log_price
X_train <- data.matrix(train[x])
X_valid <- data.matrix(valid[x])

dtrain <- xgb.DMatrix(X_train, label = y_train)
dvalid <- xgb.DMatrix(X_valid, label = y_valid)

ic <- c(
  list(which(x %in% coord) - 1),
  as.list(which(!x %in% coord) - 1)
)

# Fit via early stopping
fit <- xgb.train(
  params = list(
    learning_rate = 0.15, 
    objective = "reg:squarederror", 
    max_depth = 5,
    interaction_constraints = ic
  ),
  data = dtrain,
  watchlist = list(valid = dvalid),
  early_stopping_rounds = 20,
  nrounds = 1000,
  callbacks = list(cb.print.evaluation(period = 100))
)
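
It can help to look at what the interaction constraints built above actually contain. The following standalone snippet repeats the construction of ic from the code above and prints its structure. The indices are shifted by one because XGBoost counts features from zero.

```r
x <- c("tot_lvg_area", "lnd_sqfoot", "latitude", "longitude",
       "structure_quality", "age", "month_sold")
coord <- c("longitude", "latitude")

ic <- c(
  list(which(x %in% coord) - 1),      # one group: latitude and longitude
  as.list(which(!x %in% coord) - 1)   # every other feature forms its own group
)
str(ic)
```

The first list element groups the two coordinates (zero-based positions 2 and 3), so only they may appear together in a tree branch. Each remaining feature sits in a group of its own and therefore cannot interact with anything.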

Now it is time for a compact analysis with {hstats} to interpret the model:

average_loss(fit, X = X_valid, y = y_valid)  # 0.0247 MSE -> 0.157 RMSE

perm_importance(fit, X = X_valid, y = y_valid) |> 
  plot()

# Or combining some features
v_groups <- list(
  coord = c("longitude", "latitude"),
  size = c("lnd_sqfoot", "tot_lvg_area"),
  condition = c("age", "structure_quality")
)
perm_importance(fit, v = v_groups, X = X_valid, y = y_valid) |> 
  plot()

H <- hstats(fit, v = x, X = X_valid)
H
plot(H)
plot(H, zero = FALSE)
h2_pairwise(H, zero = FALSE, squared = FALSE, normalize = FALSE)
partial_dep(fit, v = "tot_lvg_area", X = X_valid) |> 
  plot()
partial_dep(fit, v = "tot_lvg_area", X = X_valid, BY = "structure_quality") |> 
  plot(show_points = FALSE)
plot(ii <- ice(fit, v = "tot_lvg_area", X = X_valid))
plot(ii, center = TRUE)

# Spatial plots
g <- unique(X_valid[, coord])
pp <- partial_dep(fit, v = coord, X = X_valid, grid = g)
plot(pp, d2_geom = "point", alpha = 0.5, size = 1) + 
  coord_equal()

# Takes some seconds because it generates the last plot per structure quality
partial_dep(fit, v = coord, X = X_valid, grid = g, BY = "structure_quality") |>
  plot(d2_geom = "point", alpha = 0.5) +
  coord_equal()

Results summarized by plots

Permutation importance

Figure 1: Permutation importance (4 repetitions) on the validation data. Error bars show standard errors of the estimated increase in MSE from shuffling feature values.
Figure 2: Feature groups can be shuffled together – accounting for issues of permutation importance with highly correlated features
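
For readers who want to see the mechanics behind these two figures, here is a rough base-R sketch of permutation importance with repetitions and grouped shuffling. It uses a linear model on simulated data as a stand-in; the helper names mse and perm_imp are invented for this sketch and are not part of {hstats}.

```r
set.seed(1)
n <- 500
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n), noise = rnorm(n))
y <- 2 * X$x1 + X$x2 + rnorm(n, sd = 0.1)
fit_lm <- lm(y ~ ., data = cbind(X, y = y))

mse <- function(y, p) mean((y - p)^2)

# Increase in MSE after shuffling the columns in v (jointly, with one permutation),
# averaged over m_rep repetitions, together with its standard error
perm_imp <- function(fit, X, y, v, m_rep = 4) {
  base <- mse(y, predict(fit, X))
  inc <- replicate(m_rep, {
    Xp <- X
    perm <- sample(nrow(X))
    for (nm in v) Xp[[nm]] <- X[[nm]][perm]
    mse(y, predict(fit, Xp)) - base
  })
  c(importance = mean(inc), se = sd(inc) / sqrt(m_rep))
}

imp_x1    <- perm_imp(fit_lm, X, y, "x1")
imp_noise <- perm_imp(fit_lm, X, y, "noise")
imp_group <- perm_imp(fit_lm, X, y, c("x1", "x2"))  # shuffle a feature group together
```

Shuffling a group with a single permutation keeps the within-group dependence intact, which is the point of grouping correlated features.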

H-Statistics

Let’s now move on to interaction statistics.

Figure 3: Overall and pairwise H-statistics. Overall H^2 gives the proportion of prediction variability explained by all interactions of the feature. By default, {hstats} picks the five features with largest H^2 and calculates their pairwise H^2. This explains why not all 21 feature pairs appear in the figure on the right-hand side. Pairwise H^2 is differently scaled than overall H^2: It gives the proportion of joint effect variability of the two features explained by their interaction.
Figure 4: Use “zero = FALSE” to drop variable (pairs) with value 0.

PDPs and ICEs

Figure 5: A partial dependence plot of living area.
Figure 6: Stratification shows indeed: no interactions between structure quality and living area.
Figure 7: ICE plots also show no interactions with any other feature. The interaction constraints of XGBoost did a good job.
Figure 8: This two-dimensional PDP evaluated over all unique coordinates shows a realistic profile of house prices in Miami Dade County (mind the log scale).
Figure 9: Same, but grouped by structure quality (5 is best). Since there is no interaction between location and structure quality, the plots are just shifted versions of each other. (You can’t really see it on the plots.)
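
Why do parallel ICE curves indicate "no interaction"? The following base-R sketch, using a toy function instead of the XGBoost model, may help. The function f and the helpers ice_curves and center are assumptions made up for this illustration.

```r
set.seed(1)
n <- 100
X <- data.frame(x1 = runif(n), x2 = runif(n), x3 = runif(n))
f <- function(d) d$x1 + d$x2 + d$x1 * d$x3   # x2 additive, x3 interacts with x1

# ICE curves: one row per observation, one column per grid value
ice_curves <- function(f, X, v, grid) {
  sapply(grid, function(g) {
    Xg <- X
    Xg[[v]] <- g
    f(Xg)
  })
}

grid <- seq(0, 1, by = 0.25)
ice_x2 <- ice_curves(f, X, "x2", grid)
ice_x3 <- ice_curves(f, X, "x3", grid)

# The partial dependence curve is the average of the ICE curves
pd_x2 <- colMeans(ice_x2)

# Centering: subtract each curve's value at the first grid point
center <- function(ice) ice - ice[, 1]

# Spread of the centered curves: zero means all curves are parallel
spread_x2 <- max(apply(center(ice_x2), 2, sd))  # numerically zero
spread_x3 <- max(apply(center(ice_x3), 2, sd))  # positive due to the interaction
```

For the additive feature, all centered curves coincide, so the partial dependence curve tells the full story. For the interacting feature, the centered curves fan out.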

Naive Benchmark

All methods in {hstats} are optimized for speed. But how fast are they compared to other implementations? Note that this is just a simple benchmark, run on a Windows notebook with an Intel i7-8650U CPU.

Note that {iml} offers a parallel backend, but we could not make it run with XGBoost and Windows. Let me know how fast it is using parallelism and Linux!

Setup + benchmark on permutation importance

Always using the full validation dataset and 10 repetitions.

library(iml)  # Might benefit from multiprocessing, but on Windows with XGB models, this is not easy
library(DALEX)
library(ingredients)
library(flashlight)
library(microbenchmark)

set.seed(1)

# iml
predf <- function(object, newdata) predict(object, data.matrix(newdata[x]))
mod <- Predictor$new(fit, data = as.data.frame(X_valid), y = y_valid, 
                     predict.function = predf)

# DALEX
ex <- DALEX::explain(fit, data = X_valid, y = y_valid)

# flashlight (my slightly old fashioned package)
fl <- flashlight(
  model = fit, data = valid, y = "log_price", predict_function = predf, label = "lm"
)
  
# Permutation importance: 10 repeats over full validation data (~2700 rows)
microbenchmark(
  iml = FeatureImp$new(mod, n.repetitions = 10, loss = "mse", compare = "difference"),
  dalex = feature_importance(ex, B = 10, type = "difference", n_sample = Inf),
  flashlight = light_importance(fl, v = x, n_max = Inf, m_repetitions = 10),
  hstats = perm_importance(fit, X = X_valid, y = y_valid, perms = 10),
  times = 4
)

# Unit: milliseconds
# expr             min        lq      mean    median        uq       max neval  cld
# iml        1558.3352 1585.3964 1651.9098 1625.5042 1718.4233 1798.2958     4 a   
# dalex       556.1398  573.8428  594.5660  592.1752  615.2893  637.7739     4  b  
# flashlight 1207.8085 1238.2424 1347.5105 1340.0633 1456.7787 1502.1071     4   c 
# hstats      146.0656  146.9564  151.3652  149.4352  155.7741  160.5249     4    d

{hstats} is about four times as fast as the runner-up, {DALEX}.

Partial dependence

Here, we study the time for crunching partial dependence of a continuous feature and a discrete feature.

# Partial dependence (cont)
v <- "tot_lvg_area"
microbenchmark(
  iml = FeatureEffect$new(mod, feature = v, grid.size = 50, method = "pdp"),
  dalex = partial_dependence(ex, variables = v, N = Inf, grid_points = 50),
  flashlight = light_profile(fl, v = v, pd_n_max = Inf, n_bins = 50),
  hstats = partial_dep(fit, v = v, X = X_valid, grid_size = 50, n_max = Inf),
  times = 4
)
# Unit: milliseconds
# expr            min       lq     mean    median        uq       max neval  cld
# iml        941.7763 968.5576 993.0481 1002.5849 1017.5386 1025.2462     4 a   
# dalex      694.8007 740.1619 767.1501  788.6172  794.1384  796.5654     4  b  
# flashlight 327.6056 328.7617 330.4069  330.5388  332.0522  332.9445     4   c 
# hstats     216.4040 217.0602 217.5606  217.8603  218.0611  218.1179     4    d

# Partial dependence (discrete)
v <- "structure_quality"
microbenchmark(
  iml = FeatureEffect$new(mod, feature = v, method = "pdp", grid.points = 1:5),
  dalex = partial_dependence(ex, variables = v, N = Inf, variable_type = "categorical", grid_points = 5),
  flashlight = light_profile(fl, v = v, pd_n_max = Inf),
  hstats = partial_dep(fit, v = v, X = X_valid, n_max = Inf),
  times = 4
)

# Unit: milliseconds
# expr            min        lq      mean    median        uq      max neval  cld
# iml         90.3690  91.08965  94.18403  92.57250  97.27840 101.2221     4 a   
# dalex      174.2517 174.97330 179.43483 175.87115 183.89635 191.7453     4  b  
# flashlight  43.9318  45.05070  48.09375  46.64275  51.13680  55.1577     4   c 
# hstats      24.5972  24.64975  25.01325  24.94085  25.37675  25.5741     4    d

{hstats} is 1.5 to 2 times faster than {flashlight}, and roughly 3.5 to 7 times as fast as {iml} and {DALEX}.
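
The speed-up factors quoted so far can be recomputed from the median timings in the three benchmark tables above. The snippet below just copies those medians (in milliseconds) into a matrix and divides by the {hstats} column; the object names med and speedup are arbitrary.

```r
# Median timings in milliseconds, copied from the microbenchmark outputs above
med <- rbind(
  perm_importance = c(iml = 1625.5042, dalex = 592.1752,
                      flashlight = 1340.0633, hstats = 149.4352),
  pd_continuous   = c(iml = 1002.5849, dalex = 788.6172,
                      flashlight = 330.5388, hstats = 217.8603),
  pd_discrete     = c(iml = 92.57250, dalex = 175.87115,
                      flashlight = 46.64275, hstats = 24.94085)
)

# How many times slower than {hstats} is each package (row-wise ratio)?
speedup <- round(med[, c("iml", "dalex", "flashlight")] / med[, "hstats"], 1)
speedup
```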

H-statistics

How fast can overall H-statistics be computed? And how fast are pairwise calculations?

{DALEX} does not offer these statistics yet. {iml} focuses on overall H per feature. It was the first model-agnostic implementation of H-statistics I am aware of. It uses quantile approximation by default, but we purposely force it to calculate exact statistics in order to compare the numbers. Thus, we made it slower than it actually is. Forgive me, Christoph :-).

# H-Stats -> we use a subset of 500 rows
X_v500 <- X_valid[1:500, ]
mod500 <- Predictor$new(fit, data = as.data.frame(X_v500), predict.function = predf)
fl500 <- flashlight(fl, data = as.data.frame(valid[1:500, ]))

# iml  # 90 s (no pairwise possible)
system.time(
  iml_overall <- Interaction$new(mod500, grid.size = 500)
)

# flashlight: 14s total, doing only one pairwise calculation, otherwise would take 63s
system.time(  # 12s
  fl_overall <- light_interaction(fl500, v = x, grid_size = Inf, n_max = Inf)
)
system.time(  # 2s
  fl_pairwise <- light_interaction(
    fl500, v = coord, grid_size = Inf, n_max = Inf, pairwise = TRUE
  )
)

# hstats: 3s total
system.time({
  H <- hstats(fit, v = x, X = X_v500, n_max = Inf)
  hstats_overall <- h2_overall(H, squared = FALSE, zero = FALSE)
  hstats_pairwise <- h2_pairwise(H, squared = FALSE, zero = FALSE)
}
)

# Overall statistics correspond exactly
iml_overall$results |> subset(.interaction > 1e-6)
#     .feature .interaction
# 1:  latitude    0.2458269
# 2: longitude    0.2458269

fl_overall$data |> subset(value > 0, select = c(variable, value))
#   variable  value
# 1 latitude  0.246
# 2 longitude 0.246

hstats_overall
# longitude  latitude 
# 0.2458269 0.2458269 

# Pairwise results match as well
fl_pairwise$data |> subset(value > 0, select = c(variable, value))
# latitude:longitude 0.394

hstats_pairwise
# latitude:longitude 
# 0.3942526

  • {hstats} is about four times as fast as {flashlight}
  • Since one often wants to study relative and absolute H statistics, in practice, the speed-up would be about a factor of eight.
  • In multi-classification/multi-output settings with m categories, the speed-up would be even m times larger.
  • Forcing all three packages to calculate exact statistics, all results match.

Wrap-Up

  • {hstats} is much faster than other XAI packages, at least in our use-case. This includes H-statistics, permutation importance, and partial dependence. Note that making good benchmarks is not my strength, so forgive any bias in the results.
  • With multivariate output, the potential is even larger.
  • H-Statistics match other implementations.

Try it out!

The full R code in one piece is here.
