In the following, we explain the counterfactuals
workflow for both a classification and a regression task using concrete
use cases.
library("counterfactuals")
library("iml")
library("rpart")

The Predictor class of the iml package
provides the flexibility needed to cover classification and
regression models fitted with diverse R packages. In the introduction
vignette, we saw models fitted with the mlr3 and
randomForest packages. In the following, we show extensions
to regression trees fitted with
- the caret package,
- the mlr package (a predecessor of mlr3),
- tidymodels, and
- the rpart package directly.

For each model, we generate counterfactuals for
the 100th row of the plasma dataset of the gamlss.data
package using the NICE method (NICERegr).
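iml can call the predict methods of the packages used here directly. For a model it cannot handle out of the box, `Predictor$new()` also accepts a `predict.function` argument: a function of the model and `newdata` that returns one prediction per row. The following is a minimal sketch under that assumption. `toy_model` and `toy_predict` are hypothetical stand-ins (a linear model stored as a plain list, with no `predict()` method), and the built-in `mtcars` data is used instead of plasma to keep the example small.

```r
library("iml")

# Hypothetical model: the coefficients of a linear model stored in a plain list.
# It has no predict() method, so iml cannot predict with it on its own.
fit = lm(mpg ~ wt + hp, data = mtcars)
toy_model = list(coef = coef(fit))

# Hypothetical prediction function: iml calls it with the model and newdata
# and expects one numeric prediction per row of newdata.
toy_predict = function(model, newdata) {
  b = model$coef
  unname(b[["(Intercept)"]] + b[["wt"]] * newdata$wt + b[["hp"]] * newdata$hp)
}

# Wrap the model; y names the target column in data.
pred = Predictor$new(model = toy_model, data = mtcars[, c("mpg", "wt", "hp")],
  y = "mpg", predict.function = toy_predict)

# Predictions come back as a data.frame with one row per observation.
out = pred$predict(mtcars[1:3, c("wt", "hp")])
print(out)
```

A Predictor built this way should be usable with the counterfactual methods in the same way as the wrappers shown in this vignette.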
data(plasma, package = "gamlss.data")
# The 100th observation is the point of interest; it is held out when fitting the models.
x_interest = plasma[100L, ]

library("caret")
# Regression tree fitted through caret with the complexity parameter fixed at cp = 0.01.
treecaret = caret::train(retplasma ~ ., data = plasma[-100L,], method = "rpart",
  tuneGrid = data.frame(cp = 0.01))
predcaret = Predictor$new(model = treecaret, data = plasma[-100L,], y = "retplasma")
predcaret$predict(x_interest)
#> .prediction
#> 1 342.9231

nicecaret = NICERegr$new(predcaret, optimization = "proximity",
margin_correct = 0.5, return_multiple = FALSE)
nicecaret$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [500, Inf]
#>
#> Head:
#> age sex smokstat bmi vituse calories fat fiber alcohol cholesterol betadiet retdiet betaplasma
#> <int> <fctr> <fctr> <num> <fctr> <num> <num> <num> <num> <num> <int> <int> <int>
#> 1: 46 1 3 35.25969 3 2667.5 131.6 10.1 0 550.5 1210 1291 218

library("tidymodels")
treetm = decision_tree(mode = "regression", engine = "rpart") %>%
fit(retplasma ~ ., data = plasma[-100L,])
predtm = Predictor$new(model = treetm, data = plasma[-100L,], y = "retplasma")
predtm$predict(x_interest)
#> .pred
#> 1 342.9231

nicetm = NICERegr$new(predtm, optimization = "proximity",
margin_correct = 0.5, return_multiple = FALSE)
nicetm$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [500, Inf]
#>
#> Head:
#> age sex smokstat bmi vituse calories fat fiber alcohol cholesterol betadiet retdiet betaplasma
#> <int> <fctr> <fctr> <num> <fctr> <num> <num> <num> <num> <num> <int> <int> <int>
#> 1: 46 1 3 35.25969 3 2667.5 131.6 10.1 0 550.5 1210 1291 218

library("mlr")
#> Warning in fun(pkgname, pkgpath): Packages 'paradox' and 'ParamHelpers' are conflicting and should not be loaded in the same session
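#> Warning in fun(pkgname, pkgpath): Packages 'mlr3' and 'mlr' are conflicting and should not be loaded in the same session

As the warnings indicate, mlr and mlr3 (like ParamHelpers and paradox) should not be loaded in the same session; we ignore this here for demonstration purposes.

task = mlr::makeRegrTask(data = plasma[-100L,], target = "retplasma")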
mod = mlr::makeLearner("regr.rpart")
treemlr = mlr::train(mod, task)
predmlr = Predictor$new(model = treemlr, data = plasma[-100L,], y = "retplasma")
predmlr$predict(x_interest)
#> .prediction
#> 1 342.9231

nicemlr = NICERegr$new(predmlr, optimization = "proximity",
margin_correct = 0.5, return_multiple = FALSE)
nicemlr$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [500, Inf]
#>
#> Head:
#> age sex smokstat bmi vituse calories fat fiber alcohol cholesterol betadiet retdiet betaplasma
#> <int> <fctr> <fctr> <num> <fctr> <num> <num> <num> <num> <num> <int> <int> <int>
#> 1: 46 1 3 35.25969 3 2667.5 131.6 10.1 0 550.5 1210 1291 218

treerpart = rpart(retplasma ~ ., data = plasma[-100L,])
predrpart = Predictor$new(model = treerpart, data = plasma[-100L,], y = "retplasma")
predrpart$predict(x_interest)
#> pred
#> 1 342.9231

nicerpart = NICERegr$new(predrpart, optimization = "proximity",
margin_correct = 0.5, return_multiple = FALSE)
nicerpart$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [500, Inf]
#>
#> Head:
#> age sex smokstat bmi vituse calories fat fiber alcohol cholesterol betadiet retdiet betaplasma
#> <int> <fctr> <fctr> <num> <fctr> <num> <num> <num> <num> <num> <int> <int> <int>
#> 1: 46 1 3 35.25969 3 2667.5 131.6 10.1 0 550.5 1210 1291 218
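All four interfaces return the same prediction (342.9231) and the same counterfactual for `x_interest`, presumably because each ends up fitting an rpart tree with comparable settings, so a follow-up check only needs one of them. The sketch below uses the plain rpart model: it re-predicts the counterfactual with the same Predictor to see whether it reaches the desired range [500, Inf], and lists the features that differ from `x_interest`. It assumes that the object returned by `find_counterfactuals()` stores the counterfactuals as a data.table in its `data` field; the helper `same_value` and the variable names are illustrative only, not part of the counterfactuals package.

```r
library("counterfactuals")
library("iml")
library("rpart")

data(plasma, package = "gamlss.data")
x_interest = plasma[100L, ]

# Same setup as above: a plain rpart regression tree wrapped by iml.
treerpart = rpart(retplasma ~ ., data = plasma[-100L, ])
predrpart = Predictor$new(model = treerpart, data = plasma[-100L, ], y = "retplasma")

nicerpart = NICERegr$new(predrpart, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
cfactuals = nicerpart$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))

# Assumption: the counterfactuals are available as a data.table in `data`.
cf = as.data.frame(cfactuals$data)

# Re-predict the counterfactual(s) and check the desired range [500, Inf].
pred_cf = predrpart$predict(cf)[[1L]]
in_range = pred_cf >= 500

# Illustrative helper: compare factors by label, numbers with a tolerance.
same_value = function(a, b) {
  if (is.factor(a) || is.factor(b)) {
    identical(as.character(a), as.character(b))
  } else {
    isTRUE(all.equal(as.numeric(a), as.numeric(b)))
  }
}

# Features whose value differs between x_interest and the first counterfactual.
feats = names(cf)
changed = feats[!vapply(feats, function(f) same_value(cf[[f]][1L], x_interest[[f]]),
  logical(1L))]

print(pred_cf)
print(in_range)
print(changed)
```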