Ranked probability scores for ordinal data #525

Open · wants to merge 5 commits into main
1 change: 1 addition & 0 deletions DESCRIPTION
@@ -110,6 +110,7 @@ Collate:
'prob-mn_log_loss.R'
'prob-pr_auc.R'
'prob-pr_curve.R'
'prob-ranked_prob_score.R'
'prob-roc_auc.R'
'prob-roc_aunp.R'
'prob-roc_aunu.R'
4 changes: 4 additions & 0 deletions NAMESPACE
@@ -36,6 +36,7 @@ S3method(finalize_estimator_internal,mcc)
S3method(finalize_estimator_internal,mn_log_loss)
S3method(finalize_estimator_internal,pr_auc)
S3method(finalize_estimator_internal,pr_curve)
S3method(finalize_estimator_internal,ranked_prob_score)
S3method(finalize_estimator_internal,roc_auc)
S3method(finalize_estimator_internal,roc_curve)
S3method(format,metric)
@@ -79,6 +80,7 @@ S3method(print,conf_mat)
S3method(print,metric)
S3method(print,metric_factory)
S3method(print,metric_set)
S3method(ranked_prob_score,data.frame)
S3method(recall,data.frame)
S3method(recall,matrix)
S3method(recall,table)
@@ -200,6 +202,8 @@ export(pr_curve)
export(precision)
export(precision_vec)
export(prob_metric_summarizer)
export(ranked_prob_score)
export(ranked_prob_score_vec)
export(recall)
export(recall_vec)
export(rmse)
2 changes: 2 additions & 0 deletions NEWS.md
@@ -1,5 +1,7 @@
# yardstick (development version)

* The ranked probability score for ordinal classification data was added with `ranked_prob_score()`. (#524)

# yardstick 1.3.1

## Bug Fixes
7 changes: 7 additions & 0 deletions R/estimator-helpers.R
@@ -136,6 +136,13 @@ finalize_estimator_internal.mn_log_loss <- finalize_estimator_internal.accuracy
#' @export
finalize_estimator_internal.brier_class <- finalize_estimator_internal.accuracy

#' @export
finalize_estimator_internal.ranked_prob_score <- function(metric_dispatcher,
                                                           x,
                                                           estimator,
                                                           call = caller_env()) {
  "multiclass"
}

# Classification cost extends naturally to multiclass and produces the same
# result regardless of the "event" level.
165 changes: 165 additions & 0 deletions R/prob-ranked_prob_score.R
@@ -0,0 +1,165 @@
#' Ranked probability scores for ordinal classification models
#'
#' Compute the ranked probability score (RPS) for a classification model using
#' ordered classes.
#'
#' @param truth The column identifier for the true class results
#' (that is an _ordered_ `factor`). This should be an unquoted column name
#' although this argument is passed by expression and supports
#' [quasiquotation][rlang::quasiquotation] (you can unquote column names).
#' For `_vec()` functions, a factor vector with class `ordered`.
#' @param estimate A matrix with as many columns as factor levels of `truth`. _It
#' is assumed that these are in the same order as the levels of `truth`._
#' @family class probability metrics
#' @templateVar fn ranked_prob_score
#' @template return
#' @details
#'
#' The ranked probability score is a Brier score for ordinal data that uses the
#' _cumulative_ probability of an event (i.e. `Pr[class <= i]` for `i` = 1,
#' 2, ..., `C` classes). These probabilities are compared to indicators for the
#' truth being less than or equal to class `i`.
#'
#' Since the cumulative probabilities for a vector of class predictions always
#' end at one, there is an embedded redundancy in the data. For this reason,
#' the raw mean is divided by the number of classes minus one.
#'
#' Smaller values of the score are associated with better model performance.
#'
#' @section Multiclass:
#' Ranked probability scores can be computed in the same way for any number of
#' classes. Because of this, no averaging types are supported.
#'
#' @inheritParams pr_auc
#'
#' @author Max Kuhn
#'
#' @references
#'
#' Wilks, D. S. (2011). _Statistical Methods in the Atmospheric Sciences_.
#' Academic Press. (see Chapter 7)
#'
#' Janitza, S., Tutz, G., & Boulesteix, A. L. (2016). Random forest for ordinal
#' responses: prediction and variable selection. Computational Statistics and
#' Data Analysis, 96, 57-73. (see Section 2)
#'
#' Lechner, M., & Okasa, G. (2019). Random forest estimation of the ordered
#' choice model. arXiv preprint arXiv:1907.02436. (see Section 5)
#'
#' @examples
#' library(dplyr)
#' data(hpc_cv)
#'
#' hpc_cv$obs <- as.ordered(hpc_cv$obs)
#'
#' # You can use the col1:colN tidyselect syntax
#' hpc_cv %>%
#'   filter(Resample == "Fold01") %>%
#'   ranked_prob_score(obs, VF:L)
#'
#' # Groups are respected
#' hpc_cv %>%
#'   group_by(Resample) %>%
#'   ranked_prob_score(obs, VF:L)
#'
#' @export
ranked_prob_score <- function(data, ...) {
  UseMethod("ranked_prob_score")
}
ranked_prob_score <- new_prob_metric(
  ranked_prob_score,
  direction = "minimize"
)

#' @export
#' @rdname ranked_prob_score
ranked_prob_score.data.frame <- function(data,
                                         truth,
                                         ...,
                                         na_rm = TRUE,
                                         case_weights = NULL) {
  case_weights_quo <- enquo(case_weights)

  prob_metric_summarizer(
    name = "ranked_prob_score",
    fn = ranked_prob_score_vec,
    data = data,
    truth = !!enquo(truth),
    ...,
    na_rm = na_rm,
    case_weights = !!case_weights_quo
  )
}

#' @rdname ranked_prob_score
#' @export
ranked_prob_score_vec <- function(truth,
                                  estimate,
                                  na_rm = TRUE,
                                  case_weights = NULL,
                                  ...) {
  abort_if_class_pred(truth)
  if (!is.ordered(truth)) {
    cli::cli_abort(
      "The ranked probability score requires the outcome to be an
      {.strong ordered} factor, not {.obj_type_friendly {truth}}."
    )
  }

  num_lvl <- length(levels(truth))
  if (NCOL(estimate) == 1) {
    cli::cli_abort(
      "For these data, the ranked probability score requires
      {.arg estimate} to have {num_lvl} probability columns."
    )
  }
  estimate <- as.matrix(estimate)

  # TODO should `...` be empty?
  estimator <- finalize_estimator(truth, metric_class = "ranked_prob_score")

  check_prob_metric(truth, estimate, case_weights, estimator)

  if (na_rm) {
    result <- yardstick_remove_missing(truth, estimate, case_weights)

    truth <- result$truth
    estimate <- result$estimate
    case_weights <- result$case_weights
  } else if (yardstick_any_missing(truth, estimate, case_weights)) {
    return(NA_real_)
  }

  ranked_prob_score_estimator_impl(
    truth = truth,
    estimate = estimate,
    case_weights = case_weights
  )
}

ranked_prob_score_estimator_impl <- function(truth,
                                             estimate,
                                             case_weights) {
  rps_factor(
    truth = truth,
    estimate = estimate,
    case_weights = case_weights
  )
}


cumulative_rows <- function(x) {
  t(apply(x, 1, cumsum))
}

# When `truth` is a factor
rps_factor <- function(truth, estimate, case_weights = NULL) {
  num_class <- length(levels(truth))
  inds <- hardhat::fct_encode_one_hot(truth)
  cum_ind <- cumulative_rows(inds)
  cum_estimate <- cumulative_rows(estimate)

  case_weights <- vctrs::vec_cast(case_weights, to = double())

  # RPS divides by the number of classes minus one: the last cumulative
  # probability is always one, i.e. there are only num_class - 1 independent
  # pieces of information.
  # Also, brier_ind() divides the raw mean by 2, so we multiply by 2 to undo that.
  brier_ind(cum_ind, cum_estimate, case_weights) / (num_class - 1) * 2
}
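As a sanity check on the computation described in the roxygen details above, here is a minimal, self-contained sketch in base R. The data below are hypothetical and the code deliberately avoids the package internals (e.g. `brier_ind()`); with equal case weights, its result should agree with `ranked_prob_score_vec()` as implemented in this PR.

# Editor's sketch (hypothetical data; not part of the PR diff)
lvls <- c("low", "medium", "high")
truth <- factor(
  c("low", "medium", "high", "low", "medium"),
  levels = lvls,
  ordered = TRUE
)

# One probability column per level; each row sums to one
estimate <- matrix(
  c(0.7, 0.2, 0.1,
    0.2, 0.5, 0.3,
    0.1, 0.3, 0.6,
    0.6, 0.3, 0.1,
    0.3, 0.4, 0.3),
  ncol = 3,
  byrow = TRUE
)

num_class <- length(lvls)

# Indicators for the truth being <= class i (the cumulative one-hot encoding)
cum_ind <- outer(as.integer(truth), seq_len(num_class), "<=") * 1

# Cumulative predicted probabilities Pr[class <= i], computed row-wise
cum_estimate <- t(apply(estimate, 1, cumsum))

# Mean squared difference of the two cumulative series, divided by the
# num_class - 1 independent pieces of information
rps <- mean(rowSums((cum_estimate - cum_ind)^2)) / (num_class - 1)
rps

The last cumulative column always contributes zero because both series end at one, which is why the divisor is `num_class - 1` rather than `num_class`.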
1 change: 1 addition & 0 deletions _pkgdown.yml
@@ -48,6 +48,7 @@ reference:
- mn_log_loss
- classification_cost
- brier_class
- ranked_prob_score

- title: Regression Metrics
contents:
1 change: 1 addition & 0 deletions man/average_precision.Rd

1 change: 1 addition & 0 deletions man/brier_class.Rd
1 change: 1 addition & 0 deletions man/classification_cost.Rd
1 change: 1 addition & 0 deletions man/gain_capture.Rd
1 change: 1 addition & 0 deletions man/mn_log_loss.Rd
1 change: 1 addition & 0 deletions man/pr_auc.Rd
117 changes: 117 additions & 0 deletions man/ranked_prob_score.Rd