Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions .Rbuildignore
Original file line number Diff line number Diff line change
@@ -1,15 +1,35 @@
^renv$
^renv\.lock$
# RStudio / IDE
^.*\.Rproj$
^\.Rproj\.user$
^\.vscode$

^Readme.Rmd$
^\.github$
^_pkgdown\.yml$
# renv
^renv$
^renv\.lock$

# pkgdown / docs
^docs$
^pkgdown$
^vignettes/introduction_cache
^_pkgdown\.yml$
^Readme\.Rmd$
^vignettes/introduction_cache$

# GitHub / CI
^\.github$

# R CMD build artifacts
^doc$
^Meta$
^\.vscode$
^dev/

# Development helpers
^dev$

# ---- C/C++ build artifacts (REQUIRED) ----
^src/.*\.o$
^src/.*\.so$
^src/.*\.dll$

# ---- Generated build files ----
^src/Makevars$
^src/Makevars\.win$
^src/sources\.mk$
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
src/*.o
src/*.so
src/*.dll
src/**/*.o
src/**/*.so
src/**/*.dll
.DS_Store
/doc/
/Meta/
Expand Down
12 changes: 0 additions & 12 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,6 @@ run_bgm_parallel <- function(observations, num_categories, pairwise_scale, edge_
.Call(`_bgms_run_bgm_parallel`, observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type)
}

get_explog_switch <- function() {
.Call(`_bgms_get_explog_switch`)
}

rcpp_ieee754_exp <- function(x) {
.Call(`_bgms_rcpp_ieee754_exp`, x)
}

rcpp_ieee754_log <- function(x) {
.Call(`_bgms_rcpp_ieee754_log`, x)
}

