diff --git a/R/RcppExports.R b/R/RcppExports.R index a1a4bc9..6e04e6f 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -5,8 +5,8 @@ run_bgmCompare_parallel <- function(observations, num_groups, counts_per_categor .Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, warmup, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type) } -run_bgm_parallel <- function(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type) { - .Call(`_bgms_run_bgm_parallel`, observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type) +run_bgm_parallel <- function(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type) { + .Call(`_bgms_run_bgm_parallel`, observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type) } get_explog_switch <- function() { diff --git a/R/bgm.R b/R/bgm.R index e1b1f71..dd4268d 100644 --- a/R/bgm.R +++ b/R/bgm.R @@ -207,8 +207,14 @@ #' #' @param beta_bernoulli_alpha,beta_bernoulli_beta Double. Shape parameters #' for the beta distribution in the Beta–Bernoulli and the Stochastic-Block -#' priors. Must be positive. Defaults: \code{beta_bernoulli_alpha = 1} and -#' \code{beta_bernoulli_beta = 1}. +#' priors. Must be positive. For the Stochastic-Block prior these are the shape +#' parameters for the within-cluster edge inclusion probabilities. +#' Defaults: \code{beta_bernoulli_alpha = 1} and \code{beta_bernoulli_beta = 1}. +#' +#' @param beta_bernoulli_alpha_between,beta_bernoulli_beta_between Double. +#' Shape parameters for the between-cluster edge inclusion probabilities in the +#' Stochastic-Block prior. Must be positive. +#' Default: \code{beta_bernoulli_alpha_between = 1} and \code{beta_bernoulli_beta_between = 1} #' #' @param dirichlet_alpha Double. Concentration parameter of the Dirichlet #' prior on block assignments (used with the Stochastic Block model). @@ -359,6 +365,8 @@ bgm = function( inclusion_probability = 0.5, beta_bernoulli_alpha = 1, beta_bernoulli_beta = 1, + beta_bernoulli_alpha_between = 1, + beta_bernoulli_beta_between = 1, dirichlet_alpha = 1, lambda = 1, na_action = c("listwise", "impute"), @@ -418,7 +426,7 @@ bgm = function( } else if(update_method == "hamiltonian-mc") { target_accept = 0.65 } else if(update_method == "nuts") { - target_accept = 0.80 + target_accept = 0.60 } } @@ -444,9 +452,21 @@ bgm = function( inclusion_probability = inclusion_probability, beta_bernoulli_alpha = beta_bernoulli_alpha, beta_bernoulli_beta = beta_bernoulli_beta, + beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, + beta_bernoulli_beta_between = beta_bernoulli_beta_between, dirichlet_alpha = dirichlet_alpha, lambda = lambda) + # check hyperparameters input + # If user left them NULL, pass -1 to C++ (means: ignore between prior) + if (is.null(beta_bernoulli_alpha_between) && is.null(beta_bernoulli_beta_between)) { + beta_bernoulli_alpha_between <- -1.0 + beta_bernoulli_beta_between <- -1.0 + } else if (is.null(beta_bernoulli_alpha_between) || is.null(beta_bernoulli_beta_between)) { + stop("If you wish to specify different between and within cluster probabilites, + provide both beta_bernoulli_alpha_between and beta_bernoulli_beta_between, + otherwise leave both NULL.") + } # ---------------------------------------------------------------------------- # The vector variable_type is now coded as boolean. # Ordinal (variable_bool == TRUE) or Blume-Capel (variable_bool == FALSE) @@ -572,6 +592,8 @@ bgm = function( inclusion_probability = inclusion_probability, beta_bernoulli_alpha = beta_bernoulli_alpha, beta_bernoulli_beta = beta_bernoulli_beta, + beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, + beta_bernoulli_beta_between = beta_bernoulli_beta_between, dirichlet_alpha = dirichlet_alpha, lambda = lambda, interaction_index_matrix = interaction_index_matrix, iter = iter, warmup = warmup, counts_per_category = counts_per_category, @@ -603,6 +625,7 @@ bgm = function( na_action = na_action, na_impute = na_impute, edge_selection = edge_selection, edge_prior = edge_prior, inclusion_probability = inclusion_probability, beta_bernoulli_alpha = beta_bernoulli_alpha, beta_bernoulli_beta = beta_bernoulli_beta, + beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, beta_bernoulli_beta_between = beta_bernoulli_beta_between, dirichlet_alpha = dirichlet_alpha, lambda = lambda, variable_type = variable_type, update_method = update_method, @@ -634,6 +657,8 @@ bgm = function( edge_selection = edge_selection, edge_prior = edge_prior, inclusion_probability = inclusion_probability, beta_bernoulli_alpha = beta_bernoulli_alpha, beta_bernoulli_beta = beta_bernoulli_beta, + beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, + beta_bernoulli_beta_between = beta_bernoulli_beta_between, dirichlet_alpha = dirichlet_alpha, lambda = lambda, variable_type = variable_type, update_method = update_method, diff --git a/R/function_input_utils.R b/R/function_input_utils.R index 7240a47..4377b10 100644 --- a/R/function_input_utils.R +++ b/R/function_input_utils.R @@ -31,6 +31,8 @@ check_model = function(x, inclusion_probability = 0.5, beta_bernoulli_alpha = 1, beta_bernoulli_beta = 1, + beta_bernoulli_alpha_between = 1, + beta_bernoulli_beta_between = 1, dirichlet_alpha = dirichlet_alpha, lambda = lambda) { @@ -204,18 +206,42 @@ check_model = function(x, is.null(beta_bernoulli_alpha) || is.null(beta_bernoulli_beta)) stop("Values for both scale parameters of the beta distribution need to be specified.") } + if(edge_prior == "Stochastic-Block") { theta = matrix(0.5, nrow = ncol(x), ncol = ncol(x)) - if(beta_bernoulli_alpha <= 0 || beta_bernoulli_beta <= 0 || dirichlet_alpha <= 0 || lambda <= 0) - stop("The scale parameters of the beta and Dirichlet distribution need to be positive.") - if(!is.finite(beta_bernoulli_alpha) || !is.finite(beta_bernoulli_beta) || !is.finite(dirichlet_alpha) || !is.finite(lambda)) - stop("The scale parameters of the beta distribution, the concentration parameter of the Dirichlet distribution, and the rate parameter of the Poisson distribution need to be finite.") - if(is.na(beta_bernoulli_alpha) || is.na(beta_bernoulli_beta) || - is.null(beta_bernoulli_alpha) || is.null(beta_bernoulli_beta) || - is.null(dirichlet_alpha) || is.null(dirichlet_alpha) || is.null(lambda) || is.null(lambda)) - stop("Values for both scale parameters of the beta distribution, the concentration parameter of the Dirichlet distribution, and the rate parameter of the Poisson distribution need to be specified.") + + # Check that all beta parameters are provided + if (is.null(beta_bernoulli_alpha) || is.null(beta_bernoulli_beta) || + is.null(beta_bernoulli_alpha_between) || is.null(beta_bernoulli_beta_between)) { + stop("The Stochastic-Block prior requires all four beta parameters: ", + "beta_bernoulli_alpha, beta_bernoulli_beta, ", + "beta_bernoulli_alpha_between, and beta_bernoulli_beta_between.") + } + + # Check that all beta parameters are positive + if (beta_bernoulli_alpha <= 0 || beta_bernoulli_beta <= 0 || + beta_bernoulli_alpha_between <= 0 || beta_bernoulli_beta_between <= 0 || + dirichlet_alpha <= 0 || lambda <= 0) { + stop("The parameters of the beta and Dirichlet distributions need to be positive.") + } + + # Check that all beta parameters are finite + if (!is.finite(beta_bernoulli_alpha) || !is.finite(beta_bernoulli_beta) || + !is.finite(beta_bernoulli_alpha_between) || !is.finite(beta_bernoulli_beta_between) || + !is.finite(dirichlet_alpha) || !is.finite(lambda)) { + stop("The shape parameters of the beta distribution, the concentration parameter of the Dirichlet distribution, ", + "and the rate parameter of the Poisson distribution need to be finite.") + } + + # Check for NAs + if (is.na(beta_bernoulli_alpha) || is.na(beta_bernoulli_beta) || + is.na(beta_bernoulli_alpha_between) || is.na(beta_bernoulli_beta_between) || + is.na(dirichlet_alpha) || is.na(lambda)) { + stop("Values for all shape parameters of the beta distribution, the concentration parameter of the Dirichlet distribution, ", + "and the rate parameter of the Poisson distribution cannot be NA.") + } } - } else { + }else { theta = matrix(0.5, nrow = 1, ncol = 1) edge_prior = "Not Applicable" } diff --git a/R/output_utils.R b/R/output_utils.R index 8789e6e..37f3ab4 100644 --- a/R/output_utils.R +++ b/R/output_utils.R @@ -2,7 +2,8 @@ prepare_output_bgm = function( out, x, num_categories, iter, data_columnnames, is_ordinal_variable, warmup, pairwise_scale, main_alpha, main_beta, na_action, na_impute, edge_selection, edge_prior, inclusion_probability, - beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, + beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, + beta_bernoulli_beta_between,dirichlet_alpha, lambda, variable_type, update_method, target_accept, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains ) { @@ -22,6 +23,8 @@ prepare_output_bgm = function( inclusion_probability = inclusion_probability, beta_bernoulli_alpha = beta_bernoulli_alpha, beta_bernoulli_beta = beta_bernoulli_beta, + beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, + beta_bernoulli_beta_between = beta_bernoulli_beta_between, dirichlet_alpha = dirichlet_alpha, lambda = lambda, na_action = na_action, diff --git a/man/bgm.Rd b/man/bgm.Rd index d71c0cf..b0982d9 100644 --- a/man/bgm.Rd +++ b/man/bgm.Rd @@ -18,6 +18,8 @@ bgm( inclusion_probability = 0.5, beta_bernoulli_alpha = 1, beta_bernoulli_beta = 1, + beta_bernoulli_alpha_between = 1, + beta_bernoulli_beta_between = 1, dirichlet_alpha = 1, lambda = 1, na_action = c("listwise", "impute"), @@ -81,8 +83,14 @@ of each edge (used with the Bernoulli prior). Default: \code{0.5}.} \item{beta_bernoulli_alpha, beta_bernoulli_beta}{Double. Shape parameters for the beta distribution in the Beta–Bernoulli and the Stochastic-Block -priors. Must be positive. Defaults: \code{beta_bernoulli_alpha = 1} and -\code{beta_bernoulli_beta = 1}.} +priors. Must be positive. For the Stochastic-Block prior these are the shape +parameters for the within-cluster edge inclusion probabilities. +Defaults: \code{beta_bernoulli_alpha = 1} and \code{beta_bernoulli_beta = 1}.} + +\item{beta_bernoulli_alpha_between, beta_bernoulli_beta_between}{Double. +Shape parameters for the between-cluster edge inclusion probabilities in the +Stochastic-Block prior. Must be positive. +Default: \code{beta_bernoulli_alpha_between = 1} and \code{beta_bernoulli_beta_between = 1}} \item{dirichlet_alpha}{Double. Concentration parameter of the Dirichlet prior on block assignments (used with the Stochastic Block model). diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 0d829f6..f3fc966 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -58,8 +58,8 @@ BEGIN_RCPP END_RCPP } // run_bgm_parallel -Rcpp::List run_bgm_parallel(const arma::imat& observations, const arma::ivec& num_categories, double pairwise_scale, const std::string& edge_prior, const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, double dirichlet_alpha, double lambda, const arma::imat& interaction_index_matrix, int iter, int warmup, const arma::imat& counts_per_category, const arma::imat& blume_capel_stats, double main_alpha, double main_beta, bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool edge_selection, const std::string& update_method, const arma::imat& pairwise_effect_indices, double target_accept, const arma::imat& pairwise_stats, int hmc_num_leapfrogs, int nuts_max_depth, bool learn_mass_matrix, int num_chains, int nThreads, int seed, int progress_type); -RcppExport SEXP _bgms_run_bgm_parallel(SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP pairwise_scaleSEXP, SEXP edge_priorSEXP, SEXP inclusion_probabilitySEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP interaction_index_matrixSEXP, SEXP iterSEXP, SEXP warmupSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP na_imputeSEXP, SEXP missing_indexSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP edge_selectionSEXP, SEXP update_methodSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP pairwise_statsSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP progress_typeSEXP) { +Rcpp::List run_bgm_parallel(const arma::imat& observations, const arma::ivec& num_categories, double pairwise_scale, const std::string& edge_prior, const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, double beta_bernoulli_alpha_between, double beta_bernoulli_beta_between, double dirichlet_alpha, double lambda, const arma::imat& interaction_index_matrix, int iter, int warmup, const arma::imat& counts_per_category, const arma::imat& blume_capel_stats, double main_alpha, double main_beta, bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool edge_selection, const std::string& update_method, const arma::imat& pairwise_effect_indices, double target_accept, const arma::imat& pairwise_stats, int hmc_num_leapfrogs, int nuts_max_depth, bool learn_mass_matrix, int num_chains, int nThreads, int seed, int progress_type); +RcppExport SEXP _bgms_run_bgm_parallel(SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP pairwise_scaleSEXP, SEXP edge_priorSEXP, SEXP inclusion_probabilitySEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP interaction_index_matrixSEXP, SEXP iterSEXP, SEXP warmupSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP na_imputeSEXP, SEXP missing_indexSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP edge_selectionSEXP, SEXP update_methodSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP pairwise_statsSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP progress_typeSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -70,6 +70,8 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const arma::mat& >::type inclusion_probability(inclusion_probabilitySEXP); Rcpp::traits::input_parameter< double >::type beta_bernoulli_alpha(beta_bernoulli_alphaSEXP); Rcpp::traits::input_parameter< double >::type beta_bernoulli_beta(beta_bernoulli_betaSEXP); + Rcpp::traits::input_parameter< double >::type beta_bernoulli_alpha_between(beta_bernoulli_alpha_betweenSEXP); + Rcpp::traits::input_parameter< double >::type beta_bernoulli_beta_between(beta_bernoulli_beta_betweenSEXP); Rcpp::traits::input_parameter< double >::type dirichlet_alpha(dirichlet_alphaSEXP); Rcpp::traits::input_parameter< double >::type lambda(lambdaSEXP); Rcpp::traits::input_parameter< const arma::imat& >::type interaction_index_matrix(interaction_index_matrixSEXP); @@ -95,7 +97,7 @@ BEGIN_RCPP Rcpp::traits::input_parameter< int >::type nThreads(nThreadsSEXP); Rcpp::traits::input_parameter< int >::type seed(seedSEXP); Rcpp::traits::input_parameter< int >::type progress_type(progress_typeSEXP); - rcpp_result_gen = Rcpp::wrap(run_bgm_parallel(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type)); + rcpp_result_gen = Rcpp::wrap(run_bgm_parallel(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type)); return rcpp_result_gen; END_RCPP } @@ -182,7 +184,7 @@ END_RCPP static const R_CallMethodDef CallEntries[] = { {"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 36}, - {"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 32}, + {"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 34}, {"_bgms_get_explog_switch", (DL_FUNC) &_bgms_get_explog_switch, 0}, {"_bgms_rcpp_ieee754_exp", (DL_FUNC) &_bgms_rcpp_ieee754_exp, 1}, {"_bgms_rcpp_ieee754_log", (DL_FUNC) &_bgms_rcpp_ieee754_log, 1}, diff --git a/src/bgm_parallel.cpp b/src/bgm_parallel.cpp index 5664098..f0e6c94 100644 --- a/src/bgm_parallel.cpp +++ b/src/bgm_parallel.cpp @@ -45,6 +45,8 @@ struct GibbsChainRunner : public Worker { const arma::mat& inclusion_probability; double beta_bernoulli_alpha; double beta_bernoulli_beta; + double beta_bernoulli_alpha_between; + double beta_bernoulli_beta_between; double dirichlet_alpha; double lambda; const arma::imat& interaction_index_matrix; @@ -82,6 +84,8 @@ struct GibbsChainRunner : public Worker { const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, + double beta_bernoulli_alpha_between, + double beta_bernoulli_beta_between, double dirichlet_alpha, double lambda, const arma::imat& interaction_index_matrix, @@ -114,6 +118,8 @@ struct GibbsChainRunner : public Worker { inclusion_probability(inclusion_probability), beta_bernoulli_alpha(beta_bernoulli_alpha), beta_bernoulli_beta(beta_bernoulli_beta), + beta_bernoulli_alpha_between(beta_bernoulli_alpha_between), + beta_bernoulli_beta_between(beta_bernoulli_beta_between), dirichlet_alpha(dirichlet_alpha), lambda(lambda), interaction_index_matrix(interaction_index_matrix), @@ -154,11 +160,13 @@ struct GibbsChainRunner : public Worker { chain_result, observations, num_categories, - pairwise_scale, + pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, + beta_bernoulli_alpha_between, + beta_bernoulli_beta_between, dirichlet_alpha, lambda, interaction_index_matrix, @@ -252,6 +260,8 @@ Rcpp::List run_bgm_parallel( const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, + double beta_bernoulli_alpha_between, + double beta_bernoulli_beta_between, double dirichlet_alpha, double lambda, const arma::imat& interaction_index_matrix, @@ -296,6 +306,7 @@ Rcpp::List run_bgm_parallel( GibbsChainRunner worker( observations, num_categories, pairwise_scale, edge_prior_enum, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, + beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, diff --git a/src/bgm_sampler.cpp b/src/bgm_sampler.cpp index 592e20a..b1397f2 100644 --- a/src/bgm_sampler.cpp +++ b/src/bgm_sampler.cpp @@ -1174,6 +1174,8 @@ void run_gibbs_sampler_bgm( arma::mat inclusion_probability, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, + const double beta_bernoulli_alpha_between, + const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const arma::imat& interaction_index_matrix, @@ -1255,7 +1257,8 @@ void run_gibbs_sampler_bgm( cluster_prob = block_probs_mfm_sbm( cluster_allocations, arma::conv_to::from(inclusion_indicator), - num_variables, beta_bernoulli_alpha, beta_bernoulli_beta, rng + num_variables, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, + beta_bernoulli_beta_between, rng ); for (int i = 0; i < num_variables - 1; i++) { @@ -1360,13 +1363,15 @@ void run_gibbs_sampler_bgm( cluster_allocations = block_allocations_mfm_sbm( cluster_allocations, num_variables, log_Vn, cluster_prob, arma::conv_to::from(inclusion_indicator), dirichlet_alpha, - beta_bernoulli_alpha, beta_bernoulli_beta, rng + beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, + beta_bernoulli_beta_between, rng ); cluster_prob = block_probs_mfm_sbm( cluster_allocations, arma::conv_to::from(inclusion_indicator), num_variables, - beta_bernoulli_alpha, beta_bernoulli_beta, rng + beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, + beta_bernoulli_beta_between, rng ); for (int i = 0; i < num_variables - 1; i++) { diff --git a/src/bgm_sampler.h b/src/bgm_sampler.h index 05875e2..2fcbf35 100644 --- a/src/bgm_sampler.h +++ b/src/bgm_sampler.h @@ -15,6 +15,8 @@ void run_gibbs_sampler_bgm( arma::mat inclusion_probability, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, + const double beta_bernoulli_alpha_between, + const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const arma::imat& interaction_index_matrix, diff --git a/src/sbm_edge_prior.cpp b/src/sbm_edge_prior.cpp index 4264878..182201b 100644 --- a/src/sbm_edge_prior.cpp +++ b/src/sbm_edge_prior.cpp @@ -27,11 +27,14 @@ arma::uvec table_cpp(arma::uvec x) { // ----------------------------------------------------------------------------| // Add a row and column to a matrix (and fill with beta variables) +// Modified to support separate within/between cluster hyperparameters // ----------------------------------------------------------------------------| arma::mat add_row_col_block_prob_matrix(arma::mat X, - double beta_alpha, - double beta_beta, - SafeRNG& rng) { + double beta_alpha, + double beta_beta, + SafeRNG& rng, + double beta_bernoulli_alpha_between, + double beta_bernoulli_beta_between) { arma::uword dim = X.n_rows; arma::mat Y(dim+1,dim+1,arma::fill::zeros); @@ -41,10 +44,14 @@ arma::mat add_row_col_block_prob_matrix(arma::mat X, } } + // Add new row and column for the new cluster for(arma::uword i = 0; i < dim; i++) { - Y(dim, i) = rbeta(rng, beta_alpha, beta_beta); + // Between-cluster edge probabilities (new cluster to existing clusters) + Y(dim, i) = rbeta(rng, beta_bernoulli_alpha_between, beta_bernoulli_beta_between); Y(i, dim) = Y(dim, i); } + + // Within-cluster edge probability (diagonal element for new cluster) Y(dim, dim) = rbeta(rng, beta_alpha, beta_beta); return Y; @@ -56,9 +63,9 @@ arma::mat add_row_col_block_prob_matrix(arma::mat X, // ----------------------------------------------------------------------------| // [[Rcpp::export]] arma::vec compute_Vn_mfm_sbm(arma::uword no_variables, - double dirichlet_alpha, - arma::uword t_max, - double lambda) { + double dirichlet_alpha, + arma::uword t_max, + double lambda) { arma::vec log_Vn(t_max); double r; @@ -107,13 +114,16 @@ double log_likelihood_mfm_sbm(arma::uvec cluster_assign, // ----------------------------------------------------------------------------| // Compute log-marginal for the MFM - SBM +// Modified to support separate within/between cluster hyperparameters // ----------------------------------------------------------------------------| double log_marginal_mfm_sbm(arma::uvec cluster_assign, arma::umat indicator, arma::uword node, arma::uword no_variables, double beta_bernoulli_alpha, - double beta_bernoulli_beta) { + double beta_bernoulli_beta, + double beta_bernoulli_alpha_between, + double beta_bernoulli_beta_between) { arma::uvec indices = arma::regspace(0, no_variables-1); // vector of variables indices [0, 1, ..., no_variables-1] arma::uvec select_variables = indices(arma::find(indices != node)); // vector of variables indices excluding 'node' @@ -121,13 +131,24 @@ double log_marginal_mfm_sbm(arma::uvec cluster_assign, arma::uvec indicator_node = indicator.col(node); // column of indicator matrix corresponding to 'node' arma::vec gamma_node = arma::conv_to::from(indicator_node(select_variables)); // selecting only indicators between 'node' and the remaining variables (thus excluding indicator of node with itself -- that is indicator[node,node]) arma::uvec table_cluster = table_cpp(cluster_assign_wo_node); // frequency table of clusters excluding node + + // Get the cluster assignment of the current node + arma::uword node_cluster = cluster_assign(node); + double output = 0; for(arma::uword i = 0; i < table_cluster.n_elem; i++){ if(table_cluster(i) > 0){ // if the cluster is empty -- table_cluster(i) = 0 == then it is the previous cluster of 'node' where 'node' was the only member - a singleton, thus skip) arma::uvec which_variables_cluster_i = arma::find(cluster_assign_wo_node == i); // which variables belong to cluster i double sumG = arma::accu(gamma_node(which_variables_cluster_i)); // sum the indicator variables between node and those variables double sumN = static_cast(which_variables_cluster_i.n_elem); // take the size of the group as maximum number of relations - output += R::lbeta(sumG + beta_bernoulli_alpha, sumN - sumG + beta_bernoulli_beta) - R::lbeta(beta_bernoulli_alpha, beta_bernoulli_beta); // calculate log-density for cluster i and sum it to the marginal log-likelihood + + // Determine if this is within-cluster or between-cluster + bool is_within_cluster = (i == node_cluster); + + double alpha = is_within_cluster ? beta_bernoulli_alpha : beta_bernoulli_alpha_between; + double beta = is_within_cluster ? beta_bernoulli_beta : beta_bernoulli_beta_between; + + output += R::lbeta(sumG + alpha, sumN - sumG + beta) - R::lbeta(alpha, beta); // calculate log-density for cluster i and sum it to the marginal log-likelihood } } return output; @@ -171,16 +192,20 @@ arma::uword sample_cluster(arma::vec cluster_prob, // ----------------------------------------------------------------------------| // Sample the block allocations for the MFM - SBM +// Modified to support separate within/between cluster hyperparameters // ----------------------------------------------------------------------------| arma::uvec block_allocations_mfm_sbm(arma::uvec cluster_assign, - arma::uword no_variables, - arma::vec log_Vn, - arma::mat block_probs, - arma::umat indicator, - arma::uword dirichlet_alpha, - double beta_bernoulli_alpha, - double beta_bernoulli_beta, - SafeRNG& rng) { + arma::uword no_variables, + arma::vec log_Vn, + arma::mat block_probs, + arma::umat indicator, + arma::uword dirichlet_alpha, + double beta_bernoulli_alpha, + double beta_bernoulli_beta, + double beta_bernoulli_alpha_between, + double beta_bernoulli_beta_between, + SafeRNG& rng) { + arma::uword old; arma::uword cluster; arma::uword no_clusters; @@ -212,14 +237,14 @@ arma::uvec block_allocations_mfm_sbm(arma::uvec cluster_assign, if (c < no_clusters) { if(c != old){ - loglike = log_likelihood_mfm_sbm(cluster_assign_tmp, - block_probs, - indicator, - node, - no_variables); - - prob = (static_cast(dirichlet_alpha) + static_cast(cluster_size_node(c))) * - MY_EXP(loglike); + loglike = log_likelihood_mfm_sbm(cluster_assign_tmp, + block_probs, + indicator, + node, + no_variables); + + prob = (static_cast(dirichlet_alpha) + static_cast(cluster_size_node(c))) * + MY_EXP(loglike); } else{ // if old group, the probability is set to 0.0 prob = 0.0; @@ -231,7 +256,9 @@ arma::uvec block_allocations_mfm_sbm(arma::uvec cluster_assign, node, no_variables, beta_bernoulli_alpha, - beta_bernoulli_beta); + beta_bernoulli_beta, + beta_bernoulli_alpha_between, + beta_bernoulli_beta_between); prob = static_cast(dirichlet_alpha) * MY_EXP(logmarg) * @@ -283,7 +310,9 @@ arma::uvec block_allocations_mfm_sbm(arma::uvec cluster_assign, node, no_variables, beta_bernoulli_alpha, - beta_bernoulli_beta); + beta_bernoulli_beta, + beta_bernoulli_alpha_between, + beta_bernoulli_beta_between); prob = static_cast(dirichlet_alpha) * MY_EXP(logmarg) * @@ -302,7 +331,10 @@ arma::uvec block_allocations_mfm_sbm(arma::uvec cluster_assign, if (cluster == no_clusters) { block_probs = add_row_col_block_prob_matrix(block_probs, beta_bernoulli_alpha, - beta_bernoulli_beta, rng); + beta_bernoulli_beta, + rng, + beta_bernoulli_alpha_between, + beta_bernoulli_beta_between); } } } @@ -312,13 +344,16 @@ arma::uvec block_allocations_mfm_sbm(arma::uvec cluster_assign, // ----------------------------------------------------------------------------| // Sample the block parameters for the MFM - SBM +// Modified to support separate within/between cluster hyperparameters // ----------------------------------------------------------------------------| arma::mat block_probs_mfm_sbm(arma::uvec cluster_assign, - arma::umat indicator, - arma::uword no_variables, - double beta_bernoulli_alpha, - double beta_bernoulli_beta, - SafeRNG& rng) { + arma::umat indicator, + arma::uword no_variables, + double beta_bernoulli_alpha, + double beta_bernoulli_beta, + double beta_bernoulli_alpha_between, + double beta_bernoulli_beta_between, + SafeRNG& rng) { arma::uvec cluster_size = table_cpp(cluster_assign); arma::uword no_clusters = cluster_size.n_elem; @@ -332,18 +367,25 @@ arma::mat block_probs_mfm_sbm(arma::uvec cluster_assign, for(arma::uword r = 0; r < no_clusters; r++) { for(arma::uword s = r; s < no_clusters; s++) { sumG = 0; + if(r == s) { + // Within-cluster: always use main parameters update_sumG(sumG, cluster_assign, indicator, r, r, no_variables); size = static_cast(cluster_size(r)) * (static_cast(cluster_size(r)) - 1) / 2; + block_probs(r, s) = rbeta(rng, + sumG + beta_bernoulli_alpha, + size - sumG + beta_bernoulli_beta); } else { + // Between-cluster: use between parameters update_sumG(sumG, cluster_assign, indicator, r, s, no_variables); update_sumG(sumG, cluster_assign, indicator, s, r, no_variables); size = static_cast(cluster_size(s)) * static_cast(cluster_size(r)); + + block_probs(r, s) = rbeta(rng, sumG + beta_bernoulli_alpha_between, size - sumG + beta_bernoulli_beta_between); } - block_probs(r, s) = rbeta(rng, sumG + beta_bernoulli_alpha, size - sumG + beta_bernoulli_beta); block_probs(s, r) = block_probs(r, s); } } return block_probs; -} +} \ No newline at end of file diff --git a/src/sbm_edge_prior.h b/src/sbm_edge_prior.h index 46c035f..a782347 100644 --- a/src/sbm_edge_prior.h +++ b/src/sbm_edge_prior.h @@ -23,6 +23,8 @@ arma::uvec block_allocations_mfm_sbm(arma::uvec cluster_assign, arma::uword dirichlet_alpha, double beta_bernoulli_alpha, double beta_bernoulli_beta, + double beta_bernoulli_alpha_between, + double beta_bernoulli_beta_between, SafeRNG& rng); // ----------------------------------------------------------------------------| @@ -33,4 +35,6 @@ arma::mat block_probs_mfm_sbm(arma::uvec cluster_assign, arma::uword no_variables, double beta_bernoulli_alpha, double beta_bernoulli_beta, + double beta_bernoulli_alpha_between, + double beta_bernoulli_beta_between, SafeRNG& rng); \ No newline at end of file