add fairness metrics #434

Merged
merged 34 commits into from
Oct 27, 2023
Changes from 7 commits
Commits
34 commits
b98b2ac
rename `metrics.R` -> `aaa-metrics.R`
simonpcouch May 15, 2023
92c232d
implement metric constructor
simonpcouch May 15, 2023
44ef50b
implement 3 canonical fairness metrics
simonpcouch May 15, 2023
c260656
validate canonical metrics against fairlearn
simonpcouch May 15, 2023
31e4a5f
add NEWS entry
simonpcouch May 15, 2023
5922ae2
add pkgdown entry
simonpcouch May 15, 2023
7236a14
skip error context snap pre-4.0 (currently oldrel-4)
simonpcouch May 15, 2023
60973ba
remove unused test object
simonpcouch May 15, 2023
6998108
remove `rlang::` and `vctrs::`
simonpcouch May 16, 2023
a9fc2e0
optimize `diff_range()` for speed
simonpcouch May 16, 2023
27f53e9
transition from `summarize()`
simonpcouch May 16, 2023
7fe7fb8
remove incomplete phrase
simonpcouch May 16, 2023
bdb7c5a
highlight function factory as output
simonpcouch May 16, 2023
ca6d7b8
restore passing ellipses to `.post()`
simonpcouch May 16, 2023
f1d663f
error informatively with redundant grouping
simonpcouch May 18, 2023
45f57ff
don't pass ellipses to `.post`
simonpcouch May 18, 2023
43037d9
run test on all systems
simonpcouch May 18, 2023
ef7f5a1
namespace `group_by` in test
simonpcouch May 18, 2023
556f2c5
update snaps
simonpcouch May 18, 2023
6fad6a4
defer to inputted `.fn` for metric class
simonpcouch May 18, 2023
a9231a9
check `direction` argument
simonpcouch May 18, 2023
82717fa
move `max_positive_rate_diff()` to where it's used
simonpcouch Jun 21, 2023
a2dfe36
document implementation in `@description`
simonpcouch Jun 21, 2023
375269b
abort with condition parent to improve error context
simonpcouch Jun 21, 2023
76f2d4f
special-case metric factories in `metric_set()` checks
simonpcouch Jun 22, 2023
587ff32
rename `fairness_metric()` -> `new_groupwise_metric()`
simonpcouch Jun 22, 2023
05086ea
contrast "group-wise" and usual grouped behavior of yardstick metrics
simonpcouch Jun 22, 2023
25a893f
rename `new_groupwise_metric()` arguments
simonpcouch Jun 26, 2023
4436785
clarify documentation on `aggregate` arg
simonpcouch Jun 26, 2023
12ff56a
use devl probably
EmilHvitfeldt Jun 27, 2023
251f45b
Merge branch 'fairness' of github.com:tidymodels/yardstick into fairness
EmilHvitfeldt Jun 27, 2023
6c0b76f
`aggregrate` -> `aggregate`
simonpcouch Jun 27, 2023
219562c
remove `Remotes`---probably package is now on CRAN
simonpcouch Oct 23, 2023
690e738
correct duplicated description
simonpcouch Oct 27, 2023
4 changes: 4 additions & 0 deletions NAMESPACE
@@ -130,12 +130,16 @@ export(concordance_survival_vec)
export(conf_mat)
export(curve_metric_summarizer)
export(curve_survival_metric_summarizer)
export(demographic_parity)
export(detection_prevalence)
export(detection_prevalence_vec)
export(dots_to_estimate)
export(dynamic_survival_metric_summarizer)
export(equal_opportunity)
export(equalized_odds)
export(f_meas)
export(f_meas_vec)
export(fairness_metric)
export(finalize_estimator)
export(finalize_estimator_internal)
export(gain_capture)
2 changes: 2 additions & 0 deletions NEWS.md
@@ -14,6 +14,8 @@ calculated with `roc_auc_survival()`.

* `metric_set()` can now be used with a combination of dynamic and static survival metrics.

* `demographic_parity()`, `equalized_odds()`, and `equal_opportunity()` are new metrics for measuring model fairness. Each is implemented with the `fairness_metric()` constructor, a general interface for defining group-aware metrics that allows for quickly and flexibly defining fairness metrics with the problem context in mind.

# yardstick 1.2.0

## New Metrics
File renamed without changes.
1 change: 1 addition & 0 deletions R/aaa.R
@@ -6,6 +6,7 @@ utils::globalVariables(
c(
# for class prob metrics
"estimate",
".estimator",
"threshold",
"specificity",
".level",
137 changes: 137 additions & 0 deletions R/fair-aaa.R
@@ -0,0 +1,137 @@
#' Create fairness metrics
#'
#' Fairness metrics quantify the disparity in value of a metric across a number
#' of groups. Fairness metrics with a value of zero indicate that the
#' underlying metric has parity across groups. yardstick defines
#' several common fairness metrics using this function, such as
#' [demographic_parity()], [equal_opportunity()], and [equalized_odds()].
#'
#' @param .fn A yardstick metric function or metric set.
#' @param .name The name of the metric to place in the `.metric` column
#' of the output.
#' @param .post A function to post-process the generated metric set results `x`.
#' In many cases, `function(x, ...) diff(range(x$.estimate))` or
#' `function(x, ...) {r <- range(x$.estimate); r[1] / r[2]}`.
#'
#' @section Relevant Group Level:
#' Additional arguments can be passed to the metric function that is
#' ultimately returned: `fairness_metric()` outputs a function of `by`, which
#' in turn outputs the metric function itself. That is:
#'
#' ```
#' res_fairness <- fairness_metric(...)
#' res_by <- res_fairness(by)
#' res_by(..., additional_arguments_to_.post = TRUE)
#' ```
#'
#' For finer control of how groups in `by` are treated, use the
#' `.post` argument.
#'
#' @return A function with one argument, `by`, indicating the data-masked
#' variable giving the sensitive feature. See the documentation on the
#' return value of fairness metrics like [demographic_parity()],
#' [equal_opportunity()], or [equalized_odds()] to learn more about how the
#' output of this function can be used.
#'
#' @examples
#' data(hpc_cv)
Review comment (Contributor Author):
If we decide to move forward with a fairness-oriented name, it'd be great if we could use some example data here that has a plausibly "sensitive" attribute. yardstick doesn't Suggests modeldata at the moment, which has some options. infer::gss would also work well here.

#'
#' # `demographic_parity`, among other fairness metrics,
#' # is generated with `fairness_metric()`:
#' diff_range <- function(x, ...) {diff(range(x$.estimate))}
#' demographic_parity_ <-
#' fairness_metric(
#' .fn = detection_prevalence,
#' .name = "demographic_parity",
#' .post = diff_range
#' )
#'
#' m_set <- metric_set(demographic_parity_(Resample))
#'
#' m_set(hpc_cv, truth = obs, estimate = pred)
#'
#' # the `.post` argument can be used to accommodate a wide
#' # variety of parameterizations. to encode demographic
#' # parity as a ratio instead of a difference, for example:
#' ratio_range <- function(x, ...) {
#' range <- range(x$.estimate)
#' range[1] / range[2]
#' }
#'
#' demographic_parity_ratio <-
#' fairness_metric(
#' .fn = detection_prevalence,
#' .name = "demographic_parity_ratio",
#' .post = ratio_range
#' )
#'
#' @export
fairness_metric <- function(.fn, .name, .post) {
Review comment (Contributor Author):
The second and third issues linked in the PR description may be related here—there's nothing specific to only fairness about this function for now besides naming choices. This might be a solution for grouped metrics generally.

Review comment (Member):
Yes, I would lean toward changing the naming here (and the direction = "minimize" default?) so that it is more clearly groupwise_metric() or similar. I'd change the section in the pkgdown site to something like "Fairness and Group Metrics". I think this is the right way to go both because folks have non-fairness group metric needs, and because then the name helps users understand how fairness metrics work. I think it's better for learning/using, not worse.

Review comment (Member):
I totally agree, switching this file over to talk about them as "group-wise" metrics is the right move.

Review comment (Contributor Author):
I'm game! Thank you.

A difficult bit here is that all yardstick metrics know about groups, so I want to make sure we don't imply that non-fairness-metrics aren't group-aware, there just isn't an intermediate grouped operation happening under the hood. I do think that groupwise_metric() could be a good way to phrase that (accompanied by strong docs), but also very much open to other ideas, esp. if there's some dplyr-ish concept that already speaks to this.

Review comment (Contributor Author):
Notes from the group meeting:

  • Max mentioned that it might be nice to prefix whatever this function name is with create_ or some other eliciting verb to indicate that this is a function factory, and others agreed
  • I suggested disparity_metric as a descriptor for this type of metric that doesn't have as strong of a social connotation—seems like "disparity" could describe differences across groups regardless of whether that group is regarded as a sensitive feature

Review comment (Contributor Author):
I propose moving forward with create_disparity_metric(). Any thoughts?

Review comment (Member):
I like the idea of using create_* here, but it is a departure compared to the other function factories in yardstick (of which there are a lot, like metric_tweak(), metric_set(), and so forth). Do you think it's better to stay more similar to the naming conventions of yardstick, or to use something like create_*?

I have a mild preference for something like create_groupwise_metric() because I think there is more ML community vocabulary around what "groupwise" means. The word "disparity" makes me think about the specific metric disparate impact. That being said, my opinion is not super strong here.

Review comment (Contributor Author):
Good points--if we want to look to other function factories in the package, maybe the parallel we might want to draw is with new_metric()? Something like new_groupwise_metric()?

Review comment (Member):
I like that a lot, new_groupwise_metric() 👍

if (rlang::is_missing(.fn) || !inherits_any(.fn, c("metric", "metric_set"))) {
rlang::abort("`.fn` must be a metric function or metric set.")
}
if (rlang::is_missing(.name) || !is_string(.name)) {
abort("`.name` must be a string.")
}
if (rlang::is_missing(.post) || !is_function(.post)) {
abort("`.post` must be a function.")
}

function(by) {
by_str <- rlang::as_string(rlang::enexpr(by))
res <-
function(data, ...) {
gp_vars <- dplyr::group_vars(data)

res <- dplyr::group_by(data, {{by}}, .add = TRUE)
res <- .fn(res, ...)

if (length(gp_vars) > 0) {
splits <- vctrs::vec_split(res, res[gp_vars])
.estimate <- vapply(splits$val, .post, numeric(1), ...)
} else {
.estimate <- .post(res, ...)
}

if (!rlang::is_bare_numeric(.estimate)) {
abort(
"`.post` must return a single numeric value.",
call = rlang::call2("fairness_metric")
)
}

if (length(gp_vars) > 0) {
res <- dplyr::group_by(res, !!!dplyr::groups(data), .add = FALSE)
}

res <-
dplyr::summarize(
res,
.metric = .name,
!!".by" := by_str,
.estimator = .estimator[1],
.groups = "drop"
)

res$.estimate <- .estimate

res
}
res <- new_class_metric(res, direction = "minimize")
attr(res, "by") <- by_str
res
}
}
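
The three layers of the constructor above (constructor, function of `by`, metric function) can be sketched in base R. This is a simplified illustration under stated assumptions, not the package's implementation: `toy_groupwise()`, `positive_rate()`, and `spread()` are hypothetical names, `by` is passed as a string rather than a data-masked column, and pre-existing groups and the metric class are ignored.

```r
# Hypothetical, simplified stand-in for the constructor: base R only,
# no data masking, no handling of groups already present on `data`.
toy_groupwise <- function(.fn, .name, .post) {
  stopifnot(is.function(.fn), is.character(.name), is.function(.post))

  # Layer 2: a function of `by` (a column name as a string in this sketch).
  function(by) {
    # Layer 3: the metric function that is finally applied to data.
    function(data, ...) {
      pieces <- split(data, data[[by]])
      per_group <- data.frame(
        group = names(pieces),
        .estimate = vapply(pieces, .fn, numeric(1), ...)
      )
      # Collapse the per-group estimates into one value with `.post`.
      data.frame(.metric = .name, .by = by, .estimate = .post(per_group))
    }
  }
}

# Hypothetical per-group metric and post-processing function.
positive_rate <- function(d, ...) mean(d$pred == "yes")
spread <- function(x, ...) diff(range(x$.estimate))

toy_parity <- toy_groupwise(positive_rate, "toy_parity", spread)
metric_fn <- toy_parity("group")

# Made-up data: group "a" is always predicted "yes", group "b" half the time.
toy <- data.frame(
  group = c("a", "a", "b", "b"),
  pred = c("yes", "yes", "yes", "no")
)
result <- metric_fn(toy)
```

The point of the intermediate layer is that the grouping column is fixed once, so the innermost function has the same `data, ...` shape as any other metric.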

diff_range <- function(x, ...) {
diff(range(x$.estimate))
}

max_positive_rate_diff <- function(x, ...) {
metric_values <- vctrs::vec_split(x, x$.metric)

positive_rate_diff <- vapply(metric_values$val, diff_range, numeric(1), ...)

max(positive_rate_diff)
}
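
To make the two helpers above concrete, here is a small base-R sketch on made-up values. It assumes a `metric_set()`-style result with `.metric` and `.estimate` columns, and uses `split()` as a stand-in for `vctrs::vec_split()`; `max_diff_by_metric()` is a hypothetical name for illustration.

```r
# Toy metric_set()-style output: one row per metric per group.
# Column names mirror the code above; the values are invented.
toy <- data.frame(
  .metric   = c("sens", "sens", "sens", "spec", "spec", "spec"),
  group     = c("a", "b", "c", "a", "b", "c"),
  .estimate = c(0.90, 0.70, 0.80, 0.60, 0.65, 0.55)
)

diff_range <- function(x, ...) {
  diff(range(x$.estimate))
}

# Base-R stand-in for the vec_split() step: split rows by metric,
# take the range difference within each metric, then keep the largest.
max_diff_by_metric <- function(x, ...) {
  by_metric <- split(x, x$.metric)
  diffs <- vapply(by_metric, diff_range, numeric(1), ...)
  max(diffs)
}

sens_gap <- diff_range(toy[toy$.metric == "sens", ])  # 0.90 - 0.70
spec_gap <- diff_range(toy[toy$.metric == "spec", ])  # 0.65 - 0.55
overall  <- max_diff_by_metric(toy)                   # larger of the two gaps
pooled   <- diff_range(toy)                           # 0.90 - 0.55, mixes metrics
```

Splitting by `.metric` first matters: applying `diff_range()` to the pooled rows compares sensitivity values against specificity values, which is a different quantity from the largest within-metric gap.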
47 changes: 47 additions & 0 deletions R/fair-demographic_parity.R
@@ -0,0 +1,47 @@
#' Demographic parity
#'
#' @description
#' Demographic parity is satisfied when a model's predictions have the
#' same predicted positive rate across groups. A value of 0 indicates parity
#' across groups. Note that this definition does not depend on the true
#' outcome; the `truth` argument is included in outputted metrics
#' for consistency.
#'
#' Demographic parity is sometimes referred to as group fairness,
#' disparate impact, or statistical parity.
#'
#' See the "Measuring Disparity" section for details on implementation.
#'
#' @param by The column identifier for the sensitive feature. This should be an
#' unquoted column name referring to a column in the un-preprocessed data.
#'
#' @templateVar fn demographic_parity
#' @templateVar internal_.fn detection_prevalence
#' @templateVar internal_fn [detection_prevalence()]
#' @template return-fair
#' @template event-fair
#' @template examples-fair
#'
#' @family fairness metrics
#'
#' @references
#'
#' Agarwal, A., Beygelzimer, A., Dudik, M., Langford, J., & Wallach, H. (2018).
#' "A Reductions Approach to Fair Classification." Proceedings of the 35th
#' International Conference on Machine Learning, in Proceedings of Machine
#' Learning Research. 80:60-69.
#'
#' Verma, S., & Rubin, J. (2018). "Fairness definitions explained". In
#' Proceedings of the international workshop on software fairness (pp. 1-7).
#'
#' Bird, S., Dudík, M., Edgar, R., Horn, B., Lutz, R., Milan, V., ... & Walker,
#' K. (2020). "Fairlearn: A toolkit for assessing and improving fairness in AI".
#' Microsoft, Tech. Rep. MSR-TR-2020-32.
#'
#' @export
demographic_parity <-
fairness_metric(
.fn = detection_prevalence,
.name = "demographic_parity",
.post = diff_range
)
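
As a rough illustration of the quantity described in the documentation above, the following base-R sketch computes the predicted positive rate per group and takes the spread between the highest and lowest rate. The data are invented, `"yes"` is assumed to be the event level, and this is not how the package computes the metric internally.

```r
# Made-up predictions for two groups; "yes" is treated as the event level.
toy <- data.frame(
  group = c("a", "a", "a", "a", "b", "b", "b", "b"),
  pred  = c("yes", "yes", "yes", "no", "yes", "no", "no", "no")
)

# Predicted positive rate (detection prevalence) within each group.
positive_rate <- tapply(toy$pred == "yes", toy$group, mean)

# Demographic parity as the gap between the highest and lowest rate;
# a value of 0 would mean the rates are identical across groups.
parity_gap <- diff(range(positive_rate))
```

Note that no true-outcome column appears anywhere in the sketch, consistent with the remark that this definition does not depend on `truth`.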
43 changes: 43 additions & 0 deletions R/fair-equal_opportunity.R
@@ -0,0 +1,43 @@
#' Equal opportunity
#'
#' @description
#'
#' Equal opportunity is satisfied when a model's predictions have the same
#' true positive and false negative rates across protected groups. A value of
#' 0 indicates parity across groups.
#'
#' Equal opportunity is sometimes referred to as equality of opportunity.
#'
#' See the "Measuring Disparity" section for details on implementation.
#'
#' @inheritParams demographic_parity
#'
#' @templateVar fn equal_opportunity
#' @templateVar internal_.fn sens
#' @templateVar internal_fn [sens()]
#' @template return-fair
#' @template event-fair
#' @template examples-fair
#'
#' @family fairness metrics
#'
#' @references
#'
#' Hardt, M., Price, E., & Srebro, N. (2016). "Equality of opportunity in
#' supervised learning". Advances in neural information processing systems, 29.
#'
#' Verma, S., & Rubin, J. (2018). "Fairness definitions explained". In
#' Proceedings of the international workshop on software fairness (pp. 1-7).
#'
#' Bird, S., Dudík, M., Edgar, R., Horn, B., Lutz, R., Milan, V., ... & Walker,
#' K. (2020). "Fairlearn: A toolkit for assessing and improving fairness in AI".
#' Microsoft, Tech. Rep. MSR-TR-2020-32.
#'
#' @export
equal_opportunity <-
fairness_metric(
.fn = sens,
.name = "equal_opportunity",
.post = diff_range
)
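
A comparable base-R sketch for equal opportunity, again on invented data with `"yes"` assumed to be the event level: sensitivity is computed within each group using only the rows whose true outcome is the event, and the metric is the gap between the largest and smallest value.

```r
# Made-up truth and predictions for two groups of five observations each.
toy <- data.frame(
  group = rep(c("a", "b"), each = 5),
  truth = c("yes", "yes", "yes", "yes", "no", "yes", "yes", "yes", "yes", "no"),
  pred  = c("yes", "yes", "yes", "no",  "no", "yes", "yes", "no",  "no",  "yes")
)

# Sensitivity only looks at rows where the true outcome is the event.
events <- toy[toy$truth == "yes", ]
sens_by_group <- tapply(events$pred == "yes", events$group, mean)

# Equal opportunity as the spread in sensitivity across groups.
opportunity_gap <- diff(range(sens_by_group))
```

Unlike the demographic parity calculation, rows with a true outcome of `"no"` do not affect the result here, even where the prediction is `"yes"`.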
70 changes: 70 additions & 0 deletions R/fair-equalized_odds.R
@@ -0,0 +1,70 @@
#' Equalized odds
#'
#' @description
#'
#' Equalized odds is satisfied when a model's predictions have the same false
#' positive, true positive, false negative, and true negative rates across
#' protected groups. A value of 0 indicates parity across groups.
#'
#' Equalized odds is sometimes referred to as conditional procedure accuracy
#' equality or disparate mistreatment.
#'
#' See the "Measuring Disparity" section for details on implementation.
#'
#' @inheritParams demographic_parity
#'
#' @templateVar fn equalized_odds
#' @templateVar internal_fn [sens()] and [spec()]
#' @template return-fair
#' @template examples-fair
#'
#' @section Measuring Disparity:
#' By default, this function takes the maximum difference in range of [sens()]
#' and [spec()] `.estimate`s across groups. That is, the maximum pair-wise
#' disparity in [sens()] or [spec()] between groups is the return value of
#' `equalized_odds()`'s `.estimate`.
#'
#' For finer control of group treatment, construct a context-aware fairness
#' metric with the [fairness_metric()] function by passing a custom `.post`
#' function:
#'
#' ```
#' # see yardstick:::max_positive_rate_diff for the actual `.post()`
#' diff_range <- function(x, ...) {diff(range(x$.estimate))}
#'
#' equalized_odds_2 <-
#' fairness_metric(
#' .fn = metric_set(sens, spec),
#' .name = "equalized_odds_2",
#' .post = diff_range
#' )
#' ```
#'
#' In `.post()`, `x` is the [metric_set()] output with [sens()] and [spec()]
#' values for each group, and `...` gives additional arguments (such as a grouping
#' level to refer to as the "baseline") to pass to the function outputted
#' by `equalized_odds_2()` for context.
#'
#' @family fairness metrics
#'
#' @references
#'
#' Agarwal, A., Beygelzimer, A., Dudik, M., Langford, J., & Wallach, H. (2018).
#' "A Reductions Approach to Fair Classification." Proceedings of the 35th
#' International Conference on Machine Learning, in Proceedings of Machine
#' Learning Research. 80:60-69.
#'
#' Verma, S., & Rubin, J. (2018). "Fairness definitions explained". In
#' Proceedings of the international workshop on software fairness (pp. 1-7).
#'
#' Bird, S., Dudík, M., Edgar, R., Horn, B., Lutz, R., Milan, V., ... & Walker,
#' K. (2020). "Fairlearn: A toolkit for assessing and improving fairness in AI".
#' Microsoft, Tech. Rep. MSR-TR-2020-32.
#'
#' @export
equalized_odds <-
fairness_metric(
.fn = metric_set(sens, spec),
.name = "equalized_odds",
.post = max_positive_rate_diff
)
7 changes: 7 additions & 0 deletions _pkgdown.yml
@@ -68,6 +68,13 @@ reference:
- iic
- poisson_log_loss

- title: Fairness Metrics
contents:
- fairness_metric
- demographic_parity
- equalized_odds
- equal_opportunity

- title: Curve Functions
contents:
- roc_curve
25 changes: 25 additions & 0 deletions man-roxygen/event-fair.R
@@ -0,0 +1,25 @@
#' @section Measuring Disparity:
#' By default, this function takes the difference in range of <%=internal_fn %>
#' `.estimate`s across groups. That is, the maximum pair-wise disparity between
#' groups is the return value of `<%=fn %>()`'s `.estimate`.
#'
#' For finer control of group treatment, construct a context-aware fairness
#' metric with the [fairness_metric()] function by passing a custom `.post`
#' function:
#'
#' ```
#' # the actual default `.post` is:
#' diff_range <- function(x, ...) {diff(range(x$.estimate))}
#'
#' <%=fn %>_2 <-
#' fairness_metric(
#' .fn = <%=internal_.fn %>,
#' .name = "<%=fn %>_2",
#' .post = diff_range
#' )
#' ```
#'
#' In `.post()`, `x` is the `metric_set()` output with <%=internal_fn %> values
#' for each group, and `...` gives additional arguments (such as a grouping
#' level to refer to as the "baseline") to pass to the function outputted
#' by `<%=fn %>_2()` for context.
19 changes: 19 additions & 0 deletions man-roxygen/examples-fair.R
@@ -0,0 +1,19 @@
#' @examples
#' library(dplyr)
#'
#' data(hpc_cv)
#'
#' head(hpc_cv)
#'
#' # evaluate `<%=fn %>()` by Resample
#' m_set <- metric_set(<%=fn %>(Resample))
#'
#' # use output like any other metric set
#' hpc_cv %>%
#' m_set(truth = obs, estimate = pred)
#'
#' # can mix fairness metrics and regular metrics
#' m_set_2 <- metric_set(sens, <%=fn %>(Resample))
#'
#' hpc_cv %>%
#' m_set_2(truth = obs, estimate = pred)