Commit

Merge pull request #434 from tidymodels/fairness
EmilHvitfeldt authored Oct 27, 2023
2 parents e5c36f2 + 690e738 commit ce03a94
Showing 30 changed files with 1,487 additions and 3 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -35,7 +35,7 @@ Suggests:
crayon,
ggplot2,
knitr,
- probably (>= 0.0.6),
+ probably (>= 1.0.0),
rmarkdown,
survival (>= 3.5-0),
testthat (>= 3.0.0),
4 changes: 4 additions & 0 deletions NAMESPACE
@@ -130,10 +130,13 @@ export(concordance_survival_vec)
export(conf_mat)
export(curve_metric_summarizer)
export(curve_survival_metric_summarizer)
export(demographic_parity)
export(detection_prevalence)
export(detection_prevalence_vec)
export(dots_to_estimate)
export(dynamic_survival_metric_summarizer)
export(equal_opportunity)
export(equalized_odds)
export(f_meas)
export(f_meas_vec)
export(finalize_estimator)
@@ -174,6 +177,7 @@ export(msd)
export(msd_vec)
export(new_class_metric)
export(new_dynamic_survival_metric)
export(new_groupwise_metric)
export(new_integrated_survival_metric)
export(new_numeric_metric)
export(new_prob_metric)
2 changes: 2 additions & 0 deletions NEWS.md
@@ -14,6 +14,8 @@ calculated with `roc_auc_survival()`.

* `metric_set()` can now be used with a combination of dynamic and static survival metrics.

* `demographic_parity()`, `equalized_odds()`, and `equal_opportunity()` are new metrics for measuring model fairness. Each is implemented with the `new_groupwise_metric()` constructor, a general interface for defining group-aware metrics that allows for quickly and flexibly defining fairness metrics with the problem context in mind.
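
A minimal usage sketch of the new metrics, assuming the installed yardstick version already includes this change. It follows the `hpc_cv` example in the `new_groupwise_metric()` documentation; `Resample` merely stands in for a sensitive feature here, and `fair_set` is an illustrative name.

```r
library(yardstick)

data(hpc_cv)

# Each fairness metric is a factory: call it with the grouping column
# first, then add the resulting metric function to a metric set.
fair_set <- metric_set(
  demographic_parity(Resample),
  equal_opportunity(Resample)
)

# The metric set is then used like any other.
res <- fair_set(hpc_cv, truth = obs, estimate = pred)
res
```

The result should contain one row per metric, with the grouping column recorded in `.by`.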

# yardstick 1.2.0

## New Metrics
13 changes: 13 additions & 0 deletions R/metrics.R → R/aaa-metrics.R
@@ -624,6 +624,19 @@ validate_function_class <- function(fns) {
}
}

  # Special case unevaluated group-wise metric factories
  if ("metric_factory" %in% fn_cls) {
    factories <- fn_cls[fn_cls == "metric_factory"]
    cli::cli_abort(
      c("{cli::qty(factories)}The input{?s} {.arg {names(factories)}} \\
        {?is a/are} {.help [group-wise metric](yardstick::new_groupwise_metric)} \\
        {?factory/factories} and must be passed a data-column before
        addition to a metric set.",
        "i" = "Did you mean to type e.g. `{names(factories)[1]}(col_name)`?"),
      call = rlang::call2("metric_set")
    )
  }

  # Each element of the list contains the names of the fns
  # that inherit that specific class
  fn_bad_names <- lapply(fn_cls_unique, function(x) {
1 change: 1 addition & 0 deletions R/aaa.R
@@ -6,6 +6,7 @@ utils::globalVariables(
c(
# for class prob metrics
"estimate",
".estimator",
"threshold",
"specificity",
".level",
222 changes: 222 additions & 0 deletions R/fair-aaa.R
@@ -0,0 +1,222 @@
#' Create group-wise metrics
#'
#' Group-wise metrics quantify the disparity in value of a metric across a
#' number of groups. Group-wise metrics with a value of zero indicate that the
#' underlying metric is equal across groups. yardstick defines
#' several common fairness metrics using this function, such as
#' [demographic_parity()], [equal_opportunity()], and [equalized_odds()].
#'
#' Note that _all_ yardstick metrics are group-aware in that, when passed
#' grouped data, they will return metric values calculated for each group.
#' When passed grouped data, group-wise metrics also return metric values
#' for each group, but those metric values are calculated by first additionally
#' grouping by the variable passed to `by` and then summarizing the per-group
#' metric estimates across groups using the function passed as the
#' `aggregate` argument.
#'
#' @param fn A yardstick metric function or metric set.
#' @param name The name of the metric to place in the `.metric` column
#' of the output.
#' @param aggregate A function to summarize the generated metric set results.
#' The function takes metric set results as the first argument and returns
#' a single numeric giving the `.estimate` value as output. See the Value and
#' Examples sections for example uses.
#' @inheritParams new_class_metric
#'
#' @section Relevant Group Level:
#' Additional arguments can be passed to the final metric function (the one
#' returned by calling the factory that `new_groupwise_metric()` creates).
#' That is:
#'
#' ```
#' res_fairness <- new_groupwise_metric(...)
#' res_by <- res_fairness(by)
#' res_by(..., additional_arguments_to_aggregate = TRUE)
#' ```
#'
#' For finer control of how groups in `by` are treated, use the
#' `aggregate` argument.
#'
#' @return
#' This function is a
#' [function factory](https://adv-r.hadley.nz/function-factories.html); its
#' output is itself a function. Further, the functions that this function
#' outputs are also function factories. More explicitly, this looks like:
#'
#' ```
#' # a function with similar implementation to `demographic_parity()`:
#' diff_range <- function(x) {diff(range(x$.estimate))}
#'
#' dem_parity <-
#'   new_groupwise_metric(
#'     fn = detection_prevalence,
#'     name = "dem_parity",
#'     aggregate = diff_range
#'   )
#' ```
#'
#' The outputted `dem_parity` is a function that takes one argument, `by`,
#' indicating the data-masked variable giving the sensitive feature.
#'
#' When called with a `by` argument, `dem_parity` will return a yardstick
#' metric function like any other:
#'
#' ```
#' dem_parity_by_gender <- dem_parity(gender)
#' ```
#'
#' Note that `dem_parity` doesn't take any arguments other than `by`, and thus
#' knows nothing about the data it will be applied to other than that it ought
#' to have a column with name `"gender"` in it.
#'
#' The output `dem_parity_by_gender` is a metric function that takes the
#' same arguments as the function supplied as `fn`, in this case
#' `detection_prevalence`. It will thus interface like any other yardstick
#' function except that it will look for a `"gender"` column in
#' the data it's supplied.
#'
#' In addition to the examples below, see the documentation on the
#' return value of fairness metrics like [demographic_parity()],
#' [equal_opportunity()], or [equalized_odds()] to learn more about how the
#' output of this function can be used.
#'
#' @examples
#' data(hpc_cv)
#'
#' # `demographic_parity`, among other fairness metrics,
#' # is generated with `new_groupwise_metric()`:
#' diff_range <- function(x) {diff(range(x$.estimate))}
#' demographic_parity_ <-
#'   new_groupwise_metric(
#'     fn = detection_prevalence,
#'     name = "demographic_parity",
#'     aggregate = diff_range
#'   )
#'
#' m_set <- metric_set(demographic_parity_(Resample))
#'
#' m_set(hpc_cv, truth = obs, estimate = pred)
#'
#' # the `aggregate` argument can be used to accommodate a wide
#' # variety of parameterizations. to encode demographic
#' # parity as a ratio instead of a difference, for example:
#' ratio_range <- function(x, ...) {
#'   range <- range(x$.estimate)
#'   range[1] / range[2]
#' }
#'
#' demographic_parity_ratio <-
#'   new_groupwise_metric(
#'     fn = detection_prevalence,
#'     name = "demographic_parity_ratio",
#'     aggregate = ratio_range
#'   )
#'
#' @export
new_groupwise_metric <- function(fn, name, aggregate, direction = "minimize") {
  if (is_missing(fn) || !inherits_any(fn, c("metric", "metric_set"))) {
    abort("`fn` must be a metric function or metric set.")
  }
  if (is_missing(name) || !is_string(name)) {
    abort("`name` must be a string.")
  }
  if (is_missing(aggregate) || !is_function(aggregate)) {
    abort("`aggregate` must be a function.")
  }
  arg_match(
    direction,
    values = c("maximize", "minimize", "zero")
  )

  metric_factory <-
    function(by) {
      by_str <- as_string(enexpr(by))
      res <-
        function(data, ...) {
          gp_vars <- dplyr::group_vars(data)

          if (by_str %in% gp_vars) {
            cli::cli_abort(
              "Metric is internally grouped by {.field {by_str}}; grouping \\
              {.arg data} by {.field {by_str}} is not well-defined."
            )
          }

          # error informatively when `fn` is a metric set; see `eval_safely()`
          data_grouped <- dplyr::group_by(data, {{by}}, .add = TRUE)
          res <-
            tryCatch(
              fn(data_grouped, ...),
              error = function(cnd) {
                if (!is.null(cnd$parent)) {
                  cnd <- cnd$parent
                }

                abort(conditionMessage(cnd), call = call(name))
              }
            )

          # restore to the grouping structure in the supplied data
          if (length(gp_vars) > 0) {
            res <- dplyr::group_by(res, !!!dplyr::groups(data), .add = FALSE)
          }

          group_rows <- dplyr::group_rows(res)
          group_keys <- dplyr::group_keys(res)
          res <- dplyr::ungroup(res)
          groups <- vec_chop(res, indices = group_rows)
          out <- vector("list", length = length(groups))

          for (i in seq_along(groups)) {
            group <- groups[[i]]

            .estimate <- aggregate(group)

            if (!is_bare_numeric(.estimate)) {
              abort(
                "`aggregate` must return a single numeric value.",
                call = call2("new_groupwise_metric")
              )
            }

            elt_out <- list(
              .metric = name,
              .by = by_str,
              .estimator = group$.estimator[1],
              .estimate = .estimate
            )

            out[[i]] <- tibble::new_tibble(elt_out)
          }

          group_keys <- vctrs::vec_rep_each(group_keys, times = list_sizes(out))
          out <- vec_rbind(!!!out)
          out <- vec_cbind(group_keys, out)

          out
        }
      res <- new_class_metric(res, direction = "minimize")

      structure(
        res,
        direction = direction,
        by = by_str,
        class = groupwise_metric_class(fn)
      )
    }

  structure(metric_factory, class = c("metric_factory", "function"))
}

groupwise_metric_class <- function(fn) {
  if (inherits(fn, "metric")) {
    return(class(fn))
  }

  class(attr(fn, "metrics")[[1]])
}

diff_range <- function(x) {
  estimates <- x$.estimate

  max(estimates) - min(estimates)
}
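
The nested function-factory pattern documented for `new_groupwise_metric()` can be sketched in base R without yardstick. This is a simplified illustration, not the package's implementation: `new_groupwise_sketch()`, `mean_gap`, and the toy data are hypothetical, the grouping column is passed as a string rather than data-masked, and the inner "metric" is just a per-group mean.

```r
# Hypothetical, simplified analogue of `new_groupwise_metric()`.
# Level 1: fix the per-group statistic, the name, and the aggregation.
new_groupwise_sketch <- function(fn, name, aggregate) {
  # Level 2: fix the grouping column (a string here, for simplicity).
  function(by) {
    # Level 3: the metric itself, applied to data.
    function(data, estimate) {
      per_group <- vapply(
        split(data[[estimate]], data[[by]]),
        fn,
        numeric(1)
      )
      # Mimic the shape `aggregate` receives: a table with `.estimate`.
      results <- data.frame(.estimate = unname(per_group))
      data.frame(
        .metric = name,
        .by = by,
        .estimate = aggregate(results)
      )
    }
  }
}

diff_range <- function(x) diff(range(x$.estimate))

toy <- data.frame(
  group = c("a", "a", "b", "b"),
  pred = c(1, 0, 1, 1)
)

mean_gap <- new_groupwise_sketch(mean, "mean_gap", diff_range)
mean_gap_by_group <- mean_gap("group")
out <- mean_gap_by_group(toy, estimate = "pred")
out
```

Here the per-group means are 0.5 and 1, so the aggregated `.estimate` is 0.5. Splitting the construction into three calls is what lets the same `mean_gap` factory be reused with different grouping columns.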
50 changes: 50 additions & 0 deletions R/fair-demographic_parity.R
@@ -0,0 +1,50 @@
#' Demographic parity
#'
#' @description
#' Demographic parity is satisfied when a model's predictions have the
#' same predicted positive rate across groups. A value of 0 indicates parity
#' across groups. Note that this definition does not depend on the true
#' outcome; the `truth` argument is included in outputted metrics
#' for consistency.
#'
#' `demographic_parity()` is calculated as the difference between the largest
#' and smallest value of [detection_prevalence()] across groups.
#'
#' Demographic parity is sometimes referred to as group fairness,
#' disparate impact, or statistical parity.
#'
#' See the "Measuring Disparity" section for details on implementation.
#'
#' @param by The column identifier for the sensitive feature. This should be an
#' unquoted column name referring to a column in the un-preprocessed data.
#'
#' @templateVar fn demographic_parity
#' @templateVar internal_fn detection_prevalence
#' @templateVar internal_fn [detection_prevalence()]
#' @template return-fair
#' @template event-fair
#' @template examples-fair
#'
#' @family fairness metrics
#'
#' @references
#'
#' Agarwal, A., Beygelzimer, A., Dudik, M., Langford, J., & Wallach, H. (2018).
#' "A Reductions Approach to Fair Classification." Proceedings of the 35th
#' International Conference on Machine Learning, in Proceedings of Machine
#' Learning Research. 80:60-69.
#'
#' Verma, S., & Rubin, J. (2018). "Fairness definitions explained". In
#' Proceedings of the international workshop on software fairness (pp. 1-7).
#'
#' Bird, S., Dudík, M., Edgar, R., Horn, B., Lutz, R., Milan, V., ... & Walker,
#' K. (2020). "Fairlearn: A toolkit for assessing and improving fairness in AI".
#' Microsoft, Tech. Rep. MSR-TR-2020-32.
#'
#' @export
demographic_parity <-
  new_groupwise_metric(
    fn = detection_prevalence,
    name = "demographic_parity",
    aggregate = diff_range
  )
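
As a worked illustration of the definition in the documentation above, the calculation can be done by hand in base R for a binary outcome. The data and the names `preds`, `positive_rate`, and `dp` are made up for this sketch; it does not call yardstick.

```r
# Toy predictions; "yes" is treated as the positive class.
preds <- data.frame(
  group = c("a", "a", "a", "a", "b", "b", "b", "b"),
  pred = c("yes", "yes", "no", "no", "yes", "no", "no", "no")
)

# Predicted positive rate per group. Note that the true outcome
# is not needed for this quantity.
positive_rate <- tapply(preds$pred == "yes", preds$group, mean)

# Demographic parity as the gap between the largest and smallest rate.
dp <- max(positive_rate) - min(positive_rate)
dp
```

Group `a` has a predicted positive rate of 0.5 and group `b` of 0.25, so the gap is 0.25; a value of 0 would indicate parity.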
45 changes: 45 additions & 0 deletions R/fair-equal_opportunity.R
@@ -0,0 +1,45 @@
#' Equal opportunity
#'
#' @description
#'
#' Equal opportunity is satisfied when a model's predictions have the same
#' true positive and false negative rates across protected groups. A value of
#' 0 indicates parity across groups.
#'
#' `equal_opportunity()` is calculated as the difference between the largest
#' and smallest value of [sens()] across groups.
#'
#' Equal opportunity is sometimes referred to as equality of opportunity.
#'
#' See the "Measuring Disparity" section for details on implementation.
#'
#' @inheritParams demographic_parity
#'
#' @templateVar fn equal_opportunity
#' @templateVar internal_fn sens
#' @templateVar internal_fn [sens()]
#' @template return-fair
#' @template event-fair
#' @template examples-fair
#'
#' @family fairness metrics
#'
#' @references
#'
#' Hardt, M., Price, E., & Srebro, N. (2016). "Equality of opportunity in
#' supervised learning". Advances in neural information processing systems, 29.
#'
#' Verma, S., & Rubin, J. (2018). "Fairness definitions explained". In
#' Proceedings of the international workshop on software fairness (pp. 1-7).
#'
#' Bird, S., Dudík, M., Edgar, R., Horn, B., Lutz, R., Milan, V., ... & Walker,
#' K. (2020). "Fairlearn: A toolkit for assessing and improving fairness in AI".
#' Microsoft, Tech. Rep. MSR-TR-2020-32.
#'
#' @export
equal_opportunity <-
  new_groupwise_metric(
    fn = sens,
    name = "equal_opportunity",
    aggregate = diff_range
  )
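
The same kind of hand calculation illustrates equal opportunity: sensitivity (the true positive rate) is computed per group among the truly positive cases, and the metric is the gap between the largest and smallest value. The data and variable names below are made up for this sketch, which uses base R only.

```r
# Toy data; "yes" is treated as the positive class.
obs <- data.frame(
  group = c("a", "a", "a", "a", "b", "b", "b", "b"),
  truth = c("yes", "yes", "no", "no", "yes", "yes", "yes", "no"),
  pred = c("yes", "no", "no", "yes", "yes", "yes", "yes", "no")
)

# Sensitivity only looks at rows where the true outcome is positive.
positives <- obs[obs$truth == "yes", ]
tpr <- tapply(positives$pred == "yes", positives$group, mean)

# Equal opportunity as the gap in true positive rates across groups.
eo <- max(tpr) - min(tpr)

# The false negative rate is 1 - TPR, so its gap is identical.
fnr <- 1 - tpr
fnr_gap <- max(fnr) - min(fnr)

c(equal_opportunity = eo, fnr_gap = fnr_gap)
```

Group `a` has a true positive rate of 0.5 and group `b` of 1, giving a gap of 0.5. Because the false negative rate is the complement of the true positive rate, equal true positive rates across groups imply equal false negative rates as well.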