Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: effectplots
Title: Effect Plots
Version: 0.2.2
Version: 0.2.3
Authors@R:
person("Michael", "Mayer", , "mayermichael79@gmail.com", role = c("aut", "cre"))
Description: High-performance implementation of various effect plots
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# effectplots 0.2.3

- More unit tests.
- Better code explanations for .ale().

# effectplots 0.2.2

### Minor improvement
Expand Down
58 changes: 31 additions & 27 deletions R/ale.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' Accumulated Local Effects (ALE)
#'
#' @description
#' Calculates ALE for one or multiple continuous features specified by `X`.
#' Calculates uncentered ALE for one or multiple continuous features specified by `X`.
#'
#' The concept of ALE was introduced in Apley et al. (2020) as an alternative to
#' partial dependence (PD). The Ceteris Paribus clause behind PD is a blessing and
Expand All @@ -17,7 +17,11 @@
#' observations falling into this bin. This is repeated for all bins,
#' and the values are *accumulated*.
#'
#' ALE values are plotted against right bin breaks.
#' This implementation closely follows the implementation of Apley et al. (2020).
#' Notably, we also plot the values at the bin breaks, not at the bin means.
#' The main difference from Apley et al. (2020) is that we use uniform binning, not quantile binning.
#' For large bins, we sample `ale_bin_size` observations, a step that is not necessary
#' with quantile binning. Furthermore, we don't center the values to mean 0.
#'
#' @details
#' The function is a convenience wrapper around [feature_effects()], which calls
Expand All @@ -43,7 +47,7 @@
#' M |> plot()
#'
#' M2 <- ale(fit, v = colnames(iris)[-1], data = iris, breaks = 5)
#' plot(M2, share_y = "all") # Only continuous variables shown
#' plot(M2, share_y = "all") # Only continuous variables shown
# S3 generic: hands the call over to the `ale()` method registered for the
# class of `object`; all further arguments travel along via `...`.
ale <- function(object, ...) UseMethod("ale")
Expand All @@ -65,8 +69,7 @@ ale.default <- function(
ale_n = 50000L,
ale_bin_size = 200L,
seed = NULL,
...
) {
...) {
feature_effects.default(
object = object,
v = v,
Expand Down Expand Up @@ -106,8 +109,7 @@ ale.ranger <- function(
ale_n = 50000L,
ale_bin_size = 200L,
seed = NULL,
...
) {
...) {
if (is.null(pred_fun)) {
pred_fun <- function(model, newdata, ...) {
stats::predict(model, newdata, ...)$predictions
Expand Down Expand Up @@ -149,8 +151,7 @@ ale.explainer <- function(
ale_n = 50000L,
ale_bin_size = 200L,
seed = NULL,
...
) {
...) {
ale.default(
object = object[["model"]],
v = v,
Expand Down Expand Up @@ -187,8 +188,7 @@ ale.H2OModel <- function(
ale_n = 50000L,
ale_bin_size = 200L,
seed = NULL,
...
) {
...) {
if (!requireNamespace("h2o", quietly = TRUE)) {
stop("Package 'h2o' not installed")
}
Expand Down Expand Up @@ -227,9 +227,8 @@ ale.H2OModel <- function(
#' Per bin, the local effect \eqn{D_j} is calculated, and then accumulated over bins.
#' \eqn{D_j} equals the difference between the partial dependence at the
#' lower and upper bin breaks using only observations within bin.
#' To plot the values, we can make a line plot of the resulting vector against
#' upper bin breaks. Alternatively, the vector can be extended
#' from the left by the value 0, and then plotted against *all* breaks.
#' The values are to be plotted against bin breaks.
#' Note that no centering is applied, i.e., the first value starts at 0.
#'
#' @param v Variable name in `data` to calculate ALE.
#' @param data Matrix or data.frame.
Expand All @@ -242,7 +241,7 @@ ale.H2OModel <- function(
#' @param g For internal use. The result of `as.factor(findInterval(...))`.
#' By default `NULL`.
#' @inheritParams feature_effects
#' @returns Vector representing one ALE per bin.
#' @returns Vector representing one value per break.
#' @export
#' @seealso [partial_dependence()]
#' @inherit ale references
Expand All @@ -262,47 +261,52 @@ ale.H2OModel <- function(
bin_size = 200L,
w = NULL,
g = NULL,
...
) {
...) {
if (is.null(g)) {
x <- if (is.data.frame(data)) data[[v]] else data[, v]
g <- findInterval(
x, vec = breaks, rightmost.closed = TRUE, left.open = right, all.inside = TRUE
x = x, vec = breaks, rightmost.closed = TRUE, left.open = right, all.inside = TRUE
)
g <- collapse::qF(g, sort = FALSE)
}

# List of bin indices. We remove empty or NA bins.
# List containing selected row indices per unsorted(!) bin ID
J <- lapply(
collapse::gsplit(g = g, use.g.names = TRUE),
function(z) if (length(z) <= bin_size) z else sample(z, size = bin_size)
)
# Remove empty or NA bins
ok <- !is.na(names(J)) & lengths(J, use.names = FALSE) > 0L
if (!all(ok)) {
J <- J[ok]
}

# Before flattening the list J, we store bin counts
bin_n <- lengths(J, use.names = FALSE)
ix <- as.integer(names(J))
bin_sizes <- lengths(J, use.names = FALSE)
ix <- as.integer(names(J)) # Unsorted bin IDs
J <- unlist(J, recursive = FALSE, use.names = FALSE)

# Empty bins will get an incremental effect of 0
out <- numeric(length(breaks) - 1L)
# Initialize local effects with 0
out <- numeric(length(breaks)) # first value corresponds to first left break

# Now we create a single prediction dataset. Lower bin edges first, then upper ones.
# Create single prediction data set. Lower bin edges first, then upper ones
# This makes the code harder to read, but is more efficient
data_long <- collapse::ss(data, rep.int(J, 2L))
grid_long <- rep.int(c(breaks[ix], breaks[ix + 1L]), times = c(bin_n, bin_n))
grid_long <- rep.int(c(breaks[ix], breaks[ix + 1L]), times = c(bin_sizes, bin_sizes))
if (is.data.frame(data_long)) {
data_long[[v]] <- grid_long
} else {
data_long[, v] <- grid_long
}

pred <- prep_pred(
pred_fun(object, data_long, ...), trafo = trafo, which_pred = which_pred
pred_fun(object, data_long, ...),
trafo = trafo, which_pred = which_pred
)

# Aggregate individual local effects
n <- length(J)
out[ix] <- collapse::fmean(
out[ix + 1L] <- collapse::fmean(
pred[(n + 1L):(2L * n)] - pred[1L:n],
g = collapse::fdroplevels(g[J]),
w = if (!is.null(w)) w[J],
Expand Down
2 changes: 1 addition & 1 deletion R/feature_effects.R
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ calculate_stats <- function(
w = ale_data$w,
g = if (is.null(ale_data$ix)) ix else ix[ale_data$ix],
...
)
)[-1L] # drop value at first break as we have one value too much
ok <- !is.na(out$bin_mid)

# Centering possible?
Expand Down
10 changes: 7 additions & 3 deletions man/ale.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 3 additions & 4 deletions man/dot-ale.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 7 additions & 6 deletions packaging.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#=============================================================================
# =============================================================================
# Put together the package
#=============================================================================
# =============================================================================

# WORKFLOW: UPDATE EXISTING PACKAGE
# 1) Modify package content and documentation.
Expand All @@ -15,7 +15,7 @@ library(usethis)
use_description(
fields = list(
Title = "Effect Plots",
Version = "0.2.2",
Version = "0.2.3",
Description = "High-performance implementation of various effect plots
useful for regression and probabilistic classification tasks.
The package includes partial dependence plots
Expand Down Expand Up @@ -47,7 +47,8 @@ use_gpl_license()

# Your files that do not belong to the package itself (others are added by "use_*" functions)
use_build_ignore(
c("^packaging.R$", "[.]Rproj$", "^logo.png$", "^claims.parquet$"), escape = FALSE
c("^packaging.R$", "[.]Rproj$", "^logo.png$", "^claims.parquet$"),
escape = FALSE
)

# Add short docu in Markdown (without running R code)
Expand All @@ -74,9 +75,9 @@ use_rcpp()
# use_github_action("test-coverage")
# use_github_action("pkgdown")

#=============================================================================
# =============================================================================
# Finish package building (can use fresh session)
#=============================================================================
# =============================================================================

library(devtools)

Expand Down
61 changes: 60 additions & 1 deletion tests/testthat/test-ale.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,68 @@ test_that("ale() is consistent with feature_effects()", {
)
suppressMessages(
marg <- feature_effects(
fit, v = v, data = iris, calc_pred = FALSE, pd_n = 0, w = 1:150
fit,
v = v, data = iris, calc_pred = FALSE, pd_n = 0, w = 1:150
)
)
expect_equal(ale, marg)
})

test_that(".ale() respects case weights", {
  # Integer case weights must give the same result as physically replicating
  # each row `w[i]` times and calling .ale() without weights.
  fit <- lm(Sepal.Length ~ Species * Sepal.Width, data = iris)
  v <- "Sepal.Width"
  br <- c(2, 2.5, 3, 3.5, 4.5)
  w <- c(rep(1L, times = 100L), rep(2L, times = 50L))
  # seq_len() rather than 1:nrow(): stays correct for zero-row input
  ix <- rep(seq_len(nrow(iris)), times = w)

  res_w <- .ale(fit, v = v, data = iris, breaks = br, w = w)
  res_uw <- .ale(fit, v = v, data = iris[ix, ], breaks = br)

  expect_equal(res_w, res_uw)
})

test_that("ale() respects case weights", {
  # Same idea as for .ale(): weighting by integer `w` must match replicating
  # each row `w[i]` times, now checked through the user-facing ale().
  fit <- lm(Sepal.Length ~ Species * Sepal.Width, data = iris)
  v <- "Sepal.Width"
  br <- c(2, 2.5, 3, 3.5, 4.5)
  w <- c(rep(1L, times = 100L), rep(2L, times = 50L))
  # seq_len() rather than 1:nrow(): stays correct for zero-row input
  ix <- rep(seq_len(nrow(iris)), times = w)

  # `[[1L]]` extracts the first (and here only) element of the result
  res_w <- ale(fit, v = v, data = iris, breaks = br, w = w)[[1L]]
  res_uw <- ale(fit, v = v, data = iris[ix, ], breaks = br)[[1L]]

  # The fourth element is excluded from the comparison.
  # NOTE(review): presumably this is the raw observation count per bin, which
  # necessarily differs once rows are replicated -- confirm against the
  # structure returned by feature_effects().
  expect_equal(res_w[-4L], res_uw[-4L])
})

test_that("the level order of g does not matter in .ale()", {
  # The bin factor `g` is passed once with its natural level order and once
  # with the levels reversed; both calls must produce the same ALE values.
  model <- lm(Sepal.Length ~ Species * Sepal.Width, data = iris)
  feature <- "Sepal.Width"
  bin_breaks <- c(2, 2.5, 3, 3.5, 4.5)

  bin_id <- cut(
    iris[[feature]],
    breaks = bin_breaks,
    include.lowest = TRUE,
    labels = FALSE
  )
  g_forward <- factor(bin_id)
  g_reversed <- factor(g_forward, levels = rev(levels(g_forward)))

  # Small closure so that the two calls differ in `g` only
  ale_with <- function(g) {
    .ale(model, v = feature, data = iris, breaks = bin_breaks, g = g)
  }

  expect_equal(ale_with(g_forward), ale_with(g_reversed))
})

test_that(".ale() is consistent with uncentered ALEPlot v 1.1", {
  # The expected values were obtained by stepping through ALEPlot() with
  # debugonce(ALEPlot) in order to see the uncentered values:
  #
  # debugonce(ALEPlot)
  # ALEPlot(
  #   iris,
  #   X.model = model,
  #   pred.fun = function(X.model, newdata) as.numeric(predict(X.model, newdata)),
  #   J = feature,
  #   K = 5
  # )
  expected <- c(0.0000000, 0.6103576, 0.8673605, 0.9488453, 1.1848637, 1.9006788)

  model <- lm(Sepal.Length ~ Species * Sepal.Width, data = iris)
  feature <- "Sepal.Width"
  n_bins <- 5

  # Quantile-based breaks; duplicated quantiles are collapsed via unique()
  probs <- seq(0, 1, length.out = n_bins + 1)
  quantile_breaks <- unique(stats::quantile(iris[[feature]], probs = probs))

  actual <- .ale(model, v = feature, data = iris, breaks = quantile_breaks)

  expect_equal(actual, expected, tolerance = 1e-6)
})
22 changes: 22 additions & 0 deletions tests/testthat/test-average_observed.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,25 @@ test_that("single vector input works", {
expect_equal(out1, out2)
expect_equal(out1, out3)
})

test_that("case weights are respected", {
  # Weighted averages of the observed response must equal the unweighted
  # averages computed after replicating each row `w[i]` times.
  v <- "Sepal.Width"
  br <- c(2, 2.5, 3, 3.5, 4.5)
  w <- c(rep(1L, times = 100L), rep(2L, times = 50L))
  # seq_len() rather than 1:nrow(): stays correct for zero-row input
  ix <- rep(seq_len(nrow(iris)), times = w)
  y <- iris$Sepal.Length

  res_w <- average_observed(
    iris[v],
    y = y,
    breaks = br,
    w = w
  )[[1L]]
  # Note: iris[ix, v] is a plain vector, whereas iris[v] above is a
  # one-column data.frame; both input forms are used here.
  res_uw <- average_observed(
    iris[ix, v],
    y = y[ix],
    breaks = br
  )[[1L]]

  # The fourth element is excluded from the comparison.
  # NOTE(review): presumably the raw observation count per bin, which differs
  # by construction after replication -- confirm against the result structure.
  expect_equal(res_w[-4L], res_uw[-4L])
})
22 changes: 22 additions & 0 deletions tests/testthat/test-average_predicted.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,25 @@ test_that("single vector input works", {
expect_equal(out1, out2)
expect_equal(out1, out3)
})

test_that("case weights are respected", {
  # Weighted averages of the predictions must equal the unweighted averages
  # computed after replicating each row `w[i]` times.
  fit <- lm(Sepal.Length ~ Species * Sepal.Width, data = iris)
  v <- "Sepal.Width"
  br <- c(2, 2.5, 3, 3.5, 4.5)
  w <- c(rep(1L, times = 100L), rep(2L, times = 50L))
  # seq_len() rather than 1:nrow(): stays correct for zero-row input
  ix <- rep(seq_len(nrow(iris)), times = w)

  res_w <- average_predicted(
    iris[v],
    pred = predict(fit, iris),
    breaks = br,
    w = w
  )[[1L]]
  # Note: iris[ix, v] is a plain vector, whereas iris[v] above is a
  # one-column data.frame; both input forms are used here.
  res_uw <- average_predicted(
    iris[ix, v],
    pred = predict(fit, iris[ix, ]),
    breaks = br
  )[[1L]]

  # The fourth element is excluded from the comparison.
  # NOTE(review): presumably the raw observation count per bin, which differs
  # by construction after replication -- confirm against the result structure.
  expect_equal(res_w[-4L], res_uw[-4L])
})
Loading