Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
^doc$
^Meta$
^\.vscode$
^dev/
9 changes: 5 additions & 4 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
## Other changes

* reparameterized the Blume-Capel model to use (score - baseline) instead of score.
* implemented a new way to compute the denominators and probabilities. This made their computation both faster and more stable.

## Bug fixes

* Fixed numerical problems with Blume-Capel variables using HMC and NUTS for bgm().
* fixed numerical problems with Blume-Capel variables using HMC and NUTS.

# bgms 0.1.6.1

Expand All @@ -22,9 +23,9 @@

## Bug fixes

* Fixed a problem with warmup scheduling for adaptive-metropolis in bgmCompare()
* Fixed stability problems with parallel sampling for bgm()
* Fixed spurious output errors printing to console after user interrupt.
* fixed a problem with warmup scheduling for adaptive-metropolis in bgmCompare()
* fixed stability problems with parallel sampling for bgm()
* fixed spurious output errors printing to console after user interrupt.

# bgms 0.1.6.0

Expand Down
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ sample_omrf_gibbs <- function(no_states, no_variables, no_categories, interactio
.Call(`_bgms_sample_omrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, iter)
}

sample_bcomrf_gibbs <- function(no_states, no_variables, no_categories, interactions, thresholds, variable_type, reference_category, iter) {
.Call(`_bgms_sample_bcomrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, variable_type, reference_category, iter)
sample_bcomrf_gibbs <- function(no_states, no_variables, no_categories, interactions, thresholds, variable_type, baseline_category, iter) {
.Call(`_bgms_sample_bcomrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, variable_type, baseline_category, iter)
}

compute_Vn_mfm_sbm <- function(no_variables, dirichlet_alpha, t_max, lambda) {
Expand Down
6 changes: 3 additions & 3 deletions R/bgm.R
Original file line number Diff line number Diff line change
Expand Up @@ -560,8 +560,9 @@ bgm = function(
# Ordinal (variable_bool == TRUE) or Blume-Capel (variable_bool == FALSE)
bc_vars = which(!variable_bool)
for(i in bc_vars) {
blume_capel_stats[1, i] = sum(x[, i])
blume_capel_stats[2, i] = sum((x[, i] - baseline_category[i])^2)
blume_capel_stats[1, i] = sum(x[, i] - baseline_category[i])
blume_capel_stats[2, i] = sum((x[, i] - baseline_category[i]) ^ 2)
x[, i] = x[, i] - baseline_category[i]
}
}
pairwise_stats = t(x) %*% x
Expand Down Expand Up @@ -627,7 +628,6 @@ bgm = function(
nThreads = cores, seed = seed, progress_type = progress_type
)


userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt"))
if(userInterrupt) {
warning("Stopped sampling after user interrupt, results are likely uninterpretable.")
Expand Down
7 changes: 4 additions & 3 deletions R/bgmCompare.R
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ bgmCompare = function(
} else if(update_method == "hamiltonian-mc") {
target_accept = 0.65
} else if(update_method == "nuts") {
target_accept = 0.80
target_accept = 0.65
}
}

Expand Down Expand Up @@ -414,13 +414,15 @@ bgmCompare = function(
blume_capel_stats = compute_blume_capel_stats(
x, baseline_category, ordinal_variable, group
)
for (i in which(!ordinal_variable)) {
x[, i] = x[, i] - baseline_category[i]
}

# Compute sufficient statistics for pairwise interactions
pairwise_stats = compute_pairwise_stats(
x, group
)


# Index vector used to sample interactions in a random order -----------------
Index = matrix(0, nrow = num_interactions, ncol = 3)
counter = 0
Expand Down Expand Up @@ -490,7 +492,6 @@ bgmCompare = function(

seed <- as.integer(seed)


# Call the Rcpp function
out = run_bgmCompare_parallel(
observations = observations,
Expand Down
20 changes: 10 additions & 10 deletions R/data_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ compute_counts_per_category = function(x, num_categories, group = NULL) {
counts_per_category_gr[category, variable] = sum(x[group == g, variable] == category)
}
}
counts_per_category[[g]] = counts_per_category_gr
counts_per_category[[length(counts_per_category) + 1]] = counts_per_category_gr
}
return(counts_per_category)
}
Expand All @@ -253,34 +253,34 @@ compute_blume_capel_stats = function(x, baseline_category, ordinal_variable, gro
if(is.null(group)) { # One-group design
sufficient_stats = matrix(0, nrow = 2, ncol = ncol(x))
bc_vars = which(!ordinal_variable)
for(i in bc_vars) {
sufficient_stats[1, i] = sum(x[, i])
sufficient_stats[2, i] = sum((x[, i] - baseline_category[i])^2)
for (i in bc_vars) {
sufficient_stats[1, i] = sum(x[, i] - baseline_category[i])
sufficient_stats[2, i] = sum((x[, i] - baseline_category[i]) ^ 2)
}
return(sufficient_stats)
} else { # Multi-group design
sufficient_stats = list()
for(g in unique(group)) {
sufficient_stats_gr = matrix(0, nrow = 2, ncol = ncol(x))
bc_vars = which(!ordinal_variable)
for(i in bc_vars) {
sufficient_stats_gr[1, i] = sum(x[group == g, i])
sufficient_stats_gr[2, i] = sum((x[group == g, i] - baseline_category[i])^2)
for (i in bc_vars) {
sufficient_stats_gr[1, i] = sum(x[group == g, i] - baseline_category[i])
sufficient_stats_gr[2, i] = sum((x[group == g, i] - baseline_category[i]) ^ 2)
}
sufficient_stats[[g]] = sufficient_stats_gr
sufficient_stats[[length(sufficient_stats) + 1]] = sufficient_stats_gr
}
return(sufficient_stats)
}
}

# Helper function for computing sufficient statistics for pairwise interactions
compute_pairwise_stats <- function(x, group) {
result <- vector("list", length(unique(group)))
result <- list()

for(g in unique(group)) {
obs <- x[group == g, , drop = FALSE]
# cross-product: gives number of co-occurrences of categories
result[[g]] <- t(obs) %*% obs
result[[length(result) + 1]] <- t(obs) %*% obs
}

result
Expand Down
40 changes: 18 additions & 22 deletions R/nuts_diagnostics.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,16 @@ summarize_nuts_diagnostics <- function(out, nuts_max_depth = 10, verbose = TRUE)
100 * divergence_rate,
total_divergences,
nrow(divergent_mat) * ncol(divergent_mat)
), "Consider increasing the target acceptance rate.")
} else if(divergence_rate > 0) {
message(
sprintf(
"Note: %.3f%% of transitions ended with a divergence (%d of %d).\n",
100 * divergence_rate,
total_divergences,
nrow(divergent_mat) * ncol(divergent_mat)
),
"Check R-hat and effective sample size (ESS) to ensure the chains are\n",
"mixing well."
)
), "Consider increasing the target acceptance rate or change to update_method = ``adaptive-metropolis''.")
} else if (divergence_rate > 0) {
message(sprintf(
"Note: %.3f%% of transitions ended with a divergence (%d of %d).\n",
100 * divergence_rate,
total_divergences,
nrow(divergent_mat) * ncol(divergent_mat)
),
"Check R-hat and effective sample size (ESS) to ensure the chains are\n",
"mixing well.")
}

depth_hit_rate <- max_tree_depth_hits / (nrow(treedepth_mat) * ncol(treedepth_mat))
Expand Down Expand Up @@ -84,16 +82,14 @@ summarize_nuts_diagnostics <- function(out, nuts_max_depth = 10, verbose = TRUE)
low_ebfmi_chains <- which(ebfmi_per_chain < 0.3)
min_ebfmi <- min(ebfmi_per_chain)

if(length(low_ebfmi_chains) > 0) {
warning(
sprintf(
"E-BFMI below 0.3 detected in %d chain(s): %s.\n",
length(low_ebfmi_chains),
paste(low_ebfmi_chains, collapse = ", ")
),
"This suggests inefficient momentum resampling in those chains.\n",
"Sampling efficiency may be reduced. Consider longer chains or checking convergence diagnostics."
)
if (length(low_ebfmi_chains) > 0) {
warning(sprintf(
"E-BFMI below 0.3 detected in %d chain(s): %s.\n",
length(low_ebfmi_chains),
paste(low_ebfmi_chains, collapse = ", ")
),
"This suggests inefficient momentum resampling in those chains.\n",
"Sampling efficiency may be reduced. Consider longer chains and check convergence diagnostics.")
}
}

Expand Down
42 changes: 19 additions & 23 deletions R/sampleMRF.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#' in specifying their model.
#'
#' The Blume-Capel option is specifically designed for ordinal variables that
#' have a special type of reference_category category, such as the neutral
#' have a special type of baseline category, such as the neutral
#' category in a Likert scale. The Blume-Capel model specifies the following
#' quadratic model for the threshold parameters:
#' \deqn{\mu_{\text{c}} = \alpha \times \text{c} + \beta \times (\text{c} - \text{r})^2,}{{\mu_{\text{c}} = \alpha \times \text{c} + \beta \times (\text{c} - \text{r})^2,}}
Expand All @@ -23,8 +23,8 @@
#' \eqn{\alpha > 0}{\alpha > 0} and decreasing threshold values if
#' \eqn{\alpha <0}{\alpha <0}), if \eqn{\beta < 0}{\beta < 0}, it offers an
#' increasing penalty for responding in a category further away from the
#' reference_category category r, while \eqn{\beta > 0}{\beta > 0} suggests a
#' preference for responding in the reference_category category.
#' baseline category r, while \eqn{\beta > 0}{\beta > 0} suggests a
#' preference for responding in the baseline category.
#'
#' @param no_states The number of states of the ordinal MRF to be generated.
#'
Expand Down Expand Up @@ -53,8 +53,8 @@
#' ``blume-capel''. Binary variables are automatically treated as ``ordinal''.
#' Defaults to \code{variable_type = "ordinal"}.
#'
#' @param reference_category An integer vector of length \code{no_variables} specifying the
#' reference_category category that is used for the Blume-Capel model (details below).
#' @param baseline_category An integer vector of length \code{no_variables} specifying the
#' baseline category that is used for the Blume-Capel model (details below).
#' Can be any integer value between \code{0} and \code{no_categories} (or
#' \code{no_categories[i]}).
#'
Expand Down Expand Up @@ -106,7 +106,7 @@
#' interactions = Interactions,
#' thresholds = Thresholds,
#' variable_type = c("b", "b", "o", "b", "o"),
#' reference_category = 2
#' baseline_category = 2
#' )
#'
#' @export
Expand All @@ -116,7 +116,7 @@ mrfSampler = function(no_states,
interactions,
thresholds,
variable_type = "ordinal",
reference_category,
baseline_category,
iter = 1e3) {
# Check no_states, no_variables, iter --------------------------------------------
if(no_states <= 0 ||
Expand Down Expand Up @@ -187,24 +187,20 @@ mrfSampler = function(no_states,
}
}

# Check the reference_category for Blume-Capel variables ---------------------
# Check the baseline_category for Blume-Capel variables ---------------------
if(any(variable_type == "blume-capel")) {
if(length(reference_category) == 1) {
reference_category = rep(reference_category, no_variables)
if(length(baseline_category) == 1) {
baseline_category = rep(baseline_category, no_variables)
}
if(any(reference_category < 0) || any(abs(reference_category - round(reference_category)) > .Machine$double.eps)) {
stop(paste0(
"For variables ",
which(reference_category < 0),
" ``reference_category'' was either negative or not integer."
))
if(any(baseline_category < 0) || any(abs(baseline_category - round(baseline_category)) > .Machine$double.eps)) {
stop(paste0("For variables ",
which(baseline_category < 0),
" ``baseline_category'' was either negative or not integer."))
}
if(any(reference_category - no_categories > 0)) {
stop(paste0(
"For variables ",
which(reference_category - no_categories > 0),
" the ``reference_category'' category was larger than the maximum category value."
))
if(any(baseline_category - no_categories > 0)) {
stop(paste0("For variables ",
which(baseline_category - no_categories > 0),
" the ``baseline_category'' category was larger than the maximum category value."))
}
}

Expand Down Expand Up @@ -347,7 +343,7 @@ mrfSampler = function(no_states,
interactions = interactions,
thresholds = thresholds,
variable_type = variable_type,
reference_category = reference_category,
baseline_category = baseline_category,
iter = iter
)
}
Expand Down
Loading