Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add fairness metrics #434

Merged
merged 34 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
b98b2ac
rename `metrics.R` -> `aaa-metrics.R`
simonpcouch May 15, 2023
92c232d
implement metric constructor
simonpcouch May 15, 2023
44ef50b
implement 3 canonical fairness metrics
simonpcouch May 15, 2023
c260656
validate canonical metrics against fairlearn
simonpcouch May 15, 2023
31e4a5f
add NEWS entry
simonpcouch May 15, 2023
5922ae2
add pkgdown entry
simonpcouch May 15, 2023
7236a14
skip error context snap pre-4.0 (currently oldrel-4)
simonpcouch May 15, 2023
60973ba
remove unused test object
simonpcouch May 15, 2023
6998108
remove `rlang::` and `vctrs::`
simonpcouch May 16, 2023
a9fc2e0
optimize `diff_range()` for speed
simonpcouch May 16, 2023
27f53e9
transition from `summarize()`
simonpcouch May 16, 2023
7fe7fb8
remove incomplete phrase
simonpcouch May 16, 2023
bdb7c5a
highlight function factory as output
simonpcouch May 16, 2023
ca6d7b8
restore passing ellipses to `.post()`
simonpcouch May 16, 2023
f1d663f
error informatively with redundant grouping
simonpcouch May 18, 2023
45f57ff
don't pass ellipses to `.post`
simonpcouch May 18, 2023
43037d9
run test on all systems
simonpcouch May 18, 2023
ef7f5a1
namespace `group_by` in test
simonpcouch May 18, 2023
556f2c5
update snaps
simonpcouch May 18, 2023
6fad6a4
defer to inputted `.fn` for metric class
simonpcouch May 18, 2023
a9231a9
check `direction` argument
simonpcouch May 18, 2023
82717fa
move `max_positive_rate_diff()` to where it's used
simonpcouch Jun 21, 2023
a2dfe36
document implementation in `@description`
simonpcouch Jun 21, 2023
375269b
abort with condition parent to improve error context
simonpcouch Jun 21, 2023
76f2d4f
special-case metric factories in `metric_set()` checks
simonpcouch Jun 22, 2023
587ff32
rename `fairness_metric()` -> `new_groupwise_metric()`
simonpcouch Jun 22, 2023
05086ea
contrast "group-wise" and usual grouped behavior of yardstick metrics
simonpcouch Jun 22, 2023
25a893f
rename `new_groupwise_metric()` arguments
simonpcouch Jun 26, 2023
4436785
clarify documentation on `aggregate` arg
simonpcouch Jun 26, 2023
12ff56a
use devl probably
EmilHvitfeldt Jun 27, 2023
251f45b
Merge branch 'fairness' of github.com:tidymodels/yardstick into fairness
EmilHvitfeldt Jun 27, 2023
6c0b76f
`aggregrate` -> `aggregate`
simonpcouch Jun 27, 2023
219562c
remove `Remotes`---probably package is now on CRAN
simonpcouch Oct 23, 2023
690e738
correct duplicated description
simonpcouch Oct 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Suggests:
crayon,
ggplot2,
knitr,
probably (>= 0.0.6),
probably (>= 1.0.0),
rmarkdown,
survival (>= 3.5-0),
testthat (>= 3.0.0),
Expand Down
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,13 @@ export(concordance_survival_vec)
export(conf_mat)
export(curve_metric_summarizer)
export(curve_survival_metric_summarizer)
export(demographic_parity)
export(detection_prevalence)
export(detection_prevalence_vec)
export(dots_to_estimate)
export(dynamic_survival_metric_summarizer)
export(equal_opportunity)
export(equalized_odds)
export(f_meas)
export(f_meas_vec)
export(finalize_estimator)
Expand Down Expand Up @@ -174,6 +177,7 @@ export(msd)
export(msd_vec)
export(new_class_metric)
export(new_dynamic_survival_metric)
export(new_groupwise_metric)
export(new_integrated_survival_metric)
export(new_numeric_metric)
export(new_prob_metric)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ calculated with `roc_auc_survival()`.

* `metric_set()` can now be used with a combination of dynamic and static survival metrics.