sample_omrf_gibbs <- function(no_states, no_variables, no_categories, interactions, thresholds, iter) {
.Call(`_bgms_sample_omrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, iter)
}
Expand Down
26 changes: 26 additions & 0 deletions R/generate_makevars_sources.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
cpp <- list.files(
"src",
pattern = "\\.cpp$",
recursive = TRUE,
full.names = TRUE
)

# strip leading "src/"
cpp <- sub("^src/", "", cpp)

con <- file("src/sources.mk", open = "w")

writeLines(c(
"# ------------------------------------------------------------------",
"# THIS FILE IS AUTO-GENERATED - DO NOT EDIT",
"# Generated by configure",
"# To add C++ code, place .cpp files anywhere under src/",
"# ------------------------------------------------------------------",
"SOURCES = \\"
), con)

writeLines(paste0(" ", cpp, " \\"), con)
writeLines("", con)

close(con)

5 changes: 4 additions & 1 deletion configure
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
#!/bin/sh

# Get flags from RcppParallel
# RcppParallel flags
RCPP_PARALLEL_CPPFLAGS=`"${R_HOME}/bin/Rscript" -e "cat(RcppParallel::CxxFlags())"`
RCPP_PARALLEL_LIBS=`"${R_HOME}/bin/Rscript" -e "cat(RcppParallel::LdFlags())"`

# Generate sources.mk using R
"${R_HOME}/bin/Rscript" R/generate_makevars_sources.R > src/sources.mk

# Substitute into Makevars
sed -e "s|@RCPP_PARALLEL_CPPFLAGS@|${RCPP_PARALLEL_CPPFLAGS}|" \
-e "s|@RCPP_PARALLEL_LIBS@|${RCPP_PARALLEL_LIBS}|" \
Expand Down
4 changes: 4 additions & 0 deletions configure.win
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
#!/bin/sh

# RcppParallel flags
RCPP_PARALLEL_CPPFLAGS=`"${R_HOME}/bin${R_ARCH_BIN}/Rscript.exe" -e "cat(RcppParallel::CxxFlags())"`
RCPP_PARALLEL_LIBS=`"${R_HOME}/bin${R_ARCH_BIN}/Rscript.exe" -e "cat(RcppParallel::LdFlags())"`

# Generate sources.mk using R
"${R_HOME}/bin/Rscript" R/generate_makevars_sources.R > src/sources.mk

# Substitute into Makevars.win
sed -e "s|@RCPP_PARALLEL_CPPFLAGS@|${RCPP_PARALLEL_CPPFLAGS}|" \
-e "s|@RCPP_PARALLEL_LIBS@|${RCPP_PARALLEL_LIBS}|" \
Expand Down
6 changes: 5 additions & 1 deletion src/Makevars.in
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
CXX_STD = CXX20

PKG_CPPFLAGS = @RCPP_PARALLEL_CPPFLAGS@ -DARMA_NO_DEBUG
include sources.mk

OBJECTS = $(SOURCES:.cpp=.o)

PKG_CPPFLAGS = @RCPP_PARALLEL_CPPFLAGS@ -DARMA_NO_DEBUG -I.

PKG_LIBS = $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS) @RCPP_PARALLEL_LIBS@
35 changes: 0 additions & 35 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,38 +101,6 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// get_explog_switch
Rcpp::String get_explog_switch();
RcppExport SEXP _bgms_get_explog_switch() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_explog_switch());
return rcpp_result_gen;
END_RCPP
}
// rcpp_ieee754_exp
Rcpp::NumericVector rcpp_ieee754_exp(Rcpp::NumericVector x);
RcppExport SEXP _bgms_rcpp_ieee754_exp(SEXP xSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< Rcpp::NumericVector >::type x(xSEXP);
rcpp_result_gen = Rcpp::wrap(rcpp_ieee754_exp(x));
return rcpp_result_gen;
END_RCPP
}
// rcpp_ieee754_log
Rcpp::NumericVector rcpp_ieee754_log(Rcpp::NumericVector x);
RcppExport SEXP _bgms_rcpp_ieee754_log(SEXP xSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< Rcpp::NumericVector >::type x(xSEXP);
rcpp_result_gen = Rcpp::wrap(rcpp_ieee754_log(x));
return rcpp_result_gen;
END_RCPP
}
// sample_omrf_gibbs
IntegerMatrix sample_omrf_gibbs(int no_states, int no_variables, IntegerVector no_categories, NumericMatrix interactions, NumericMatrix thresholds, int iter);
RcppExport SEXP _bgms_sample_omrf_gibbs(SEXP no_statesSEXP, SEXP no_variablesSEXP, SEXP no_categoriesSEXP, SEXP interactionsSEXP, SEXP thresholdsSEXP, SEXP iterSEXP) {
Expand Down Expand Up @@ -185,9 +153,6 @@ END_RCPP
static const R_CallMethodDef CallEntries[] = {
{"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 36},
{"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 34},
{"_bgms_get_explog_switch", (DL_FUNC) &_bgms_get_explog_switch, 0},
{"_bgms_rcpp_ieee754_exp", (DL_FUNC) &_bgms_rcpp_ieee754_exp, 1},
{"_bgms_rcpp_ieee754_log", (DL_FUNC) &_bgms_rcpp_ieee754_log, 1},
{"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 6},
{"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 8},
{"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4},
Expand Down
4 changes: 2 additions & 2 deletions src/bgm_helper.cpp → src/bgm/bgm_helper.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include <RcppArmadillo.h>
#include "bgm_helper.h"
#include "common_helpers.h"
#include "bgm/bgm_helper.h"
#include "utils/common_helpers.h"



Expand Down
2 changes: 1 addition & 1 deletion src/bgm_helper.h → src/bgm/bgm_helper.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include <RcppArmadillo.h>
#include "rng_utils.h"
#include "rng/rng_utils.h"

// Vectorize main_effect matrix
arma::vec vectorize_main_effects_bgm(
Expand Down
10 changes: 5 additions & 5 deletions src/bgm_logp_and_grad.cpp → src/bgm/bgm_logp_and_grad.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include <RcppArmadillo.h>
#include "bgm_helper.h"
#include "bgm_logp_and_grad.h"
#include "common_helpers.h"
#include "explog_switch.h"
#include "variable_helpers.h"
#include "bgm/bgm_helper.h"
#include "bgm/bgm_logp_and_grad.h"
#include "utils/common_helpers.h"
#include "math/explog_switch.h"
#include "utils/variable_helpers.h"



Expand Down
File renamed without changes.
21 changes: 21 additions & 0 deletions src/bgm/bgm_output.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#pragma once
#include <RcppArmadillo.h>

struct bgmOutput {
// required
arma::mat main_samples;
arma::mat pairwise_samples;

// optional (only if edge_selection)
arma::imat indicator_samples;
arma::imat allocation_samples;

// optional (only if NUTS)
arma::ivec treedepth_samples;
arma::ivec divergent_samples;
arma::vec energy_samples;

// metadata
int chain_id = -1;
bool userInterrupt = false;
};
46 changes: 22 additions & 24 deletions src/bgm_sampler.cpp → src/bgm/bgm_sampler.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
#include <RcppArmadillo.h>
#include "bgm_helper.h"
#include "bgm_logp_and_grad.h"
#include "bgm_sampler.h"
#include "common_helpers.h"
#include "mcmc_adaptation.h"
#include "mcmc_hmc.h"
#include "mcmc_leapfrog.h"
#include "mcmc_nuts.h"
#include "mcmc_rwm.h"
#include "mcmc_utils.h"
#include "sbm_edge_prior.h"
#include "rng_utils.h"
#include "progress_manager.h"
#include "chainResults.h"
#include "bgm/bgm_helper.h"
#include "bgm/bgm_logp_and_grad.h"
#include "bgm/bgm_sampler.h"
#include "bgm/bgm_output.h"
#include "mcmc/mcmc_adaptation.h"
#include "mcmc/mcmc_hmc.h"
#include "mcmc/mcmc_leapfrog.h"
#include "mcmc/mcmc_nuts.h"
#include "mcmc/mcmc_rwm.h"
#include "mcmc/mcmc_utils.h"
#include "priors/sbm_edge_prior.h"
#include "sbm_edge_prior_interface.h"
#include "rng/rng_utils.h"
#include "utils/common_helpers.h"
#include "utils/progress_manager.h"



Expand Down Expand Up @@ -1186,8 +1187,8 @@ void gibbs_update_step_bgm (
* - Parallel execution across chains is handled by `run_bgm_parallel()`;
* this function is for one chain only.
*/
void run_gibbs_sampler_bgm(
ChainResult& chain_result,
bgmOutput run_gibbs_sampler_bgm(
int chain_id,
arma::imat observations,
const arma::ivec& num_categories,
const double pairwise_scale,
Expand Down Expand Up @@ -1221,9 +1222,6 @@ void run_gibbs_sampler_bgm(
SafeRNG& rng,
ProgressManager& pm
) {

int chain_id = chain_result.chain_id;

// --- Setup: dimensions and storage structures
const int num_variables = observations.n_cols;
const int num_persons = observations.n_rows;
Expand Down Expand Up @@ -1432,22 +1430,22 @@ void run_gibbs_sampler_bgm(
}
}

bgmOutput chain_result;
chain_result.chain_id = chain_id;
chain_result.userInterrupt = userInterrupt;

chain_result.main_effect_samples = main_effect_samples;
chain_result.pairwise_effect_samples = pairwise_effect_samples;

chain_result.main_samples = main_effect_samples;
chain_result.pairwise_samples = pairwise_effect_samples;
if (update_method == nuts) {
chain_result.treedepth_samples = treedepth_samples;
chain_result.divergent_samples = divergent_samples;
chain_result.energy_samples = energy_samples;
}

if (edge_selection) {
chain_result.indicator_samples = indicator_samples;

if (edge_prior == Stochastic_Block)
chain_result.allocation_samples = allocation_samples;
}

return chain_result;
}
9 changes: 5 additions & 4 deletions src/bgm_sampler.h → src/bgm/bgm_sampler.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#pragma once
#include <RcppArmadillo.h>
#include "common_helpers.h"
#include "utils/common_helpers.h"
#include "bgm/bgm_output.h"

// forward declaration
struct SafeRNG;
class ProgressManager;
struct ChainResult;

void run_gibbs_sampler_bgm(
ChainResult& chain_result,
bgmOutput run_gibbs_sampler_bgm(
int chain_id,
arma::imat observations,
const arma::ivec& num_categories,
const double pairwise_scale,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include <RcppArmadillo.h>
#include <cmath>
#include "bgmCompare_helper.h"
#include "common_helpers.h"
#include "bgmCompare/bgmCompare_helper.h"
#include "utils/common_helpers.h"



Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include <RcppArmadillo.h>
#include "rng_utils.h"
#include "rng/rng_utils.h"



Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#include <RcppArmadillo.h>
#include "bgmCompare_helper.h"
#include "bgmCompare_logp_and_grad.h"
#include "bgmCompare/bgmCompare_helper.h"
#include "bgmCompare/bgmCompare_logp_and_grad.h"
#include <cmath>
#include "explog_switch.h"
#include "common_helpers.h"
#include "variable_helpers.h"
#include "math/explog_switch.h"
#include "utils/common_helpers.h"
#include "utils/variable_helpers.h"



Expand Down
4 changes: 3 additions & 1 deletion src/sampler_output.h → src/bgmCompare/bgmCompare_output.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
* - chain_id: Identifier of the chain.
* - has_indicator: True if indicator samples are stored.
*/
struct SamplerOutput {
struct bgmCompareOutput {
arma::mat main_samples;
arma::mat pairwise_samples;
arma::imat indicator_samples;

arma::ivec treedepth_samples;
arma::ivec divergent_samples;
arma::vec energy_samples;

int chain_id;
bool has_indicator;
bool userInterrupt;
Expand Down
Loading
Loading