Skip to content

Commit

Permalink
Remove intercept coming from workflows based on engine encodings (#1033)
Browse files Browse the repository at this point in the history
* for intercepts coming from workflows (details in the PR description)

* bump version for extratests

* update NEWS
  • Loading branch information
hfrick authored Dec 7, 2023
1 parent d65dde2 commit 0372970
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 1 deletion.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: parsnip
Title: A Common API to Modeling and Analysis Functions
Version: 1.1.1.9003
Version: 1.1.1.9004
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

* When computing censoring weights, the resulting vectors are no longer named (#1023).

* Fixed a bug in the integration with workflows where using a model formula with a formula preprocessor could result in a double intercept (#1033).


# parsnip 1.1.1

* Fixed bug where prediction on rank deficient `lm()` models produced `.pred_res` instead of `.pred`. (#985)
Expand Down
4 changes: 4 additions & 0 deletions R/convert_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@
rlang::abort("`composition` should be either 'data.frame' or 'matrix'.")
}

if (remove_intercept) {
data <- data[, colnames(data) != "(Intercept)", drop = FALSE]
}

## Assemble model.frame call from call arguments
mf_call <- quote(model.frame(formula, data))
mf_call$na.action <- match.call()$na.action # TODO this should work better
Expand Down
9 changes: 9 additions & 0 deletions R/fit_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@
form_form <-
function(object, control, env, ...) {

encoding_info <-
get_encoding(class(object)[1]) %>%
dplyr::filter(mode == object$mode, engine == object$engine)

remove_intercept <- encoding_info %>% dplyr::pull(remove_intercept)
if (remove_intercept) {
env$data <- env$data[, colnames(env$data) != "(Intercept)", drop = FALSE]
}

if (inherits(env$data, "data.frame")) {
check_outcome(eval_tidy(rlang::f_lhs(env$formula), env$data), object)
}
Expand Down

0 comments on commit 0372970

Please sign in to comment.