What makes an ML model a black box? It’s the interactions!
The first step in understanding interactions is to measure their strength. This is exactly what Friedman and Popescu’s H-statistics [1] do:
Statistic | Short description | How to read its value? |
---|---|---|
\(H^2_j\) | Overall interaction strength per feature | Proportion of prediction variability explained by interactions on feature \(j\). |
\(H^2_{jk}\) | Pairwise interaction strength | Proportion of joint effect variability of features \(j\) and \(k\) coming from their pairwise interaction. |
\(H^2_{jkl}\) | Three-way interaction strength | Proportion of joint effect variability of three features coming from their three-way interaction. |
See the section “Background” for details and definitions.
{hstats} computes these statistics comparatively fast and for any model, even for multi-output models or models with case weights. Additionally, we provide a global statistic \(H^2\) measuring the proportion of prediction variability unexplained by main effects [5], and an experimental feature importance measure. Once strong interactions have been identified, their shape can be investigated by stratified partial dependence or ICE plots.
The core functions hstats(), partial_dep(), ice(), perm_importance(), and average_loss() can be applied directly to DALEX explainers, meta learners (mlr3, tidymodels, caret), and most other models. If you need more flexibility, a tailored prediction function can be specified. Both data.frame and matrix data structures are supported.
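H-statistics are computed on a random subset of the rows of X. To get more robust results, increase the default n_max = 500 of hstats().
{hstats} is not the first R package to explore interactions. Here is an incomplete selection: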
# From CRAN
install.packages("hstats")
# From GitHub
devtools::install_github("ModelOriented/hstats")
To demonstrate the typical workflow, we use a beautiful house price dataset with about 14,000 transactions from Miami-Dade County, available in the {shapviz} package and analyzed in [3]. We are going to model logarithmic sales prices with XGBoost.
library(hstats)
library(ggplot2)
library(xgboost)
library(shapviz)
colnames(miami) <- tolower(colnames(miami))
miami$log_ocean <- log(miami$ocean_dist)
x <- c("log_ocean", "tot_lvg_area", "lnd_sqfoot", "structure_quality", "age", "month_sold")
# Train/valid split
set.seed(1)
ix <- sample(nrow(miami), 0.8 * nrow(miami))
y_train <- log(miami$sale_prc[ix])
y_valid <- log(miami$sale_prc[-ix])
X_train <- data.matrix(miami[ix, x])
X_valid <- data.matrix(miami[-ix, x])
dtrain <- xgb.DMatrix(X_train, label = y_train)
dvalid <- xgb.DMatrix(X_valid, label = y_valid)
# Fit via early stopping
fit <- xgb.train(
  params = list(learning_rate = 0.15, objective = "reg:squarederror", max_depth = 5),
  data = dtrain,
  watchlist = list(valid = dvalid),
  early_stopping_rounds = 20,
  nrounds = 1000,
  callbacks = list(cb.print.evaluation(period = 100))
)
# Mean squared error: 0.0515
average_loss(fit, X = X_valid, y = y_valid)
Let’s calculate different H-statistics via hstats():
# 4 seconds on simple laptop - a random forest will take 2 minutes
set.seed(782)
system.time(
  s <- hstats(fit, X = X_train)  #, approx = TRUE: twice as fast
)
s
# H^2 (normalized)
# [1] 0.10
plot(s) # Or summary(s) for numeric output
# Save for later
# saveRDS(s, file = "h_statistics.rds")
Interpretation
Remarks
plot(h2_pairwise(s, normalize = FALSE, squared = FALSE))
Since distance to the ocean and age have high values in overall interaction strength, it is not surprising that a strong relative pairwise interaction is translated into a strong absolute one.
Note: {hstats} can crunch three-way interaction statistics \(H^2_{jkl}\) as well. To calculate them for the \(m\) features with the strongest overall interactions, set threeway_m = m.
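Let’s study different plots to understand what the strong interaction between distance to the ocean and age looks like. We will check three visualizations: the partial dependence of age stratified by distance to the ocean, the two-dimensional partial dependence of both features, and centered ICE curves of age stratified by distance to the ocean.
They all reveal a substantial interaction between the two variables in the sense that the age effect gets weaker the closer a house is to the ocean.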
Note that numeric BY features are automatically binned into quartile groups.
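For readers unfamiliar with quartile binning, here is a base-R sketch of the idea. It assumes quantile-based cut points, and the helper `bin_quartiles()` is hypothetical: {hstats} may implement the binning differently internally.

```r
# Hypothetical helper illustrating quartile binning of a numeric BY feature
bin_quartiles <- function(z) {
  breaks <- quantile(z, probs = seq(0, 1, by = 0.25), names = FALSE)
  # include.lowest = TRUE keeps the minimum inside the first group
  cut(z, breaks = unique(breaks), include.lowest = TRUE)
}

groups <- bin_quartiles(1:100)
table(groups)  # four groups of 25 values each
```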
partial_dep(fit, v = "age", X = X_train, BY = "log_ocean") |>
plot(show_points = FALSE)
pd <- partial_dep(fit, v = c("age", "log_ocean"), X = X_train, grid_size = 1000)
plot(pd)
plot(pd, d2_geom = "line", show_points = FALSE)
ice(fit, v = "age", X = X_train, BY = "log_ocean") |>
plot(center = TRUE)
In the spirit of [1], and related to [4], we can extract a partial-dependence-based variable importance measure from “hstats” objects. It measures not only the main effect strength (see [4]) but also all interaction effects of the feature. It is rather experimental, so use it with care (details in the section “Background”):
pd_importance(s) |>
plot()
# Compared with four times repeated permutation importance regarding MSE
set.seed(10)
perm_importance(fit, X = X_valid, y = y_valid) |>
plot()
Permutation importance returns the same order in this case.
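The mechanism behind permutation importance fits in a few lines of base R: shuffle one feature at a time, recompute the loss, and average the increase over `m_rep` repetitions. This is a simplified sketch that assumes MSE as the loss and a prediction function `f(X)` on a data.frame. `perm_importance_sketch()` is a hypothetical name, and the actual `perm_importance()` may differ in details such as normalization.

```r
perm_importance_sketch <- function(f, X, y, m_rep = 4) {
  mse <- function(pred) mean((y - pred)^2)
  base_loss <- mse(f(X))
  sapply(names(X), function(j) {
    perm_losses <- replicate(m_rep, {
      X_perm <- X
      X_perm[[j]] <- sample(X_perm[[j]])  # shuffle x_j only
      mse(f(X_perm))
    })
    mean(perm_losses) - base_loss         # average increase in MSE
  })
}

set.seed(1)
X <- data.frame(x1 = runif(200), x2 = runif(200), x3 = runif(200))
y <- 3 * X$x1 + X$x2 + rnorm(200, sd = 0.1)
f <- function(X) 3 * X$x1 + X$x2  # stand-in for a fitted model; ignores x3

imp <- perm_importance_sketch(f, X, y)
sort(imp, decreasing = TRUE)
```

Because the stand-in model ignores `x3`, shuffling it leaves the loss unchanged and its importance is exactly zero.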
The main functions work smoothly on DALEX explainers:
library(hstats)
library(DALEX)
library(ranger)
set.seed(1)
fit <- ranger(Sepal.Length ~ ., data = iris)
ex <- DALEX::explain(fit, data = iris[, -1], y = iris[, 1])
s <- hstats(ex)
s  # 0.054
plot(s)
# Strongest relative interaction (different visualizations)
ice(ex, v = "Sepal.Width", BY = "Petal.Width") |>
plot(center = TRUE)
partial_dep(ex, v = "Sepal.Width", BY = "Petal.Width") |>
plot(show_points = FALSE)
partial_dep(ex, v = c("Sepal.Width", "Petal.Width"), grid_size = 200) |>
plot()
perm_importance(ex)
# Petal.Length Petal.Width Sepal.Width Species
# 0.59836442 0.11625137 0.07966910 0.03982554
Strongest relative interaction shown as ICE plot.
{hstats} also works with multivariate output; see the following examples for probabilistic classification with {ranger}, {lightgbm}, and {xgboost}.
library(hstats)
ix <- c(1:40, 51:90, 101:140)
train <- iris[ix, ]
valid <- iris[-ix, ]
X_train <- data.matrix(train[, -5])
X_valid <- data.matrix(valid[, -5])
y_train <- train[[5]]
y_valid <- valid[[5]]
library(ranger)
set.seed(1)
fit <- ranger(Species ~ ., data = train, probability = TRUE)
average_loss(fit, X = valid, y = "Species", loss = "mlogloss")  # 0.02
perm_importance(fit, X = iris, y = "Species", loss = "mlogloss")
(s <- hstats(fit, X = iris[, -5]))
plot(s, normalize = FALSE, squared = FALSE)
ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width") |>
plot(center = TRUE)
Note: {lightgbm} versions < 4.0.0 require passing reshape = TRUE to the prediction function.
library(lightgbm)
set.seed(1)
params <- list(objective = "multiclass", num_class = 3, learning_rate = 0.2)
dtrain <- lgb.Dataset(X_train, label = as.integer(y_train) - 1)
dvalid <- lgb.Dataset(X_valid, label = as.integer(y_valid) - 1)
fit <- lgb.train(
  params = params,
  data = dtrain,
  valids = list(valid = dvalid),
  early_stopping_rounds = 20,
  nrounds = 1000
)
# mlogloss: 9.331699e-05
average_loss(fit, X = X_valid, y = y_valid, loss = "mlogloss")
perm_importance(fit, X = X_valid, y = y_valid, loss = "mlogloss", m_rep = 100)
# Petal.Length Petal.Width Sepal.Width Sepal.Length
# 2.624241332 1.011168660 0.082477177 0.009757393
partial_dep(fit, v = "Petal.Length", X = X_train) |>
plot(show_points = FALSE)
ice(fit, v = "Petal.Length", X = X_train) |>
plot(alpha = 0.05)
# Interaction statistics, including three-way stats
(H <- hstats(fit, X = X_train, threeway_m = 4))  # 0.3010446 0.4167927 0.1623982
plot(H, ncol = 1)
Mind the reshape = TRUE sent to the prediction function.
library(xgboost)
set.seed(1)
params <- list(objective = "multi:softprob", num_class = 3, learning_rate = 0.2)
dtrain <- xgb.DMatrix(X_train, label = as.integer(y_train) - 1)
dvalid <- xgb.DMatrix(X_valid, label = as.integer(y_valid) - 1)
fit <- xgb.train(
  params = params,
  data = dtrain,
  watchlist = list(valid = dvalid),
  early_stopping_rounds = 20,
  nrounds = 1000
)
# We need to pass reshape = TRUE to get a beautiful matrix
predict(fit, head(X_train, 2), reshape = TRUE)
# [,1] [,2] [,3]
# [1,] 0.9974016 0.002130089 0.0004682819
# [2,] 0.9971375 0.002129525 0.0007328897
# mlogloss: 0.006689544
average_loss(fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE)
partial_dep(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |>
plot(show_points = FALSE)
ice(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |>
plot(alpha = 0.05)
perm_importance(
  fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE, m_rep = 100
)
# Permutation importance regarding mlogloss
# Petal.Length Petal.Width Sepal.Length Sepal.Width
# 1.731532873 0.276671377 0.009158659 0.005717263
# Interaction statistics including three-way stats
(H <- hstats(fit, X = X_train, reshape = TRUE, threeway_m = 4))  # 0.02714399 0.16067364 0.11606973
plot(H, normalize = FALSE, squared = FALSE, facet_scales = "free_y", ncol = 1)
Here, we provide examples for {tidymodels}, {caret}, and {mlr3}.
library(hstats)
library(tidymodels)
set.seed(1)
iris_recipe <- iris |>
  recipe(Sepal.Length ~ .)
reg <- linear_reg() |>
  set_engine("lm")
iris_wf <- workflow() |>
  add_recipe(iris_recipe) |>
  add_model(reg)
fit <- iris_wf |>
  fit(iris)
s <- hstats(fit, X = iris[, -1])
s  # 0 -> no interactions
partial_dep(fit, v = "Petal.Width", X = iris) |>
plot()
imp <- perm_importance(fit, X = iris, y = "Sepal.Length")
imp
# Permutation importance
# Petal.Length Species Petal.Width Sepal.Width
# 4.39197781 0.35038891 0.11966090 0.09604322
plot(imp)
library(hstats)
library(caret)
set.seed(1)
fit <- train(
  Sepal.Length ~ .,
  data = iris,
  method = "lm",
  tuneGrid = data.frame(intercept = TRUE),
  trControl = trainControl(method = "none")
)
h2(hstats(fit, X = iris[, -1])) # 0
ice(fit, v = "Petal.Width", X = iris) |>
plot(center = TRUE)
perm_importance(fit, X = iris, y = "Sepal.Length") |>
plot()
library(hstats)
library(mlr3)
library(mlr3learners)
set.seed(1)
# Probabilistic classification
task_iris <- TaskClassif$new(id = "class", backend = iris, target = "Species")
fit_rf <- lrn("classif.ranger", predict_type = "prob")
fit_rf$train(task_iris)
s <- hstats(fit_rf, X = iris[, -5], predict_type = "prob")
plot(s)
# Permutation importance (wrt multi-logloss)
p <- perm_importance(
  fit_rf, X = iris, y = "Species", loss = "mlogloss", predict_type = "prob"
)
plot(p)
In [1], Friedman and Popescu introduced different statistics to measure interaction strength based on partial dependence functions. Closely following their notation, we will summarize the main ideas.
Let \(F: R^p \to R\) denote the prediction function that maps the \(p\)-dimensional feature vector \(\boldsymbol x = (x_1, \dots, x_p)\) to its prediction. Furthermore, let \(F_s(\boldsymbol x_s) = E_{\boldsymbol x_{\setminus s}}(F(\boldsymbol x_s, \boldsymbol x_{\setminus s}))\) be the partial dependence function of \(F\) on the feature subset \(\boldsymbol x_s\), where \(s \subseteq \{1, \dots, p\}\), as introduced in [2]. Here, the expectation runs over the joint marginal distribution of features \(\boldsymbol x_{\setminus s}\) not in \(\boldsymbol x_s\).
Given data, \(F_s(\boldsymbol x_s)\) can be estimated by the empirical partial dependence function
\[ \hat F_s(\boldsymbol x_s) = \frac{1}{n} \sum_{i = 1}^n F(\boldsymbol x_s, \boldsymbol x_{i \setminus s}), \]
where \(\boldsymbol x_{i\setminus s}\), \(i = 1, \dots, n\), are the observed values of \(\boldsymbol x_{\setminus s}\).
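The estimator above can be sketched in a few lines of base R. The helper `pd_hat()` is hypothetical (not part of {hstats}), `f` is assumed to be any prediction function that maps a data.frame to a numeric vector, and `grid` holds one evaluation point \(\boldsymbol x_s\) per row.

```r
pd_hat <- function(f, X, v, grid) {
  vapply(seq_len(nrow(grid)), function(g) {
    X_mod <- X
    for (col in v) X_mod[[col]] <- grid[[col]][g]  # set x_s to the g-th grid point
    mean(f(X_mod))                                 # average over observed x_{i \ s}
  }, numeric(1))
}

X <- data.frame(x1 = c(0, 1, 2), x2 = c(1, 2, 3))
f <- function(X) X$x1 + 2 * X$x2
pd_hat(f, X, v = "x1", grid = data.frame(x1 = c(0, 1)))  # 4 5
```

Each grid point costs \(n\) predictions, which is where the quadratic cost discussed below comes from when the grid consists of the observed rows themselves.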
A partial dependence plot (PDP) plots the values of \(\hat F_s(\boldsymbol x_s)\) over a grid of evaluation points \(\boldsymbol x_s\). Its disaggregated version is called individual conditional expectation (ICE), see [7].
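To make the relation between ICE and PDP concrete, here is a base-R sketch under the same assumptions: a prediction function `f` on a data.frame, and a hypothetical helper named `ice_sketch()`. Each row of the result is one ICE curve, and averaging the curves column-wise gives the PDP. Centering the curves at the first grid value is one common choice and only illustrates the idea behind centered ICE plots; it is not necessarily how {hstats} centers them.

```r
ice_sketch <- function(f, X, v, grid) {
  sapply(grid, function(z) {
    X_mod <- X
    X_mod[[v]] <- z  # same value of x_v for every row
    f(X_mod)         # one prediction per observation i
  })                 # result: n x length(grid) matrix (for n > 1)
}

X <- data.frame(x1 = c(0, 1, 2), x2 = c(1, 2, 3))
f <- function(X) X$x1 * X$x2  # interaction: the slope in x1 depends on x2

ice_mat <- ice_sketch(f, X, v = "x1", grid = c(0, 1, 2))
colMeans(ice_mat)                      # the PDP: 0 2 4
centered <- ice_mat - ice_mat[, 1]     # centered ICE curves
```

Non-parallel ICE curves (here with slopes 1, 2, and 3) are the visual signature of an interaction that the averaged PDP hides.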
If there are no interactions involving \(x_j\), we can decompose the prediction function \(F\) into the sum of the partial dependence \(F_j\) on \(x_j\) and the partial dependence \(F_{\setminus j}\) on all other features \(\boldsymbol x_{\setminus j}\), i.e.,
\[ F(\boldsymbol x) = F_j(x_j) + F_{\setminus j}(\boldsymbol x_{\setminus j}). \]
Correspondingly, Friedman and Popescu’s statistic of overall interaction strength of \(x_j\) is given by
\[ H_{j}^2 = \frac{\frac{1}{n} \sum_{i = 1}^n\big[F(\boldsymbol x_i) - \hat F_j(x_{ij}) - \hat F_{\setminus j}(\boldsymbol x_{i\setminus j})\big]^2}{\frac{1}{n} \sum_{i = 1}^n\big[F(\boldsymbol x_i)\big]^2}. \]
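A direct, unoptimized sketch of \(H_j^2\) in base R follows. It assumes a prediction function `f` on a data.frame, and it mean-centers \(F\), \(\hat F_j\), and \(\hat F_{\setminus j}\), because the additive decomposition above only holds for centered functions. The helpers `pd_at_data()` and `h2_overall_sketch()` are hypothetical names, not the {hstats} implementation.

```r
# \hat F_v evaluated at each observed x_{i,v}: O(n^2) predictions, then centered
pd_at_data <- function(f, X, v) {
  pd <- vapply(seq_len(nrow(X)), function(i) {
    X_mod <- X
    for (col in v) X_mod[[col]] <- X[[col]][i]  # fix x_v at row i
    mean(f(X_mod))                              # average over all rows
  }, numeric(1))
  pd - mean(pd)
}

h2_overall_sketch <- function(f, X, j) {
  pred <- f(X)
  F_c <- pred - mean(pred)                           # centered predictions
  F_j <- pd_at_data(f, X, j)                         # main effect of x_j
  F_not_j <- pd_at_data(f, X, setdiff(names(X), j))  # all other features
  sum((F_c - F_j - F_not_j)^2) / sum(F_c^2)
}

set.seed(1)
X <- data.frame(x1 = runif(50), x2 = runif(50), x3 = runif(50))
f <- function(X) X$x1 * X$x2 + X$x3  # x1 and x2 interact, x3 is additive

h2_x1 <- h2_overall_sketch(f, X, "x1")  # clearly positive
h2_x3 <- h2_overall_sketch(f, X, "x3")  # zero up to rounding
```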
Remarks
Again following [1], if there are no interaction effects between features \(x_j\) and \(x_k\), their two-dimensional partial dependence function \(F_{jk}\) can be written as the sum of the univariate partial dependencies, i.e.,
\[ F_{jk}(x_j, x_k) = F_j(x_j) + F_k(x_k). \]
Correspondingly, Friedman and Popescu’s statistic of pairwise interaction strength between \(x_j\) and \(x_k\) is defined as
\[ H_{jk}^2 = \frac{A_{jk}}{\frac{1}{n} \sum_{i = 1}^n\big[\hat F_{jk}(x_{ij}, x_{ik})\big]^2} \]
where
\[ A_{jk} = \frac{1}{n} \sum_{i = 1}^n\big[\hat F_{jk}(x_{ij}, x_{ik}) - \hat F_j(x_{ij}) - \hat F_k(x_{ik})\big]^2. \]
Remarks
Modification
To be better able to compare pairwise interaction strength across variable pairs, and to overcome the problem mentioned in the last remark, we suggest, as an alternative, the unnormalized test statistic on the scale of the predictions, i.e., \(\sqrt{A_{jk}}\).
Furthermore, instead of focusing on pairwise calculations for the most important features, we can select the features with the strongest overall interactions.
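Both \(H^2_{jk}\) and the unnormalized \(\sqrt{A_{jk}}\) can be sketched in base R with a hypothetical helper that evaluates mean-centered partial dependence at the observed rows. The names are illustrative and not taken from {hstats}, and `f` is assumed to be a prediction function on a data.frame.

```r
# \hat F_v evaluated at each observed x_{i,v}, then centered
pd_at_data <- function(f, X, v) {
  pd <- vapply(seq_len(nrow(X)), function(i) {
    X_mod <- X
    for (col in v) X_mod[[col]] <- X[[col]][i]
    mean(f(X_mod))
  }, numeric(1))
  pd - mean(pd)
}

h2_pairwise_sketch <- function(f, X, j, k) {
  F_jk <- pd_at_data(f, X, c(j, k))
  F_j <- pd_at_data(f, X, j)
  F_k <- pd_at_data(f, X, k)
  A_jk <- mean((F_jk - F_j - F_k)^2)
  c(H2 = A_jk / mean(F_jk^2), unnormalized = sqrt(A_jk))
}

set.seed(1)
X <- data.frame(x1 = runif(50), x2 = runif(50), x3 = runif(50))
f <- function(X) X$x1 * X$x2 + X$x3

h12 <- h2_pairwise_sketch(f, X, "x1", "x2")  # interacting pair
h13 <- h2_pairwise_sketch(f, X, "x1", "x3")  # additive pair
```

The toy model has an interaction only between `x1` and `x2`, so only that pair gets a clearly positive value. The unnormalized version is expressed on the scale of the predictions rather than as a proportion.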
[1] also describes a test statistic to measure three-way interactions: in case there are no three-way interactions between features \(x_j\), \(x_k\) and \(x_l\), their three-dimensional partial dependence function \(F_{jkl}\) can be decomposed into lower order terms:
\[ F_{jkl}(x_j, x_k, x_l) = B_{jkl} - C_{jkl} \]
with
\[ B_{jkl} = F_{jk}(x_j, x_k) + F_{jl}(x_j, x_l) + F_{kl}(x_k, x_l) \]
and
\[ C_{jkl} = F_j(x_j) + F_k(x_k) + F_l(x_l). \]
The squared and scaled difference between the two sides of the equation leads to the statistic
\[ H_{jkl}^2 = \frac{\frac{1}{n} \sum_{i = 1}^n \big[\hat F_{jkl}(x_{ij}, x_{ik}, x_{il}) - B^i_{jkl} + C^i_{jkl}\big]^2}{\frac{1}{n} \sum_{i = 1}^n \hat F_{jkl}(x_{ij}, x_{ik}, x_{il})^2}, \]
where
\[ B^i_{jkl} = \hat F_{jk}(x_{ij}, x_{ik}) + \hat F_{jl}(x_{ij}, x_{il}) + \hat F_{kl}(x_{ik}, x_{il}) \]
and
\[ C^i_{jkl} = \hat F_j(x_{ij}) + \hat F_k(x_{ik}) + \hat F_l(x_{il}). \]
Similar remarks as for \(H^2_{jk}\) apply.
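The three-way statistic can be sketched the same way. The sketch below assumes mean-centered partial dependence functions, a prediction function `f` on a data.frame, and hypothetical helper names. `B` and `C` mirror \(B^i_{jkl}\) and \(C^i_{jkl}\) from the formulas above.

```r
pd_at_data <- function(f, X, v) {
  pd <- vapply(seq_len(nrow(X)), function(i) {
    X_mod <- X
    for (col in v) X_mod[[col]] <- X[[col]][i]
    mean(f(X_mod))
  }, numeric(1))
  pd - mean(pd)  # centered
}

h2_threeway_sketch <- function(f, X, j, k, l) {
  F_jkl <- pd_at_data(f, X, c(j, k, l))
  B <- pd_at_data(f, X, c(j, k)) + pd_at_data(f, X, c(j, l)) +
    pd_at_data(f, X, c(k, l))
  C <- pd_at_data(f, X, j) + pd_at_data(f, X, k) + pd_at_data(f, X, l)
  mean((F_jkl - B + C)^2) / mean(F_jkl^2)
}

set.seed(1)
X <- data.frame(x1 = runif(30), x2 = runif(30), x3 = runif(30))
h_3way <- h2_threeway_sketch(function(X) X$x1 * X$x2 * X$x3, X, "x1", "x2", "x3")
h_2way <- h2_threeway_sketch(function(X) X$x1 * X$x2 + X$x3, X, "x1", "x2", "x3")
```

A model with only a pairwise interaction yields a three-way statistic of zero, while the product of all three features does not.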
If the model is additive in all features (no interactions), then
\[ F(\boldsymbol x) = \sum_{j=1}^{p} F_j(x_j), \]
i.e., the (centered) predictions can be written as the sum of the (centered) main effects.
To measure the relative amount of variability unexplained by all main effects, we can therefore study the test statistic of total interaction strength
\[ H^2 = \frac{\frac{1}{n} \sum_{i = 1}^n \left[F(\boldsymbol x_i) - \sum_{j = 1}^p\hat F_j(x_{ij})\right]^2}{\frac{1}{n} \sum_{i = 1}^n\left[F(\boldsymbol x_i)\right]^2}. \]
A value of 0 means there are no interaction effects at all. Due to (typically undesired) extrapolation effects of partial dependence functions, depending on the model, values above 1 may occur.
In [5], \(1 - H^2\) is called the additivity index. A similar measure using accumulated local effects is discussed in [6].
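A base-R sketch of the total statistic \(H^2\) and the additivity index \(1 - H^2\) follows. It assumes mean-centered functions and a prediction function `f` on a data.frame. `h2_total_sketch()` is a hypothetical name.

```r
pd_at_data <- function(f, X, v) {
  pd <- vapply(seq_len(nrow(X)), function(i) {
    X_mod <- X
    for (col in v) X_mod[[col]] <- X[[col]][i]
    mean(f(X_mod))
  }, numeric(1))
  pd - mean(pd)  # centered
}

h2_total_sketch <- function(f, X) {
  pred <- f(X)
  F_c <- pred - mean(pred)
  # Sum of all centered main effects \hat F_j(x_ij)
  main_effects <- Reduce(`+`, lapply(names(X), function(j) pd_at_data(f, X, j)))
  sum((F_c - main_effects)^2) / sum(F_c^2)
}

set.seed(1)
X <- data.frame(x1 = runif(50), x2 = runif(50), x3 = runif(50))
h2_add <- h2_total_sketch(function(X) X$x1 + X$x2^2 + X$x3, X)  # additive model
h2_int <- h2_total_sketch(function(X) X$x1 * X$x2 + X$x3, X)    # with interaction
additivity_index <- 1 - h2_int
```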
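Calculation of all \(H_j^2\) requires \(O(n^2 p)\) predictions, while calculating all pairwise \(H_{jk}^2\) requires \(O(n^2 p^2)\) predictions. Therefore, we suggest reducing the workload in two important ways: (1) evaluate the statistics on a subset of \(n'\) rows only, and (2) calculate pairwise statistics only for the \(m\) features with the strongest overall interactions.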
This leads to a total number of \(O(n'^2 p)\) predictions. If also three-way interactions are to be studied, \(m\) should be of the order \(p^{1/3}\).
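The following arithmetic gives a feel for these orders of magnitude. It is a rough count under a naive scheme: each \(H_j^2\) evaluates \(\hat F_j\) and \(\hat F_{\setminus j}\) at \(n'\) points, each an average over \(n'\) rows, and the pairwise and three-way statistics are restricted to the \(m\) strongest features. The values \(n' = 500\), \(p = 6\) (as in the Miami example), and \(m = 5\) are illustrative. With these values, the overall statistics need 3 million predictions, and the pairwise and three-way statistics need 2.5 million each. With \(n = 10{,}000\) rows instead, the overall statistics alone would need 1.2 billion predictions.

```r
# Rough prediction counts of the naive scheme (illustrative only)
n_predictions <- function(n, p, m) {
  c(
    overall = 2 * n^2 * p,          # F_j and F_{\j} at n points, n rows each
    pairwise = choose(m, 2) * n^2,  # F_jk for pairs among the m strongest
    threeway = choose(m, 3) * n^2   # F_jkl for triples among the m strongest
  )
}

counts <- n_predictions(n = 500, p = 6, m = 5)
counts
```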
[4] proposed the standard deviation of the partial dependence function as a measure of variable importance (for continuous predictors).
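The following base-R sketch computes the standard deviation of a partial dependence curve evaluated on a quantile grid. `pd_sd_importance()` and the grid choice are assumptions made for illustration. The toy model `x1 * x2`, where `x2` takes the values -1 and 1 in equal shares, exposes a blind spot of this measure: the partial dependence of `x1` is flat, so its importance comes out as zero even though `x1` drives the predictions.

```r
pd_sd_importance <- function(f, X, j, grid_size = 20) {
  grid <- quantile(X[[j]], probs = seq(0, 1, length.out = grid_size), names = FALSE)
  pd <- vapply(grid, function(z) {
    X_mod <- X
    X_mod[[j]] <- z  # fix x_j at the grid value for all rows
    mean(f(X_mod))   # empirical partial dependence
  }, numeric(1))
  sd(pd)
}

set.seed(1)
X <- data.frame(x1 = runif(100), x2 = rep(c(-1, 1), 50))
f <- function(X) X$x1 * X$x2  # x1 matters only through its interaction with x2

imp_x1 <- pd_sd_importance(f, X, "x1")  # essentially 0: the PD of x1 is flat
imp_x2 <- pd_sd_importance(f, X, "x2")  # positive
```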
Since the partial dependence function suppresses interaction effects, we propose a different measure in the spirit of the interaction statistics above: If \(x_j\) has no effects, the (centered) prediction function \(F\) equals the (centered) partial dependence \(F_{\setminus j}\) on all other features \(\boldsymbol x_{\setminus j}\), i.e.,
\[ F(\boldsymbol x) = F_{\setminus j}(\boldsymbol x_{\setminus j}). \]
Therefore, the following measure of variable importance follows:
\[ PDI_{j} = \frac{\frac{1}{n} \sum_{i = 1}^n\big[F(\boldsymbol x_i) - \hat F_{\setminus j}(\boldsymbol x_{i\setminus j})\big]^2}{\frac{1}{n} \sum_{i = 1}^n\big[F(\boldsymbol x_i)\big]^2}. \]
It differs from \(H^2_j\) only by not subtracting the main effect of the \(j\)-th feature in the numerator. It can be read as the proportion of prediction variability unexplained by all other features. As such, it measures variable importance of the \(j\)-th feature, including its interaction effects.
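A base-R sketch of \(PDI_j\) follows. It assumes mean-centered functions and a prediction function `f` on a data.frame, and `pdi_sketch()` is a hypothetical name. The toy model `x1 * x2` has a flat partial dependence in `x1`, yet \(PDI_1\) is clearly positive because the measure also captures interaction effects. The unused feature `x3` gets exactly zero.

```r
pd_at_data <- function(f, X, v) {
  pd <- vapply(seq_len(nrow(X)), function(i) {
    X_mod <- X
    for (col in v) X_mod[[col]] <- X[[col]][i]
    mean(f(X_mod))
  }, numeric(1))
  pd - mean(pd)  # centered
}

pdi_sketch <- function(f, X, j) {
  pred <- f(X)
  F_c <- pred - mean(pred)
  F_not_j <- pd_at_data(f, X, setdiff(names(X), j))  # all features except x_j
  sum((F_c - F_not_j)^2) / sum(F_c^2)
}

set.seed(1)
X <- data.frame(x1 = runif(40), x2 = rep(c(-1, 1), 20), x3 = runif(40))
f <- function(X) X$x1 * X$x2  # x3 is unused

pdi_x1 <- pdi_sketch(f, X, "x1")  # > 0 although the PD of x1 is flat
pdi_x3 <- pdi_sketch(f, X, "x3")  # 0: x3 has no effect
```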