add rlang type checkers (#950)
* add type checking files

* remove newly unneeded checking functions

* snapshot updates from tidymodels/recipes#1381

* updates files

* basic replacements

* type checker replacements

* tidymodels/tailor#53

* Update R/checks.R

Co-authored-by: Simon P. Couch <[email protected]>

* add remote to get proper error messages

* typo

* update remotes?

* only test snapshots with more recent version of R *with* rankdeficient

---------

Co-authored-by: Simon P. Couch <[email protected]>
topepo and simonpcouch authored Oct 23, 2024
1 parent bc5422a commit 1184068
Showing 23 changed files with 1,144 additions and 243 deletions.
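The common thread across the files below: hand-rolled argument validators (`check_direction()`, `check_best()`, `val_class_and_single()`, and friends) are swapped for rlang's standalone type checkers (`check_bool()`, `check_string()`, `check_number_whole()`, and so on). Those helpers are not exported by rlang; they are vendored into the package as a standalone file (the "add type checking files" bullet above), typically via `usethis::use_standalone("r-lib/rlang", "types-check")`. A minimal sketch of the calling pattern, with a made-up caller and approximate error text:

```r
# Hypothetical caller; check_bool()/check_number_decimal() come from the
# vendored standalone-types-check.R file, not from library(rlang).
acquire <- function(maximize, best) {
  check_bool(maximize)                               # single TRUE/FALSE
  check_number_decimal(best, allow_infinite = FALSE) # single finite number
  invisible(NULL)
}

acquire(maximize = "yes", best = 2)
#> Error in `acquire()`:
#> ! `maximize` must be `TRUE` or `FALSE`, not the string "yes".
```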
15 changes: 9 additions & 6 deletions DESCRIPTION
@@ -4,7 +4,7 @@ Version: 1.2.1.9000
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0003-2402-136X")),
person(given = "Posit Software, PBC", role = c("cph", "fnd"))
person("Posit Software, PBC", role = c("cph", "fnd"))
)
Description: The ability to tune models is important. 'tune' contains
functions and classes to be used in conjunction with other
@@ -27,12 +27,12 @@ Imports:
ggplot2,
glue (>= 1.6.2),
GPfit,
hardhat (>= 1.2.0),
hardhat (>= 1.4.0.9002),
lifecycle (>= 1.0.0),
parsnip (>= 1.2.0),
parsnip (>= 1.2.1.9003),
purrr (>= 1.0.0),
recipes (>= 1.0.4),
rlang (>= 1.1.0),
recipes (>= 1.1.0.9001),
rlang (>= 1.1.4),
rsample (>= 1.2.1.9000),
tailor,
tibble (>= 3.1.0),
@@ -57,8 +57,11 @@ Suggests:
xgboost,
xml2
Remotes:
tidymodels/hardhat,
tidymodels/parsnip,
tidymodels/recipes,
tidymodels/rsample,
tidymodels/tailor,
tidymodels/tailor,
tidymodels/workflows
Config/Needs/website: pkgdown, tidymodels, kknn, doParallel, doFuture,
tidyverse/tidytemplate
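The development version floors (`hardhat (>= 1.4.0.9002)`, `parsnip (>= 1.2.1.9003)`, `recipes (>= 1.1.0.9001)`) plus the new `Remotes:` entries pin the GitHub dev versions of the tidymodels packages so that snapshot tests see the matching rlang-style error messages (see the "add remote to get proper error messages" bullet in the commit message). A hedged sketch of installing those dependencies locally; either call honors the `Remotes:` field:

```r
# Run from the package root; both resolve the Remotes: entries above and
# install the GitHub development versions alongside CRAN dependencies.
remotes::install_deps(dependencies = TRUE)
# or
pak::local_install_dev_deps()
```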
1 change: 1 addition & 0 deletions NAMESPACE
@@ -244,6 +244,7 @@ export(tune_bayes)
export(tune_grid)
export(val_class_and_single)
export(val_class_or_null)
import(rlang)
import(vctrs)
import(workflows)
importFrom(GPfit,GP_fit)
2 changes: 1 addition & 1 deletion R/0_imports.R
@@ -22,7 +22,7 @@
#' @importFrom cli cli_inform cli_warn cli_abort qty
#' @importFrom foreach foreach getDoParName %dopar%
#' @importFrom tibble obj_sum size_sum

#' @import rlang

# ------------------------------------------------------------------------------
# Only a small number of functions in workflows.
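The `import(rlang)` line in NAMESPACE is generated, not hand-edited: the roxygen tag `#' @import rlang` added in `R/0_imports.R` produces it when the documentation is rebuilt. A sketch of the regeneration step (standard roxygen2 workflow, not part of the diff):

```r
# Rebuild NAMESPACE and Rd files from the roxygen comments.
devtools::document()
# equivalently: roxygen2::roxygenise()
```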
10 changes: 5 additions & 5 deletions R/acquisition.R
@@ -74,8 +74,8 @@ print.prob_improve <- function(x, ...) {
#' @export
predict.prob_improve <-
function(object, new_data, maximize, iter, best, ...) {
check_direction(maximize)
check_best(best)
check_bool(maximize)
check_number_decimal(best, allow_infinite = FALSE)

if (is.function(object$trade_off)) {
trade_off <- object$trade_off(iter)
@@ -126,8 +126,8 @@ exp_improve <- function(trade_off = 0, eps = .Machine$double.eps) {

#' @export
predict.exp_improve <- function(object, new_data, maximize, iter, best, ...) {
check_direction(maximize)
check_best(best)
check_bool(maximize)
check_number_decimal(best, allow_infinite = FALSE)

if (is.function(object$trade_off)) {
trade_off <- object$trade_off(iter)
@@ -177,7 +177,7 @@ conf_bound <- function(kappa = 0.1) {

#' @export
predict.conf_bound <- function(object, new_data, maximize, iter, ...) {
check_direction(maximize)
check_bool(maximize)

if (is.function(object$kappa)) {
kappa <- object$kappa(iter)
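The replacements here are slightly stricter than the helpers they retire: `check_best()` only rejected non-numeric, non-scalar, or missing input, whereas `check_number_decimal(best, allow_infinite = FALSE)` also rejects infinite values. A hedged illustration, calling the vendored checker directly (the `:::` access and the exact message wording are approximate):

```r
# Called directly only for illustration; in the package these run inside
# predict.prob_improve() and predict.exp_improve().
tune:::check_number_decimal("high", allow_infinite = FALSE, arg = "best")
#> Error: `best` must be a number, not the string "high".
tune:::check_number_decimal(Inf, allow_infinite = FALSE, arg = "best")  # also errors now
```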
31 changes: 10 additions & 21 deletions R/checks.R
@@ -493,26 +493,6 @@ get_objective_name <- function(x, metrics) {
x
}


# ------------------------------------------------------------------------------
# acq functions

check_direction <- function(x) {
if (!is.logical(x) || length(x) != 1) {
rlang::abort("`maximize` should be a single logical.")
}
invisible(NULL)
}


check_best <- function(x) {
if (!is.numeric(x) || length(x) != 1 || is.na(x)) {
rlang::abort("`best` should be a single, non-missing numeric.")
}
invisible(NULL)
}


# ------------------------------------------------------------------------------

check_class_or_null <- function(x, cls = "numeric") {
@@ -537,6 +517,7 @@ val_class_or_null <- function(x, cls = "numeric", where = NULL) {
}
invisible(NULL)
}
# TODO remove this once finetune is updated

check_class_and_single <- function(x, cls = "numeric") {
isTRUE(inherits(x, cls) & length(x) == 1)
@@ -558,7 +539,7 @@ val_class_and_single <- function(x, cls = "numeric", where = NULL) {
}
invisible(NULL)
}

# TODO remove this once finetune is updated

# Check the data going into the GP. If there are all missing values, fail. If some
# are missing, remove them and send a warning. If all metrics are the same, fail.
@@ -644,3 +625,11 @@ check_eval_time <- function(eval_time, metrics) {
invisible(NULL)

}

check_time_limit_arg <- function(x, call = rlang::caller_env()) {
if (!inherits(x, c("logical", "numeric")) || length(x) != 1L) {
cli::cli_abort("{.arg time_limit} should be either a single numeric or
logical value.", call = call)
}
invisible(NULL)
}
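The one new bespoke validator, `check_time_limit_arg()`, exists because `time_limit` legitimately accepts either a logical (`FALSE` to disable the limit) or a single number of minutes, a union that no single scalar checker covers. Passing `call = rlang::caller_env()` through to `cli::cli_abort()` attributes the error to the function the user actually called. A hypothetical session showing the effect (the message text is taken from the diff; the framing is an assumption):

```r
control_bayes(time_limit = c(10, 20))
#> Error in `control_bayes()`:
#> ! `time_limit` should be either a single numeric or logical value.
```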
5 changes: 1 addition & 4 deletions R/compute_metrics.R
@@ -87,6 +87,7 @@ compute_metrics.tune_results <- function(x,
summarize = TRUE,
event_level = "first") {
rlang::check_dots_empty()
check_bool(summarize)
if (!".predictions" %in% names(x)) {
rlang::abort(paste0(
"`x` must have been generated with the ",
@@ -114,10 +115,6 @@
))
}

if (!inherits(summarize, "logical") || length(summarize) != 1L) {
rlang::abort("The `summarize` argument must be a single logical value.")
}

param_names <- .get_tune_parameter_names(x)
outcome_name <- .get_tune_outcome_names(x)

55 changes: 27 additions & 28 deletions R/control.R
@@ -38,15 +38,15 @@ control_grid <- function(verbose = FALSE, allow_par = TRUE,
# Any added arguments should also be added in superset control functions
# in other packages

# add options for seeds per resample
# add options for seeds per resample
check_bool(verbose)
check_bool(allow_par)
check_bool(save_pred)
check_bool(save_workflow)
check_string(event_level)
check_character(pkgs, allow_null = TRUE)
check_function(extract, allow_null = TRUE)

val_class_and_single(verbose, "logical", "control_grid()")
val_class_and_single(allow_par, "logical", "control_grid()")
val_class_and_single(save_pred, "logical", "control_grid()")
val_class_and_single(save_workflow, "logical", "control_grid()")
val_class_and_single(event_level, "character", "control_grid()")
val_class_or_null(pkgs, "character", "control_grid()")
val_class_or_null(extract, "function", "control_grid()")
val_parallel_over(parallel_over, "control_grid()")


@@ -241,26 +241,27 @@ control_bayes <-
# in other packages

# add options for seeds per resample
check_bool(verbose)
check_bool(verbose_iter)
check_bool(allow_par)
check_bool(save_pred)
check_bool(save_workflow)
check_bool(save_gp_scoring)
check_character(pkgs, allow_null = TRUE)
check_function(extract, allow_null = TRUE)
check_number_whole(no_improve, min = 0, allow_infinite = TRUE)
check_number_whole(uncertain, min = 0, allow_infinite = TRUE)
check_number_whole(seed)

check_time_limit_arg(time_limit)

val_class_and_single(verbose, "logical", "control_bayes()")
val_class_and_single(verbose_iter, "logical", "control_bayes()")
val_class_and_single(save_pred, "logical", "control_bayes()")
val_class_and_single(save_gp_scoring, "logical", "control_bayes()")
val_class_and_single(save_workflow, "logical", "control_bayes()")
val_class_and_single(no_improve, c("numeric", "integer"), "control_bayes()")
val_class_and_single(uncertain, c("numeric", "integer"), "control_bayes()")
val_class_and_single(seed, c("numeric", "integer"), "control_bayes()")
val_class_or_null(extract, "function", "control_bayes()")
val_class_and_single(time_limit, c("logical", "numeric"), "control_bayes()")
val_class_or_null(pkgs, "character", "control_bayes()")
val_class_and_single(event_level, "character", "control_bayes()")
val_parallel_over(parallel_over, "control_bayes()")
val_class_and_single(allow_par, "logical", "control_bayes()")


if (!is.infinite(uncertain) && uncertain > no_improve) {
cli::cli_alert_warning(
"Uncertainty sample scheduled after {uncertain} poor iterations but the search will stop after {no_improve}."
cli::cli_warn(
"Uncertainty sample scheduled after {uncertain} poor iterations but the
search will stop after {no_improve}."
)
}

@@ -296,13 +297,11 @@ print.control_bayes <- function(x, ...) {
# ------------------------------------------------------------------------------

val_parallel_over <- function(parallel_over, where) {
if (is.null(parallel_over)) {
return(invisible(NULL))
check_string(parallel_over, allow_null = TRUE)
if (!is.null(parallel_over)) {
rlang::arg_match0(parallel_over, c("resamples", "everything"), "parallel_over")
}

val_class_and_single(parallel_over, "character", where)
rlang::arg_match0(parallel_over, c("resamples", "everything"), "parallel_over")

invisible(NULL)
}

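Two behavioral notes on `control_bayes()` beyond the straight checker swap: the integer-ish arguments now run through `check_number_whole()` (with `min = 0` and `allow_infinite = TRUE` for `no_improve` and `uncertain`), and the scheduling notice moves from `cli::cli_alert_warning()`, which only prints to the console, to `cli::cli_warn()`, which signals a real warning condition. A hedged sketch with approximate messages:

```r
control_bayes(no_improve = 2.5)
#> Error in `control_bayes()`:
#> ! `no_improve` must be a whole number, not the number 2.5.

# The uncertainty notice can now be caught like any other warning:
w <- tryCatch(
  control_bayes(no_improve = 5, uncertain = 10),
  warning = function(cnd) conditionMessage(cnd)
)
```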
4 changes: 1 addition & 3 deletions R/extract.R
@@ -109,9 +109,7 @@ extract_spec_parsnip.tune_results <- function(x, ...) {
#' @rdname extract-tune
extract_recipe.tune_results <- function(x, ..., estimated = TRUE) {
check_empty_dots(...)
if (!rlang::is_bool(estimated)) {
rlang::abort("`estimated` must be a single `TRUE` or `FALSE`.")
}
check_bool(estimated)
extract_recipe(extract_workflow(x), estimated = estimated)
}
check_empty_dots <- function(...) {