-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add fairness metrics #434
Merged
Merged
add fairness metrics #434
Changes from all commits
Commits
Show all changes
34 commits
Select commit
Hold shift + click to select a range
b98b2ac
rename `metrics.R` -> `aaa-metrics.R`
simonpcouch 92c232d
implement metric constructor
simonpcouch 44ef50b
implement 3 canonical fairness metrics
simonpcouch c260656
validate canonical metrics against fairlearn
simonpcouch 31e4a5f
add NEWS entry
simonpcouch 5922ae2
add pkgdown entry
simonpcouch 7236a14
skip error context snap pre-4.0 (currently oldrel-4)
simonpcouch 60973ba
remove unused test object
simonpcouch 6998108
remove `rlang::` and `vctrs::`
simonpcouch a9fc2e0
optimize `diff_range()` for speed
simonpcouch 27f53e9
transition from `summarize()`
simonpcouch 7fe7fb8
remove incomplete phrase
simonpcouch bdb7c5a
highlight function factory as output
simonpcouch ca6d7b8
restore passing ellipses to `.post()`
simonpcouch f1d663f
error informatively with redundant grouping
simonpcouch 45f57ff
don't pass ellipses to `.post`
simonpcouch 43037d9
run test on all systems
simonpcouch ef7f5a1
namespace `group_by` in test
simonpcouch 556f2c5
update snaps
simonpcouch 6fad6a4
defer to inputted `.fn` for metric class
simonpcouch a9231a9
check `direction` argument
simonpcouch 82717fa
move `max_positive_rate_diff()` to where it's used
simonpcouch a2dfe36
document implementation in `@description`
simonpcouch 375269b
abort with condition parent to improve error context
simonpcouch 76f2d4f
special-case metric factories in `metric_set()` checks
simonpcouch 587ff32
rename `fairness_metric()` -> `new_groupwise_metric()`
simonpcouch 05086ea
contrast "group-wise" and usual grouped behavior of yardstick metrics
simonpcouch 25a893f
rename `new_groupwise_metric()` arguments
simonpcouch 4436785
clarify documentation on `aggregate` arg
simonpcouch 12ff56a
use devl probably
EmilHvitfeldt 251f45b
Merge branch 'fairness' of github.com:tidymodels/yardstick into fairness
EmilHvitfeldt 6c0b76f
`aggregrate` -> `aggregate`
simonpcouch 219562c
remove `Remotes`---probably package is now on CRAN
simonpcouch 690e738
correct duplicated description
simonpcouch File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
#' Create group-wise metrics | ||
#' | ||
#' Group-wise metrics quantify the disparity in value of a metric across a | ||
#' number of groups. Group-wise metrics with a value of zero indicate that the | ||
#' underlying metric is equal across groups. yardstick defines | ||
#' several common fairness metrics using this function, such as | ||
#' [demographic_parity()], [equal_opportunity()], and [equalized_odds()]. | ||
#' | ||
#' Note that _all_ yardstick metrics are group-aware in that, when passed | ||
#' grouped data, they will return metric values calculated for each group. | ||
#' When passed grouped data, group-wise metrics also return metric values | ||
#' for each group, but those metric values are calculated by first additionally | ||
#' grouping by the variable passed to `by` and then summarizing the per-group | ||
#' metric estimates across groups using the function passed as the | ||
#' `aggregate` argument. | ||
#' | ||
#' @param fn A yardstick metric function or metric set. | ||
#' @param name The name of the metric to place in the `.metric` column | ||
#' of the output. | ||
#' @param aggregate A function to summarize the generated metric set results. | ||
#' The function takes metric set results as the first argument and returns | ||
#' a single numeric giving the `.estimate` value as output. See the Value and | ||
#' Examples sections for example uses. | ||
#' @inheritParams new_class_metric | ||
#' | ||
#' @section Relevant Group Level: | ||
#' Additional arguments can be passed to the function outputted by | ||
#' the function that this function outputs. That is: | ||
#' | ||
#' ``` | ||
#' res_fairness <- new_groupwise_metric(...) | ||
#' res_by <- res_fairness(by) | ||
#' res_by(..., additional_arguments_to_aggregate = TRUE) | ||
#' ``` | ||
#' | ||
#' For finer control of how groups in `by` are treated, use the | ||
#' `aggregate` argument. | ||
#' | ||
#' @return | ||
#' This function is a | ||
#' [function factory](https://adv-r.hadley.nz/function-factories.html); it's | ||
#' output is itself a function. Further, the functions that this function | ||
#' outputs are also function factories. More explicitly, this looks like: | ||
#' | ||
#' ``` | ||
#' # a function with similar implementation to `demographic_parity()`: | ||
#' diff_range <- function(x) {diff(range(x$.estimate))} | ||
#' | ||
#' dem_parity <- | ||
#' new_groupwise_metric( | ||
#' fn = detection_prevalence, | ||
#' name = "dem_parity", | ||
#' aggregate = diff_range | ||
#' ) | ||
#' ``` | ||
#' | ||
#' The outputted `dem_parity` is a function that takes one argument, `by`, | ||
#' indicating the data-masked variable giving the sensitive feature. | ||
#' | ||
#' When called with a `by` argument, `dem_parity` will return a yardstick | ||
#' metric function like any other: | ||
#' | ||
#' ``` | ||
#' dem_parity_by_gender <- dem_parity(gender) | ||
#' ``` | ||
#' | ||
#' Note that `dem_parity` doesn't take any arguments other than `by`, and thus | ||
#' knows nothing about the data it will be applied to other than that it ought | ||
#' to have a column with name `"gender"` in it. | ||
#' | ||
#' The output `dem_parity_by_gender` is a metric function that takes the | ||
#' same arguments as the function supplied as `fn`, in this case | ||
#' `detection_prevalence`. It will thus interface like any other yardstick | ||
#' function except that it will look for a `"gender"` column in | ||
#' the data it's supplied. | ||
#' | ||
#' In addition to the examples below, see the documentation on the | ||
#' return value of fairness metrics like [demographic_parity()], | ||
#' [equal_opportunity()], or [equalized_odds()] to learn more about how the | ||
#' output of this function can be used. | ||
#' | ||
#' @examples | ||
#' data(hpc_cv) | ||
#' | ||
#' # `demographic_parity`, among other fairness metrics, | ||
#' # is generated with `new_groupwise_metric()`: | ||
#' diff_range <- function(x) {diff(range(x$.estimate))} | ||
#' demographic_parity_ <- | ||
#' new_groupwise_metric( | ||
#' fn = detection_prevalence, | ||
#' name = "demographic_parity", | ||
#' aggregate = diff_range | ||
#' ) | ||
#' | ||
#' m_set <- metric_set(demographic_parity_(Resample)) | ||
#' | ||
#' m_set(hpc_cv, truth = obs, estimate = pred) | ||
#' | ||
#' # the `post` argument can be used to accommodate a wide | ||
#' # variety of parameterizations. to encode demographic | ||
#' # parity as a ratio inside of a difference, for example: | ||
#' ratio_range <- function(x, ...) { | ||
#' range <- range(x$.estimate) | ||
#' range[1] / range[2] | ||
#' } | ||
#' | ||
#' demographic_parity_ratio <- | ||
#' new_groupwise_metric( | ||
#' fn = detection_prevalence, | ||
#' name = "demographic_parity_ratio", | ||
#' aggregate = ratio_range | ||
#' ) | ||
#' | ||
#' @export | ||
new_groupwise_metric <- function(fn, name, aggregate, direction = "minimize") { | ||
if (is_missing(fn) || !inherits_any(fn, c("metric", "metric_set"))) { | ||
abort("`fn` must be a metric function or metric set.") | ||
} | ||
if (is_missing(name) || !is_string(name)) { | ||
abort("`name` must be a string.") | ||
} | ||
if (is_missing(aggregate) || !is_function(aggregate)) { | ||
abort("`aggregate` must be a function.") | ||
} | ||
arg_match( | ||
direction, | ||
values = c("maximize", "minimize", "zero") | ||
) | ||
|
||
metric_factory <- | ||
function(by) { | ||
by_str <- as_string(enexpr(by)) | ||
res <- | ||
function(data, ...) { | ||
gp_vars <- dplyr::group_vars(data) | ||
|
||
if (by_str %in% gp_vars) { | ||
cli::cli_abort( | ||
"Metric is internally grouped by {.field {by_str}}; grouping \\ | ||
{.arg data} by {.field {by_str}} is not well-defined." | ||
) | ||
} | ||
|
||
# error informatively when `fn` is a metric set; see `eval_safely()` | ||
data_grouped <- dplyr::group_by(data, {{by}}, .add = TRUE) | ||
res <- | ||
tryCatch( | ||
fn(data_grouped, ...), | ||
error = function(cnd) { | ||
if (!is.null(cnd$parent)) { | ||
cnd <- cnd$parent | ||
} | ||
|
||
abort(conditionMessage(cnd), call = call(name)) | ||
} | ||
) | ||
|
||
# restore to the grouping structure in the supplied data | ||
if (length(gp_vars) > 0) { | ||
res <- dplyr::group_by(res, !!!dplyr::groups(data), .add = FALSE) | ||
} | ||
|
||
group_rows <- dplyr::group_rows(res) | ||
group_keys <- dplyr::group_keys(res) | ||
res <- dplyr::ungroup(res) | ||
groups <- vec_chop(res, indices = group_rows) | ||
out <- vector("list", length = length(groups)) | ||
|
||
for (i in seq_along(groups)) { | ||
group <- groups[[i]] | ||
|
||
.estimate <- aggregate(group) | ||
|
||
if (!is_bare_numeric(.estimate)) { | ||
abort( | ||
"`aggregate` must return a single numeric value.", | ||
call = call2("new_groupwise_metric") | ||
) | ||
} | ||
|
||
elt_out <- list( | ||
.metric = name, | ||
.by = by_str, | ||
.estimator = group$.estimator[1], | ||
.estimate = .estimate | ||
) | ||
|
||
out[[i]] <- tibble::new_tibble(elt_out) | ||
} | ||
|
||
group_keys <- vctrs::vec_rep_each(group_keys, times = list_sizes(out)) | ||
out <- vec_rbind(!!!out) | ||
out <- vec_cbind(group_keys, out) | ||
|
||
out | ||
} | ||
res <- new_class_metric(res, direction = "minimize") | ||
|
||
structure( | ||
res, | ||
direction = direction, | ||
by = by_str, | ||
class = groupwise_metric_class(fn) | ||
) | ||
} | ||
|
||
structure(metric_factory, class = c("metric_factory", "function")) | ||
} | ||
|
||
groupwise_metric_class <- function(fn) { | ||
if (inherits(fn, "metric")) { | ||
return(class(fn)) | ||
} | ||
|
||
class(attr(fn, "metrics")[[1]]) | ||
} | ||
|
||
diff_range <- function(x) { | ||
estimates <- x$.estimate | ||
|
||
max(estimates) - min(estimates) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
#' Demographic parity | ||
simonpcouch marked this conversation as resolved.
Show resolved
Hide resolved
|
||
#' | ||
#' @description | ||
#' Demographic parity is satisfied when a model's predictions have the | ||
#' same predicted positive rate across groups. A value of 0 indicates parity | ||
#' across groups. Note that this definition does not depend on the true | ||
#' outcome; the `truth` argument is included in outputted metrics | ||
#' for consistency. | ||
#' | ||
#' `demographic_parity()` is calculated as the difference between the largest | ||
#' and smallest value of [detection_prevalence()] across groups. | ||
#' | ||
#' Demographic parity is sometimes referred to as group fairness, | ||
#' disparate impact, or statistical parity. | ||
#' | ||
#' See the "Measuring Disparity" section for details on implementation. | ||
#' | ||
#' @param by The column identifier for the sensitive feature. This should be an | ||
#' unquoted column name referring to a column in the un-preprocessed data. | ||
#' | ||
#' @templateVar fn demographic_parity | ||
#' @templateVar internal_fn detection_prevalence | ||
#' @templateVar internal_fn [detection_prevalence()] | ||
#' @template return-fair | ||
#' @template event-fair | ||
#' @template examples-fair | ||
#' | ||
#' @family fairness metrics | ||
#' | ||
#' @references | ||
#' | ||
#' Agarwal, A., Beygelzimer, A., Dudik, M., Langford, J., & Wallach, H. (2018). | ||
#' "A Reductions Approach to Fair Classification." Proceedings of the 35th | ||
#' International Conference on Machine Learning, in Proceedings of Machine | ||
#' Learning Research. 80:60-69. | ||
#' | ||
#' Verma, S., & Rubin, J. (2018). "Fairness definitions explained". In | ||
#' Proceedings of the international workshop on software fairness (pp. 1-7). | ||
#' | ||
#' Bird, S., Dudík, M., Edgar, R., Horn, B., Lutz, R., Milan, V., ... & Walker, | ||
#' K. (2020). "Fairlearn: A toolkit for assessing and improving fairness in AI". | ||
#' Microsoft, Tech. Rep. MSR-TR-2020-32. | ||
#' | ||
#' @export | ||
demographic_parity <- | ||
new_groupwise_metric( | ||
fn = detection_prevalence, | ||
name = "demographic_parity", | ||
aggregate = diff_range | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#' Equal opportunity | ||
#' | ||
#' @description | ||
#' | ||
#' Equal opportunity is satisfied when a model's predictions have the same | ||
#' true positive and false negative rates across protected groups. A value of | ||
#' 0 indicates parity across groups. | ||
#' | ||
#' `equal_opportunity()` is calculated as the difference between the largest | ||
#' and smallest value of [sens()] across groups. | ||
#' | ||
#' Equal opportunity is sometimes referred to as equality of opportunity. | ||
#' | ||
#' See the "Measuring Disparity" section for details on implementation. | ||
#' | ||
#' @inheritParams demographic_parity | ||
#' | ||
#' @templateVar fn equal_opportunity | ||
#' @templateVar internal_fn sens | ||
#' @templateVar internal_fn [sens()] | ||
#' @template return-fair | ||
#' @template event-fair | ||
#' @template examples-fair | ||
#' | ||
#' @family fairness metrics | ||
#' | ||
#' @references | ||
#' | ||
#' Hardt, M., Price, E., & Srebro, N. (2016). "Equality of opportunity in | ||
#' supervised learning". Advances in neural information processing systems, 29. | ||
#' | ||
#' Verma, S., & Rubin, J. (2018). "Fairness definitions explained". In | ||
#' Proceedings of the international workshop on software fairness (pp. 1-7). | ||
#' | ||
#' Bird, S., Dudík, M., Edgar, R., Horn, B., Lutz, R., Milan, V., ... & Walker, | ||
#' K. (2020). "Fairlearn: A toolkit for assessing and improving fairness in AI". | ||
#' Microsoft, Tech. Rep. MSR-TR-2020-32. | ||
#' | ||
#' @export | ||
equal_opportunity <- | ||
new_groupwise_metric( | ||
fn = sens, | ||
name = "equal_opportunity", | ||
aggregate = diff_range | ||
) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we decide to move forward with a fairness-oriented name, it'd be great if we could use some example data here that has a plausibly "sensitive" attribute. yardstick doesn't
Suggests
modeldata at the moment, which has some options.infer::gss
would also work well here.