-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbrm_mv_model_testing_mediumtaxa.R
112 lines (89 loc) · 4.55 KB
/
brm_mv_model_testing_mediumtaxa.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Simulate multivariate normal data in two levels: maternal taxa abundance means are MVN,
# then offspring taxa abundance means are MVN from that.
# This uses an intermediate number of taxa for local testing.
library(mvtnorm)
library(brms)
# Session-wide brms settings: allow up to 4 cores (mc.cores), fit through the
# cmdstanr backend, and with brms.file_refit = "on_change" reuse a fit cached
# at `file =` unless brms detects that the model or data changed.
options(
  brms.file_refit = "on_change",
  brms.backend = "cmdstanr",
  mc.cores = 4
)
# Simulation settings. Note there are more taxa than mothers, so any sample
# covariance estimated from the maternal draws would be rank-deficient.
n_mothers <- 20
n_taxa <- 50
offspring_per_mother <- 10 # 5 will be retained for traits, 5 for microbiome
set.seed(1)
# Covariance used to generate taxa abundances: independent, unit variance.
sigma_taxa <- diag(n_taxa)
# Maternal taxa abundance means: one row per mother, one column per taxon.
X_maternal <- rmvnorm(n_mothers, mean = rep(0, n_taxa), sigma = sigma_taxa)
# Within-mother covariance used later to draw offspring around each maternal
# mean. Use the known generating covariance rather than cov(X_maternal): with
# n_mothers <= n_taxa the sample covariance has rank at most n_mothers - 1, so
# the offspring deviations would be confined to a low-dimensional subspace.
sigma_maternal <- sigma_taxa
# Coefficients indicating which taxa predict the outcome.
# We will not include any interaction effect.
# Only include a few taxa with a nonzero effect.
beta <- c(50, 10, -50, rep(0, n_taxa - 3))
# Maternal trait. It is not used later in this script, but its rnorm() call is
# kept so that the random-number stream for the offspring draws is unchanged.
y_maternal <- 0 + X_maternal %*% beta + rnorm(n_mothers, 0, 1)
# Offspring microbiome: for each mother (row of X_maternal), draw
# offspring_per_mother taxa vectors from a multivariate normal centred on that
# mother's means, with covariance sigma_maternal. Mothers are visited in row
# order, so the random draws happen in the same order as a row-wise apply().
X_offspring <- lapply(seq_len(n_mothers), function(i) {
  rmvnorm(offspring_per_mother, mean = X_maternal[i, ], sigma = sigma_maternal)
})
# Offspring trait: linear predictor from the regression coefficients (beta),
# plus unit-variance noise.
y_offspring <- lapply(X_offspring, function(Xoi) {
  drop(Xoi %*% beta) + rnorm(offspring_per_mother, 0, 1)
})
# Stack into one row per offspring. The taxa columns are named explicitly
# (X1, ..., X<n_taxa>) because the model formulas refer to them by name.
# Previously they relied on data.frame()'s default naming of an unnamed matrix.
X_offspring_all <- do.call(rbind, X_offspring)
colnames(X_offspring_all) <- paste0("X", seq_len(n_taxa))
dt <- data.frame(
  maternal_id = factor(rep(seq_len(n_mothers), each = offspring_per_mother)),
  offspring_id = rep(seq_len(offspring_per_mother), times = n_mothers),
  X_offspring_all,
  y = unlist(y_offspring)
)
# Within each mother, set half of the values to be missing for x, and the
# other half for y. Offspring with offspring_id <= offspring_per_mother / 2
# keep their taxa (Xmiss*) but get a missing trait (ymiss). The remaining
# offspring keep their trait but have every taxon missing.
x_cols <- paste0("X", seq_len(n_taxa))
x_observed <- dt$offspring_id <= offspring_per_mother / 2
# Copy all taxa columns at once, then blank out whole rows. This replaces the
# earlier row-by-row build, which rbind()-ed data frames with NA vectors.
xmiss <- dt[x_cols]
xmiss[!x_observed, ] <- NA
names(xmiss) <- paste0("Xmiss", seq_len(n_taxa))
dt <- cbind(dt, xmiss)
dt$ymiss <- ifelse(x_observed, NA, dt$y)
# Model without missing data ----------------------------------------------
# Model without missing data, including a regularized horseshoe prior on the
# coefficients of y. Let's see if the coefficients can be recovered.
# Interactions between taxa aren't included.
# Each taxon is its own response, so the sd and sigma priors (one per
# response) are constructed programmatically.
x_resps <- paste0("X", seq_len(n_taxa))
sd_X_priors <- do.call(c, lapply(x_resps, function(r) {
  prior_string("gamma(1, 1)", class = "sd", resp = r)
}))
sigma_X_priors <- do.call(c, lapply(x_resps, function(r) {
  prior_string("gamma(1, 1)", class = "sigma", resp = r)
}))
# Also construct the formulas programmatically: one mvbind() model for all
# taxa, and a regression of y on every taxon.
X_formula <- paste0(
  "mvbind(", paste(x_resps, collapse = ","), ") ~ (1||maternal_id)"
)
y_formula <- paste0(
  "y ~ ", paste(x_resps, collapse = "+"), " + (1||maternal_id)"
)
# Horseshoe prior for the y coefficients. par_ratio = (expected nonzero) /
# (expected zero) coefficients; 3 taxa have nonzero effects in the simulation.
# The numeric value is written into the prior string. prior() would capture
# `3/(n_taxa-3)` unevaluated, and that only resolves if `n_taxa` happens to be
# visible wherever brms evaluates the horseshoe() call.
hs_prior_y <- sprintf(
  paste0(
    "horseshoe(df = 1, df_global = 1, scale_slab = 20, df_slab = 4, ",
    "par_ratio = %.17g)"
  ),
  3 / (n_taxa - 3)
)
fit_file_nomiss <- "project/fits/brmtest_mv_nomiss_reghorseshoe_midtaxa"
# brms writes the fit to `file` after sampling. Create the directory up front
# so a long run is not lost to a missing output path.
dir.create(dirname(fit_file_nomiss), recursive = TRUE, showWarnings = FALSE)
modmv_nomiss_reghorseshoe <- brm(
  bf(X_formula) + bf(y_formula) + set_rescor(FALSE),
  prior = c(
    sd_X_priors,
    sigma_X_priors,
    prior(gamma(1, 1), class = sd, resp = y),
    prior(gamma(1, 1), class = sigma, resp = y),
    prior_string(hs_prior_y, class = "b", resp = "y")
  ),
  data = dt,
  chains = 4, iter = 4500, warmup = 2000,
  init = 0, seed = 1240,
  file = fit_file_nomiss
)
# Model with missing data -------------------------------------------------
# Same structure, fitted to the Xmiss*/ymiss columns. The mi() terms let brms
# handle the missing taxa (both as responses and as predictors of ymiss) and
# the missing trait values. A regularized horseshoe prior is placed on the
# fixed effects of ymiss.
# Construct the sd and sigma priors (one per taxon response) programmatically.
xmiss_resps <- paste0("Xmiss", seq_len(n_taxa))
sd_Xmiss_priors <- do.call(c, lapply(xmiss_resps, function(r) {
  prior_string("gamma(1, 1)", class = "sd", resp = r)
}))
sigma_Xmiss_priors <- do.call(c, lapply(xmiss_resps, function(r) {
  prior_string("gamma(1, 1)", class = "sigma", resp = r)
}))
# Also construct the formulas programmatically.
Xmiss_formula <- paste0(
  "mvbind(", paste(xmiss_resps, collapse = ","), ") | mi() ~ (1||maternal_id)"
)
ymiss_formula <- paste0(
  "ymiss | mi() ~ ",
  paste(paste0("mi(", xmiss_resps, ")"), collapse = "+"),
  " + (1||maternal_id)"
)
# Horseshoe prior for the ymiss coefficients. par_ratio = (expected nonzero) /
# (expected zero) coefficients, with 3 nonzero effects in the simulation. The
# numeric value is written into the prior string. prior() would capture
# `3/(n_taxa-3)` unevaluated, and that only resolves if `n_taxa` happens to be
# visible wherever brms evaluates the horseshoe() call.
hs_prior_ymiss <- sprintf(
  paste0(
    "horseshoe(df = 1, df_global = 1, scale_slab = 20, df_slab = 4, ",
    "par_ratio = %.17g)"
  ),
  3 / (n_taxa - 3)
)
fit_file_miss <- "project/fits/brmtest_mv_miss_reghorseshoe_midtaxa"
# brms writes the fit to `file` after sampling. Create the directory up front
# so a long run is not lost to a missing output path.
dir.create(dirname(fit_file_miss), recursive = TRUE, showWarnings = FALSE)
modmv_miss_reghorseshoe <- brm(
  bf(Xmiss_formula) + bf(ymiss_formula) + set_rescor(FALSE),
  prior = c(
    sd_Xmiss_priors,
    sigma_Xmiss_priors,
    prior(gamma(1, 1), class = sd, resp = ymiss),
    prior(gamma(1, 1), class = sigma, resp = ymiss),
    prior_string(hs_prior_ymiss, class = "b", resp = "ymiss")
  ),
  data = dt,
  chains = 4, iter = 4500, warmup = 2000,
  init = 0, seed = 1239,
  file = fit_file_miss
)