Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pass calls around #522

Merged
merged 3 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion R/aaa-metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,8 @@ validate_function_class <- function(fns) {
"*" = "a mix of dynamic and static survival metrics.",
"i" = "The following metric function types are being mixed:",
fn_pastable
))
),
call = rlang::call2("metric_set"))
}

# Safely evaluate metrics in such a way that we can capture the
Expand Down
7 changes: 4 additions & 3 deletions R/aaa-new.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,13 @@ new_static_survival_metric <- function(fn, direction) {
}

#' @include import-standalone-types-check.R
new_metric <- function(fn, direction, class = NULL) {
check_function(fn)
new_metric <- function(fn, direction, class = NULL, call = caller_env()) {
check_function(fn, call = call)

direction <- arg_match(
direction,
values = c("maximize", "minimize", "zero")
values = c("maximize", "minimize", "zero"),
error_call = call
)

class <- c(class, "metric", "function")
Expand Down
4 changes: 2 additions & 2 deletions R/conf_mat.R
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ conf_mat.grouped_df <- function(data,
}

conf_mat_impl <- function(truth, estimate, case_weights, call = caller_env()) {
abort_if_class_pred(truth)
estimate <- as_factor_from_class_pred(estimate)
abort_if_class_pred(truth, call = call)
estimate <- as_factor_from_class_pred(estimate, call = call)

estimator <- "not binary"
check_class_metric(truth, estimate, case_weights, estimator, call = call)
Expand Down
5 changes: 3 additions & 2 deletions R/metric-tweak.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ metric_tweak <- function(.name, .fn, ...) {

# ------------------------------------------------------------------------------

check_protected_names <- function(fixed) {
check_protected_names <- function(fixed, call = caller_env()) {
protected <- protected_names()
has_protected_name <- any(names(fixed) %in% protected)

Expand All @@ -94,7 +94,8 @@ check_protected_names <- function(fixed) {
}

cli::cli_abort(
"Arguments passed through {.arg ...} cannot be named any of: {protected}."
"Arguments passed through {.arg ...} cannot be named any of: {protected}.",
EmilHvitfeldt marked this conversation as resolved.
Show resolved Hide resolved
call = call
)
}

Expand Down
5 changes: 3 additions & 2 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ is_class_pred <- function(x) {
inherits(x, "class_pred")
}

as_factor_from_class_pred <- function(x) {
as_factor_from_class_pred <- function(x, call) {
if (!is_class_pred(x)) {
return(x)
}
Expand All @@ -79,7 +79,8 @@ as_factor_from_class_pred <- function(x) {
cli::cli_abort(
"A {.cls class_pred} input was detected, but the {.pkg probably}
package isn't installed. Install {.pkg probably} to be able to convert
{.cls class_pred} to {.cls factor}."
{.cls class_pred} to {.cls factor}.",
call = call
)
}
probably::as.factor(x)
Expand Down
3 changes: 3 additions & 0 deletions R/prob-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ one_vs_all_impl <- function(fn,
truth,
estimate,
case_weights,
call,
...) {
lvls <- levels(truth)
other <- "..other"
Expand Down Expand Up @@ -76,12 +77,14 @@ one_vs_all_with_level <- function(fn,
truth,
estimate,
case_weights,
call,
...) {
res <- one_vs_all_impl(
fn = fn,
truth = truth,
estimate = estimate,
case_weights = case_weights,
call = call,
...
)

Expand Down
8 changes: 5 additions & 3 deletions R/prob-roc_auc.R
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,8 @@ warn_roc_truth_no_control <- function(control) {
stop_roc_truth_no_control <- function(control) {
cli::cli_abort(
msg_roc_truth_no_control(control),
class = "yardstick_error_roc_truth_no_control"
class = "yardstick_error_roc_truth_no_control",
call = call("roc_curve")
)
}

Expand All @@ -415,9 +416,10 @@ warn_roc_truth_no_event <- function(event) {
class = "yardstick_warning_roc_truth_no_event"
)
}
stop_roc_truth_no_event <- function(event) {
stop_roc_truth_no_event <- function(event, call = caller_env()) {
cli::cli_abort(
msg_roc_truth_no_event(event),
class = "yardstick_error_roc_truth_no_event"
class = "yardstick_error_roc_truth_no_event",
call = call
)
}
15 changes: 9 additions & 6 deletions R/prob-roc_curve.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,20 @@ roc_curve_estimator_impl <- function(truth,
estimate,
estimator,
event_level,
case_weights) {
case_weights,
call = caller_env()) {
if (is_binary(estimator)) {
roc_curve_binary(truth, estimate, event_level, case_weights)
roc_curve_binary(truth, estimate, event_level, case_weights, call)
} else {
roc_curve_multiclass(truth, estimate, case_weights)
roc_curve_multiclass(truth, estimate, case_weights, call)
}
}

roc_curve_binary <- function(truth,
estimate,
event_level,
case_weights) {
case_weights,
call) {
lvls <- levels(truth)

if (!is_event_first(event_level)) {
Expand All @@ -153,7 +155,7 @@ roc_curve_binary <- function(truth,
control <- lvls[[2]]

if (compute_n_occurrences(truth, event) == 0L) {
stop_roc_truth_no_event(event)
stop_roc_truth_no_event(event, call)
}
if (compute_n_occurrences(truth, control) == 0L) {
stop_roc_truth_no_control(control)
Expand Down Expand Up @@ -197,7 +199,8 @@ roc_curve_binary <- function(truth,
# One-VS-All approach
roc_curve_multiclass <- function(truth,
estimate,
case_weights) {
case_weights,
call) {
one_vs_all_with_level(
fn = roc_curve_binary,
truth = truth,
Expand Down
8 changes: 4 additions & 4 deletions tests/testthat/_snaps/aaa-metrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
Code
metric_set(rmse, accuracy)
Condition
Error in `validate_function_class()`:
Error in `metric_set()`:
x The combination of metric functions must be:
* only numeric metrics.
* a mix of class metrics and class probability metrics.
Expand All @@ -72,7 +72,7 @@
Code
metric_set(accuracy, foobar, sens, rlang::abort)
Condition
Error in `validate_function_class()`:
Error in `metric_set()`:
x The combination of metric functions must be:
* only numeric metrics.
* a mix of class metrics and class probability metrics.
Expand All @@ -86,7 +86,7 @@
Code
metric_set(accuracy, foobar, sens, rlang::abort)
Condition
Error in `validate_function_class()`:
Error in `metric_set()`:
x The combination of metric functions must be:
* only numeric metrics.
* a mix of class metrics and class probability metrics.
Expand All @@ -100,7 +100,7 @@
Code
metric_set(foobar)
Condition
Error in `validate_function_class()`:
Error in `metric_set()`:
x The combination of metric functions must be:
* only numeric metrics.
* a mix of class metrics and class probability metrics.
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/_snaps/aaa-new.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
Code
new_class_metric(1, "maximize")
Condition
Error in `new_metric()`:
Error in `new_class_metric()`:
! `fn` must be a function, not the number 1.

# `direction` is validated

Code
new_class_metric(function() 1, "min")
Condition
Error in `new_metric()`:
Error in `new_class_metric()`:
! `direction` must be one of "maximize", "minimize", or "zero", not "min".
i Did you mean "minimize"?

Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/_snaps/class-bal_accuracy.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# work with class_pred input

Code
bal_accuracy(cp_truth, cp_estimate)
bal_accuracy_vec(cp_truth, cp_estimate)
Condition
Error in `UseMethod()`:
! no applicable method for 'bal_accuracy' applied to an object of class "c('class_pred', 'vctrs_vctr')"
Error in `bal_accuracy_vec()`:
! `truth` should not a <class_pred> object.

6 changes: 3 additions & 3 deletions tests/testthat/_snaps/metric-tweak.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@
Code
metric_tweak("f_meas2", f_meas, data = 2)
Condition
Error in `check_protected_names()`:
Error in `metric_tweak()`:
! Arguments passed through `...` cannot be named any of: data, truth, and estimate.

---

Code
metric_tweak("f_meas2", f_meas, truth = 2)
Condition
Error in `check_protected_names()`:
Error in `metric_tweak()`:
! Arguments passed through `...` cannot be named any of: data, truth, and estimate.

---

Code
metric_tweak("f_meas2", f_meas, estimate = 2)
Condition
Error in `check_protected_names()`:
Error in `metric_tweak()`:
! Arguments passed through `...` cannot be named any of: data, truth, and estimate.

# `name` must be a string
Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/_snaps/prob-roc_curve.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
Code
roc_curve_vec(no_event$truth, no_event$Class1)[[".estimate"]]
Condition
Error in `stop_roc_truth_no_event()`:
Error in `roc_curve_vec()`:
! No event observations were detected in `truth` with event level 'Class1'.

# roc_curve() - error is thrown when missing controls

Code
roc_curve_vec(no_control$truth, no_control$Class1)[[".estimate"]]
Condition
Error in `stop_roc_truth_no_control()`:
Error in `roc_curve()`:
! No control observations were detected in `truth` with control level 'Class2'.

# roc_curve() - multiclass one-vs-all approach results in error
Expand All @@ -20,7 +20,7 @@
roc_curve_vec(no_event$obs, as.matrix(dplyr::select(no_event, VF:L)))[[
".estimate"]]
Condition
Error in `stop_roc_truth_no_control()`:
Error in `roc_curve()`:
! No control observations were detected in `truth` with control level '..other'.

# roc_curve() - `options` is deprecated
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/_snaps/probably.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
Code
conf_mat(cp_hpc_cv, obs, pred)
Condition
Error in `conf_mat_impl()`:
Error in `conf_mat()`:
! `truth` should not a <class_pred> object.

---

Code
conf_mat(dplyr::group_by(cp_hpc_cv, Resample), obs, pred)
Condition
Error in `conf_mat_impl()`:
Error in `conf_mat()`:
! `truth` should not a <class_pred> object.

# `class_pred` errors when passed to `metrics()`
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-class-bal_accuracy.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ test_that("work with class_pred input", {

expect_snapshot(
error = TRUE,
bal_accuracy(cp_truth, cp_estimate)
bal_accuracy_vec(cp_truth, cp_estimate)
)
})

Expand Down
Loading