ensembleML provides a single, consistent
API for ensemble machine learning in R. Regardless of which
algorithm you choose, the core workflow is always:
em_fit() -> em_predict() -> em_evaluate()
Advanced usage adds:
em_cv() # k-fold cross-validation (stability estimates)
em_tune() # grid-search hyperparameter optimisation
em_compare() # side-by-side algorithm comparison
em_importance() # feature importance
em_partial() # partial dependence plots
em_confusion() # confusion matrix heatmap
em_calibration() # calibration / reliability diagram
em_residuals() # regression diagnostics
# Attach the package so the em_*() functions are found when this code is
# run on its own (e.g. copied out of the vignette).
library(ensembleML)

data(iris)
set.seed(42)

# Hold out part of iris: 120 rows for training, the remaining rows for testing.
idx <- sample(nrow(iris), 120)
train <- iris[idx, ]
test <- iris[-idx, ]

# verbose = TRUE prints the fit summary shown below.
rf <- em_fit(Species ~ ., data = train, method = "random_forest",
             verbose = TRUE)
#> [ensembleML] task auto-detected as 'classification'
#>
#> ╭────────────────────────────────────────────────────╮
#> │ Algorithm: random_forest │
#> │ Task: classification │
#> │ Response: Species │
#> │ Classes: setosa, versicolor, virginica│
#> │ Predictors: 4 (Sepal.Length, Sepal.Width, Petal.Length, …)│
#> │ Training n: 120 │
#> │ Fit time: 0.020 sec │
#> │ Train metrics: accuracy=1.0000 kappa=1.0000 precision=1.0000 recall=1.0000 f1=1.0000 auc=NA│
#> │ ⚠ Use em_evaluate() on held-out data │
#> ╰────────────────────────────────────────────────────╯

Switching algorithms requires changing a single argument:
# Only the `method` argument changes between algorithms.
xgb <- em_fit(Species ~ ., data = train, method = "xgboost")
ada <- em_fit(Species ~ ., data = train, method = "adaboost")
bag <- em_fit(Species ~ ., data = train, method = "bagging")

# Predicted class labels for the held-out rows.
preds <- em_predict(rf, newdata = test)
head(preds)
#> 7 11 12 19 23 28
#> setosa setosa setosa setosa setosa setosa
#> Levels: setosa versicolor virginica

Class probabilities:
# One row per test observation, one column per class.
probs <- em_predict(rf, newdata = test, type = "prob")
head(probs, 3)
#> setosa versicolor virginica
#> 7 1.000 0.000 0
#> 11 0.998 0.002 0
#> 12 1.000 0.000 0

# Score the model on the held-out data.
em_evaluate(rf, newdata = test)
#> accuracy kappa precision recall f1 auc
#> 0.9333 0.8997 0.9364 0.9364 0.9364 NA

Select specific metrics:
# Request only a subset of the available metrics.
em_evaluate(
  rf,
  newdata = test,
  metrics = c("accuracy", "f1", "kappa")
)
#> accuracy f1 kappa
#> 0.9333 0.9364 0.8997

Use em_cv() to estimate the mean +/- SD of each metric across folds
before committing to a method:
# 5-fold cross-validation repeated 3 times on the full data set.
cv_res <- em_cv(Species ~ ., data = iris, method = "random_forest",
                cv_folds = 5, repeats = 3)
cv_res$summary
em_plot_cv(cv_res, metric = "accuracy")

# Grid search over 3 ntree values x 3 mtry values.
grid <- list(ntree = c(100, 300, 500), mtry = c(1, 2, 3))
tuned <- em_tune(
  Species ~ ., data = train, method = "random_forest",
  param_grid = grid, cv_folds = 5
)
tuned$best_params
tuned$best_score
tuned$results

# Side-by-side comparison of algorithms on the same train/test split.
cmp <- em_compare(Species ~ ., train = train, test = test)
cmp$table

# iris has only 4 predictors, so top_n = 4 shows all of them.
em_importance(rf, top_n = 4)

em_partial(rf, data = train, feature = "Petal.Length")

em_confusion(rf, newdata = test)
em_confusion(rf, newdata = test, normalise = TRUE)

Everything works identically for numeric (regression) responses:
set.seed(7)
n <- 200
x1 <- rnorm(n)
x2 <- rnorm(n)

# y is a linear function of x1 and x2 plus unit-variance noise, so the
# predictors carry real signal for the model to learn.
reg_data <- data.frame(
  x1 = x1, x2 = x2,
  y = 3 + 2 * x1 - x2 + rnorm(n)
)

# First 160 rows for training, last 40 held out.
reg_train <- reg_data[1:160, ]
reg_test <- reg_data[161:200, ]

reg_model <- em_fit(y ~ ., data = reg_train, method = "random_forest")
#> [ensembleML] task auto-detected as 'regression'
em_evaluate(reg_model, newdata = reg_test)
em_residuals(reg_model, reg_test)
#> `geom_smooth()` using formula = 'y ~ x'

If you use ensembleML in published work, please cite it:

citation("ensembleML")

The individual algorithms should also be cited — see
citation("ensembleML") for the full list of references.
# Record the R version and package versions used to build this document.
utils::sessionInfo()
#> R version 4.2.1 (2022-06-23 ucrt)
#> Platform: x86_64-w64-mingw32/x64 (64-bit)
#> Running under: Windows 10 x64 (build 26200)
#>
#> Matrix products: default
#>
#> locale:
#> [1] LC_COLLATE=C
#> [2] LC_CTYPE=English_United States.utf8
#> [3] LC_MONETARY=English_United States.utf8
#> [4] LC_NUMERIC=C
#> [5] LC_TIME=English_United States.utf8
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] ensembleML_0.2.5
#>
#> loaded via a namespace (and not attached):
#> [1] bslib_0.10.0 compiler_4.2.1 pillar_1.11.1
#> [4] RColorBrewer_1.1-3 jquerylib_0.1.4 tools_4.2.1
#> [7] digest_0.6.39 lattice_0.20-45 nlme_3.1-168
#> [10] jsonlite_2.0.0 evaluate_1.0.5 lifecycle_1.0.5
#> [13] tibble_3.3.1 gtable_0.3.6 mgcv_1.8-40
#> [16] pkgconfig_2.0.3 rlang_1.1.7 Matrix_1.6-5
#> [19] cli_3.6.5 rstudioapi_0.18.0 yaml_2.3.12
#> [22] xfun_0.57 fastmap_1.2.0 gridExtra_2.3
#> [25] withr_3.0.2 dplyr_1.2.0 knitr_1.51
#> [28] generics_0.1.4 sass_0.4.10 vctrs_0.7.2
#> [31] grid_4.2.1 tidyselect_1.2.1 glue_1.7.0
#> [34] R6_2.6.1 otel_0.2.0 rmarkdown_2.31
#> [37] ggplot2_4.0.2 farver_2.1.2 magrittr_2.0.3
#> [40] splines_4.2.1 scales_1.4.0 htmltools_0.5.9
#> [43] randomForest_4.7-1.2 labeling_0.4.3 S7_0.2.1
#> [46] cachem_1.1.0