Skip to content

Commit

Permalink
fixed the log-posterior overflow errors
Browse files Browse the repository at this point in the history
  • Loading branch information
tonymugen committed Aug 14, 2020
1 parent 7373512 commit 8c4b3fb
Showing 1 changed file with 15 additions and 53 deletions.
68 changes: 15 additions & 53 deletions src/gmmvb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,70 +309,32 @@ double GmmVB::logPost_(){
for (size_t m = 0; m < Npop; m++) {
scSum.push_back( log(alpha0_ + (*N_)[m]) + 0.5*(d_*log(nuNm[m]) + lnDet_[m]) );
}
vector<size_t> maxInd(N, 0); // index of the largest kernel value
vector<double> curMaxVal( N, -numeric_limits<double>::infinity() ); // store the current largest kernel value here
for (size_t m = 0; m < Npop; m++) {
for (size_t iRow = 0; iRow < N; iRow++) {
const double newVal = scSum[m] - 0.5*nuNm[m]*K.getElem(iRow, m);
if (newVal > curMaxVal[iRow]) {
maxInd[iRow] = m;
curMaxVal[iRow] = newVal;
}
K.setElem(iRow, m, newVal);
}
}
// Subtract the first column from the rest
for (size_t m = 1; m < Npop; m++) {
// Subtract the largest column from the rest and sum the exponents row-wise
vector<double> expRowSums(N, 0.0);
for (size_t m = 0; m < Npop; m++) {
for (size_t iRow = 0; iRow < N; iRow++) {
const double diff = K.getElem(iRow, m) - K.getElem(iRow, 0);
K.setElem(iRow, m, diff);
if (m != maxInd[iRow]) {
expRowSums[iRow] += exp(K.getElem(iRow, m) - K.getElem(iRow, maxInd[iRow]));
}
}
}
// sum everything
double lnP = 0.0; // sum of the additive kernel will go here
for (size_t iRow = 1; iRow < N; iRow++) {
double rowVal = 0.0;
for (size_t m = 0; m < Npop; m++) {
rowVal += exp(K.getElem(iRow, m));
}
lnP += log1p(rowVal);
}
for (size_t iRow = 0; iRow < N; iRow++) {
lnP += K.getElem(iRow, 0);
}
/*
TODO: get this part to work
for (size_t iRow = 0; iRow < N; iRow++) { // sacrificing the tight loop to make numerical safety happen
double regSum = 0.0;
double bigSum = 0.0; // will be used if large values of Km are encountered
for (size_t m = 1; m < Npop; m++) {
const double df = K.getElem(iRow, m);
if (df >= 100) { // well into approximation territory, but don't want to do this too often
if (bigSum > 0.0) { // something already added
double ldif = bigSum - df;
if ( (ldif > 0.0) && (ldif <= 5.0) ) { // over 5.0 the correction is unnecessary regardless of the df or bigSum value
bigSum += log1p( exp(-ldif) );
} else if ( (ldif < 0.0) && (ldif >= -5.0) ) {
bigSum = df + log1p( exp(ldif) );
} else if (ldif < 0.0) {
bigSum = df;
} // or leave bigSum as is
} else {
bigSum = df;
}
} else if (bigSum > 0.0) {
if (df >= 95) {
bigSum += log1p( exp(df-bigSum) );
}
// otherwise do nothing
} else {
const double edf = exp(df);
if ( (numeric_limits<double>::max() - regSum) <= edf ) { // do not bother adding any more if regSum will overflow
regSum += edf;
}
}
}
if (bigSum > 0.0) {
lnP += K.getElem(iRow, 0) + bigSum;
} else {
lnP += K.getElem(iRow, 0) + log1p(regSum);
}
double lnP = 0.0;
for (size_t i = 0; i < N; i++) {
lnP += curMaxVal[i] + log1p(expRowSums[i]);
}
*/
return(lnP);
}

Expand Down

0 comments on commit 8c4b3fb

Please sign in to comment.