pass weights to xgboost internal validation set #803

Open · wants to merge 4 commits into base: main

joeycouse

In response to #771 (comment)

This pull request passes case weights to the internal validation set of xgboost. This causes test failures here:

expect_null(xgboost::getinfo(wted_val$watchlist$validation, "weight"))

and here.

expect_null(xgboost::getinfo(wted_val$watchlist$validation, "weight"))

Seems like the original intention was to not pass case weights to the internal validation set?

topepo (Member) commented Aug 31, 2022

This depends on the type of weight:

  • Importance weights only affect the model estimation and supervised recipes steps. They are not used with yardstick functions for calculating measures of model performance.

  • Frequency weights are used for all parts of the preprocessing, model fitting, and performance estimation operations.

(This is from the blog post, but it should be better documented in hardhat.)

So we should do this, but just for frequency weights (which are less likely to be used with boosting).
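
To make the rule above concrete, here is a minimal sketch of how the weight type could decide what reaches the validation rows. `validation_weights()` and `val_index` are hypothetical names used only for illustration; this is not parsnip's implementation. It assumes hardhat and vctrs are installed and relies only on `hardhat::is_frequency_weights()` and `vctrs::vec_data()`.

```r
library(hardhat)

# Hypothetical helper (not part of parsnip): return the weights that should
# accompany the validation rows, or NULL when they should be dropped.
validation_weights <- function(weights, val_index) {
  if (is.null(weights)) {
    return(NULL)
  }
  if (is_frequency_weights(weights)) {
    # Frequency weights stand in for repeated rows, so they also apply
    # when performance is measured on the held-out rows.
    return(vctrs::vec_data(weights)[val_index])
  }
  # Importance weights (and anything else) only affect model estimation.
  NULL
}

freq_wts <- frequency_weights(1:32)
imp_wts <- importance_weights(1:32)

# Rows 3, 17 and 26 play the role of the held-out validation rows.
validation_weights(freq_wts, c(3, 17, 26))
validation_weights(imp_wts, c(3, 17, 26))
```

Under these assumptions the first call should return the integer weights for the three held-out rows, and the second should return `NULL`.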

joeycouse (Author)

@topepo @simonpcouch I've updated the PR to pass only frequency weights to the internal validation set. This is done by delaying the conversion of the weights to numeric/integer until they are passed to as_xgb_data().

library(parsnip)

freq_weights <- hardhat::frequency_weights(1:32)

mtcar_x <- mtcars[, -1]
mtcar_mat <- as.matrix(mtcar_x)

set.seed(1)
val_freq_wts <- parsnip:::as_xgb_data(mtcar_mat, mtcars$mpg, weights = freq_weights, validation = 1/10)
xgboost::getinfo(val_freq_wts$watchlist$validation, "weight")
#> [1]  3 17 26


# Same data, but with importance weights: none should reach the validation set
imp_wts <- hardhat::importance_weights(1:32)

set.seed(1)
val_imp_wts <- parsnip:::as_xgb_data(mtcar_mat, mtcars$mpg, weights = imp_wts, validation = 1/10)
xgboost::getinfo(val_imp_wts$watchlist$validation, "weight")
#> NULL

Created on 2022-09-01 by the reprex package (v2.0.1)