* `demographic_parity()`, `equalized_odds()`, and `equal_opportunity()` are new metrics for measuring model fairness. Each is implemented with the `new_groupwise_metric()` constructor, a general interface for defining group-aware metrics that allows for quickly and flexibly defining fairness metrics with the problem context in mind.

# yardstick 1.2.0

## New Metrics
Expand Down
13 changes: 13 additions & 0 deletions R/metrics.R → R/aaa-metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,19 @@ validate_function_class <- function(fns) {
}
}

# Special case unevaluated group-wise metric factories
if ("metric_factory" %in% fn_cls) {
factories <- fn_cls[fn_cls == "metric_factory"]
cli::cli_abort(
c("{cli::qty(factories)}The input{?s} {.arg {names(factories)}} \\
{?is a/are} {.help [group-wise metric](yardstick::new_groupwise_metric)} \\
{?factory/factories} and must be passed a data-column before
addition to a metric set.",
"i" = "Did you mean to type e.g. `{names(factories)[1]}(col_name)`?"),
call = rlang::call2("metric_set")
)
}

# Each element of the list contains the names of the fns
# that inherit that specific class
fn_bad_names <- lapply(fn_cls_unique, function(x) {
Expand Down
1 change: 1 addition & 0 deletions R/aaa.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ utils::globalVariables(
c(
# for class prob metrics
"estimate",
".estimator",
"threshold",
"specificity",
".level",
Expand Down
222 changes: 222 additions & 0 deletions R/fair-aaa.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
#' Create group-wise metrics
#'
#' Group-wise metrics quantify the disparity in value of a metric across a
#' number of groups. Group-wise metrics with a value of zero indicate that the
#' underlying metric is equal across groups. yardstick defines
#' several common fairness metrics using this function, such as
#' [demographic_parity()], [equal_opportunity()], and [equalized_odds()].
#'
#' Note that _all_ yardstick metrics are group-aware in that, when passed
#' grouped data, they will return metric values calculated for each group.
#' When passed grouped data, group-wise metrics also return metric values
#' for each group, but those metric values are calculated by first additionally
#' grouping by the variable passed to `by` and then summarizing the per-group
#' metric estimates across groups using the function passed as the
#' `aggregate` argument.
#'
#' @param fn A yardstick metric function or metric set.
#' @param name The name of the metric to place in the `.metric` column
#' of the output.
#' @param aggregate A function to summarize the generated metric set results.
#' The function takes metric set results as the first argument and returns
#' a single numeric giving the `.estimate` value as output. See the Value and
#' Examples sections for example uses.
#' @inheritParams new_class_metric
#'
#' @section Relevant Group Level:
#' Additional arguments can be passed to the function outputted by
#' the function that this function outputs. That is:
#'
#' ```
#' res_fairness <- new_groupwise_metric(...)
#' res_by <- res_fairness(by)
#' res_by(..., additional_arguments_to_aggregate = TRUE)
#' ```
#'
#' For finer control of how groups in `by` are treated, use the
#' `aggregate` argument.
#'
#' @return
#' This function is a
#' [function factory](https://adv-r.hadley.nz/function-factories.html); it's
#' output is itself a function. Further, the functions that this function
#' outputs are also function factories. More explicitly, this looks like:
#'
#' ```
#' # a function with similar implementation to `demographic_parity()`:
#' diff_range <- function(x) {diff(range(x$.estimate))}
#'
#' dem_parity <-
#' new_groupwise_metric(
#' fn = detection_prevalence,
#' name = "dem_parity",
#' aggregate = diff_range
#' )
#' ```
#'
#' The outputted `dem_parity` is a function that takes one argument, `by`,
#' indicating the data-masked variable giving the sensitive feature.
#'
#' When called with a `by` argument, `dem_parity` will return a yardstick
#' metric function like any other:
#'
#' ```
#' dem_parity_by_gender <- dem_parity(gender)
#' ```
#'
#' Note that `dem_parity` doesn't take any arguments other than `by`, and thus
#' knows nothing about the data it will be applied to other than that it ought
#' to have a column with name `"gender"` in it.
#'
#' The output `dem_parity_by_gender` is a metric function that takes the
#' same arguments as the function supplied as `fn`, in this case
#' `detection_prevalence`. It will thus interface like any other yardstick
#' function except that it will look for a `"gender"` column in
#' the data it's supplied.
#'
#' In addition to the examples below, see the documentation on the
#' return value of fairness metrics like [demographic_parity()],
#' [equal_opportunity()], or [equalized_odds()] to learn more about how the
#' output of this function can be used.
#'
#' @examples
#' data(hpc_cv)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we decide to move forward with a fairness-oriented name, it'd be great if we could use some example data here that has a plausibly "sensitive" attribute. yardstick doesn't Suggests modeldata at the moment, which has some options. infer::gss would also work well here.

#'
#' # `demographic_parity`, among other fairness metrics,
#' # is generated with `new_groupwise_metric()`:
#' diff_range <- function(x) {diff(range(x$.estimate))}
#' demographic_parity_ <-
#' new_groupwise_metric(
#' fn = detection_prevalence,
#' name = "demographic_parity",
#' aggregate = diff_range
#' )
#'
#' m_set <- metric_set(demographic_parity_(Resample))
#'
#' m_set(hpc_cv, truth = obs, estimate = pred)
#'
#' # the `post` argument can be used to accommodate a wide
#' # variety of parameterizations. to encode demographic
#' # parity as a ratio inside of a difference, for example:
#' ratio_range <- function(x, ...) {
#' range <- range(x$.estimate)
#' range[1] / range[2]
#' }
#'
#' demographic_parity_ratio <-
#' new_groupwise_metric(
#' fn = detection_prevalence,
#' name = "demographic_parity_ratio",
#' aggregate = ratio_range
#' )
#'
#' @export
new_groupwise_metric <- function(fn, name, aggregate, direction = "minimize") {
if (is_missing(fn) || !inherits_any(fn, c("metric", "metric_set"))) {
abort("`fn` must be a metric function or metric set.")
}
if (is_missing(name) || !is_string(name)) {
abort("`name` must be a string.")
}
if (is_missing(aggregate) || !is_function(aggregate)) {
abort("`aggregate` must be a function.")
}
arg_match(
direction,
values = c("maximize", "minimize", "zero")
)

metric_factory <-
function(by) {
by_str <- as_string(enexpr(by))
res <-
function(data, ...) {
gp_vars <- dplyr::group_vars(data)

if (by_str %in% gp_vars) {
cli::cli_abort(
"Metric is internally grouped by {.field {by_str}}; grouping \\
{.arg data} by {.field {by_str}} is not well-defined."
)
}

# error informatively when `fn` is a metric set; see `eval_safely()`
data_grouped <- dplyr::group_by(data, {{by}}, .add = TRUE)
res <-
tryCatch(
fn(data_grouped, ...),
error = function(cnd) {
if (!is.null(cnd$parent)) {
cnd <- cnd$parent
}

abort(conditionMessage(cnd), call = call(name))
}
)

# restore to the grouping structure in the supplied data
if (length(gp_vars) > 0) {
res <- dplyr::group_by(res, !!!dplyr::groups(data), .add = FALSE)
}

group_rows <- dplyr::group_rows(res)
group_keys <- dplyr::group_keys(res)
res <- dplyr::ungroup(res)
groups <- vec_chop(res, indices = group_rows)
out <- vector("list", length = length(groups))

for (i in seq_along(groups)) {
group <- groups[[i]]

.estimate <- aggregate(group)

if (!is_bare_numeric(.estimate)) {
abort(
"`aggregate` must return a single numeric value.",
call = call2("new_groupwise_metric")
)
}

elt_out <- list(
.metric = name,
.by = by_str,
.estimator = group$.estimator[1],
.estimate = .estimate
)

out[[i]] <- tibble::new_tibble(elt_out)
}

group_keys <- vctrs::vec_rep_each(group_keys, times = list_sizes(out))
out <- vec_rbind(!!!out)
out <- vec_cbind(group_keys, out)

out
}
res <- new_class_metric(res, direction = "minimize")

structure(
res,
direction = direction,
by = by_str,
class = groupwise_metric_class(fn)
)
}

structure(metric_factory, class = c("metric_factory", "function"))
}

groupwise_metric_class <- function(fn) {
if (inherits(fn, "metric")) {
return(class(fn))
}

class(attr(fn, "metrics")[[1]])
}

diff_range <- function(x) {
estimates <- x$.estimate

max(estimates) - min(estimates)
}
50 changes: 50 additions & 0 deletions R/fair-demographic_parity.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#' Demographic parity
simonpcouch marked this conversation as resolved.
Show resolved Hide resolved
#'
#' @description
#' Demographic parity is satisfied when a model's predictions have the
#' same predicted positive rate across groups. A value of 0 indicates parity
#' across groups. Note that this definition does not depend on the true
#' outcome; the `truth` argument is included in outputted metrics
#' for consistency.
#'
#' `demographic_parity()` is calculated as the difference between the largest
#' and smallest value of [detection_prevalence()] across groups.
#'
#' Demographic parity is sometimes referred to as group fairness,
#' disparate impact, or statistical parity.
#'
#' See the "Measuring Disparity" section for details on implementation.
#'
#' @param by The column identifier for the sensitive feature. This should be an
#' unquoted column name referring to a column in the un-preprocessed data.
#'
#' @templateVar fn demographic_parity
#' @templateVar internal_fn detection_prevalence
#' @templateVar internal_fn [detection_prevalence()]
#' @template return-fair
#' @template event-fair
#' @template examples-fair
#'
#' @family fairness metrics
#'
#' @references
#'
#' Agarwal, A., Beygelzimer, A., Dudik, M., Langford, J., & Wallach, H. (2018).
#' "A Reductions Approach to Fair Classification." Proceedings of the 35th
#' International Conference on Machine Learning, in Proceedings of Machine
#' Learning Research. 80:60-69.
#'
#' Verma, S., & Rubin, J. (2018). "Fairness definitions explained". In
#' Proceedings of the international workshop on software fairness (pp. 1-7).
#'
#' Bird, S., Dudík, M., Edgar, R., Horn, B., Lutz, R., Milan, V., ... & Walker,
#' K. (2020). "Fairlearn: A toolkit for assessing and improving fairness in AI".
#' Microsoft, Tech. Rep. MSR-TR-2020-32.
#'
#' @export
demographic_parity <-
new_groupwise_metric(
fn = detection_prevalence,
name = "demographic_parity",
aggregate = diff_range
)
45 changes: 45 additions & 0 deletions R/fair-equal_opportunity.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#' Equal opportunity
#'
#' @description
#'
#' Equal opportunity is satisfied when a model's predictions have the same
#' true positive and false negative rates across protected groups. A value of
#' 0 indicates parity across groups.
#'
#' `equal_opportunity()` is calculated as the difference between the largest
#' and smallest value of [sens()] across groups.
#'
#' Equal opportunity is sometimes referred to as equality of opportunity.
#'
#' See the "Measuring Disparity" section for details on implementation.
#'
#' @inheritParams demographic_parity
#'
#' @templateVar fn equal_opportunity
#' @templateVar internal_fn sens
#' @templateVar internal_fn [sens()]
#' @template return-fair
#' @template event-fair
#' @template examples-fair
#'
#' @family fairness metrics
#'
#' @references
#'
#' Hardt, M., Price, E., & Srebro, N. (2016). "Equality of opportunity in
#' supervised learning". Advances in neural information processing systems, 29.
#'
#' Verma, S., & Rubin, J. (2018). "Fairness definitions explained". In
#' Proceedings of the international workshop on software fairness (pp. 1-7).
#'
#' Bird, S., Dudík, M., Edgar, R., Horn, B., Lutz, R., Milan, V., ... & Walker,
#' K. (2020). "Fairlearn: A toolkit for assessing and improving fairness in AI".
#' Microsoft, Tech. Rep. MSR-TR-2020-32.
#'
#' @export
equal_opportunity <-
new_groupwise_metric(
fn = sens,
name = "equal_opportunity",
aggregate = diff_range
)
Loading
Loading