From 3228b2a46810c8e6dcf9018e5a3fe0980eff03af Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 25 Oct 2024 18:29:32 -0700 Subject: [PATCH 1/3] pass calls --- R/aaa-metrics.R | 3 ++- R/aaa-new.R | 7 ++++--- R/conf_mat.R | 4 ++-- R/metric-tweak.R | 5 +++-- R/misc.R | 5 +++-- R/prob-helpers.R | 3 +++ R/prob-roc_auc.R | 8 +++++--- R/prob-roc_curve.R | 15 +++++++++------ tests/testthat/_snaps/aaa-metrics.md | 8 ++++---- tests/testthat/_snaps/aaa-new.md | 4 ++-- tests/testthat/_snaps/class-bal_accuracy.md | 6 +++--- tests/testthat/_snaps/metric-tweak.md | 6 +++--- tests/testthat/_snaps/prob-roc_curve.md | 6 +++--- tests/testthat/_snaps/probably.md | 4 ++-- tests/testthat/test-class-bal_accuracy.R | 2 +- 15 files changed, 49 insertions(+), 37 deletions(-) diff --git a/R/aaa-metrics.R b/R/aaa-metrics.R index eb6345fd..b0041f3a 100644 --- a/R/aaa-metrics.R +++ b/R/aaa-metrics.R @@ -717,7 +717,8 @@ validate_function_class <- function(fns) { "*" = "a mix of dynamic and static survival metrics.", "i" = "The following metric function types are being mixed:", fn_pastable - )) + ), + call = rlang::call2("metric_set")) } # Safely evaluate metrics in such a way that we can capture the diff --git a/R/aaa-new.R b/R/aaa-new.R index c8e51ee6..c1beba78 100644 --- a/R/aaa-new.R +++ b/R/aaa-new.R @@ -65,12 +65,13 @@ new_static_survival_metric <- function(fn, direction) { } #' @include import-standalone-types-check.R -new_metric <- function(fn, direction, class = NULL) { - check_function(fn) +new_metric <- function(fn, direction, class = NULL, call = caller_env()) { + check_function(fn, call = call) direction <- arg_match( direction, - values = c("maximize", "minimize", "zero") + values = c("maximize", "minimize", "zero"), + error_call = call ) class <- c(class, "metric", "function") diff --git a/R/conf_mat.R b/R/conf_mat.R index c3b3303d..0d14263c 100644 --- a/R/conf_mat.R +++ b/R/conf_mat.R @@ -216,8 +216,8 @@ conf_mat.grouped_df <- function(data, } conf_mat_impl <- function(truth, estimate, case_weights, call = caller_env()) { - abort_if_class_pred(truth) - estimate <- as_factor_from_class_pred(estimate) + abort_if_class_pred(truth, call = call) + estimate <- as_factor_from_class_pred(estimate, call = call) estimator <- "not binary" check_class_metric(truth, estimate, case_weights, estimator, call = call) diff --git a/R/metric-tweak.R b/R/metric-tweak.R index 2a295848..f3908deb 100644 --- a/R/metric-tweak.R +++ b/R/metric-tweak.R @@ -85,7 +85,7 @@ metric_tweak <- function(.name, .fn, ...) { # ------------------------------------------------------------------------------ -check_protected_names <- function(fixed) { +check_protected_names <- function(fixed, call = caller_env()) { protected <- protected_names() has_protected_name <- any(names(fixed) %in% protected) @@ -94,7 +94,8 @@ check_protected_names <- function(fixed) { } cli::cli_abort( - "Arguments passed through {.arg ...} cannot be named any of: {protected}." + "Arguments passed through {.arg ...} cannot be named any of: {protected}.", + call = call ) } diff --git a/R/misc.R b/R/misc.R index 60fd5bdc..76acf64d 100644 --- a/R/misc.R +++ b/R/misc.R @@ -70,7 +70,7 @@ is_class_pred <- function(x) { inherits(x, "class_pred") } -as_factor_from_class_pred <- function(x) { +as_factor_from_class_pred <- function(x, call) { if (!is_class_pred(x)) { return(x) } @@ -79,7 +79,8 @@ as_factor_from_class_pred <- function(x) { cli::cli_abort( "A {.cls class_pred} input was detected, but the {.pkg probably} package isn't installed. Install {.pkg probably} to be able to convert - {.cls class_pred} to {.cls factor}." + {.cls class_pred} to {.cls factor}.", + call = call ) } probably::as.factor(x) diff --git a/R/prob-helpers.R b/R/prob-helpers.R index d118866e..0ef1633b 100644 --- a/R/prob-helpers.R +++ b/R/prob-helpers.R @@ -37,6 +37,7 @@ one_vs_all_impl <- function(fn, truth, estimate, case_weights, + call, ...) { lvls <- levels(truth) other <- "..other" @@ -76,12 +77,14 @@ one_vs_all_with_level <- function(fn, truth, estimate, case_weights, + call, ...) { res <- one_vs_all_impl( fn = fn, truth = truth, estimate = estimate, case_weights = case_weights, + call = call, ... ) diff --git a/R/prob-roc_auc.R b/R/prob-roc_auc.R index 09e080b7..883d6a3a 100644 --- a/R/prob-roc_auc.R +++ b/R/prob-roc_auc.R @@ -399,7 +399,8 @@ warn_roc_truth_no_control <- function(control) { stop_roc_truth_no_control <- function(control) { cli::cli_abort( msg_roc_truth_no_control(control), - class = "yardstick_error_roc_truth_no_control" + class = "yardstick_error_roc_truth_no_control", + call = call("roc_curve") ) } @@ -415,9 +416,10 @@ warn_roc_truth_no_event <- function(event) { class = "yardstick_warning_roc_truth_no_event" ) } -stop_roc_truth_no_event <- function(event) { +stop_roc_truth_no_event <- function(event, call = caller_env()) { cli::cli_abort( msg_roc_truth_no_event(event), - class = "yardstick_error_roc_truth_no_event" + class = "yardstick_error_roc_truth_no_event", + call = call ) } diff --git a/R/prob-roc_curve.R b/R/prob-roc_curve.R index e4d3d43c..6b4947f8 100644 --- a/R/prob-roc_curve.R +++ b/R/prob-roc_curve.R @@ -131,18 +131,20 @@ roc_curve_estimator_impl <- function(truth, estimate, estimator, event_level, - case_weights) { + case_weights, + call = caller_env()) { if (is_binary(estimator)) { - roc_curve_binary(truth, estimate, event_level, case_weights) + roc_curve_binary(truth, estimate, event_level, case_weights, call) } else { - roc_curve_multiclass(truth, estimate, case_weights) + roc_curve_multiclass(truth, estimate, case_weights, call) } } roc_curve_binary <- function(truth, estimate, event_level, - case_weights) { + case_weights, + call) { lvls <- levels(truth) if (!is_event_first(event_level)) { @@ -153,7 +155,7 @@ roc_curve_binary <- function(truth, control <- lvls[[2]] if (compute_n_occurrences(truth, event) == 0L) { - stop_roc_truth_no_event(event) + stop_roc_truth_no_event(event, call) } if (compute_n_occurrences(truth, control) == 0L) { stop_roc_truth_no_control(control) @@ -197,7 +199,8 @@ roc_curve_binary <- function(truth, # One-VS-All approach roc_curve_multiclass <- function(truth, estimate, - case_weights) { + case_weights, + call) { one_vs_all_with_level( fn = roc_curve_binary, truth = truth, diff --git a/tests/testthat/_snaps/aaa-metrics.md b/tests/testthat/_snaps/aaa-metrics.md index db5b793d..c866f26f 100644 --- a/tests/testthat/_snaps/aaa-metrics.md +++ b/tests/testthat/_snaps/aaa-metrics.md @@ -48,7 +48,7 @@ Code metric_set(rmse, accuracy) Condition - Error in `validate_function_class()`: + Error in `metric_set()`: x The combination of metric functions must be: * only numeric metrics. * a mix of class metrics and class probability metrics. @@ -72,7 +72,7 @@ Code metric_set(accuracy, foobar, sens, rlang::abort) Condition - Error in `validate_function_class()`: + Error in `metric_set()`: x The combination of metric functions must be: * only numeric metrics. * a mix of class metrics and class probability metrics. @@ -86,7 +86,7 @@ Code metric_set(accuracy, foobar, sens, rlang::abort) Condition - Error in `validate_function_class()`: + Error in `metric_set()`: x The combination of metric functions must be: * only numeric metrics. * a mix of class metrics and class probability metrics. @@ -100,7 +100,7 @@ Code metric_set(foobar) Condition - Error in `validate_function_class()`: + Error in `metric_set()`: x The combination of metric functions must be: * only numeric metrics. * a mix of class metrics and class probability metrics. diff --git a/tests/testthat/_snaps/aaa-new.md b/tests/testthat/_snaps/aaa-new.md index aadb6ade..06f0902f 100644 --- a/tests/testthat/_snaps/aaa-new.md +++ b/tests/testthat/_snaps/aaa-new.md @@ -3,7 +3,7 @@ Code new_class_metric(1, "maximize") Condition - Error in `new_metric()`: + Error in `new_class_metric()`: ! `fn` must be a function, not the number 1. # `direction` is validated @@ -11,7 +11,7 @@ Code new_class_metric(function() 1, "min") Condition - Error in `new_metric()`: + Error in `new_class_metric()`: ! `direction` must be one of "maximize", "minimize", or "zero", not "min". i Did you mean "minimize"? diff --git a/tests/testthat/_snaps/class-bal_accuracy.md b/tests/testthat/_snaps/class-bal_accuracy.md index 36d48f36..71624fb3 100644 --- a/tests/testthat/_snaps/class-bal_accuracy.md +++ b/tests/testthat/_snaps/class-bal_accuracy.md @@ -1,8 +1,8 @@ # work with class_pred input Code - bal_accuracy(cp_truth, cp_estimate) + bal_accuracy_vec(cp_truth, cp_estimate) Condition - Error in `UseMethod()`: - ! no applicable method for 'bal_accuracy' applied to an object of class "c('class_pred', 'vctrs_vctr')" + Error in `bal_accuracy_vec()`: + ! `truth` should not a object. diff --git a/tests/testthat/_snaps/metric-tweak.md b/tests/testthat/_snaps/metric-tweak.md index 047db20a..507587c6 100644 --- a/tests/testthat/_snaps/metric-tweak.md +++ b/tests/testthat/_snaps/metric-tweak.md @@ -3,7 +3,7 @@ Code metric_tweak("f_meas2", f_meas, data = 2) Condition - Error in `check_protected_names()`: + Error in `metric_tweak()`: ! Arguments passed through `...` cannot be named any of: data, truth, and estimate. --- @@ -11,7 +11,7 @@ Code metric_tweak("f_meas2", f_meas, truth = 2) Condition - Error in `check_protected_names()`: + Error in `metric_tweak()`: ! Arguments passed through `...` cannot be named any of: data, truth, and estimate. --- @@ -19,7 +19,7 @@ Code metric_tweak("f_meas2", f_meas, estimate = 2) Condition - Error in `check_protected_names()`: + Error in `metric_tweak()`: ! Arguments passed through `...` cannot be named any of: data, truth, and estimate. # `name` must be a string diff --git a/tests/testthat/_snaps/prob-roc_curve.md b/tests/testthat/_snaps/prob-roc_curve.md index 7e868a4a..1d915d5a 100644 --- a/tests/testthat/_snaps/prob-roc_curve.md +++ b/tests/testthat/_snaps/prob-roc_curve.md @@ -3,7 +3,7 @@ Code roc_curve_vec(no_event$truth, no_event$Class1)[[".estimate"]] Condition - Error in `stop_roc_truth_no_event()`: + Error in `roc_curve_vec()`: ! No event observations were detected in `truth` with event level 'Class1'. # roc_curve() - error is thrown when missing controls @@ -11,7 +11,7 @@ Code roc_curve_vec(no_control$truth, no_control$Class1)[[".estimate"]] Condition - Error in `stop_roc_truth_no_control()`: + Error in `roc_curve()`: ! No control observations were detected in `truth` with control level 'Class2'. # roc_curve() - multiclass one-vs-all approach results in error @@ -20,7 +20,7 @@ roc_curve_vec(no_event$obs, as.matrix(dplyr::select(no_event, VF:L)))[[ ".estimate"]] Condition - Error in `stop_roc_truth_no_control()`: + Error in `roc_curve()`: ! No control observations were detected in `truth` with control level '..other'. # roc_curve() - `options` is deprecated diff --git a/tests/testthat/_snaps/probably.md b/tests/testthat/_snaps/probably.md index 46e43eeb..469b6c11 100644 --- a/tests/testthat/_snaps/probably.md +++ b/tests/testthat/_snaps/probably.md @@ -11,7 +11,7 @@ Code conf_mat(cp_hpc_cv, obs, pred) Condition - Error in `conf_mat_impl()`: + Error in `conf_mat()`: ! `truth` should not a object. --- @@ -19,7 +19,7 @@ Code conf_mat(dplyr::group_by(cp_hpc_cv, Resample), obs, pred) Condition - Error in `conf_mat_impl()`: + Error in `conf_mat()`: ! `truth` should not a object. # `class_pred` errors when passed to `metrics()` diff --git a/tests/testthat/test-class-bal_accuracy.R b/tests/testthat/test-class-bal_accuracy.R index 22cd943e..f25e3f8c 100644 --- a/tests/testthat/test-class-bal_accuracy.R +++ b/tests/testthat/test-class-bal_accuracy.R @@ -87,7 +87,7 @@ test_that("work with class_pred input", { expect_snapshot( error = TRUE, - bal_accuracy(cp_truth, cp_estimate) + bal_accuracy_vec(cp_truth, cp_estimate) ) }) From c5de84a31f4b09789370bbf3ea80a59cc011f089 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 28 Oct 2024 21:36:36 -0700 Subject: [PATCH 2/3] Update R/metric-tweak.R Co-authored-by: Max Kuhn --- R/metric-tweak.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/metric-tweak.R b/R/metric-tweak.R index f3908deb..e81b9ef4 100644 --- a/R/metric-tweak.R +++ b/R/metric-tweak.R @@ -94,7 +94,7 @@ check_protected_names <- function(fixed, call = caller_env()) { } cli::cli_abort( - "Arguments passed through {.arg ...} cannot be named any of: {protected}.", + "Arguments passed through {.arg ...} cannot be named any of: {.arg {protected}}.", call = call ) } From df2366c660225557a9e4c09ba30bbd7e2f6cac2d Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 28 Oct 2024 21:40:51 -0700 Subject: [PATCH 3/3] update snapshots from review suggestions --- tests/testthat/_snaps/metric-tweak.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/testthat/_snaps/metric-tweak.md b/tests/testthat/_snaps/metric-tweak.md index 507587c6..beeecc52 100644 --- a/tests/testthat/_snaps/metric-tweak.md +++ b/tests/testthat/_snaps/metric-tweak.md @@ -4,7 +4,7 @@ metric_tweak("f_meas2", f_meas, data = 2) Condition Error in `metric_tweak()`: - ! Arguments passed through `...` cannot be named any of: data, truth, and estimate. + ! Arguments passed through `...` cannot be named any of: `data`, `truth`, and `estimate`. --- @@ -12,7 +12,7 @@ metric_tweak("f_meas2", f_meas, truth = 2) Condition Error in `metric_tweak()`: - ! Arguments passed through `...` cannot be named any of: data, truth, and estimate. + ! Arguments passed through `...` cannot be named any of: `data`, `truth`, and `estimate`. --- @@ -20,7 +20,7 @@ metric_tweak("f_meas2", f_meas, estimate = 2) Condition Error in `metric_tweak()`: - ! Arguments passed through `...` cannot be named any of: data, truth, and estimate. + ! Arguments passed through `...` cannot be named any of: `data`, `truth`, and `estimate`. # `name` must be a string