-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #434 from tidymodels/fairness
- Loading branch information
Showing
30 changed files
with
1,487 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
#' Create group-wise metrics | ||
#' | ||
#' Group-wise metrics quantify the disparity in value of a metric across a | ||
#' number of groups. Group-wise metrics with a value of zero indicate that the | ||
#' underlying metric is equal across groups. yardstick defines | ||
#' several common fairness metrics using this function, such as | ||
#' [demographic_parity()], [equal_opportunity()], and [equalized_odds()]. | ||
#' | ||
#' Note that _all_ yardstick metrics are group-aware in that, when passed | ||
#' grouped data, they will return metric values calculated for each group. | ||
#' When passed grouped data, group-wise metrics also return metric values | ||
#' for each group, but those metric values are calculated by first additionally | ||
#' grouping by the variable passed to `by` and then summarizing the per-group | ||
#' metric estimates across groups using the function passed as the | ||
#' `aggregate` argument. | ||
#' | ||
#' @param fn A yardstick metric function or metric set. | ||
#' @param name The name of the metric to place in the `.metric` column | ||
#' of the output. | ||
#' @param aggregate A function to summarize the generated metric set results. | ||
#' The function takes metric set results as the first argument and returns | ||
#' a single numeric giving the `.estimate` value as output. See the Value and | ||
#' Examples sections for example uses. | ||
#' @inheritParams new_class_metric | ||
#' | ||
#' @section Relevant Group Level: | ||
#' Additional arguments can be passed to the function outputted by | ||
#' the function that this function outputs. That is: | ||
#' | ||
#' ``` | ||
#' res_fairness <- new_groupwise_metric(...) | ||
#' res_by <- res_fairness(by) | ||
#' res_by(..., additional_arguments_to_aggregate = TRUE) | ||
#' ``` | ||
#' | ||
#' For finer control of how groups in `by` are treated, use the | ||
#' `aggregate` argument. | ||
#' | ||
#' @return | ||
#' This function is a | ||
#' [function factory](https://adv-r.hadley.nz/function-factories.html); it's | ||
#' output is itself a function. Further, the functions that this function | ||
#' outputs are also function factories. More explicitly, this looks like: | ||
#' | ||
#' ``` | ||
#' # a function with similar implementation to `demographic_parity()`: | ||
#' diff_range <- function(x) {diff(range(x$.estimate))} | ||
#' | ||
#' dem_parity <- | ||
#' new_groupwise_metric( | ||
#' fn = detection_prevalence, | ||
#' name = "dem_parity", | ||
#' aggregate = diff_range | ||
#' ) | ||
#' ``` | ||
#' | ||
#' The outputted `dem_parity` is a function that takes one argument, `by`, | ||
#' indicating the data-masked variable giving the sensitive feature. | ||
#' | ||
#' When called with a `by` argument, `dem_parity` will return a yardstick | ||
#' metric function like any other: | ||
#' | ||
#' ``` | ||
#' dem_parity_by_gender <- dem_parity(gender) | ||
#' ``` | ||
#' | ||
#' Note that `dem_parity` doesn't take any arguments other than `by`, and thus | ||
#' knows nothing about the data it will be applied to other than that it ought | ||
#' to have a column with name `"gender"` in it. | ||
#' | ||
#' The output `dem_parity_by_gender` is a metric function that takes the | ||
#' same arguments as the function supplied as `fn`, in this case | ||
#' `detection_prevalence`. It will thus interface like any other yardstick | ||
#' function except that it will look for a `"gender"` column in | ||
#' the data it's supplied. | ||
#' | ||
#' In addition to the examples below, see the documentation on the | ||
#' return value of fairness metrics like [demographic_parity()], | ||
#' [equal_opportunity()], or [equalized_odds()] to learn more about how the | ||
#' output of this function can be used. | ||
#' | ||
#' @examples | ||
#' data(hpc_cv) | ||
#' | ||
#' # `demographic_parity`, among other fairness metrics, | ||
#' # is generated with `new_groupwise_metric()`: | ||
#' diff_range <- function(x) {diff(range(x$.estimate))} | ||
#' demographic_parity_ <- | ||
#' new_groupwise_metric( | ||
#' fn = detection_prevalence, | ||
#' name = "demographic_parity", | ||
#' aggregate = diff_range | ||
#' ) | ||
#' | ||
#' m_set <- metric_set(demographic_parity_(Resample)) | ||
#' | ||
#' m_set(hpc_cv, truth = obs, estimate = pred) | ||
#' | ||
#' # the `post` argument can be used to accommodate a wide | ||
#' # variety of parameterizations. to encode demographic | ||
#' # parity as a ratio inside of a difference, for example: | ||
#' ratio_range <- function(x, ...) { | ||
#' range <- range(x$.estimate) | ||
#' range[1] / range[2] | ||
#' } | ||
#' | ||
#' demographic_parity_ratio <- | ||
#' new_groupwise_metric( | ||
#' fn = detection_prevalence, | ||
#' name = "demographic_parity_ratio", | ||
#' aggregate = ratio_range | ||
#' ) | ||
#' | ||
#' @export | ||
new_groupwise_metric <- function(fn, name, aggregate, direction = "minimize") { | ||
if (is_missing(fn) || !inherits_any(fn, c("metric", "metric_set"))) { | ||
abort("`fn` must be a metric function or metric set.") | ||
} | ||
if (is_missing(name) || !is_string(name)) { | ||
abort("`name` must be a string.") | ||
} | ||
if (is_missing(aggregate) || !is_function(aggregate)) { | ||
abort("`aggregate` must be a function.") | ||
} | ||
arg_match( | ||
direction, | ||
values = c("maximize", "minimize", "zero") | ||
) | ||
|
||
metric_factory <- | ||
function(by) { | ||
by_str <- as_string(enexpr(by)) | ||
res <- | ||
function(data, ...) { | ||
gp_vars <- dplyr::group_vars(data) | ||
|
||
if (by_str %in% gp_vars) { | ||
cli::cli_abort( | ||
"Metric is internally grouped by {.field {by_str}}; grouping \\ | ||
{.arg data} by {.field {by_str}} is not well-defined." | ||
) | ||
} | ||
|
||
# error informatively when `fn` is a metric set; see `eval_safely()` | ||
data_grouped <- dplyr::group_by(data, {{by}}, .add = TRUE) | ||
res <- | ||
tryCatch( | ||
fn(data_grouped, ...), | ||
error = function(cnd) { | ||
if (!is.null(cnd$parent)) { | ||
cnd <- cnd$parent | ||
} | ||
|
||
abort(conditionMessage(cnd), call = call(name)) | ||
} | ||
) | ||
|
||
# restore to the grouping structure in the supplied data | ||
if (length(gp_vars) > 0) { | ||
res <- dplyr::group_by(res, !!!dplyr::groups(data), .add = FALSE) | ||
} | ||
|
||
group_rows <- dplyr::group_rows(res) | ||
group_keys <- dplyr::group_keys(res) | ||
res <- dplyr::ungroup(res) | ||
groups <- vec_chop(res, indices = group_rows) | ||
out <- vector("list", length = length(groups)) | ||
|
||
for (i in seq_along(groups)) { | ||
group <- groups[[i]] | ||
|
||
.estimate <- aggregate(group) | ||
|
||
if (!is_bare_numeric(.estimate)) { | ||
abort( | ||
"`aggregate` must return a single numeric value.", | ||
call = call2("new_groupwise_metric") | ||
) | ||
} | ||
|
||
elt_out <- list( | ||
.metric = name, | ||
.by = by_str, | ||
.estimator = group$.estimator[1], | ||
.estimate = .estimate | ||
) | ||
|
||
out[[i]] <- tibble::new_tibble(elt_out) | ||
} | ||
|
||
group_keys <- vctrs::vec_rep_each(group_keys, times = list_sizes(out)) | ||
out <- vec_rbind(!!!out) | ||
out <- vec_cbind(group_keys, out) | ||
|
||
out | ||
} | ||
res <- new_class_metric(res, direction = "minimize") | ||
|
||
structure( | ||
res, | ||
direction = direction, | ||
by = by_str, | ||
class = groupwise_metric_class(fn) | ||
) | ||
} | ||
|
||
structure(metric_factory, class = c("metric_factory", "function")) | ||
} | ||
|
||
groupwise_metric_class <- function(fn) { | ||
if (inherits(fn, "metric")) { | ||
return(class(fn)) | ||
} | ||
|
||
class(attr(fn, "metrics")[[1]]) | ||
} | ||
|
||
diff_range <- function(x) { | ||
estimates <- x$.estimate | ||
|
||
max(estimates) - min(estimates) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
#' Demographic parity | ||
#' | ||
#' @description | ||
#' Demographic parity is satisfied when a model's predictions have the | ||
#' same predicted positive rate across groups. A value of 0 indicates parity | ||
#' across groups. Note that this definition does not depend on the true | ||
#' outcome; the `truth` argument is included in outputted metrics | ||
#' for consistency. | ||
#' | ||
#' `demographic_parity()` is calculated as the difference between the largest | ||
#' and smallest value of [detection_prevalence()] across groups. | ||
#' | ||
#' Demographic parity is sometimes referred to as group fairness, | ||
#' disparate impact, or statistical parity. | ||
#' | ||
#' See the "Measuring Disparity" section for details on implementation. | ||
#' | ||
#' @param by The column identifier for the sensitive feature. This should be an | ||
#' unquoted column name referring to a column in the un-preprocessed data. | ||
#' | ||
#' @templateVar fn demographic_parity | ||
#' @templateVar internal_fn detection_prevalence | ||
#' @templateVar internal_fn [detection_prevalence()] | ||
#' @template return-fair | ||
#' @template event-fair | ||
#' @template examples-fair | ||
#' | ||
#' @family fairness metrics | ||
#' | ||
#' @references | ||
#' | ||
#' Agarwal, A., Beygelzimer, A., Dudik, M., Langford, J., & Wallach, H. (2018). | ||
#' "A Reductions Approach to Fair Classification." Proceedings of the 35th | ||
#' International Conference on Machine Learning, in Proceedings of Machine | ||
#' Learning Research. 80:60-69. | ||
#' | ||
#' Verma, S., & Rubin, J. (2018). "Fairness definitions explained". In | ||
#' Proceedings of the international workshop on software fairness (pp. 1-7). | ||
#' | ||
#' Bird, S., Dudík, M., Edgar, R., Horn, B., Lutz, R., Milan, V., ... & Walker, | ||
#' K. (2020). "Fairlearn: A toolkit for assessing and improving fairness in AI". | ||
#' Microsoft, Tech. Rep. MSR-TR-2020-32. | ||
#' | ||
#' @export | ||
demographic_parity <- | ||
new_groupwise_metric( | ||
fn = detection_prevalence, | ||
name = "demographic_parity", | ||
aggregate = diff_range | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#' Equal opportunity | ||
#' | ||
#' @description | ||
#' | ||
#' Equal opportunity is satisfied when a model's predictions have the same | ||
#' true positive and false negative rates across protected groups. A value of | ||
#' 0 indicates parity across groups. | ||
#' | ||
#' `equal_opportunity()` is calculated as the difference between the largest | ||
#' and smallest value of [sens()] across groups. | ||
#' | ||
#' Equal opportunity is sometimes referred to as equality of opportunity. | ||
#' | ||
#' See the "Measuring Disparity" section for details on implementation. | ||
#' | ||
#' @inheritParams demographic_parity | ||
#' | ||
#' @templateVar fn equal_opportunity | ||
#' @templateVar internal_fn sens | ||
#' @templateVar internal_fn [sens()] | ||
#' @template return-fair | ||
#' @template event-fair | ||
#' @template examples-fair | ||
#' | ||
#' @family fairness metrics | ||
#' | ||
#' @references | ||
#' | ||
#' Hardt, M., Price, E., & Srebro, N. (2016). "Equality of opportunity in | ||
#' supervised learning". Advances in neural information processing systems, 29. | ||
#' | ||
#' Verma, S., & Rubin, J. (2018). "Fairness definitions explained". In | ||
#' Proceedings of the international workshop on software fairness (pp. 1-7). | ||
#' | ||
#' Bird, S., Dudík, M., Edgar, R., Horn, B., Lutz, R., Milan, V., ... & Walker, | ||
#' K. (2020). "Fairlearn: A toolkit for assessing and improving fairness in AI". | ||
#' Microsoft, Tech. Rep. MSR-TR-2020-32. | ||
#' | ||
#' @export | ||
equal_opportunity <- | ||
new_groupwise_metric( | ||
fn = sens, | ||
name = "equal_opportunity", | ||
aggregate = diff_range | ||
) |
Oops, something went wrong.