Skip to content

Commit

Permalink
working VB fit
Browse files Browse the repository at this point in the history
  • Loading branch information
tonymugen committed Aug 3, 2020
1 parent c28a49b commit 39646af
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 16 deletions.
45 changes: 32 additions & 13 deletions src/gmmvb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,28 @@ GmmVB::GmmVB(GmmVB &&in) : yVec_{in.yVec_}, Nm_{in.Nm_}, lambda0_{in.lambda0_},

void GmmVB::fitModel(vector<double> &lowerBound) {
lowerBound.clear();
for (uint16_t it = 0; it < 2; it++) {
for (uint16_t it = 0; it < maxIt_; it++) {
eStep_();
const double curLB = mStep_();
if ( lowerBound.size() && ( fabs( (lowerBound.back() - curLB)/lowerBound.back() ) <= stoppingDiff_) ) {
lowerBound.push_back(curLB);
break;
if ( lowerBound.size() ) {
if (lowerBound.back() >= curLB) { // do not bother saving the lower bound value if it decreased (can happen when we are close to the optimum)
break;
} else if ( ( curLB - lowerBound.back() )/fabs( lowerBound.back() ) <= stoppingDiff_ ) {
lowerBound.push_back(curLB);
break;
}
}
lowerBound.push_back(curLB);
}
// scale the outputs as necessary
for (size_t m = 0; m < M_.getNrows(); m++) { // crossproduct into covariance
S_[m] /= (*Nm_)[m];
}
double dm = 2.0;
for (size_t m = 2; m <= M_.getNrows(); m++) { // add ln K! as recommended in Bishop, page 484
lowerBound.back() += log(dm);
dm += 1.0;
}
}

void GmmVB::eStep_(){
Expand All @@ -145,12 +158,12 @@ void GmmVB::eStep_(){
// start with parameters not varying across individuals
vector<double> startSum;
vector<double> lamNmRat;
vector<double> invLamNm;
vector<double> nuNm;
for (size_t m = 0; m < R_.getNcols(); m++) {
const double lNm = lambda0_ + (*Nm_)[m];
lamNmRat.push_back( (*Nm_)[m]/lNm );
invLamNm.push_back(0.5*(nu0p1_ + (*Nm_)[m])/lNm);
startSum.push_back( nuc_.digamma(alpha0_ + (*Nm_)[m]) + 0.5*(sumDiGam_[m] + lnDet_[m] - d_/lNm) );
nuNm.push_back( 0.5*(nu0p1_ + (*Nm_)[m]) );
startSum.push_back( nuc_.digamma(alpha0_ + (*Nm_)[m]) + 0.5*(sumDiGam_[m] + lnDet_[m] - 0.5*d_/lNm) );
}
// scale the mean matrix
vector<double> vMsc(Npop*d, 0.0);
Expand Down Expand Up @@ -180,7 +193,7 @@ void GmmVB::eStep_(){
}
}
for (size_t iRow = 0; iRow < N; iRow++) {
const double lnRhoLoc = startSum[m] - invLamNm[m]*lnRho.getElem(iRow, m);
const double lnRhoLoc = startSum[m] - nuNm[m]*lnRho.getElem(iRow, m);
lnRho.setElem(iRow, m, lnRhoLoc);
}
}
Expand All @@ -193,7 +206,7 @@ void GmmVB::eStep_(){
continue;
} else {
double diff = lnRho.getElem(iRow, l) - lnRho.getElem(iRow, m);
if (diff >= lnMaxDbl_) { // will overflow right away
if (diff >= lnMaxDbl_) { // will overflow right away
R_.setElem(iRow, m, 0.0);
noOverflow = false;
break;
Expand Down Expand Up @@ -243,7 +256,7 @@ double GmmVB::mStep_() {
}
}
// unscaled Sm
Arsd.gemm(true, 1.0, wtArsd, false, 1.0, S_[m]);
Arsd.gemm(true, 1.0, wtArsd, false, 0.0, S_[m]);
const double lNmRatio = (lambda0_*(*Nm_)[m])/(lambda0_ + (*Nm_)[m]);
// lower triangle of Sigma_m
for (size_t jD = 0; jD < d; jD++) {
Expand All @@ -262,8 +275,12 @@ double GmmVB::mStep_() {
SigM_[m].addToElem(kk, kk, tau0_);
}
// invert
SigM_[m].chol();
SigM_[m].cholInv();
try {
SigM_[m].chol();
SigM_[m].cholInv();
} catch (string problem) {
SigM_[m].pseudoInv();
}

// calculate the lower bound portion for this population
sumDiGam_[m] = 0.0;
Expand Down Expand Up @@ -309,7 +326,9 @@ double GmmVB::mStep_() {
double rLnr = 0.0;
for (size_t i = 0; i < N; i++) {
const double r = R_.getElem(i, m);
rLnr += r*log(r);
if ( r > numeric_limits<double>::epsilon() ) { // lim x-> 0 (x ln(x) ) = 0
rLnr += r*log(r);
}
}
// put it all together (+= because we are summing across populations
lwrBound += psiElmt - 0.5*(nu0p1_ + (*Nm_)[m])*(matTr + aSaT) - rLnr + nuc_.lnGamma(alpha0_ + (*Nm_)[m]) - 0.5*d_*log(lmNmSm);
Expand Down
8 changes: 5 additions & 3 deletions src/utilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,9 +437,11 @@ double NumerUtil::dotProd(const vector<double> &v1, const vector<double> &v2) co
return dotProd;
}
void NumerUtil::updateWeightedMean(const double &xn, const double &wn, double &mu, double &w) const{
const double a = mu*w;
w += wn;
mu = (a + wn*xn)/w;
if ( wn > numeric_limits<double>::epsilon() ) {
const double a = mu*w;
w += wn;
mu = (a + wn*xn)/w;
}
}
double NumerUtil::mean(const double arr[], const size_t &len){
double mean = 0.0;
Expand Down

0 comments on commit 39646af

Please sign in to comment.