From 24ada80e08965ac4aa520a9abdbf5e437be58c22 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Mon, 4 Nov 2024 09:46:25 -0600 Subject: [PATCH] address `compute_grid_info()` bug for partially regular grids --- R/grid_helpers.R | 31 +++++--- tests/testthat/test-grid_helpers.R | 112 +++++++++++++++++++++++++++++ 2 files changed, 133 insertions(+), 10 deletions(-) diff --git a/R/grid_helpers.R b/R/grid_helpers.R index 6fcad549..1cc1a641 100644 --- a/R/grid_helpers.R +++ b/R/grid_helpers.R @@ -319,27 +319,39 @@ compute_grid_info <- function(workflow, grid) { res <- min_grid(extract_spec_parsnip(workflow), grid) if (any_parameters_preprocessor) { - res$.iter_preprocessor <- seq_len(nrow(res)) + res$.iter_preprocessor <- + vctrs::vec_group_id(res[parameters_preprocessor$id]) + attr(res$.iter_preprocessor, "n") <- NULL } else { res$.iter_preprocessor <- 1L } res$.msg_preprocessor <- new_msgs_preprocessor( - seq_len(max(res$.iter_preprocessor)), + res$.iter_preprocessor, max(res$.iter_preprocessor) ) if (nrow(res) != nrow(grid) || (any_parameters_model && !any_parameters_preprocessor)) { - res$.iter_model <- seq_len(dplyr::n_distinct(res[parameters_model$id])) + res$.iter_model <- vctrs::vec_group_id(res[parameters_model$id]) + attr(res$.iter_model, "n") <- NULL } else { res$.iter_model <- 1L } res$.iter_config <- list(list()) + shift_submodels <- integer(length(unique(res$.iter_preprocessor))) for (row in seq_len(nrow(res))) { - res$.iter_config[row] <- list(iter_config(res[row, ])) + res_row <- res[row, ] + iter_config <- iter_config( + res_row, + shift = shift_submodels[res_row$.iter_preprocessor] + ) + shift_submodels[res_row$.iter_preprocessor] <- + shift_submodels[res_row$.iter_preprocessor] + + length(res_row$.submodels[[1]]) + res$.iter_config[row] <- list(iter_config) } res$.msg_model <- @@ -348,19 +360,18 @@ compute_grid_info <- function(workflow, grid) { res } -iter_config <- function(res_row) { +iter_config <- function(res_row, shift) { submodels <- res_row$.submodels[[1]] - if (identical(submodels, list())) { - models <- res_row$.iter_model - } else { - models <- seq_len(length(submodels[[1]]) + 1) + model_configs <- res_row$.iter_model + if (!identical(submodels, list())) { + model_configs <- model_configs + seq_len(length(submodels[[1]]) + 1L) - 1 } paste0( "Preprocessor", res_row$.iter_preprocessor, "_Model", - format_with_padding(models) + format_with_padding(shift + model_configs) ) } diff --git a/tests/testthat/test-grid_helpers.R b/tests/testthat/test-grid_helpers.R index e55fc4a2..e08e055e 100644 --- a/tests/testthat/test-grid_helpers.R +++ b/tests/testthat/test-grid_helpers.R @@ -169,3 +169,115 @@ test_that("compute_grid_info - recipe and model (with submodels)", { ) expect_equal(nrow(res), 3) }) + +test_that("compute_grid_info - recipe and model (with and without submodels)", { + library(workflows) + library(parsnip) + library(recipes) + library(dials) + + rec <- recipe(mpg ~ ., mtcars) %>% step_spline_natural(deg_free = tune()) + spec <- boost_tree(mode = "regression", trees = tune(), loss_reduction = tune()) + + wflow <- workflow() + wflow <- add_model(wflow, spec) + wflow <- add_recipe(wflow, rec) + + # use grid_regular to (partially) trigger submodel trick + set.seed(1) + param_set <- extract_parameter_set_dials(wflow) + grid <- bind_rows(grid_regular(param_set), grid_space_filling(param_set)) + res <- compute_grid_info(wflow, grid) + + expect_equal(length(unique(res$.iter_preprocessor)), 5) + expect_equal( + unique(res$.msg_preprocessor), + paste0("preprocessor ", 1:5, "/5") + ) + expect_equal(res$trees, c(rep(max(grid$trees), 10), 1)) + expect_equal(res$.iter_model, c(rep(1:3, each = 3), 4, 5)) + expect_equal( + res$.iter_config[1:3], + list( + c("Preprocessor1_Model1", "Preprocessor1_Model2", "Preprocessor1_Model3", "Preprocessor1_Model4"), + c("Preprocessor2_Model1", "Preprocessor2_Model2", "Preprocessor2_Model3"), + c("Preprocessor3_Model1", "Preprocessor3_Model2", "Preprocessor3_Model3") + ) + ) + expect_equal(res$.msg_model[1:3], paste0("preprocessor ", 1:3, "/5, model 1/5")) + expect_equal( + res$.submodels[1:3], + list( + list(trees = c(1L, 1000L, 1000L)), + list(trees = c(1L, 1000L)), + list(trees = c(1L, 1000L)) + ) + ) + expect_named( + res, + c(".iter_preprocessor", ".msg_preprocessor", "deg_free", "trees", + "loss_reduction", ".iter_model", ".iter_config", ".msg_model", ".submodels"), + ignore.order = TRUE + ) + expect_equal(nrow(res), 11) +}) + +test_that("compute_grid_info - model (with and without submodels)", { + library(workflows) + library(parsnip) + library(recipes) + library(dials) + + rec <- recipe(mpg ~ ., mtcars) + spec <- mars(num_terms = tune(), prod_degree = tune(), prune_method = tune()) %>% + set_mode("classification") %>% + set_engine("earth") + + wflow <- workflow() + wflow <- add_model(wflow, spec) + wflow <- add_recipe(wflow, rec) + + set.seed(123) + params_grid <- grid_space_filling( + num_terms() %>% range_set(c(1L, 12L)), + prod_degree(), + prune_method(values = c("backward", "none", "forward")), + size = 7, + type = "latin_hypercube" + ) + + res <- compute_grid_info(wflow, params_grid) + + expect_equal(res$.iter_preprocessor, rep(1, 5)) + expect_equal(res$.msg_preprocessor, rep("preprocessor 1/1", 5)) + expect_equal(length(unique(res$num_terms)), 5) + expect_equal(res$.iter_model, 1:5) + expect_equal( + res$.iter_config, + list( + c("Preprocessor1_Model1", "Preprocessor1_Model2"), + c("Preprocessor1_Model3", "Preprocessor1_Model4"), + "Preprocessor1_Model5", "Preprocessor1_Model6", "Preprocessor1_Model7" + ) + ) + expect_equal( + unique(res$.msg_model), + paste0("preprocessor 1/1, model ", 1:5,"/5") + ) + expect_equal( + res$.submodels, + list( + list(num_terms = c(1)), + list(num_terms = c(3)), + list(), list(), list() + ) + ) + expect_named( + res, + c(".iter_preprocessor", ".msg_preprocessor", "num_terms", "prod_degree", + "prune_method", ".iter_model", ".iter_config", ".msg_model", ".submodels"), + ignore.order = TRUE + ) + expect_equal(nrow(res), 5) +}) +