Skip to content

Commit

Permalink
move prelim computations outside gibbs loop
Browse files Browse the repository at this point in the history
  • Loading branch information
ecmerkle committed May 10, 2024
1 parent 8b9e468 commit 7734f3e
Showing 1 changed file with 84 additions and 58 deletions.
142 changes: 84 additions & 58 deletions inst/stan/stanmarg_bsam.stan
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,11 @@ transformed data { // (re)construct skeleton matrices in Stan (not that interest

int Ncont = p + q - Nord;
array[max(nclus[,2]) > 1 ? max(nclus[,2]) : 0] int<lower = 0> intone;
array[len_alph] int paidx;
array[len_b, 2] int pbidx;
int pridx = 1;
int f1idx = 1;
int f2idx = 1;

array[Ng,2] int g_start1;
array[Ng,2] int g_start4;
Expand All @@ -631,6 +636,9 @@ transformed data { // (re)construct skeleton matrices in Stan (not that interest
array[15] int pos;
array[15] int len_free_c;
array[15] int pos_c;

array[Ng] int matdim;
int maxdim;

for (i in 1:15) {
len_free[i] = 0;
Expand Down Expand Up @@ -863,6 +871,46 @@ transformed data { // (re)construct skeleton matrices in Stan (not that interest
}
}

for (g in 1:Ng) {
if (g == Ng) {
matdim[g] = (len_free[4] - g_start4[g, 1] + 1) + (len_free[14] - g_start14[g, 1] + 1);
} else {
matdim[g] = (g_start4[(g + 1), 1] - g_start4[g, 1]) + (g_start14[(g + 1), 1] - g_start14[g, 1]);
}
maxdim = max(matdim);

// indexing of free params across rows of Alpha combined with B
for (r in 1:m) {
real askel = Alpha_skeleton[g, r, 1];
if (is_inf(askel)) {
paidx[f1idx] = pridx;
f1idx += 1;
pridx += 1;
}
for (c in 1:m) {
real bskel = B_skeleton[g, r, c];
if (is_inf(bskel)) {
// find columnwise "free" index
// this could be sent in as data to improve efficiency
int f3idx = 1;
for (cc in 1:c) {
for (rr in 1:m) {
if (is_inf(B_skeleton[g, rr, cc])) {
if (cc < c || (cc == c && rr < r)) {
f3idx += 1;
}
}
}
}
pbidx[f2idx, 1] = f3idx;
pbidx[f2idx, 2] = pridx;
f2idx += 1;
pridx += 1;
}
}
}
}

// for clusterwise loglik computations
if (max(nclus[,2]) > 1) for (i in 1:max(nclus[,2])) intone[i] = 1;

Expand Down Expand Up @@ -1260,10 +1308,13 @@ generated quantities { // these matrices are saved in the output but do not figu
vector[len_free[9]] Psi_pri;
vector[len_free[4]] b_primn;
vector[len_free[14]] alpha_primn;
array[Ng] int matdim;

array[Ntot] vector[m] eta;

// intermediate computations for gibbs sampler
array[Np] vector[maxdim] gamma0;
array[Np] matrix[maxdim, maxdim] Omega_inv;

// sign constraints and correlations
vector[len_free[1]] ly_sign;
vector[len_free[4]] bet_sign;
Expand Down Expand Up @@ -1350,12 +1401,6 @@ generated quantities { // these matrices are saved in the output but do not figu
Psi_prior_shape[g] = fill_matrix(to_vector(psi_sd_shape), Psi_skeleton[g], w9skel, g_start9[g,1], g_start9[g,2]);
Psi_prior_rate[g] = fill_matrix(to_vector(psi_sd_rate), Psi_skeleton[g], w9skel, g_start9[g,1], g_start9[g,2]);

if (g == Ng) {
matdim[g] = (len_free[4] - g_start4[g, 1] + 1) + (len_free[14] - g_start14[g, 1] + 1);
} else {
matdim[g] = (g_start4[(g + 1), 1] - g_start4[g, 1]) + (g_start14[(g + 1), 1] - g_start14[g, 1]);
}

// around here, rstan line numbers are off by about 135
Lambda[g] = fill_matrix(ly_sign, Lambda_y_skeleton[g], w1skel, g_start1[g,1], g_start1[g,2]);
B[g] = fill_matrix(B_free, B_skeleton[g], w4skel, g_start4[g,1], g_start4[g,2]);
Expand All @@ -1365,6 +1410,31 @@ generated quantities { // these matrices are saved in the output but do not figu
Psi[g] = quad_form_sym(Psi_r[g], Psi_sd[g]);
}

// arrange prior info
for (mm in 1:Np) {
int pidx = 1;
int g = grpnum[mm];

gamma0[mm] = rep_vector(0, maxdim);
Omega_inv[mm] = diag_matrix(gamma0[mm]);

for (r in 1:m) {
real askel = Alpha_skeleton[g, r, 1];
if (is_inf(askel)) {
gamma0[mm, pidx] = alpha_prior[g, r, 1];
Omega_inv[mm, pidx, pidx] = alpha_prior_prec[g, r, 1];
pidx += 1;
}
for (c in 1:m) {
real bskel = B_skeleton[g, r, c];
if (is_inf(bskel)) {
gamma0[mm, pidx] = b_prior[g, r, c];
Omega_inv[mm, pidx, pidx] = b_prior_prec[g, r, c];
pidx += 1;
}
}
}
}

for (i in 1:ngibbs) {
for (mm in 1:Np) {
Expand All @@ -1380,15 +1450,11 @@ generated quantities { // these matrices are saved in the output but do not figu
int r1 = startrow[mm];
int r2 = endrow[mm];
int g = grpnum[mm];
vector[matdim[g]] gamma0 = rep_vector(0, matdim[g]);
matrix[matdim[g], matdim[g]] Omega_inv = diag_matrix(gamma0);
vector[matdim[g]] params;
matrix[matdim[g], matdim[g]] FVF = rep_matrix(0, matdim[g], matdim[g]);
vector[matdim[g]] FVz = rep_vector(0, matdim[g]);
matrix[matdim[g], matdim[g]] Dinv;
int pidx = 1;
int f1idx = 1;
int f2idx = 1;

IBinv = inverse(I - B[g]);
if (Ndum_x[mm] > 0) {
Expand All @@ -1397,24 +1463,6 @@ generated quantities { // these matrices are saved in the output but do not figu
IBinv[dum_lv_x_idx[mm, j], dum_lv_x_idx[mm, j]] = 1;
}
}

// priors
for (r in 1:m) {
real askel = Alpha_skeleton[g, r, 1];
if (is_inf(askel)) {
gamma0[pidx] = alpha_prior[g, r, 1];
Omega_inv[pidx, pidx] = alpha_prior_prec[g, r, 1];
pidx += 1;
}
for (c in 1:m) {
real bskel = B_skeleton[g, r, c];
if (is_inf(bskel)) {
gamma0[pidx] = b_prior[g, r, c];
Omega_inv[pidx, pidx] = b_prior_prec[g, r, c];
pidx += 1;
}
}
}

// sample lvs
Psi0_inv = inverse_spd( quad_form_sym(Psi[g], IBinv') );
Expand Down Expand Up @@ -1474,41 +1522,19 @@ generated quantities { // these matrices are saved in the output but do not figu
FVz += etamat' * Psi_inv * z;
}

FVF += Omega_inv;
FVz += Omega_inv * gamma0;
FVF += Omega_inv[mm, matdim[g], matdim[g]];
FVz += Omega_inv[mm, matdim[g], matdim[g]] * gamma0[mm, matdim[g]];

Dinv = inverse_spd(FVF);

params = multi_normal_rng(Dinv * FVz, Dinv);

// now put parameters in free parameter vectors
pidx = 1;
for (r in 1:m) {
real askel = Alpha_skeleton[g, r, 1];
if (is_inf(askel)) {
Alpha_free[f1idx] = params[pidx];
f1idx +=1;
pidx += 1;
}
for (c in 1:m) {
real bskel = B_skeleton[g, r, c];
if (is_inf(bskel)) {
// find columnwise "free" index
// this is inefficient and should be sent in as data
f2idx = 1;
for (cc in 1:c) {
for (rr in 1:m) {
if (is_inf(B_skeleton[g, rr, cc])) {
if (cc < c || (cc == c && rr < r)) {
f2idx += 1;
}
}
}
}
B_free[f2idx] = params[pidx];
pidx += 1;
}
}
for (j in 1:len_alph) {
Alpha_free[j] = params[paidx[j]];
}
for (j in 1:len_b) {
B_free[pbidx[j, 1]] = params[pbidx[j, 2]];
}
}

Expand Down

0 comments on commit 7734f3e

Please sign in to comment.