diff --git a/R/grid_helpers.R b/R/grid_helpers.R index 6fcad549..3cfd0dca 100644 --- a/R/grid_helpers.R +++ b/R/grid_helpers.R @@ -319,28 +319,49 @@ compute_grid_info <- function(workflow, grid) { res <- min_grid(extract_spec_parsnip(workflow), grid) if (any_parameters_preprocessor) { - res$.iter_preprocessor <- seq_len(nrow(res)) + res$.iter_preprocessor <- + vctrs::vec_group_id(res[parameters_preprocessor$id]) + attr(res$.iter_preprocessor, "n") <- NULL + model_iters_needed <- duplicated(res$.iter_preprocessor) } else { res$.iter_preprocessor <- 1L + model_iters_needed <- TRUE } res$.msg_preprocessor <- new_msgs_preprocessor( - seq_len(max(res$.iter_preprocessor)), + res$.iter_preprocessor, max(res$.iter_preprocessor) ) - if (nrow(res) != nrow(grid) || - (any_parameters_model && !any_parameters_preprocessor)) { - res$.iter_model <- seq_len(dplyr::n_distinct(res[parameters_model$id])) - } else { - res$.iter_model <- 1L + res$.iter_model <- 1L + + if (any_parameters_model) { + model_iter_ids <- vctrs::vec_group_id( + res[model_iters_needed, parameters_model$id] + ) + res$.iter_model[model_iters_needed] <- model_iter_ids + 1 + } + + if (isTRUE(model_iters_needed)) { + res$.iter_model <- res$.iter_model - 1 } res$.iter_config <- list(list()) + shift_submodels <- integer(length(unique(res$.iter_preprocessor))) for (row in seq_len(nrow(res))) { - res$.iter_config[row] <- list(iter_config(res[row, ])) + res_row <- res[row, ] + iter_config <- iter_config( + res_row, + shift = shift_submodels[res_row$.iter_preprocessor] + ) + shift_submodels[res_row$.iter_preprocessor] <- + shift_submodels[res_row$.iter_preprocessor] + + length(res_row$.submodels[[1]]) + res$.iter_config[row] <- list(iter_config) } + res$.iter_config <- format_.iter_config(res$.iter_config) + res$.msg_model <- new_msgs_model(i = res$.iter_model, n = max(res$.iter_model), res$.msg_preprocessor) @@ -348,22 +369,34 @@ compute_grid_info <- function(workflow, grid) { res } -iter_config <- function(res_row) { +iter_config <- function(res_row, shift) { submodels <- res_row$.submodels[[1]] - if (identical(submodels, list())) { - models <- res_row$.iter_model - } else { - models <- seq_len(length(submodels[[1]]) + 1) + model_configs <- res_row$.iter_model + if (!identical(submodels, list())) { + model_configs <- model_configs + seq_len(length(submodels[[1]]) + 1L) - 1 } - paste0( - "Preprocessor", - res_row$.iter_preprocessor, - "_Model", - format_with_padding(models) + # return separately so new .iter_model can be passed through + # `format_with_padding with all iterations` + list( + config = paste0( + "Preprocessor", + res_row$.iter_preprocessor, + "_Model" + ), + model_config = shift + model_configs ) } +format_.iter_config <- function(x) { + res <- dplyr::bind_rows(x, .id = "idx") + res$model_config <- format_with_padding(res$model_config) + res$.iter_config <- paste0(res$config, res$model_config) + res <- tidyr::nest(res, .by = "idx") + + purrr::map(res$data, ~.x$.iter_config) +} + # This generates a "dummy" grid_info object that has the same # structure as a grid-info object with no tunable recipe parameters # and no tunable model parameters. diff --git a/tests/testthat/test-grid_helpers.R b/tests/testthat/test-grid_helpers.R index e55fc4a2..3c0dcfcf 100644 --- a/tests/testthat/test-grid_helpers.R +++ b/tests/testthat/test-grid_helpers.R @@ -169,3 +169,161 @@ test_that("compute_grid_info - recipe and model (with submodels)", { ) expect_equal(nrow(res), 3) }) + +test_that("compute_grid_info - recipe and model (with and without submodels)", { + library(workflows) + library(parsnip) + library(recipes) + library(dials) + + rec <- recipe(mpg ~ ., mtcars) %>% step_spline_natural(deg_free = tune()) + spec <- boost_tree(mode = "regression", trees = tune(), loss_reduction = tune()) + + wflow <- workflow() + wflow <- add_model(wflow, spec) + wflow <- add_recipe(wflow, rec) + + # use grid_regular to (partially) trigger submodel trick + set.seed(1) + param_set <- extract_parameter_set_dials(wflow) + grid <- bind_rows(grid_regular(param_set), grid_space_filling(param_set)) + res <- compute_grid_info(wflow, grid) + + expect_equal(length(unique(res$.iter_preprocessor)), 5) + expect_equal( + unique(res$.msg_preprocessor), + paste0("preprocessor ", 1:5, "/5") + ) + expect_equal(res$trees, c(rep(max(grid$trees), 10), 1)) + expect_equal(unique(res$.iter_model), 1:3) + expect_equal( + res$.iter_config[1:3], + list( + c("Preprocessor1_Model1", "Preprocessor1_Model2", "Preprocessor1_Model3", "Preprocessor1_Model4"), + c("Preprocessor2_Model1", "Preprocessor2_Model2", "Preprocessor2_Model3"), + c("Preprocessor3_Model1", "Preprocessor3_Model2", "Preprocessor3_Model3") + ) + ) + expect_equal(res$.msg_model[1:3], paste0("preprocessor ", 1:3, "/5, model 1/3")) + expect_equal( + res$.submodels[1:3], + list( + list(trees = c(1L, 1000L, 1000L)), + list(trees = c(1L, 1000L)), + list(trees = c(1L, 1000L)) + ) + ) + expect_named( + res, + c(".iter_preprocessor", ".msg_preprocessor", "deg_free", "trees", + "loss_reduction", ".iter_model", ".iter_config", ".msg_model", ".submodels"), + ignore.order = TRUE + ) + expect_equal(nrow(res), 11) +}) + +test_that("compute_grid_info - model (with and without submodels)", { + library(workflows) + library(parsnip) + library(recipes) + library(dials) + + rec <- recipe(mpg ~ ., mtcars) + spec <- mars(num_terms = tune(), prod_degree = tune(), prune_method = tune()) %>% + set_mode("classification") %>% + set_engine("earth") + + wflow <- workflow() + wflow <- add_model(wflow, spec) + wflow <- add_recipe(wflow, rec) + + set.seed(123) + params_grid <- grid_space_filling( + num_terms() %>% range_set(c(1L, 12L)), + prod_degree(), + prune_method(values = c("backward", "none", "forward")), + size = 7, + type = "latin_hypercube" + ) + + res <- compute_grid_info(wflow, params_grid) + + expect_equal(res$.iter_preprocessor, rep(1, 5)) + expect_equal(res$.msg_preprocessor, rep("preprocessor 1/1", 5)) + expect_equal(length(unique(res$num_terms)), 5) + expect_equal(res$.iter_model, 1:5) + expect_equal( + res$.iter_config, + list( + c("Preprocessor1_Model1", "Preprocessor1_Model2"), + c("Preprocessor1_Model3", "Preprocessor1_Model4"), + "Preprocessor1_Model5", "Preprocessor1_Model6", "Preprocessor1_Model7" + ) + ) + expect_equal( + unique(res$.msg_model), + paste0("preprocessor 1/1, model ", 1:5,"/5") + ) + expect_equal( + res$.submodels, + list( + list(num_terms = c(1)), + list(num_terms = c(3)), + list(), list(), list() + ) + ) + expect_named( + res, + c(".iter_preprocessor", ".msg_preprocessor", "num_terms", "prod_degree", + "prune_method", ".iter_model", ".iter_config", ".msg_model", ".submodels"), + ignore.order = TRUE + ) + expect_equal(nrow(res), 5) +}) + +test_that("compute_grid_info - recipe and model (no submodels but has inner grid)", { + library(workflows) + library(parsnip) + library(recipes) + library(dials) + + set.seed(1) + + helper_objects <- helper_objects_tune() + + wflow <- workflow() %>% + add_recipe(helper_objects$rec_tune_1) %>% + add_model(helper_objects$svm_mod) + + pset <- extract_parameter_set_dials(wflow) %>% + update(num_comp = dials::num_comp(c(1, 3))) + + grid <- dials::grid_regular(pset, levels = 3) + + res <- compute_grid_info(wflow, grid) + + expect_equal(res$.iter_preprocessor, rep(1:3, each = 3)) + expect_equal(res$.msg_preprocessor, rep(paste0("preprocessor ", 1:3, "/3"), each = 3)) + expect_equal(res$.iter_model, rep(1:3, times = 3)) + expect_equal( + res$.iter_config, + as.list(paste0( + rep(paste0("Preprocessor", 1:3, "_Model"), each = 3), + rep(1:3, times = 3) + )) + ) + expect_equal( + unique(res$.msg_model), + paste0( + rep(paste0("preprocessor ", 1:3, "/3, model "), each = 3), + paste0(rep(1:3, times = 3), "/3") + ) + ) + expect_named( + res, + c("cost", "num_comp", ".submodels", ".iter_preprocessor", ".msg_preprocessor", + ".iter_model", ".iter_config", ".msg_model"), + ignore.order = TRUE + ) + expect_equal(nrow(res), 9) +})