Skip to content

Commit

Permalink
better call routing for errors (#1214)
Browse files Browse the repository at this point in the history
* pass call through user-facing predict methods

* pass call for internal code; probably will never be surfaced by user

* show error is from autoplot() instead of map_glmnet_coefs()

* remove unused bartMachine code

* fix bug in condense_control and route user-facing call

* unit tests for one-hot encodings

* pass calls through data conversion functions

* redoc

* small formatting changes

* route some glmnet checking calls

* route some spec updating calls

* some predict call routing

* mark dev function as internal

* redoc

* revert passing in predict

* Apply suggestions from code review

Co-authored-by: Simon P. Couch <[email protected]>

* update snapshots

* redoc

---------

Co-authored-by: topepo <[email protected]>
Co-authored-by: Simon P. Couch <[email protected]>
  • Loading branch information
3 people authored Oct 22, 2024
1 parent 33f621c commit 58e4329
Show file tree
Hide file tree
Showing 40 changed files with 262 additions and 164 deletions.
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ export(bag_mars)
export(bag_mlp)
export(bag_tree)
export(bart)
export(bartMachine_interval_calc)
export(boost_tree)
export(case_weights_allowed)
export(cforest_train)
Expand Down
4 changes: 2 additions & 2 deletions R/arguments.R
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ make_form_call <- function(object, env = NULL) {
}

# TODO we need something to indicate that case weights are being used.
make_xy_call <- function(object, target, env) {
make_xy_call <- function(object, target, env, call = rlang::caller_env()) {
fit_args <- object$method$fit$args
uses_weights <- has_weights(env)

Expand All @@ -283,7 +283,7 @@ make_xy_call <- function(object, target, env) {
data.frame = rlang::expr(maybe_data_frame(x)),
matrix = rlang::expr(maybe_matrix(x)),
dgCMatrix = rlang::expr(maybe_sparse_matrix(x)),
cli::cli_abort("Invalid data type target: {target}.")
cli::cli_abort("Invalid data type target: {target}.", call = call)
)
if (uses_weights) {
object$method$fit$args[[ unname(data_args["weights"]) ]] <- rlang::expr(weights)
Expand Down
10 changes: 6 additions & 4 deletions R/autoplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@ autoplot.glmnet <- function(object, ..., min_penalty = 0, best_penalty = NULL,
}


map_glmnet_coefs <- function(x) {
map_glmnet_coefs <- function(x, call = rlang::caller_env()) {
coefs <- coef(x)
# If parsnip is used to fit the model, glmnet should be attached and this will
# work. If an object is loaded from a new session, they will need to load the
# package.
if (is.null(coefs)) {
cli::cli_abort(
"Please load the {.pkg glmnet} package before running {.fun autoplot}."
"Please load the {.pkg glmnet} package before running {.fun autoplot}.",
call = call
)
}
p <- x$dim[1]
Expand Down Expand Up @@ -89,9 +90,10 @@ top_coefs <- function(x, top_n = 5) {
dplyr::slice(seq_len(top_n))
}

autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) {
autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L,
call = rlang::caller_env(), ...) {
tidy_coefs <-
map_glmnet_coefs(x) %>%
map_glmnet_coefs(x, call = call) %>%
dplyr::filter(penalty >= min_penalty)

actual_min_penalty <- min(tidy_coefs$penalty)
Expand Down
48 changes: 0 additions & 48 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,61 +130,13 @@ update.bart <-
)
}


#' Developer functions for predictions via BART models
#' @export
#' @keywords internal
#' @name bart-internal
#' @inheritParams predict.model_fit
#' @param obj A parsnip object.
#' @param ci Confidence (TRUE) or prediction interval (FALSE)
#' @param level Confidence level.
#' @param std_err Attach column for standard error of prediction or not.
bartMachine_interval_calc <- function(new_data, obj, ci = TRUE, level = 0.95) {
  # Compute lower/upper interval bounds for `new_data` from the bartMachine fit
  # stored in `obj$fit`. `ci = TRUE` requests credible intervals and
  # `ci = FALSE` requests prediction intervals; `level` is forwarded as the
  # interval coverage. Returns a tibble with columns `.pred_lower` and
  # `.pred_upper`, plus `.std_err` for prediction intervals when the model
  # spec asks for a standard error.
  if (obj$spec$mode == "classification") {
    cli::cli_abort(
      "Prediction intervals are not possible for classification"
    )
  }

  # The flag lives in a nested list and is NULL when that entry is absent.
  # Normalising it with isTRUE() keeps the scalar `if ()` tests below from
  # failing with "argument is of length zero" (this previously also broke the
  # `ci = TRUE` path, which never needs the flag).
  get_std_err <- isTRUE(obj$spec$method$pred$pred_int$extras$std_error)

  # The call is built unevaluated so that `obj$fit` and `new_data` are looked
  # up only when the call is evaluated below.
  if (ci) {
    cl <-
      rlang::call2(
        "calc_credible_intervals",
        .ns = "bartMachine",
        bart_machine = rlang::expr(obj$fit),
        new_data = rlang::expr(new_data),
        ci_conf = level
      )
  } else {
    cl <-
      rlang::call2(
        "calc_prediction_intervals",
        .ns = "bartMachine",
        bart_machine = rlang::expr(obj$fit),
        new_data = rlang::expr(new_data),
        pi_conf = level
      )
  }
  res <- rlang::eval_tidy(cl)

  if (!ci) {
    # On the prediction-interval path the result is used as a list: the
    # per-row sample draws give the standard error and `$interval` holds the
    # bounds.
    if (get_std_err) {
      .std_error <- apply(res$all_prediction_samples, 1, stats::sd, na.rm = TRUE)
    }
    res <- res$interval
  }

  res <- tibble::as_tibble(res)
  names(res) <- c(".pred_lower", ".pred_upper")

  # Scalar condition, so use short-circuit `&&` rather than elementwise `&`.
  if (!ci && get_std_err) {
    res$.std_err <- .std_error
  }
  res
}

#' @export
#' @rdname bart-internal
#' @keywords internal
dbart_predict_calc <- function(obj, new_data, type, level = 0.95, std_err = FALSE) {
types <- c("numeric", "class", "prob", "conf_int", "pred_int")
Expand Down
13 changes: 9 additions & 4 deletions R/condense_control.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
#'
#' @return A control object with the same elements and classes of `ref`, with
#' values of `x`.
#' @param call The execution environment of a currently running function, e.g.
#' `caller_env()`. The function will be mentioned in error messages as the
#' source of the error. See the call argument of [rlang::abort()] for more
#' information.
#' @keywords internal
#' @export
#'
Expand All @@ -20,16 +24,17 @@
#'
#' ctrl <- condense_control(ctrl, control_parsnip())
#' str(ctrl)
condense_control <- function(x, ref) {
condense_control <- function(x, ref, ..., call = rlang::caller_env()) {
check_dots_empty()
mismatch <- setdiff(names(ref), names(x))
if (length(mismatch)) {
cli::cli_abort(
c(
"Object of class {.cls class(x)[1]} cannot be coerced to
object of class {.cls class(ref)[1]}.",
"{.obj_type_friendly {x}} cannot be coerced to {.obj_type_friendly {ref}}.",
"i" = "{cli::qty(mismatch)} The argument{?s} {.arg {mismatch}}
{?is/are} missing."
)
),
call = call
)
}
res <- x[names(ref)]
Expand Down
9 changes: 7 additions & 2 deletions R/contr_one_hot.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
#' This contrast function produces a model matrix with indicator columns for
#' each level of each factor.
#'
#' @param n A vector of character factor levels or the number of unique levels.
#' @param n A vector of character factor levels (of length >=1) or the number
#' of unique levels (>= 1).
#' @param contrasts This argument is for backwards compatibility and only the
#' default of `TRUE` is supported.
#' @param sparse This argument is for backwards compatibility and only the
Expand All @@ -24,9 +25,13 @@ contr_one_hot <- function(n, contrasts = TRUE, sparse = FALSE) {
}

if (is.character(n)) {
if (length(n) < 1) {
cli::cli_abort("{.arg n} cannot be empty.")
}
names <- n
n <- length(names)
} else if (is.numeric(n)) {
check_number_whole(n, min = 1)
n <- as.integer(n)

if (length(n) != 1L) {
Expand All @@ -35,7 +40,7 @@ contr_one_hot <- function(n, contrasts = TRUE, sparse = FALSE) {

names <- as.character(seq_len(n))
} else {
cli::cli_abort("{.arg n} must be a character vector or an integer of size 1.")
check_number_whole(n, min = 1)
}

out <- diag(n)
Expand Down
28 changes: 17 additions & 11 deletions R/convert_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,21 @@
na.action = na.omit,
indicators = "traditional",
composition = "data.frame",
remove_intercept = TRUE) {
remove_intercept = TRUE,
call = rlang::caller_env()) {
if (!(composition %in% c("data.frame", "matrix", "dgCMatrix"))) {
cli::cli_abort(
"{.arg composition} should be either {.val data.frame}, {.val matrix}, or
{.val dgCMatrix}."
{.val dgCMatrix}.",
call = call
)
}

if (sparsevctrs::has_sparse_elements(data)) {
cli::cli_abort(
"Sparse data cannot be used with formula interface. Please use
{.fn fit_xy} instead."
"Sparse data cannot be used with formula interface. Please use
{.fn fit_xy} instead.",
call = call
)
}

Expand Down Expand Up @@ -84,7 +87,7 @@

w <- as.vector(model.weights(mod_frame))
if (!is.null(w) && !is.numeric(w)) {
cli::cli_abort("{.arg weights} must be a numeric vector.")
cli::cli_abort("{.arg weights} must be a numeric vector.", call = call)
}

# TODO: Do we actually use the offset when fitting?
Expand Down Expand Up @@ -175,10 +178,12 @@
.convert_form_to_xy_new <- function(object,
new_data,
na.action = na.pass,
composition = "data.frame") {
composition = "data.frame",
call = rlang::caller_env()) {
if (!(composition %in% c("data.frame", "matrix"))) {
cli::cli_abort(
"{.arg composition} should be either {.val data.frame} or {.val matrix}."
"{.arg composition} should be either {.val data.frame} or {.val matrix}.",
call = call
)
}

Expand Down Expand Up @@ -244,9 +249,10 @@
y,
weights = NULL,
y_name = "..y",
remove_intercept = TRUE) {
remove_intercept = TRUE,
call = rlang::caller_env()) {
if (is.vector(x)) {
cli::cli_abort("{.arg x} cannot be a vector.")
cli::cli_abort("{.arg x} cannot be a vector.", call = call)
}

if (remove_intercept) {
Expand Down Expand Up @@ -279,10 +285,10 @@

if (!is.null(weights)) {
if (!is.numeric(weights)) {
cli::cli_abort("{.arg weights} must be a numeric vector.")
cli::cli_abort("{.arg weights} must be a numeric vector.", call = call)
}
if (length(weights) != nrow(x)) {
cli::cli_abort("{.arg weights} should have {nrow(x)} elements.")
cli::cli_abort("{.arg weights} should have {nrow(x)} elements.", call = call)
}

form <- patch_formula_environment_with_case_weights(
Expand Down
16 changes: 9 additions & 7 deletions R/descriptors.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,22 +103,23 @@ NULL

# Descriptor retrievers --------------------------------------------------------

get_descr_form <- function(formula, data) {
get_descr_form <- function(formula, data, call = rlang::caller_env()) {
if (inherits(data, "tbl_spark")) {
res <- get_descr_spark(formula, data)
} else {
res <- get_descr_df(formula, data)
res <- get_descr_df(formula, data, call = call)
}
res
}

get_descr_df <- function(formula, data) {
get_descr_df <- function(formula, data, call = rlang::caller_env()) {

tmp_dat <-
.convert_form_to_xy_fit(formula,
data,
indicators = "none",
remove_intercept = TRUE)
remove_intercept = TRUE,
call = call)

if(is.factor(tmp_dat$y)) {
.lvls <- function() {
Expand All @@ -136,7 +137,8 @@ get_descr_df <- function(formula, data) {
formula,
data,
indicators = "traditional",
remove_intercept = TRUE
remove_intercept = TRUE,
call = call
)$x
)
}
Expand Down Expand Up @@ -263,7 +265,7 @@ get_descr_spark <- function(formula, data) {
)
}

get_descr_xy <- function(x, y) {
get_descr_xy <- function(x, y, call = rlang::caller_env()) {

.lvls <- if (is.factor(y)) {
function() table(y, dnn = NULL)
Expand Down Expand Up @@ -291,7 +293,7 @@ get_descr_xy <- function(x, y) {
}

.dat <- function() {
.convert_xy_to_form_fit(x, y, remove_intercept = TRUE)$data
.convert_xy_to_form_fit(x, y, remove_intercept = TRUE, call = call)$data
}

.x <- function() {
Expand Down
5 changes: 3 additions & 2 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ fit.model_spec <-
}

if (all(c("x", "y") %in% names(dots))) {
cli::cli_abort("`fit.model_spec()` is for the formula methods. Use `fit_xy()` instead.")
cli::cli_abort("{.fn fit.model_spec} is for the formula methods. Use {.fn fit_xy} instead.")
}
cl <- match.call(expand.dots = TRUE)
# Create an environment with the evaluated argument objects. This will be
Expand Down Expand Up @@ -307,7 +307,8 @@ fit_xy.model_spec <-

if (object$engine == "spark") {
cli::cli_abort(
"spark objects can only be used with the formula interface to {.fn fit} with a spark data object."
"spark objects can only be used with the formula interface to {.fn fit}
with a spark data object."
)
}

Expand Down
9 changes: 5 additions & 4 deletions R/fit_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ form_form <-

# if descriptors are needed, update descr_env with the calculated values
if (requires_descrs(object)) {
data_stats <- get_descr_form(env$formula, env$data)
data_stats <- get_descr_form(env$formula, env$data, call = call)
scoped_descrs(data_stats)
}

Expand Down Expand Up @@ -86,7 +86,7 @@ xy_xy <- function(object,

# if descriptors are needed, update descr_env with the calculated values
if (requires_descrs(object)) {
data_stats <- get_descr_xy(env$x, env$y)
data_stats <- get_descr_xy(env$x, env$y, call = call)
scoped_descrs(data_stats)
}

Expand All @@ -96,7 +96,7 @@ xy_xy <- function(object,
# sub in arguments to actual syntax for corresponding engine
object <- translate(object, engine = object$engine)

fit_call <- make_xy_call(object, target, env)
fit_call <- make_xy_call(object, target, env, call)

res <- list(lvl = levels(env$y), spec = object)

Expand Down Expand Up @@ -141,7 +141,8 @@ form_xy <- function(object, control, env,
...,
composition = target,
indicators = indicators,
remove_intercept = remove_intercept
remove_intercept = remove_intercept,
call = call
)
env$x <- data_obj$x
env$y <- data_obj$y
Expand Down
Loading

0 comments on commit 58e4329

Please sign in to comment.