Skip to content

Commit

Permalink
Merge pull request #522 from tidymodels/pass-call
Browse files Browse the repository at this point in the history
pass calls around
  • Loading branch information
EmilHvitfeldt authored Oct 29, 2024
2 parents f560b94 + df2366c commit 2b0c096
Show file tree
Hide file tree
Showing 15 changed files with 52 additions and 40 deletions.
3 changes: 2 additions & 1 deletion R/aaa-metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,8 @@ validate_function_class <- function(fns) {
"*" = "a mix of dynamic and static survival metrics.",
"i" = "The following metric function types are being mixed:",
fn_pastable
))
),
call = rlang::call2("metric_set"))
}

# Safely evaluate metrics in such a way that we can capture the
Expand Down
7 changes: 4 additions & 3 deletions R/aaa-new.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,13 @@ new_static_survival_metric <- function(fn, direction) {
}

#' @include import-standalone-types-check.R
new_metric <- function(fn, direction, class = NULL) {
check_function(fn)
new_metric <- function(fn, direction, class = NULL, call = caller_env()) {
check_function(fn, call = call)

direction <- arg_match(
direction,
values = c("maximize", "minimize", "zero")
values = c("maximize", "minimize", "zero"),
error_call = call
)

class <- c(class, "metric", "function")
Expand Down
4 changes: 2 additions & 2 deletions R/conf_mat.R
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ conf_mat.grouped_df <- function(data,
}

conf_mat_impl <- function(truth, estimate, case_weights, call = caller_env()) {
abort_if_class_pred(truth)
estimate <- as_factor_from_class_pred(estimate)
abort_if_class_pred(truth, call = call)
estimate <- as_factor_from_class_pred(estimate, call = call)

estimator <- "not binary"
check_class_metric(truth, estimate, case_weights, estimator, call = call)
Expand Down
5 changes: 3 additions & 2 deletions R/metric-tweak.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ metric_tweak <- function(.name, .fn, ...) {

# ------------------------------------------------------------------------------

check_protected_names <- function(fixed) {
check_protected_names <- function(fixed, call = caller_env()) {
protected <- protected_names()
has_protected_name <- any(names(fixed) %in% protected)

Expand All @@ -94,7 +94,8 @@ check_protected_names <- function(fixed) {
}

cli::cli_abort(
"Arguments passed through {.arg ...} cannot be named any of: {protected}."
"Arguments passed through {.arg ...} cannot be named any of: {.arg {protected}}.",
call = call
)
}

Expand Down
5 changes: 3 additions & 2 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ is_class_pred <- function(x) {
inherits(x, "class_pred")
}

as_factor_from_class_pred <- function(x) {
as_factor_from_class_pred <- function(x, call) {
if (!is_class_pred(x)) {
return(x)
}
Expand All @@ -79,7 +79,8 @@ as_factor_from_class_pred <- function(x) {
cli::cli_abort(
"A {.cls class_pred} input was detected, but the {.pkg probably}
package isn't installed. Install {.pkg probably} to be able to convert
{.cls class_pred} to {.cls factor}."
{.cls class_pred} to {.cls factor}.",
call = call
)
}
probably::as.factor(x)
Expand Down
3 changes: 3 additions & 0 deletions R/prob-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ one_vs_all_impl <- function(fn,
truth,
estimate,
case_weights,
call,
...) {
lvls <- levels(truth)
other <- "..other"
Expand Down Expand Up @@ -76,12 +77,14 @@ one_vs_all_with_level <- function(fn,
truth,
estimate,
case_weights,
call,
...) {
res <- one_vs_all_impl(
fn = fn,
truth = truth,
estimate = estimate,
case_weights = case_weights,
call = call,
...
)

Expand Down
8 changes: 5 additions & 3 deletions R/prob-roc_auc.R
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,8 @@ warn_roc_truth_no_control <- function(control) {
stop_roc_truth_no_control <- function(control) {
cli::cli_abort(
msg_roc_truth_no_control(control),
class = "yardstick_error_roc_truth_no_control"
class = "yardstick_error_roc_truth_no_control",
call = call("roc_curve")
)
}

Expand All @@ -415,9 +416,10 @@ warn_roc_truth_no_event <- function(event) {
class = "yardstick_warning_roc_truth_no_event"
)
}
stop_roc_truth_no_event <- function(event) {
stop_roc_truth_no_event <- function(event, call = caller_env()) {
cli::cli_abort(
msg_roc_truth_no_event(event),
class = "yardstick_error_roc_truth_no_event"
class = "yardstick_error_roc_truth_no_event",
call = call
)
}
15 changes: 9 additions & 6 deletions R/prob-roc_curve.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,20 @@ roc_curve_estimator_impl <- function(truth,
estimate,
estimator,
event_level,
case_weights) {
case_weights,
call = caller_env()) {
if (is_binary(estimator)) {
roc_curve_binary(truth, estimate, event_level, case_weights)
roc_curve_binary(truth, estimate, event_level, case_weights, call)
} else {
roc_curve_multiclass(truth, estimate, case_weights)
roc_curve_multiclass(truth, estimate, case_weights, call)
}
}

roc_curve_binary <- function(truth,
estimate,
event_level,
case_weights) {
case_weights,
call) {
lvls <- levels(truth)

if (!is_event_first(event_level)) {
Expand All @@ -153,7 +155,7 @@ roc_curve_binary <- function(truth,
control <- lvls[[2]]

if (compute_n_occurrences(truth, event) == 0L) {
stop_roc_truth_no_event(event)
stop_roc_truth_no_event(event, call)
}
if (compute_n_occurrences(truth, control) == 0L) {
stop_roc_truth_no_control(control)
Expand Down Expand Up @@ -197,7 +199,8 @@ roc_curve_binary <- function(truth,
# One-VS-All approach
roc_curve_multiclass <- function(truth,
estimate,
case_weights) {
case_weights,
call) {
one_vs_all_with_level(
fn = roc_curve_binary,
truth = truth,
Expand Down
8 changes: 4 additions & 4 deletions tests/testthat/_snaps/aaa-metrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
Code
metric_set(rmse, accuracy)
Condition
Error in `validate_function_class()`:
Error in `metric_set()`:
x The combination of metric functions must be:
* only numeric metrics.
* a mix of class metrics and class probability metrics.
Expand All @@ -72,7 +72,7 @@
Code
metric_set(accuracy, foobar, sens, rlang::abort)
Condition
Error in `validate_function_class()`:
Error in `metric_set()`:
x The combination of metric functions must be:
* only numeric metrics.
* a mix of class metrics and class probability metrics.
Expand All @@ -86,7 +86,7 @@
Code
metric_set(accuracy, foobar, sens, rlang::abort)
Condition
Error in `validate_function_class()`:
Error in `metric_set()`:
x The combination of metric functions must be:
* only numeric metrics.
* a mix of class metrics and class probability metrics.
Expand All @@ -100,7 +100,7 @@
Code
metric_set(foobar)
Condition
Error in `validate_function_class()`:
Error in `metric_set()`:
x The combination of metric functions must be:
* only numeric metrics.
* a mix of class metrics and class probability metrics.
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/_snaps/aaa-new.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
Code
new_class_metric(1, "maximize")
Condition
Error in `new_metric()`:
Error in `new_class_metric()`:
! `fn` must be a function, not the number 1.

# `direction` is validated

Code
new_class_metric(function() 1, "min")
Condition
Error in `new_metric()`:
Error in `new_class_metric()`:
! `direction` must be one of "maximize", "minimize", or "zero", not "min".
i Did you mean "minimize"?

Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/_snaps/class-bal_accuracy.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# work with class_pred input

Code
bal_accuracy(cp_truth, cp_estimate)
bal_accuracy_vec(cp_truth, cp_estimate)
Condition
Error in `UseMethod()`:
! no applicable method for 'bal_accuracy' applied to an object of class "c('class_pred', 'vctrs_vctr')"
Error in `bal_accuracy_vec()`:
! `truth` should not a <class_pred> object.

12 changes: 6 additions & 6 deletions tests/testthat/_snaps/metric-tweak.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,24 @@
Code
metric_tweak("f_meas2", f_meas, data = 2)
Condition
Error in `check_protected_names()`:
! Arguments passed through `...` cannot be named any of: data, truth, and estimate.
Error in `metric_tweak()`:
! Arguments passed through `...` cannot be named any of: `data`, `truth`, and `estimate`.

---

Code
metric_tweak("f_meas2", f_meas, truth = 2)
Condition
Error in `check_protected_names()`:
! Arguments passed through `...` cannot be named any of: data, truth, and estimate.
Error in `metric_tweak()`:
! Arguments passed through `...` cannot be named any of: `data`, `truth`, and `estimate`.

---

Code
metric_tweak("f_meas2", f_meas, estimate = 2)
Condition
Error in `check_protected_names()`:
! Arguments passed through `...` cannot be named any of: data, truth, and estimate.
Error in `metric_tweak()`:
! Arguments passed through `...` cannot be named any of: `data`, `truth`, and `estimate`.

# `name` must be a string

Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/_snaps/prob-roc_curve.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
Code
roc_curve_vec(no_event$truth, no_event$Class1)[[".estimate"]]
Condition
Error in `stop_roc_truth_no_event()`:
Error in `roc_curve_vec()`:
! No event observations were detected in `truth` with event level 'Class1'.

# roc_curve() - error is thrown when missing controls

Code
roc_curve_vec(no_control$truth, no_control$Class1)[[".estimate"]]
Condition
Error in `stop_roc_truth_no_control()`:
Error in `roc_curve()`:
! No control observations were detected in `truth` with control level 'Class2'.

# roc_curve() - multiclass one-vs-all approach results in error
Expand All @@ -20,7 +20,7 @@
roc_curve_vec(no_event$obs, as.matrix(dplyr::select(no_event, VF:L)))[[
".estimate"]]
Condition
Error in `stop_roc_truth_no_control()`:
Error in `roc_curve()`:
! No control observations were detected in `truth` with control level '..other'.

# roc_curve() - `options` is deprecated
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/_snaps/probably.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
Code
conf_mat(cp_hpc_cv, obs, pred)
Condition
Error in `conf_mat_impl()`:
Error in `conf_mat()`:
! `truth` should not a <class_pred> object.

---

Code
conf_mat(dplyr::group_by(cp_hpc_cv, Resample), obs, pred)
Condition
Error in `conf_mat_impl()`:
Error in `conf_mat()`:
! `truth` should not a <class_pred> object.

# `class_pred` errors when passed to `metrics()`
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-class-bal_accuracy.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ test_that("work with class_pred input", {

expect_snapshot(
error = TRUE,
bal_accuracy(cp_truth, cp_estimate)
bal_accuracy_vec(cp_truth, cp_estimate)
)
})

Expand Down

0 comments on commit 2b0c096

Please sign in to comment.