Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shifted lognormal uniform #20

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ export(dlognormal)
export(dlognormal_natural)
export(dlomax)
export(dshifted_inv_gaussian)
export(dshifted_lognormal_uniform)
export(dsimplex)
export(dsoftplusnormal)
export(dsymlognormal)
Expand All @@ -47,6 +48,7 @@ export(logistic)
export(logit)
export(logitnormal)
export(lognormal_natural)
export(logsumexp)
export(lomax)
export(pkumaraswamy)
export(qbeta_mean)
Expand Down Expand Up @@ -76,13 +78,15 @@ export(rlognormal)
export(rlognormal_natural)
export(rlomax)
export(rshifted_inv_gaussian)
export(rshifted_lognormal_uniform)
export(rsimplex)
export(rsoftplusnormal)
export(rstudent_mean)
export(rsymlognormal)
export(runit_lindley)
export(rweibull_median)
export(shifted_inv_gaussian)
export(shifted_lognormal_uniform)
export(simplex)
export(softplus)
export(softplusnormal)
Expand Down
13 changes: 13 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,16 @@ symlog <- function(x) {
inv_symlog <- function(x) {
  # Elementwise sign(x) * (exp(|x|) - 1): the magnitude is computed first and
  # the sign of the input is re-applied afterwards.
  magnitude <- exp(abs(x)) - 1
  sign(x) * magnitude
}

#' Logarithm of the sum of exponentials.
#'
#' A more numerically stable equivalent to \code{log(sum(exp(x)))}
#'
#' @source \url{https://en.wikipedia.org/wiki/LogSumExp}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The source link seems to be dead

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, will replace with a link to Wikipedia

#' @param x a vector of values
#' @return log(sum(exp(x)))
#' @export
logsumexp <- function(x) {
  # Numerically stable log(sum(exp(x))).
  #
  # x: numeric vector of log-scale values.
  # Returns a single numeric value; -Inf for empty input (log of an empty sum).

  # An empty sum is 0, so its log is -Inf; returning early also avoids the
  # warning max() emits for zero-length input.
  if (length(x) == 0L) {
    return(-Inf)
  }

  x_max <- max(x)

  # If the maximum is -Inf (all terms are zero) or +Inf (some term is
  # infinite), `x - x_max` would contain Inf - Inf = NaN. In both cases the
  # correct result is the maximum itself. NA/NaN maxima also end up here and
  # are propagated unchanged.
  if (!is.finite(x_max)) {
    return(x_max)
  }

  # Shifting by the maximum makes the largest exponent exp(0) = 1, which
  # prevents overflow for large inputs and total underflow for small ones.
  x_max + log(sum(exp(x - x_max)))
}
259 changes: 259 additions & 0 deletions R/shifted_lognormal_uniform.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
#' Log probability density and RNG functions for the mixture of shifted
#' lognormal and uniform distributions.
#'
#' @source Some background, discussion and examples at
#' \url{http://www.martinmodrak.cz/2021/04/01/using-brms-to-model-reaction-times-contaminated-with-errors/}
#'
#' @details The mixture of shifted lognormal and uniform can be described as
#' \deqn{y_i =
#' \begin{cases}
#' u_i & \mathrm{if} \quad z_i = 0 \\
#' s_i + r_i & \mathrm{if} \quad z_i = 1
#' \end{cases}
#' \\
#' u_i \sim Uniform(0, \alpha) \\
#' \log(r_i) \sim Normal(\mu_i, \sigma) \\
#' P(z_i = 0) = \theta}
#'
#'
#' Where \eqn{\theta} corresponds to \code{mix}, \eqn{\alpha} to
#' \code{max_uniform} and \eqn{s_i} to \code{shift}.
#'
#' @param n the number of values to draw from the RNG
#' @param y the observed value
#' @param meanlog the mean of the lognormal component
#' @param sdlog the sd of the lognormal component
#' @param mix the probability of the value coming from the uniform component
#' @param shift the shift of the lognormal distribution
#' @param max_uniform the maximum value of the uniform component
#' @name shifted_lognormal_uniform_distribution
NULL

#' @rdname shifted_lognormal_uniform_distribution
#' @export
rshifted_lognormal_uniform <- function(n, meanlog = 0, sdlog = 1, mix = 0.1,
                                       shift = 0, max_uniform = 100) {
  # Draw `n` values from the mixture: with probability `mix` a
  # Uniform(0, max_uniform) draw, otherwise `shift` plus a lognormal draw.
  #
  # Returns a numeric vector of length `n`. Stops if `n` is not a single
  # non-negative number or if any distribution parameter is out of range.

  # Separate conditions are checked in order and stop at the first failure.
  # The elementwise `&` form let zero-length input (e.g. NULL) pass vacuously.
  stopifnot(is.numeric(n), length(n) == 1, !is.na(n), n >= 0)
  n <- as.integer(n)
  stopifnot(all(sdlog > 0))
  stopifnot(all(mix >= 0 & mix <= 1))
  stopifnot(all(shift >= 0))
  stopifnot(all(max_uniform > 0))

  # ifelse() on an empty test would return logical(0); keep the type numeric.
  if (n == 0L) {
    return(numeric(0))
  }

  # ifelse() sizes its result by the test vector, so `mix` is recycled or
  # truncated to exactly `n` entries to guarantee `n` draws are returned.
  from_uniform <- runif(n) < rep_len(mix, n)

  ifelse(
    from_uniform,
    runif(n, 0, max_uniform),
    shift + rlnorm(n, meanlog = meanlog, sdlog = sdlog)
  )
}


#' @rdname shifted_lognormal_uniform_distribution
#' @export
dshifted_lognormal_uniform <- function(y, meanlog = 0, sdlog = 1, mix = 0.1,
                                       shift = 0, max_uniform = 100) {
  # Log density of the mixture
  #   mix * Uniform(y | 0, max_uniform) +
  #   (1 - mix) * Lognormal(y - shift | meanlog, sdlog).
  #
  # All arguments are recycled against each other by the usual R rules.
  # Returns a numeric vector of log densities (NOT densities on the natural
  # scale). Stops if `y` is not strictly positive or a parameter is out of
  # range.
  stopifnot(all(y > 0))
  stopifnot(all(sdlog > 0))
  stopifnot(all(mix >= 0 & mix <= 1))
  stopifnot(all(shift >= 0))
  stopifnot(all(max_uniform > 0))

  # Component log densities. dunif() yields -Inf above `max_uniform`, and
  # dlnorm() yields -Inf at or below `shift`, so each component drops out of
  # the mixture where it has no support.
  # The lognormal component is deliberately not renormalised by
  # plnorm(max_uniform - shift): it is not truncated at `max_uniform`, which
  # matches rshifted_lognormal_uniform() and the mixture stated in the docs.
  unif_llh <- dunif(y, min = 0, max = max_uniform, log = TRUE)
  lognormal_llh <- dlnorm(y - shift, meanlog = meanlog, sdlog = sdlog,
                          log = TRUE)

  weighted_unif <- log(mix) + unif_llh
  weighted_lognormal <- log1p(-mix) + lognormal_llh

  # Elementwise log(exp(a) + exp(b)), stabilised by factoring out the larger
  # term. When both terms are -Inf the shifted form would be NaN, so the
  # (infinite) maximum is returned directly in that case.
  largest <- pmax(weighted_unif, weighted_lognormal)
  ifelse(
    is.finite(largest),
    largest + log(exp(weighted_unif - largest) +
                    exp(weighted_lognormal - largest)),
    largest
  )
}

posterior_predict_shifted_lognormal_uniform <- function(i, prep, ...) {
  # Posterior predictive draws for observation `i`, used as the
  # `posterior_predict` hook of the custom brms family.
  # Stops when observation `i` carries a lower bound above 0 or a finite
  # upper bound, since truncated predictions are not implemented.
  lower <- prep$data$lb
  upper <- prep$data$ub
  is_truncated <- (!is.null(lower) && lower[i] > 0) ||
    (!is.null(upper) && upper[i] < Inf)
  if (is_truncated) {
    stop("Predictions for truncated distributions not supported")
  }

  # Distributional parameters for observation `i`.
  dpar_i <- function(name) brms::get_dpar(prep, name, i = i)
  mu <- dpar_i("mu")
  sigma <- dpar_i("sigma")
  mix <- dpar_i("mix")
  shiftprop <- dpar_i("shiftprop")

  # vreal1 holds the maximum shift and vreal2 the upper end of the uniform
  # component; the actual shift is the estimated proportion of the maximum.
  upper_uniform <- prep$data$vreal2[i]
  lognormal_shift <- shiftprop * prep$data$vreal1[i]

  rshifted_lognormal_uniform(
    prep$ndraws,
    meanlog = mu, sdlog = sigma, mix = mix,
    shift = lognormal_shift, max_uniform = upper_uniform
  )
}


posterior_epred_shifted_lognormal_uniform <- function(prep) {
  # Expected value of the response for every draw and observation, used as the
  # `posterior_epred` hook of the custom brms family:
  #   mix * max_uniform / 2 + (1 - mix) * (shift + exp(mu + sigma^2 / 2)).
  # Stops when any observation is truncated, since that is not implemented.

  # This hook receives no observation index, so the bounds of all
  # observations are inspected at once.
  lower <- prep$data$lb
  upper <- prep$data$ub
  is_truncated <- (!is.null(lower) && any(lower > 0, na.rm = TRUE)) ||
    (!is.null(upper) && any(upper < Inf, na.rm = TRUE))
  if (is_truncated) {
    stop("Predictions for truncated distributions not supported")
  }

  mu <- brms::get_dpar(prep, "mu")
  sigma <- brms::get_dpar(prep, "sigma")
  mix <- brms::get_dpar(prep, "mix")
  shiftprop <- brms::get_dpar(prep, "shiftprop")

  # NOTE(review): assumes `mu` is a draws x observations matrix, as in the
  # brms custom-family vignette -- confirm. Per-observation data must then
  # vary along columns; multiplying by the raw vector would recycle it down
  # the draws dimension instead.
  expand_obs <- function(values) {
    matrix(values, nrow = nrow(mu), ncol = ncol(mu), byrow = TRUE)
  }
  max_shift <- expand_obs(prep$data$vreal1)
  max_uniform <- expand_obs(prep$data$vreal2)

  shift <- shiftprop * max_shift

  shifted_lognormal_mean <- shift + exp(mu + sigma^2 / 2)
  uniform_mean <- 0.5 * max_uniform

  mix * uniform_mean + (1 - mix) * shifted_lognormal_mean
}


log_lik_shifted_lognormal_uniform <- function(i, prep) {
  # Pointwise log-likelihood of observation `i`, used as the `log_lik` hook of
  # the custom brms family.

  # Distributional parameters for observation `i`.
  dpar_i <- function(name) brms::get_dpar(prep, name, i = i)
  mu <- dpar_i("mu")
  sigma <- dpar_i("sigma")
  mix <- dpar_i("mix")
  shiftprop <- dpar_i("shiftprop")

  # vreal1 holds the maximum shift and vreal2 the upper end of the uniform
  # component; the actual shift is the estimated proportion of the maximum.
  upper_uniform <- prep$data$vreal2[i]
  lognormal_shift <- shiftprop * prep$data$vreal1[i]

  dshifted_lognormal_uniform(
    prep$data$Y[i],
    meanlog = mu, sdlog = sigma, mix = mix,
    shift = lognormal_shift, max_uniform = upper_uniform
  )
}


#' A mixture of shifted lognormal and uniform distribution suitable for
#' modelling reaction times.
#'
#' A contaminated response time distribution. The mixture can be described as
#' \deqn{y_i =
#' \begin{cases}
#' u_i & \mathrm{if} \quad z_i = 0 \\
#' p_i s_i + r_i & \mathrm{if} \quad z_i = 1
#' \end{cases}
#' \\
#' u_i \sim Uniform(0, \alpha) \\
#' \log(r_i) \sim Normal(\mu_i, \sigma) \\
#' P(z_i = 0) = \theta \\
#' 0 < p_i < 1
#' }
#' Here \eqn{\mu, \sigma, \theta} (\code{mix}) and \eqn{p} (\code{shiftprop})
#' are estimated, whereas \eqn{s_i} (\code{max_shift})
#' and \eqn{\alpha} (\code{max_uniform}) are given as data via \code{vreal()}.
#'
#' @details
#' Note that you cannot build this distribution with the built-in support
#' for mixtures in \code{brms},
#' because the uniform component is effectively a zero-parameter distribution
#' which cannot be expressed in \code{brms}.
#'

#' @source Idea by
#' Nathaniel Haines (https://twitter.com/Nate__Haines), code by Martin Modrák.
#' Some background, discussion and examples at
#' \url{http://www.martinmodrak.cz/2021/04/01/using-brms-to-model-reaction-times-contaminated-with-errors/}
#'
#' @examples library(brms)
#' set.seed(31546522)
#' # Bounds of the data
#' max_shift <- 0.3
#' shift <- runif(1) * max_shift
#' max_uniform <- 10
#' mix <- 0.1
#'
#' # Generate parameters
#' N <- 100
#' Intercept <- 0.3
#' beta <- 0.5
#' X <- rnorm(N)
#' mu <- rep(Intercept, N) + beta * X
#' sigma <- 0.5
#'
#' rt <- rshifted_lognormal_uniform(N, meanlog = mu, sdlog = sigma, mix = mix,
#' shift = shift, max_uniform = max_uniform)
#'
#' dd <- data.frame(rt = rt, x = X,
#' max_shift = max_shift, max_uniform = max_uniform)
#'
#' fam <- shifted_lognormal_uniform()
#' fit_mix <- brm(rt | vreal(max_shift, max_uniform) ~ x, data = dd, family = fam,
#' stanvars = fam$stanvars,
#' prior = c(prior(beta(1, 5), class = "mix")))
#' plot(fit_mix)
#' @export
shifted_lognormal_uniform <- function(link = "identity", link_sigma = "log",
                                      link_mix = "logit", link_shiftprop = "logit") {
  # Stan definitions of the lpdf, lcdf and lccdf of the family. Each function
  # branches on whether `y` lies at or below the shift (only the uniform
  # component applies), at or above `max_uniform` (only the lognormal
  # applies), or in between (both components are mixed).
  stan_functions <- "
real shifted_lognormal_uniform_lpdf(real y, real mu, real sigma, real mix,
real shiftprop, real max_shift, real max_uniform) {
real shift = shiftprop * max_shift;
if(y <= shift) {
// Could only be created by the contamination
return log(mix) + uniform_lpdf(y | 0, max_uniform);
} else if(y >= max_uniform) {
// Could only come from the lognormal
return log1m(mix) + lognormal_lpdf(y - shift | mu, sigma);
} else {
// Actually mixing
real lognormal_llh = lognormal_lpdf(y - shift | mu, sigma);
real uniform_llh = uniform_lpdf(y | 0, max_uniform);
return log_mix(mix, uniform_llh, lognormal_llh);
}
}

real shifted_lognormal_uniform_lcdf(real y, real mu, real sigma, real mix,
real shiftprop, real max_shift, real max_uniform) {
real shift = shiftprop * max_shift;
if(y <= shift) {
return log(mix) + uniform_lcdf(y | 0, max_uniform);
} else if(y >= max_uniform) {
// The whole uniform part is below, so the mixture part is log(1) = 0
return log_mix(mix, 0, lognormal_lcdf(y - shift | mu, sigma));
} else {
real lognormal_llh = lognormal_lcdf(y - shift | mu, sigma);
real uniform_llh = uniform_lcdf(y | 0, max_uniform);
return log_mix(mix, uniform_llh, lognormal_llh);
}
}

real shifted_lognormal_uniform_lccdf(real y, real mu, real sigma, real mix,
real shiftprop, real max_shift, real max_uniform) {

real shift = shiftprop * max_shift;
if(y <= shift) {
// The whole lognormal part is above, so the mixture part is log(1) = 0
return log_mix(mix, uniform_lccdf(y | 0, max_uniform), 0);
} else if(y >= max_uniform) {
return log1m(mix) + lognormal_lccdf(y - shift | mu, sigma);
} else {
real lognormal_llh = lognormal_lccdf(y - shift | mu, sigma);
real uniform_llh = uniform_lccdf(y | 0, max_uniform);
return log_mix(mix, uniform_llh, lognormal_llh);
}

}
"

  custom_fam <- brms::custom_family(
    "shifted_lognormal_uniform",
    # Estimated distributional parameters, each with its own link
    dpars = c("mu", "sigma", "mix", "shiftprop"),
    links = c(link, link_sigma, link_mix, link_shiftprop),
    type = "real",
    # Parameter bounds, in the same order as `dpars`
    lb = c(NA, 0, 0, 0),
    ub = c(NA, NA, 1, 1),
    # Known per-observation data: max_shift and max_uniform
    vars = c("vreal1[n]", "vreal2[n]"),
    posterior_predict = posterior_predict_shifted_lognormal_uniform,
    posterior_epred = posterior_epred_shifted_lognormal_uniform,
    log_lik = log_lik_shifted_lognormal_uniform
  )

  # The Stan functions travel with the family object so callers can pass
  # `stanvars = fam$stanvars` to brm().
  custom_fam$stanvars <- brms::stanvar(block = "functions",
                                       scode = stan_functions)
  custom_fam
}

19 changes: 15 additions & 4 deletions R/test-helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,8 @@ test_rng_quantiles <- function(rng_fun,
#' Vector is used as is, scalar will be interpreted as c(thresh, 1-thresh).
#' Default = 0.05
#' @param debug Scalar Boolean argument, whether debug info is printed or not. Default = False.
#' @param formula the formula used in the brms fit
#' @param prior any priors for brms
#'
#' @return None
#'
Expand Down Expand Up @@ -493,7 +495,9 @@ expect_brms_family <- function(n_data_sampels = 1000,
seed = 1235813,
data_threshold = NULL,
thresh = 0.05,
debug = FALSE) {
debug = FALSE,
formula = y ~ 1,
prior = NULL) {
if (is.null(ref_intercept)) {
ref_intercept <- intercept
}
Expand All @@ -505,7 +509,9 @@ expect_brms_family <- function(n_data_sampels = 1000,
family,
rng,
seed = seed,
data_threshold = data_threshold
data_threshold = data_threshold,
formula = formula,
prior = prior
)

success <- test_brms_quantile(
Expand Down Expand Up @@ -563,6 +569,8 @@ expect_brms_family <- function(n_data_sampels = 1000,
#' @param data_threshold Usually unused. But in rare cases, data too close at the boundary may cause trouble.
#' If so, set a two entry real vector c(lower, upper). If one of them is NA, the data will not be capped for that boundary.
#' Default = Null, will be in R terms "invisible" and will not cap any input data.
#' @param formula the formula used in the brms fit
#' @param prior any priors for brms
#'
#' @return brms model for the specified family.
#'
Expand All @@ -584,7 +592,9 @@ construct_brms <- function(n_data_sampels,
family,
rng,
seed = NULL,
data_threshold = NULL) {
data_threshold = NULL,
formula = y ~ 1,
prior = NULL) {
if (!(is.function(family) && is.function(rng) && is.function(rng_link))) {
stop("family, rng or rng_link argument were not a function!")
}
Expand Down Expand Up @@ -631,8 +641,9 @@ construct_brms <- function(n_data_sampels,
data <- list(y = y_data)

posterior_fit <- brms::brm(
y ~ 1,
formula,
data = data,
prior = prior,
family = family(),
stanvars = family()$stanvars,
chains = 2,
Expand Down
20 changes: 20 additions & 0 deletions man/logsumexp.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading