Skip to content

Commit

Permalink
Changes for #956
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Nov 13, 2024
1 parent a212f78 commit 7b95f20
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 0 deletions.
21 changes: 21 additions & 0 deletions R/linear_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,27 @@ translate.linear_reg <- function(x, engine = x$engine, ...) {
# evaluated value for the parameter.
x$args$penalty <- rlang::eval_tidy(x$args$penalty)
}

# ------------------------------------------------------------------------------
# We want to avoid folks passing in a poisson family instead of using
# poisson_reg(). It's hard to detect this.

is_fam <- names(x$eng_args) == "family"
if (any(is_fam)) {
eng_args <- rlang::eval_tidy(x$eng_args[[which(is_fam)]])
if (is.function(eng_args)) {
eng_args <- try(eng_args(), silent = TRUE)
}
if (inherits(eng_args, "family")) {
eng_args <- eng_args$family
}
if (eng_args == "poisson") {
cli::cli_abort(
"A Poisson family was requested for {.fn linear_reg}. Please use
{.fn poisson_reg} and the engines in the {.pkg poissonreg} package.",
call = rlang::call2("linear_reg"))
}
}
x
}

Expand Down
36 changes: 36 additions & 0 deletions tests/testthat/_snaps/linear_reg.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,39 @@
Error in `fit()`:
! `penalty` must be a number larger than or equal to 0 or `NULL`, not the number -1.

# Poisson family (#956)

Code
linear_reg(penalty = 1) %>% set_engine("glmnet", family = poisson) %>%
translate()
Condition
Error in `linear_reg()`:
! A Poisson family was requested for `linear_reg()`. Please use `poisson_reg()` and the engines in the poissonreg package.

---

Code
linear_reg(penalty = 1) %>% set_engine("glmnet", family = stats::poisson) %>%
translate()
Condition
Error in `linear_reg()`:
! A Poisson family was requested for `linear_reg()`. Please use `poisson_reg()` and the engines in the poissonreg package.

---

Code
linear_reg(penalty = 1) %>% set_engine("glmnet", family = stats::poisson()) %>%
translate()
Condition
Error in `linear_reg()`:
! A Poisson family was requested for `linear_reg()`. Please use `poisson_reg()` and the engines in the poissonreg package.

---

Code
linear_reg(penalty = 1) %>% set_engine("glmnet", family = "poisson") %>%
translate()
Condition
Error in `linear_reg()`:
! A Poisson family was requested for `linear_reg()`. Please use `poisson_reg()` and the engines in the poissonreg package.

30 changes: 30 additions & 0 deletions tests/testthat/test-linear_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,33 @@ test_that("check_args() works", {
}
)
})


test_that('Poisson family (#956)', {
expect_snapshot(
linear_reg(penalty = 1) %>%
set_engine("glmnet", family = poisson) %>%
translate(),
error = TRUE
)
expect_snapshot(
linear_reg(penalty = 1) %>%
set_engine("glmnet", family = stats::poisson) %>%
translate(),
error = TRUE
)
expect_snapshot(
linear_reg(penalty = 1) %>%
set_engine("glmnet", family = stats::poisson()) %>%
translate(),
error = TRUE
)
expect_snapshot(
linear_reg(penalty = 1) %>%
set_engine("glmnet", family = "poisson") %>%
translate(),
error = TRUE
)


})

0 comments on commit 7b95f20

Please sign in to comment.