started on support for missing data
tonymugen committed Sep 25, 2020
1 parent 70f5a4c commit 0083bd3
Showing 11 changed files with 1,312 additions and 797 deletions.
23 changes: 21 additions & 2 deletions R/RcppExports.R
@@ -1,8 +1,8 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

vbFit <- function(yVec, d, nPop, alphaPr, sigSqPr, ppRatio, nReps) {
.Call(`_MuGaMix_vbFit`, yVec, d, nPop, alphaPr, sigSqPr, ppRatio, nReps)
selectTraits <- function(yVec, pVec, d, nPops, pi) {
.Call(`_MuGaMix_selectFeatures`, yVec, pVec, d, nPops, pi)
}

testLpostNR <- function(yVec, d, Npop, theta, P, ind, limit, incr) {
@@ -57,6 +57,25 @@ gradTestSI <- function(yVec, lnFac, Npop, theta, iSigTheta, ind, limit, incr) {
.Call(`_MuGaMix_gradTestSI`, yVec, lnFac, Npop, theta, iSigTheta, ind, limit, incr)
}

#' Variational Bayes model fit
#'
#' Fits a Gaussian mixture model using variational Bayes. Assumes no missing data.
#'
#' @param yVec vectorized data matrix
#' @param d number of traits
#' @param nPop number of populations
#' @param alphaPr prior population size
#' @param sigSqPr prior variance
#' @param ppRatio population to error covariance ratio
#' @param nReps number of model fit attempts before picking the best fit
#' @return list containing population means (\code{popMeans}), covariances (\code{covariances}), effective population sizes (\code{effNm}), population assignment probabilities (\code{p}), and the deviance information criterion (DIC, \code{DIC}).
#'
#' @keywords internal
#'
vbFit <- function(yVec, d, nPop, alphaPr, sigSqPr, ppRatio, nReps) {
.Call(`_MuGaMix_vbFit`, yVec, d, nPop, alphaPr, sigSqPr, ppRatio, nReps)
}
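
Note: the only hint at what nReps does is the docstring above; a plausible reading is a multi-start scheme in which the model is fitted nReps times from random initializations and the fit with the smallest DIC is kept. A minimal C++ sketch of that selection logic (all names here are hypothetical, not MuGaMix internals):

// Hypothetical multi-start selection, assuming the smallest DIC marks the best fit.
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

struct VbResult {
    std::vector<double> popMeans;    // population means
    std::vector<double> covariances; // covariances
    std::vector<double> effNm;       // effective population sizes
    std::vector<double> p;           // population assignment probabilities
    double dic;                      // deviance information criterion
};

// Stand-in for one variational-Bayes run from a random start.
VbResult oneVbFit(const std::vector<double> &yVec, int32_t d, int32_t nPop,
                  double alphaPr, double sigSqPr, double ppRatio);

VbResult bestOfNReps(const std::vector<double> &yVec, int32_t d, int32_t nPop,
                     double alphaPr, double sigSqPr, double ppRatio, int32_t nReps) {
    VbResult best;
    best.dic = std::numeric_limits<double>::infinity();
    for (int32_t i = 0; i < nReps; ++i) {
        VbResult cur = oneVbFit(yVec, d, nPop, alphaPr, sigSqPr, ppRatio);
        if (cur.dic < best.dic) { // keep the attempt with the lowest DIC
            best = std::move(cur);
        }
    }
    return best;
}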

#' Run the sampler with no replication
#'
#' Runs the sampler on the data assuming no fixed effects, missing trait data, or replication.
36 changes: 26 additions & 10 deletions src/RcppExports.cpp
@@ -5,20 +5,18 @@

using namespace Rcpp;

// vbFit
Rcpp::List vbFit(const std::vector<double>& yVec, const int32_t& d, const int32_t& nPop, const double& alphaPr, const double& sigSqPr, const double& ppRatio, const int32_t nReps);
RcppExport SEXP _MuGaMix_vbFit(SEXP yVecSEXP, SEXP dSEXP, SEXP nPopSEXP, SEXP alphaPrSEXP, SEXP sigSqPrSEXP, SEXP ppRatioSEXP, SEXP nRepsSEXP) {
// selectFeatures
Rcpp::List selectFeatures(const std::vector<double>& yVec, const std::vector<double>& pVec, const int32_t& d, const int32_t& nPops, const double& pi);
RcppExport SEXP _MuGaMix_selectFeatures(SEXP yVecSEXP, SEXP pVecSEXP, SEXP dSEXP, SEXP nPopsSEXP, SEXP piSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const std::vector<double>& >::type yVec(yVecSEXP);
Rcpp::traits::input_parameter< const std::vector<double>& >::type pVec(pVecSEXP);
Rcpp::traits::input_parameter< const int32_t& >::type d(dSEXP);
Rcpp::traits::input_parameter< const int32_t& >::type nPop(nPopSEXP);
Rcpp::traits::input_parameter< const double& >::type alphaPr(alphaPrSEXP);
Rcpp::traits::input_parameter< const double& >::type sigSqPr(sigSqPrSEXP);
Rcpp::traits::input_parameter< const double& >::type ppRatio(ppRatioSEXP);
Rcpp::traits::input_parameter< const int32_t >::type nReps(nRepsSEXP);
rcpp_result_gen = Rcpp::wrap(vbFit(yVec, d, nPop, alphaPr, sigSqPr, ppRatio, nReps));
Rcpp::traits::input_parameter< const int32_t& >::type nPops(nPopsSEXP);
Rcpp::traits::input_parameter< const double& >::type pi(piSEXP);
rcpp_result_gen = Rcpp::wrap(selectFeatures(yVec, pVec, d, nPops, pi));
return rcpp_result_gen;
END_RCPP
}
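
Note: every generated wrapper in this file follows the same three-step pattern visible above: convert each incoming SEXP with Rcpp::traits::input_parameter, call the underlying C++ function, and hand the result back through Rcpp::wrap. A minimal hand-written illustration of the same pattern, using a hypothetical one-argument function:

// Sketch of the generated-wrapper pattern for a hypothetical sumVec().
#include <Rcpp.h>
#include <vector>

double sumVec(const std::vector<double> &x); // the exported C++ function (assumed)

RcppExport SEXP _MuGaMix_sumVec(SEXP xSEXP) {
BEGIN_RCPP
    Rcpp::RObject rcpp_result_gen;
    Rcpp::traits::input_parameter< const std::vector<double>& >::type x(xSEXP); // SEXP -> C++ type
    rcpp_result_gen = Rcpp::wrap(sumVec(x)); // C++ result -> SEXP
    return rcpp_result_gen;
END_RCPP
}
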
@@ -253,6 +251,23 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// vbFit
Rcpp::List vbFit(const std::vector<double>& yVec, const int32_t& d, const int32_t& nPop, const double& alphaPr, const double& sigSqPr, const double& ppRatio, const int32_t nReps);
RcppExport SEXP _MuGaMix_vbFit(SEXP yVecSEXP, SEXP dSEXP, SEXP nPopSEXP, SEXP alphaPrSEXP, SEXP sigSqPrSEXP, SEXP ppRatioSEXP, SEXP nRepsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const std::vector<double>& >::type yVec(yVecSEXP);
Rcpp::traits::input_parameter< const int32_t& >::type d(dSEXP);
Rcpp::traits::input_parameter< const int32_t& >::type nPop(nPopSEXP);
Rcpp::traits::input_parameter< const double& >::type alphaPr(alphaPrSEXP);
Rcpp::traits::input_parameter< const double& >::type sigSqPr(sigSqPrSEXP);
Rcpp::traits::input_parameter< const double& >::type ppRatio(ppRatioSEXP);
Rcpp::traits::input_parameter< const int32_t >::type nReps(nRepsSEXP);
rcpp_result_gen = Rcpp::wrap(vbFit(yVec, d, nPop, alphaPr, sigSqPr, ppRatio, nReps));
return rcpp_result_gen;
END_RCPP
}
// runSamplerNR
Rcpp::List runSamplerNR(const std::vector<double>& yVec, const int32_t& d, const int32_t& Npop, const int32_t& Nadapt, const int32_t& Nsamp, const int32_t& Nthin, const int32_t& Nchains);
RcppExport SEXP _MuGaMix_runSamplerNR(SEXP yVecSEXP, SEXP dSEXP, SEXP NpopSEXP, SEXP NadaptSEXP, SEXP NsampSEXP, SEXP NthinSEXP, SEXP NchainsSEXP) {
@@ -307,7 +322,7 @@ END_RCPP
}

static const R_CallMethodDef CallEntries[] = {
{"_MuGaMix_vbFit", (DL_FUNC) &_MuGaMix_vbFit, 7},
{"_MuGaMix_selectFeatures", (DL_FUNC) &_MuGaMix_selectFeatures, 5},
{"_MuGaMix_testLpostNR", (DL_FUNC) &_MuGaMix_testLpostNR, 8},
{"_MuGaMix_testLpostP", (DL_FUNC) &_MuGaMix_testLpostP, 8},
{"_MuGaMix_testLpostLocNR", (DL_FUNC) &_MuGaMix_testLpostLocNR, 8},
@@ -321,6 +336,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_MuGaMix_lpTestSI", (DL_FUNC) &_MuGaMix_lpTestSI, 8},
{"_MuGaMix_gradTestSInr", (DL_FUNC) &_MuGaMix_gradTestSInr, 8},
{"_MuGaMix_gradTestSI", (DL_FUNC) &_MuGaMix_gradTestSI, 8},
{"_MuGaMix_vbFit", (DL_FUNC) &_MuGaMix_vbFit, 7},
{"_MuGaMix_runSamplerNR", (DL_FUNC) &_MuGaMix_runSamplerNR, 7},
{"_MuGaMix_runSampler", (DL_FUNC) &_MuGaMix_runSampler, 7},
{"_MuGaMix_runSamplerMiss", (DL_FUNC) &_MuGaMix_runSamplerMiss, 8},
66 changes: 33 additions & 33 deletions src/danuts.cpp
@@ -71,7 +71,7 @@ SamplerNUTS::SamplerNUTS(SamplerNUTS &&in) {
theta_ = in.theta_;
in.model_ = nullptr;
in.theta_ = nullptr;
memcpy( lastEpsilons_, in.lastEpsilons_, 20*sizeof(double) );
memcpy( lastEpsilons_, in.lastEpsilons_, 20 * sizeof(double) );
}
}
SamplerNUTS& SamplerNUTS::operator=(SamplerNUTS &&in){
@@ -88,7 +88,7 @@ SamplerNUTS& SamplerNUTS::operator=(SamplerNUTS &&in){
theta_ = in.theta_;
in.model_ = nullptr;
in.theta_ = nullptr;
memcpy( lastEpsilons_, in.lastEpsilons_, 20*sizeof(double) );
memcpy( lastEpsilons_, in.lastEpsilons_, 20 * sizeof(double) );
}
return *this;
}
@@ -128,14 +128,14 @@ void SamplerNUTS::findInitialEpsilon_(){
a = '1';
}
} else {
logp -= 0.5*nuc_.dotProd(r0);
logpPrime -= 0.5*nuc_.dotProd(rPrime);
logp -= 0.5 * nuc_.dotProd(r0);
logpPrime -= 0.5 * nuc_.dotProd(rPrime);
a = ((logpPrime - logp) > -0.6931472 ? '1' : '\0' ); // -0.6931472 = log(0.5); taking a log of the I() condition; '\0' equivalent to a = -1.0 in Algorithm 4
}

if (a) { // a = 1.0
for (uint16_t i = 0; i < 7; i++) { // do not do more than seven doublings; initial values may be wrong and result in epsilon_ too large for regular operation
epsilon_ = 2.0*epsilon_;
epsilon_ = 2.0 * epsilon_;
thetaPrime = *theta_;
rPrime = r0;
leapfrog_(thetaPrime, rPrime, epsilon_);
@@ -150,15 +150,15 @@ void SamplerNUTS::findInitialEpsilon_(){
throw string("log-posterior evaluates to +Inf in findInitialEpsilon_. This should never happen. Check your implementation.");
}
} else {
logpPrime -= 0.5*nuc_.dotProd(rPrime);
logpPrime -= 0.5 * nuc_.dotProd(rPrime);
if ((logpPrime - logp) > -0.6931472) { // take a log of the while() test inequality; a = 1.0 so the direction is the same as in the description
break;
}
}
}
} else { // a = -1.0
for (uint16_t i = 0; i < 7; i++) { // do not do more than seven halves or epsilon_ will be too small
epsilon_ = 0.5*epsilon_;
epsilon_ = 0.5 * epsilon_;
thetaPrime = *theta_;
rPrime = r0;
leapfrog_(thetaPrime, rPrime, epsilon_);
@@ -173,7 +173,7 @@ void SamplerNUTS::findInitialEpsilon_(){
throw string("log-posterior evaluates to +Inf in findInitialEpsilon_. This should never happen. Check your implementation.");
}
} else {
logpPrime -= 0.5*nuc_.dotProd(rPrime);
logpPrime -= 0.5 * nuc_.dotProd(rPrime);
if ((logpPrime - logp) < -0.6931472) { // take a log of the while() test inequality; a = -1.0, so the inequality is switched
break;
}
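
Note: findInitialEpsilon_ is the step-size heuristic of Hoffman & Gelman (2014), Algorithm 4: take one probe leapfrog step, let the resulting acceptance statistic pick a direction, then keep doubling (or halving) epsilon until the change in the log joint density crosses log(0.5), with the extra cap of seven iterations noted in the comments above. A self-contained sketch against a generic interface (the function-pointer signatures are assumptions, standing in for model_->logPost() and leapfrog_(), not the class's actual members):

// Sketch of the Hoffman & Gelman (2014) Algorithm 4 heuristic. Assumed
// interface: logJoint(theta, r) returns logPost(theta) - 0.5*r.r, and
// leapfrog(theta, r, eps) advances both in place.
#include <vector>

double findReasonableEpsilon(const std::vector<double> &theta0, const std::vector<double> &r0,
                             double (*logJoint)(const std::vector<double>&, const std::vector<double>&),
                             void (*leapfrog)(std::vector<double>&, std::vector<double>&, double)) {
    const double logHalf = -0.6931472; // log(0.5), the same constant used above
    double eps = 1.0;
    std::vector<double> theta = theta0;
    std::vector<double> r = r0;
    const double lp0 = logJoint(theta, r);
    leapfrog(theta, r, eps);
    // One probe step picks the search direction: double while acceptance is
    // high, halve while it is low.
    const double dir = (logJoint(theta, r) - lp0 > logHalf) ? 1.0 : -1.0;
    for (int i = 0; i < 7; ++i) { // capped at seven iterations, as in the code above
        eps *= (dir > 0.0) ? 2.0 : 0.5;
        theta = theta0;
        r = r0;
        leapfrog(theta, r, eps);
        // Stop once the acceptance statistic crosses 1/2 from the chosen side.
        if (dir * (logJoint(theta, r) - lp0) <= dir * logHalf) {
            break;
        }
    }
    return eps;
}
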
@@ -187,13 +187,13 @@ void SamplerNUTS::leapfrog_(vector<double> &theta, vector<double> &r, const doub
vector<double> thtGrad; // Make sure that the model implementing the gradient resizes it properly!
model_->gradient(theta, thtGrad);
for (size_t j = 0; j < theta.size(); j++) {
r[j] += 0.5*epsilon*thtGrad[j]; // half-step update of r
theta[j] += epsilon*r[j]; // leapfrog update of theta
r[j] += 0.5 * epsilon*thtGrad[j]; // half-step update of r
theta[j] += epsilon * r[j]; // leapfrog update of theta
}
model_->gradient(theta, thtGrad);
// one more half-step update of r
for (size_t k = 0; k < theta.size(); k++) {
r[k] += 0.5*epsilon*thtGrad[k];
r[k] += 0.5 * epsilon*thtGrad[k];
}

}
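
Note: leapfrog_ above is the standard velocity-Verlet scheme: a half-step on the momentum, a full step on the position, and a second half-step on the momentum. A runnable toy version for a standard-normal target (gradient of the log-density is -theta), handy for checking that the integrator nearly conserves the Hamiltonian:

// Toy leapfrog for logPost(theta) = -0.5*theta^2 (standard normal target),
// so grad logPost = -theta. The Hamiltonian H = 0.5*theta^2 + 0.5*r^2
// should drift only slightly over many steps.
#include <cstdio>

int main() {
    double theta = 1.0;
    double r = 0.5;
    const double eps = 0.1;
    const double h0 = 0.5 * theta * theta + 0.5 * r * r; // initial Hamiltonian
    for (int step = 0; step < 100; ++step) {
        r += 0.5 * eps * (-theta); // half-step momentum update
        theta += eps * r;          // full-step position update
        r += 0.5 * eps * (-theta); // second half-step momentum update
    }
    const double h1 = 0.5 * theta * theta + 0.5 * r * r;
    std::printf("Hamiltonian drift after 100 steps: %.6f\n", h1 - h0);
    return 0;
}
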
@@ -218,7 +218,7 @@ void SamplerNUTS::buildTreePos_(const vector<double> &theta, const vector<double
throw string("log-posterior evaluates to +Inf in buildTreePos_. This should never happen. Check your implementation.");
}
} else {
testVal -= 0.5*nuc_.dotProd(rPrime);
testVal -= 0.5 * nuc_.dotProd(rPrime);
nPrime = (lu <= testVal ? 1.0 : 0.0);
s = (lu < (deltaMax_ + testVal) ? '1' : '\0');
}
@@ -231,7 +231,7 @@
vector<double> thetaDprm; // theta''
buildTreePos_(thetaPlus, rPlus, lu, epsilon, j-1, thetaPlus, rPlus, thetaMinus, rMinus, thetaDprm, nDprm, sDPrm);
nPrime += nDprm;
if ( nPrime && (rng_.runif() <= nDprm/nPrime) ) { // nPrime now nPrime+nDprm
if ( nPrime && (rng_.runif() <= nDprm / nPrime) ) { // nPrime now nPrime+nDprm
thetaPrime = move(thetaDprm);
}
if (sDPrm) { // only now necessary to test the dot-product condition; equivalent to s''I(...) in Algorithm 3
@@ -277,7 +277,7 @@ void SamplerNUTS::buildTreeNeg_(const vector<double> &theta, const vector<double
throw string("log-posterior evaluates to +Inf in buildTreeNeg_. This should never happen. Check your implementation.");
}
} else {
testVal -= 0.5*nuc_.dotProd(rPrime);
testVal -= 0.5 * nuc_.dotProd(rPrime);
nPrime = (lu <= testVal ? 1.0 : 0.0);
s = (lu < (deltaMax_ + testVal) ? '1' : '\0');
}
@@ -290,7 +290,7 @@
vector<double> thetaDprm; // theta''
buildTreeNeg_(thetaMinus, rMinus, lu, epsilon, j-1, thetaPlus, rPlus, thetaMinus, rMinus, thetaDprm, nDprm, sDPrm);
nPrime += nDprm;
if ( nPrime && (rng_.runif() <= nDprm/nPrime) ) { // nPrime now nPrime+nDprm
if ( nPrime && (rng_.runif() <= nDprm / nPrime) ) { // nPrime now nPrime+nDprm
thetaPrime = move(thetaDprm);
}
if (sDPrm) { // only now necessary to test the dot-product condition; equivalent to s''I(...) in Algorithm 3
@@ -336,7 +336,7 @@ void SamplerNUTS::buildTreePos_(const vector<double> &theta, const vector<double
throw string("log-posterior evaluates to +Inf in adaptive buildTreePos_. This should never happen. Check your implementation.");
}
} else {
testVal -= 0.5*nuc_.dotProd(rPrime);
testVal -= 0.5 * nuc_.dotProd(rPrime);
nPrime = (lu <= testVal ? 1.0 : 0.0);
s = (lu < (deltaMax_ + testVal) ? '1' : '\0');
const double pDiff = testVal - nH0_;
@@ -357,7 +357,7 @@
alphaPrime += alphaDprm;
nAlphaPrime += nAlphaDprm;
nPrime += nDprm;
if ( (nPrime > 0.0) && (nDprm > 0.0) && (rng_.runif() <= nDprm/nPrime) ) { // nPrime now nPrime+nDprm
if ( (nPrime > 0.0) && (nDprm > 0.0) && (rng_.runif() <= nDprm / nPrime) ) { // nPrime now nPrime+nDprm
thetaPrime = move(thetaDprm);
}
if (sDPrm) { // only now necessary to test the dot-product condition; equivalent to s''I(...) in Algorithm 3 and 6
@@ -402,7 +402,7 @@ void SamplerNUTS::buildTreeNeg_(const vector<double> &theta, const vector<double
throw string("log-posterior evaluates to +Inf in adaptive buildTreeNeg_. This should never happen. Check your implementation.");
}
} else {
testVal -= 0.5*nuc_.dotProd(rPrime);
testVal -= 0.5 * nuc_.dotProd(rPrime);
nPrime = (lu <= testVal ? 1.0 : 0.0);
s = (lu < (deltaMax_ + testVal) ? '1' : '\0');
const double pDiff = testVal - nH0_;
@@ -423,7 +423,7 @@
alphaPrime += alphaDprm;
nAlphaPrime += nAlphaDprm;
nPrime += nDprm;
if ( (nPrime > 0.0) && (nDprm > 0.0) && (rng_.runif() <= nDprm/nPrime) ) { // nPrime now nPrime+nDprm
if ( (nPrime > 0.0) && (nDprm > 0.0) && (rng_.runif() <= nDprm / nPrime) ) { // nPrime now nPrime+nDprm
thetaPrime = move(thetaDprm);
}
if (sDPrm) { // only now necessary to test the dot-product condition; equivalent to s''I(...) in Algorithm 3 and 6
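
Note: each buildTree base case above applies two log-scale tests from the NUTS paper: the slice-membership indicator n' = I[log u <= log p(theta', r')] and the divergence guard s = I[log u < deltaMax + log p(theta', r')], where log p(theta', r') = logPost(theta') - r'.r'/2. Pulled out as a standalone helper (sketch; the member functions above inline this logic):

// Log-scale base-case tests of NUTS tree building (sketch).
struct TreeBaseCase {
    double nPrime; // 1.0 if the new leaf falls inside the slice, else 0.0
    bool s;        // false flags a divergent (energy-error) leaf
};

TreeBaseCase treeBaseCase(double logPostPrime, double rDotR, double lu, double deltaMax) {
    const double testVal = logPostPrime - 0.5 * rDotR; // log joint density at the leaf
    TreeBaseCase out;
    out.nPrime = (lu <= testVal) ? 1.0 : 0.0; // slice-membership indicator
    out.s = lu < (deltaMax + testVal);        // energy error still within tolerance
    return out;
}
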
@@ -485,11 +485,11 @@ int16_t SamplerNUTS::adapt(){
}
}
const double mt0 = m_ + t0_;
Hprevious_ = (1.0 - 1.0/mt0)*Hprevious_ + delta_/mt0;
const double logEps = mu_ - (sqrt(m_)*Hprevious_)/gamma_;
Hprevious_ = (1.0 - 1.0 / mt0) * Hprevious_ + delta_ / mt0;
const double logEps = mu_ - (sqrt(m_) * Hprevious_) / gamma_;
epsilon_ = exp(logEps);
const double mPwr = pow(m_, negKappa_);
logEpsBarPrevious_ = mPwr*logEps + (1.0 - mPwr)*logEpsBarPrevious_;
logEpsBarPrevious_ = mPwr * logEps + (1.0 - mPwr) * logEpsBarPrevious_;
lastEpsilons_[static_cast<size_t>(m_)%20] = epsilon_;
m_ += 1.0;
return -1;
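
Note: the tail of adapt() is the dual-averaging step-size update of Hoffman & Gelman's Algorithm 6; the second adapt() hunk below additionally replaces a low alpha/nAlpha statistic with the empirical acceptance fraction nAcc/n once a trajectory has at least five steps. Condensed into one function (sketch; the constants follow the paper's recommended defaults, and the class's actual member values are not visible in this diff):

// Dual-averaging step-size update, condensed from adapt() (sketch).
#include <algorithm>
#include <cmath>
#include <cstddef>

struct DualAvgState {
    double m = 1.0;               // adaptation step counter (m_)
    double Hprev = 0.0;           // running average statistic (Hprevious_)
    double logEpsBar = 0.0;       // smoothed log step size (logEpsBarPrevious_)
    double mu = 0.0;              // shrinkage target (mu_); the paper sets log(10*eps0)
    double gamma = 0.05;          // shrinkage strength (gamma_)
    double t0 = 10.0;             // iteration offset (t0_)
    double negKappa = -0.75;      // decay exponent (negKappa_)
    double delta = 0.8;           // target acceptance rate (delta_)
    double lastEpsilons[20] = {}; // ring buffer of the 20 newest step sizes
};

double dualAvgUpdate(DualAvgState &st, double alpha, double nAlpha, double nAcc, double n) {
    double aFrac = alpha / nAlpha;
    if (n >= 5.0) { // enough steps for a usable empirical acceptance rate
        aFrac = std::max(aFrac, nAcc / n);
    }
    const double mt0 = st.m + st.t0;
    st.Hprev = (1.0 - 1.0 / mt0) * st.Hprev + (st.delta - aFrac) / mt0;
    const double logEps = st.mu - (std::sqrt(st.m) * st.Hprev) / st.gamma;
    const double mPwr = std::pow(st.m, st.negKappa);
    st.logEpsBar = mPwr * logEps + (1.0 - mPwr) * st.logEpsBar;
    const double eps = std::exp(logEps);
    st.lastEpsilons[static_cast<std::size_t>(st.m) % 20] = eps; // as in the code above
    st.m += 1.0;
    return eps;
}
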
@@ -498,7 +498,7 @@
}

}
nH0_ -= 0.5*nuc_.dotProd(r0);
nH0_ -= 0.5 * nuc_.dotProd(r0);
const double lu = log( rng_.runifnz() ) + nH0_; // log(slice variable)

vector<double> thetaPlus(*theta_);
@@ -517,7 +517,7 @@
buildTreeNeg_(thetaMinus, rMinus, lu, -epsilon_, j, thetaPlus, rPlus, thetaMinus, rMinus, thetaPrime, nPrime, sPrime, alpha, nAlpha);
}
if (sPrime) {
if ( (nPrime >= n) || (rng_.runif() <= nPrime/n) ) {
if ( (nPrime >= n) || (rng_.runif() <= nPrime / n) ) {
(*theta_) = move(thetaPrime);
nAcc += 1.0;
}
@@ -551,16 +551,16 @@ int16_t SamplerNUTS::adapt(){
// Supplement the Hoffman and Gelman approach by looking at the actual acceptance rates when there is a large enough number of HMC steps.
// This seems to bump up epsilon a bit to reduce the number of steps.
// Using the nAcc/n statistic by itself makes epsilon too large, primarily because small n does not allow for a good acceptance rate estimate.
double aFrac = alpha/nAlpha;
double aFrac = alpha / nAlpha;
if (n >= 5) {
aFrac = max(aFrac, nAcc/n);
aFrac = max(aFrac, nAcc / n);
}
const double mt0 = m_ + t0_;
Hprevious_ = (1.0 - 1.0/mt0)*Hprevious_ + (delta_ - aFrac)/mt0;
const double logEps = mu_ - (sqrt(m_)*Hprevious_)/gamma_;
Hprevious_ = (1.0 - 1.0 / mt0) * Hprevious_ + (delta_ - aFrac) / mt0;
const double logEps = mu_ - (sqrt(m_) * Hprevious_) / gamma_;
epsilon_ = exp(logEps);
const double mPwr = pow(m_, negKappa_);
logEpsBarPrevious_ = mPwr*logEps + (1.0 - mPwr)*logEpsBarPrevious_;
logEpsBarPrevious_ = mPwr * logEps + (1.0 - mPwr) * logEpsBarPrevious_;
lastEpsilons_[static_cast<size_t>(m_)%20] = epsilon_;
m_ += 1.0;

@@ -614,7 +614,7 @@ int16_t SamplerNUTS::update() {
}

}
const double lu = log( rng_.runifnz() ) + lPost - 0.5*nuc_.dotProd(rPlus); // log(slice variable)
const double lu = log( rng_.runifnz() ) + lPost - 0.5 * nuc_.dotProd(rPlus); // log(slice variable)

vector<double> thetaPlus(*theta_);
vector<double> thetaMinus(*theta_);
@@ -631,7 +631,7 @@
buildTreeNeg_(thetaMinus, rMinus, lu, -epsilon_, j, thetaPlus, rPlus, thetaMinus, rMinus, thetaPrime, nPrime, sPrime);
}
if (sPrime) {
if ( (nPrime >= n) || (rng_.runif() <= nPrime/n) ) {
if ( (nPrime >= n) || (rng_.runif() <= nPrime / n) ) {
(*theta_) = move(thetaPrime);
}
vector<double> thetaDiff;
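
Note: thetaDiff above feeds the NUTS U-turn test: trajectory doubling continues only while (thetaPlus - thetaMinus) . rMinus >= 0 and (thetaPlus - thetaMinus) . rPlus >= 0, i.e. while neither end of the trajectory has started moving back across the sampled span. As a standalone predicate (sketch):

// NUTS U-turn check (sketch): growth stops once either endpoint's momentum
// points back across the span between the two trajectory ends.
#include <cstddef>
#include <vector>

bool noUTurn(const std::vector<double> &thetaPlus, const std::vector<double> &thetaMinus,
             const std::vector<double> &rPlus, const std::vector<double> &rMinus) {
    double dotPlus = 0.0;
    double dotMinus = 0.0;
    for (std::size_t i = 0; i < thetaPlus.size(); ++i) {
        const double diff = thetaPlus[i] - thetaMinus[i]; // thetaDiff in the code above
        dotPlus += diff * rPlus[i];
        dotMinus += diff * rMinus[i];
    }
    return (dotPlus >= 0.0) && (dotMinus >= 0.0);
}
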
@@ -682,7 +682,7 @@ SamplerMetro& SamplerMetro::operator=(SamplerMetro &&in){
int16_t SamplerMetro::adapt(){
vector<double> thetaPrime = *theta_;
for (auto &t : thetaPrime) {
t += incr_*rng_.rnorm();
t += incr_ * rng_.rnorm();
}
double lAlpha = model_->logPost(thetaPrime) - model_->logPost(*theta_);
double lU = log( rng_.runifnz() );
@@ -697,7 +697,7 @@
int16_t SamplerMetro::update(){
vector<double> thetaPrime = *theta_;
for (auto &t : thetaPrime) {
t += incr_*rng_.rnorm();
t += incr_ * rng_.rnorm();
}
double lAlpha = model_->logPost(thetaPrime) - model_->logPost(*theta_);
double lU = log( rng_.runifnz() );
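
Note: SamplerMetro::adapt() and update() are plain random-walk Metropolis, exactly as shown above: jitter every coordinate by a Gaussian increment of scale incr_, then accept when log alpha = logPost(theta') - logPost(theta) beats log U. A runnable toy version on a standard-normal target, using <random> in place of the package's RNG wrapper (which, per runifnz, avoids exact zeros before taking the log):

// Toy random-walk Metropolis on a standard-normal target (sketch).
#include <cmath>
#include <cstdio>
#include <random>
#include <utility>
#include <vector>

double logPost(const std::vector<double> &theta) { // log-density up to a constant
    double s = 0.0;
    for (double t : theta) {
        s += t * t;
    }
    return -0.5 * s;
}

int main() {
    std::mt19937 rng(2020);
    std::normal_distribution<double> rnorm(0.0, 1.0);
    std::uniform_real_distribution<double> runif(0.0, 1.0);
    const double incr = 0.5; // proposal scale, analogous to incr_
    std::vector<double> theta{0.0, 0.0};
    int nAcc = 0;
    for (int i = 0; i < 10000; ++i) {
        std::vector<double> thetaPrime = theta;
        for (double &t : thetaPrime) {
            t += incr * rnorm(rng); // Gaussian jitter on every coordinate
        }
        const double lAlpha = logPost(thetaPrime) - logPost(theta);
        if (std::log(runif(rng)) <= lAlpha) { // Metropolis accept/reject
            theta = std::move(thetaPrime);
            ++nAcc;
        }
    }
    std::printf("acceptance rate: %.3f\n", nAcc / 10000.0);
    return 0;
}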