Skip to content
Permalink

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories. View the default comparison for this range or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also compare across forks. Learn more about diff comparisons here.
base repository: epiverse-trace/serofoi
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: 130946f6225041134ddedf7a7bd681072f2b8058
Choose a base ref
..
head repository: epiverse-trace/serofoi
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: 1373c19da7a60618af73d01e7c2c156769cc0ed2
Choose a head ref
12 changes: 0 additions & 12 deletions CITATION.cff
Original file line number Diff line number Diff line change
@@ -153,18 +153,6 @@ references:
name-particle: van den
orcid: https://orcid.org/0000-0002-9335-7468
year: '2024'
- type: software
title: Hmisc
abstract: 'Hmisc: Harrell Miscellaneous'
notes: Imports
url: https://hbiostat.org/R/Hmisc/
repository: https://CRAN.R-project.org/package=Hmisc
authors:
- family-names: Harrell Jr
given-names: Frank E
email: fh@fharrell.com
orcid: https://orcid.org/0000-0002-8271-5493
year: '2024'
- type: software
title: loo
abstract: 'loo: Efficient Leave-One-Out Cross-Validation and WAIC for Bayesian Models'
3 changes: 1 addition & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: serofoi
Title: Estimates the Force-of-Infection of a Given Pathogen from
Population Based Seroprevalence Studies
Version: 1.0.1
Version: 1.0.2
Authors@R: c(
person("Zulma M.", "Cucunubá", , "zulma.cucunuba@javeriana.edu.co", role = c("aut", "cre"),
comment = c(ORCID = "0000-0002-8165-3198")),
@@ -29,7 +29,6 @@ Imports:
ggplot2,
glue,
graphics,
Hmisc,
loo,
Matrix,
methods,
178 changes: 132 additions & 46 deletions R/plot_seromodel.R
Original file line number Diff line number Diff line change
@@ -17,22 +17,32 @@ prepare_serosurvey_for_plotting <- function( #nolint
serosurvey,
alpha = 0.05
) {
# The binomial confidence interval calculation is based on:
# https://forum.posit.co/t/apply-binomial-test-for-each-row-in-a-data-table/32112/2 #nolint
serosurvey$seroprev <- serosurvey$n_seropositive / serosurvey$n_sample

# Binomial confidence interval for one row of the serosurvey.
# `alpha` is taken from the enclosing function; the interval is computed at
# confidence level `1 - alpha` and returned with the names
# `seroprev_lower` / `seroprev_upper`.
get_seroprev_binconf <- function(seropositive, sample, seroprevalence) {
  binom_result <- stats::binom.test(
    x = seropositive,
    n = sample,
    p = seroprevalence,
    conf.level = 1 - alpha
  )
  # Only the confidence interval of the test result is kept
  bounds <- binom_result$conf.int
  names(bounds) <- c("seroprev_lower", "seroprev_upper")
  bounds
}

serosurvey <- cbind(
serosurvey,
Hmisc::binconf(
serosurvey$n_seropositive,
serosurvey$n_sample,
alpha = alpha,
method = "exact",
return.df = TRUE
)
) |>
dplyr::rename(
seroprev = "PointEst",
seroprev_lower = "Lower",
seroprev_upper = "Upper"
) |>
serosurvey <- serosurvey |>
dplyr::rowwise() |>
dplyr::mutate(
binconf = list(get_seroprev_binconf(
seropositive = .data$n_seropositive,
sample = .data$n_sample,
seroprevalence = .data$seroprev
)
)) |>
dplyr::ungroup() |>
tidyr::unnest_wider(.data$binconf) |>
dplyr::arrange(.data$age_group) |>
dplyr::relocate(!!dplyr::sym("age_group"))

@@ -288,6 +298,7 @@ plot_seroprevalence_estimates <- function(
#'
#' @inheritParams extract_central_estimates
#' @inheritParams fit_seromodel
#' @inheritParams plot_rhats
#' @param foi_df Dataframe with columns
#' \describe{
#' \item{`year`/`age`}{Year/Age (depending on the model)}
@@ -316,15 +327,29 @@ plot_foi_estimates <- function(
alpha = 0.05,
foi_df = NULL,
foi_max = NULL,
size_text = 11
size_text = 11,
plot_constant = FALSE,
x_axis = NA
) {
# TODO: Add checks for foi_df (size, colnames, etc.)
checkmate::assert_class(seromodel, "stanfit", null.ok = TRUE)

model_name <- seromodel@model_name
stopifnot(
"seromodel@name should start with either 'age' or 'time'" =
startsWith(model_name, "age") | startsWith(model_name, "time")
"seromodel@name should start with either 'age', 'time' or 'constant' " =
startsWith(model_name, "age") | startsWith(model_name, "time") |
startsWith(model_name, "constant")
)

error_msg_x_axis <- paste0(
"x_axis specification as either 'age' or 'time' when plotting ",
"constant FOI estimates is required to avoid ambiguity"
)
validate_plot_constant(
plot_constant = plot_constant,
x_axis = x_axis,
model_name = model_name,
error_msg_x_axis = error_msg_x_axis
)

foi_central_estimates <- extract_central_estimates(
@@ -337,23 +362,29 @@ plot_foi_estimates <- function(
if (is.null(foi_max))
foi_max <- max(foi_central_estimates$upper)

if (startsWith(model_name, "age")) {
if (startsWith(model_name, "age") || (plot_constant && (x_axis == "age"))) {
xlab <- "Age"
ages <- 1:max(serosurvey$age_max)
foi_central_estimates <- dplyr::mutate(
foi_central_estimates,
age = ages
)

if (!is.null(foi_df)) {
foi_central_estimates <- dplyr::left_join(
foi_central_estimates, foi_df,
by = "age"
)
}

foi_plot <- ggplot2::ggplot(
data = foi_central_estimates, ggplot2::aes(x = .data$age)
)
} else if (startsWith(model_name, "time")) {

} else if (
startsWith(model_name, "time") ||
(plot_constant && x_axis == "time")
) {
checkmate::assert_names(names(serosurvey), must.include = "survey_year")
xlab <- "Year"
ages <- rev(1:max(serosurvey$age_max))
@@ -404,8 +435,11 @@ plot_foi_estimates <- function(

#' Plot r-hats convergence criteria for the specified model
#'
#' @inheritParams extract_central_estimates
#' @inheritParams plot_serosurvey
#' @inheritParams plot_summary
#' @param x_axis either `"time"` or `"age"`. Specifies the x-axis values and
#' label for the additional plots of constant models. Only relevant when
#' `plot_constant = TRUE` and `seromodel@model_name == "constant"`
#' @return ggplot object showing the r-hats of the model to be compared with the
#' convergence criteria (horizontal dashed line)
#' @examples
@@ -425,20 +459,33 @@ plot_foi_estimates <- function(
plot_rhats <- function(
seromodel,
serosurvey,
par_name = "foi_expanded",
size_text = 11
size_text = 11,
plot_constant = FALSE,
x_axis = NA
) {
checkmate::assert_class(seromodel, "stanfit", null.ok = TRUE)

model_name <- seromodel@model_name
stopifnot(
"seromodel@name should start with either 'age' or 'time'" =
startsWith(model_name, "age") | startsWith(model_name, "time")
"seromodel@name should start with either 'age', 'time' or 'constant' " =
startsWith(model_name, "age") | startsWith(model_name, "time") |
startsWith(model_name, "constant")
)

error_msg_x_axis <- paste0(
"x_axis specification as either 'age' or 'time' when plotting rhat ",
"estimates of constant models is required to avoid ambiguity"
)
validate_plot_constant(
plot_constant = plot_constant,
x_axis = x_axis,
model_name = model_name,
error_msg_x_axis = error_msg_x_axis
)

rhats <- bayesplot::rhat(seromodel, par_name)
rhats <- bayesplot::rhat(seromodel, "foi_expanded")

if (startsWith(model_name, "age")) {
if (startsWith(model_name, "age") || (plot_constant && (x_axis == "age"))) {
xlab <- "Age"
ages <- 1:max(serosurvey$age_max)
rhats_df <- data.frame(
@@ -449,7 +496,10 @@ plot_rhats <- function(
rhats_plot <- ggplot2::ggplot(
data = rhats_df, ggplot2::aes(x = .data$age)
)
} else if (startsWith(model_name, "time")) {
} else if (
startsWith(model_name, "time") ||
(plot_constant && x_axis == "time")
) {
checkmate::assert_names(names(serosurvey), must.include = "survey_year")
xlab <- "Year"
ages <- rev(1:max(serosurvey$age_max))
@@ -482,13 +532,16 @@ plot_rhats <- function(
ggplot2::xlab(xlab) +
ggplot2::ylab("Convergence (r-hats)")

return(rhats_plot)
return(rhats_plot)
}

#' Plots model summary
#'
#' @inheritParams summarise_seromodel
#' @inheritParams plot_serosurvey
#' @param plot_constant boolean specifying whether to plot the single FOI
#' estimate and its corresponding rhat value instead of showing this
#' information in the summary. Only relevant when
#' `seromodel@model_name == "constant"`
#' @return ggplot object with a summary of the specified model
#' @examples
#' data(veev2012)
@@ -501,23 +554,35 @@ plot_summary <- function(
loo_estimate_digits = 1,
central_estimate_digits = 2,
rhat_digits = 2,
size_text = 11
size_text = 11,
plot_constant = FALSE
) {
checkmate::assert_class(seromodel, "stanfit", null.ok = TRUE)

summary_table <- t( #convert summary to table
summarise_seromodel(
seromodel = seromodel,
serosurvey = serosurvey,
loo_estimate_digits = loo_estimate_digits,
central_estimate_digits = central_estimate_digits,
rhat_digits = rhat_digits
)
summary_list <- summarise_seromodel(
seromodel = seromodel,
serosurvey = serosurvey,
loo_estimate_digits = loo_estimate_digits,
central_estimate_digits = central_estimate_digits,
rhat_digits = rhat_digits
)

if (plot_constant) {
if (startsWith(seromodel@model_name, "constant")) {
drop <- c("foi", "foi_rhat")
summary_list <- summary_list[!(names(summary_list) %in% drop)]
} else {
error_msg <- paste0(
"plot_constant is only relevant when ",
"`seromodel@model_name == 'constant'`"
)
stop(error_msg)
}
}

summary_df <- data.frame(
row = rev(seq_len(NCOL(summary_table))),
text = paste0(colnames(summary_table), ": ", summary_table[1, ])
row = rev(seq_len(NCOL(t(summary_list)))),
text = paste0(colnames(t(summary_list)), ": ", summary_list)
)

summary_plot <- ggplot2::ggplot(
@@ -562,17 +627,32 @@ plot_seromodel <- function(
central_estimate_digits = 2,
seroreversion_digits = 2,
rhat_digits = 2,
size_text = 11
size_text = 11,
plot_constant = FALSE,
x_axis = NA
) {
checkmate::assert_class(seromodel, "stanfit", null.ok = TRUE)

model_name <- seromodel@model_name
error_msg_x_axis <- paste0(
"x_axis specification as either 'age' or 'time' when plotting ",
"FOI and rhat estimates of constant models is required to avoid ambiguity"
)
validate_plot_constant(
plot_constant = plot_constant,
x_axis = x_axis,
model_name = model_name,
error_msg_x_axis = error_msg_x_axis
)

summary_plot <- plot_summary(
seromodel = seromodel,
serosurvey = serosurvey,
loo_estimate_digits = loo_estimate_digits,
central_estimate_digits = central_estimate_digits,
rhat_digits = rhat_digits,
size_text = size_text
size_text = size_text,
plot_constant = plot_constant
)

seroprev_plot <- plot_seroprevalence_estimates(
@@ -589,21 +669,27 @@ plot_seromodel <- function(
seroprev_plot
)

model_name <- seromodel@model_name
if (!startsWith(model_name, "constant")) {
# This condition (!p | q) is equivalent to !(p & !q)
# This is to preserve the default behavior for single FOI estimation for
# constant models
if (!startsWith(model_name, "constant") || plot_constant) {
foi_plot <- plot_foi_estimates(
seromodel,
serosurvey,
alpha = alpha,
foi_df = foi_df,
foi_max = foi_max,
size_text
size_text,
plot_constant = plot_constant,
x_axis = x_axis
)

rhats_plot <- plot_rhats(
seromodel = seromodel,
serosurvey = serosurvey,
size_text = size_text
size_text = size_text,
plot_constant = plot_constant,
x_axis = x_axis
)

plot_list <- c(
24 changes: 24 additions & 0 deletions R/validation.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# TODO: Add documentation and return calls to validation functions

validate_serosurvey <- function(serosurvey) {
# Check that necessary columns are present
col_types <- list(
@@ -140,3 +142,25 @@ validate_foi_index <- function(

return(foi_index)
}

#' Validate the `plot_constant` / `x_axis` combination
#'
#' When `plot_constant` is `FALSE` nothing else is checked. When it is `TRUE`,
#' `model_name` must start with `"constant"` and `x_axis` must be exactly one
#' of `"age"` or `"time"`.
#'
#' @param plot_constant Single non-missing flag indicating whether the
#'   additional plots for a constant model were requested.
#' @param x_axis Value to validate against `"age"` / `"time"`; only checked
#'   when `plot_constant` is `TRUE`.
#' @param model_name Character string with the model name.
#' @param error_msg_x_axis Error message raised when `x_axis` is invalid.
#' @return `TRUE` if all checks pass; otherwise an error is raised.
#' @noRd
validate_plot_constant <- function(
  plot_constant,
  x_axis,
  model_name,
  error_msg_x_axis
) {
  # Fail with a readable message instead of the generic `if` condition error
  # raised for NULL, NA or multi-element inputs
  stopifnot(
    "plot_constant must be a single non-missing TRUE/FALSE value" =
      length(plot_constant) == 1 && !is.na(plot_constant)
  )

  if (plot_constant) {
    if (!startsWith(model_name, "constant")) {
      error_msg <- paste0(
        "plot_constant is only relevant when ",
        "`seromodel@model_name == 'constant'`"
      )
      stop(error_msg)
    }
    # The length check comes first so that NULL or vector-valued `x_axis`
    # triggers the intended message rather than an invalid `if` condition
    if (length(x_axis) != 1 || !(x_axis %in% c("age", "time"))) {
      stop(error_msg_x_axis)
    }
  }

  return(TRUE)
}
Loading