Richard McElreath recently shared a tutorial on how to estimate latent group means (à la Mundlak) in Stan. I follow his code, reproduce the models in brms, and try to implement a solution Matti Vuorre posted.
In a recent lecture, Richard McElreath gave an introduction to Mundlak devices (code). Briefly, social scientists would often like to remove confounding at the group level.1
Econometricians and epidemiologists tend to use fixed effects regression to do so. Psychologists often use random effects regressions and adjust for the mean of the predictor of interest. The two solutions yield similar results for the within-group effect, but differ on how they treat group-level effects.
Arguably, there are more ways to mess up the random effects model, but it’s more flexible and efficient and it makes it easier to model e.g., cross-level interactions and effect heterogeneity.
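To make the contrast concrete, here is a minimal sketch of the two specifications (using lm/lme4 just for illustration, and assuming a hypothetical long-format data frame d with outcome Y, predictor X, and group identifier id):
library(dplyr)
library(lme4)
# Fixed effects: group dummies absorb all group-level confounding,
# but group-level predictors drop out of the model
fe <- lm(Y ~ X + factor(id), data = d)
# Random effects + observed group mean (Mundlak device): the coefficient of X
# is the within-group effect, the coefficient of Xbar picks up the between-group part
d  <- d %>% group_by(id) %>% mutate(Xbar = mean(X)) %>% ungroup()
re <- lmer(Y ~ X + Xbar + (1 | id), data = d)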
One way you can mess up the random effects model is if your group sizes are small (e.g., sibling groups in a low-fertility country) and your estimate of the group mean is a poor stand-in for the confounds you’d like to adjust for. A solution to this is to estimate the latent group mean instead, i.e. to account for the fact that we are estimating it with uncertainty. Doing so3 is fairly easy in Stan, but it’s less clear how to do it with everyone’s favourite Stan interface, brms.
In which way does this sampling error at the group level bias your results? It attenuates your estimate of b_Xbar (Lüdtke’s bias). I thought it would also bias my estimate of b_X, because I’m underadjusting by ignoring the measurement error in my covariate. That is not so. Why? To get the intuition, it helped me to consider the case where I first subtract Xbar from X. Xdiff is necessarily uncorrelated with Xbar, meaning any changes in the association of Xbar and Y (which we get by modelling sampling error) are irrelevant to Y ~ Xdiff.
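To check this intuition, here is a quick simulation sketch (with made-up parameters, separate from the simulations below): Xdiff is uncorrelated with Xbar by construction, and the coefficient of X is recovered even though Xbar is only a noisy stand-in for the confound.
set.seed(1)
N_groups <- 200; n_per <- 3
g <- rep(1:N_groups, each = n_per)
Ug <- rnorm(N_groups)                  # group-level confound
X <- rnorm(N_groups * n_per, Ug[g])    # individual predictor
Y <- 1 * X + 2 * Ug[g] + rnorm(N_groups * n_per)
Xbar <- ave(X, g)                      # observed (noisy) group mean
Xdiff <- X - Xbar                      # within-group deviation
cor(Xdiff, Xbar)                       # ~0 by construction
coef(lm(Y ~ X + Xbar))                 # coefficient of X close to the true 1
coef(lm(Y ~ Xdiff + Xbar))             # identical within-group coefficient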
I think I’m not the only one who is confused about this. Matti Vuorre shared a solution to center X with the latent group mean using brms’ non-linear syntax. Centering with the latent group mean is not necessary though. Also, I tried the syntax and it doesn’t correctly recover simulated effects.4
My two attempts at a solution (a compact brms preview follows this list):

- Using me(): compute the group mean Xbar, e.g. df %>% group_by(id) %>% mutate(Xbar = mean(X)), and its standard error, df %>% group_by(id) %>% mutate(Xse = sd(X)/sqrt(n())), then put me(Xbar, Xse, gr = id) into the formula, explicitly specifying the grouping variable.
- Using mi(): duplicate the predictor, X -> X2, add bf(X2 | mi(Xse) ~ (1|id)) in brms to estimate the latent group mean with shrinkage, and put mi(X2) into the regression on Y.
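As a compact preview (a sketch only; the full models with priors, data simulation, and the grouping-variable/SE details follow below), the two attempts look like this in brms:
# Attempt 1: me() with an explicitly specified grouping variable
m_me <- brm(Y ~ X + Z + me(Xbar, Xse, gr = id) + (1 | id),
            data = sim, family = bernoulli())
# Attempt 2: mi() on a copy of X; its latent group mean is estimated
# with shrinkage in a second, Gaussian model part
m_mi <- brm(bf(Y ~ X + Z + mi(X2) + (1 | id), family = bernoulli()) +
              bf(X2 | mi(Xse) ~ (1 | id), family = gaussian()),
            data = sim)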
A few things I learned along the way: the me() approach didn’t work for me at first until I specified the grouping variable, which makes sense. I had to add .01 to Xse when the standard error was zero (in my Bernoulli exposure simulation). mi() cannot be combined with non-Gaussian variables; when I treated the binary exposure X as Gaussian in a linear probability model, convergence was difficult and I had to fix the sigma to 0.01. The me() approach seems to work though. If I do all this, the three approaches (rethinking, me, mi) converge.

Estimating the latent group mean did not improve my estimates of b_X. It improves estimates of b_Xbar (i.e. the group-level confound), so it’s doing what it’s supposed to, but b_X is always already estimated close to the true value with a non-latent Mundlak model. I tried to create conditions to make outperformance likely (small group sizes, much within-subject variation in X relative to between-subject variation, so that Xbar poorly correlates with the group-level confound). Am I missing something in my simulations or is latent group mean centering rarely worth the effort? I’d be glad for any pointers. Edit: Niclas Kuper gave me the pointer I needed, so I have clarified above that we do not expect estimation of b_X to improve by using latent group means.

The simulations and their results are documented in detail below. Click on “Implementations” to see the model code.
I started with Richard’s simulations of a binary outcome. I was not able to reproduce his model performances exactly and ended up increasing the simulated sample size to make estimates a bit more stable (to more clearly show bias vs. variance).
Up to the latent Mundlak model, the brms equivalents perform, well, equivalently. For the latent Mundlak model, brms arrives at slightly different coefficients, but a similar overall result.
library(tidyverse)
library(rethinking)
library(tidybayes)
library(tidybayes.rethinking)
library(brms)
options(brms.backend = "cmdstanr",     # I use the cmdstanr backend
        mc.cores = 8,
        brms.threads = 1,              # which allows me to multithread
        brms.file_refit = "on_change") # this is useful when re-running: only refit models whose code or data changed
set.seed(201910)
# families
N_groups <- 300
a0 <- (-2)
b_X <- (1)
b_Z <- (-0.5)
b_Ug <- (3)
# 2 or more siblings
g_sizes <- 2 + rpois(N_groups, lambda = 0.2) # sample into groups
table(g_sizes)
N_id <- sum(g_sizes)
g <- rep(1:N_groups, times = g_sizes)
Ug <- rnorm(N_groups, sd = 0.8) # group confounds
X <- rnorm(N_id, Ug[g] ) # individual varying trait
Z <- rnorm(N_groups) # group varying trait (observed)
Y <- rbern(N_id, p=inv_logit( a0 + b_X * X + b_Ug*Ug[g] + b_Z*Z[g] ) )
# glm(Y ~ X + Z[g] + Ug[g], binomial())
# glm(Y ~ X + Z[g], binomial())
groups <- tibble(id = factor(1:N_groups), Ug, Z)
sim <- tibble(id = factor(g), X, Y) %>% full_join(groups, by = "id") %>% arrange(id) %>% group_by(id) %>%
mutate(Xbar = mean(X)) %>% ungroup()
sim %>% distinct(id, Ug, Xbar) %>% select(-id) %>% cor(use = "p")
glm(Y ~ X + Z, data = sim, binomial(link = "logit"))
glm(Y ~ Ug + X + Z, data = sim, binomial(link = "logit"))
glm(Y ~ id + X + Z, data = sim, binomial(link = "logit"))
sim <- sim %>% group_by(id) %>%
mutate(Xse = sd(X)/sqrt(n())) %>% ungroup() %>%
mutate(X2 = X)
dat <- list(Y = Y, X = X, g = g, Ng = N_groups, Z = Z)
# fixed effects
mf <- ulam(alist(
Y ~ bernoulli(p),
logit(p) <- a[g] + b_X*X + b_Z*Z[g],
a[g] ~ dnorm(0,1.5), # can't get it to work using dunif(-1,1)
c(b_X,b_Z) ~ dnorm(0,1)
), data=dat, chains = 4, cores=4 )
summary(mf)
mf %>% gather_draws(b_X, b_Z) %>% mean_hdci()
b_mf <- brm(Y ~ 1 + id + X + Z, data = sim,
family = bernoulli(),
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X"),
set_prior("normal(0, 1)", class = "b", coef = "Z")
,set_prior("uniform(-1, 1)",lb = -1, ub = 1, class = "b"))
# , sample_prior = "only"
)
# b_mf %>% gather_draws(`b_id.*`, regex=T) %>%
# ggplot(aes(inv_logit(.value))) + geom_histogram(binwidth = .01)
b_mf %>% gather_draws(b_X, b_Z) %>% mean_hdci()
# varying effects (non-centered - next week! )
mr <- ulam(
alist(
Y ~ bernoulli(p),
logit(p) <- a[g] + b_X*X + b_Z*Z[g],
transpars> vector[Ng]:a <<- abar + z*tau,
z[g] ~ dnorm(0,1),
c(b_X,b_Z) ~ dnorm(0,1),
abar ~ dnorm(0,1),
tau ~ dexp(1)
), data=dat , chains=4, cores=4, sample=TRUE)
mr %>% gather_draws(b_X, b_Z) %>% mean_hdci()
# The Mundlak Machine
xbar <- sapply(1:N_groups, function(j) mean (X[g==j]))
dat$Xbar <- xbar
mrx <- ulam(
alist(
Y ~ bernoulli(p),
logit (p) <- a[g] + b_X*X + b_Z*Z[g] + buy*Xbar[g],
transpars> vector[Ng]:a <<- abar + z*tau,
z[g] ~ dnorm(0,1),
c(b_X, buy, b_Z) ~ dnorm(0,1),
abar ~ dnorm(0,1),
tau ~ dexp(1)
),
data=dat, chains=4, cores=4 ,
sample=TRUE )
mrx %>% gather_draws(b_X, b_Z) %>% mean_hdci()
b_mrx <- brm(Y ~ (1|id) + X + Z + Xbar, data = sim,
family = bernoulli(),
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X"),
set_prior("normal(0, 1)", class = "b", coef = "Z"),
set_prior("normal(0, 1)", class = "b", coef = "Xbar"),
set_prior("exponential(1)", class = "sd")))
b_mrx
b_mrx %>% gather_draws(b_X, b_Z) %>% mean_hdci()
b_mrc <- brm(Y ~ (1|id) + X + Z, data = sim %>% mutate(X = X - Xbar),
family = bernoulli(),
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X"),
set_prior("normal(0, 1)", class = "b", coef = "Z"),
set_prior("exponential(1)", class = "sd")))
b_mrc
b_mrc %>% gather_draws(b_X, b_Z) %>% mean_hdci()
# The Latent Mundlak Machine
mru <- ulam(
alist(
# y model
Y ~ bernoulli(p),
logit(p) <- a[g] + b_X*X + b_Z*Z[g] + buy*u[g],
transpars> vector[Ng]:a <<- abar + z*tau,
# X model
X ~ normal(mu,sigma),
mu <- aX + bux*u[g],
vector[Ng]:u ~ normal (0,1),
# priors
z[g] ~ dnorm(0,1),
c(aX, b_X, buy, b_Z) ~ dnorm(0, 1),
bux ~ dexp(1),
abar ~ dnorm (0,1),
tau ~ dexp(1),
sigma ~ dexp(1)
),
data = dat, chains = 4, cores=4, sample=TRUE)
mru %>% gather_draws(b_X, b_Z) %>% mean_hdci()
b_mru_gr <- brm(Y ~ 1 +(1|id) + X + Z + me(Xbar, Xse, gr = id), data = sim,
family = bernoulli(),
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X"),
set_prior("normal(0, 1)", class = "b", coef = "Z"),
set_prior("normal(0, 1)", class = "b", coef = "meXbarXsegrEQid"),
set_prior("exponential(1)", class = "sd"),
set_prior("exponential(1)", class = "sdme")))
b_mru_mi <- brm(bf(Y ~ (1|id) + X + Z + mi(X2), family = bernoulli()) +
bf(X2 | mi(Xse) ~ (1|id),
family = gaussian()), data = sim,
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X", resp = "Y"),
set_prior("normal(0, 1)", class = "b", coef = "Z", resp = "Y"),
set_prior("normal(0, 1)", class = "b", coef = "miX2", resp = "Y"),
set_prior("exponential(1)", class = "sd", resp = "Y"),
set_prior("exponential(1)", class = "sd", resp = "X2"),
set_prior("exponential(1)", class = "sigma", resp = "X2")))
b_mru_mi
b_mru_mi %>% gather_draws(b_Y_X, b_Y_Z) %>% mean_hdci()
This is the approach Matti Vuorre posted on his blog, except that I rearranged it to estimate the slope of Xbar instead. This approach does not work at all, as far as I can tell.
latent_formula <- bf(
# Y model
Y ~ interceptY + bX*X + bXbar*Xbar + bZ*Z,
interceptY ~ 1 + (1 | id),
bX + bZ + bXbar ~ 1,
# Xbar model
nlf(X ~ Xbar),
Xbar ~ 1 + (1 | id),
nl = TRUE
) #+
# bernoulli()
get_prior(latent_formula, data = sim)
                prior class      coef group resp dpar      nlpar lb ub       source
 student_t(3, 0, 2.5) sigma                                        0          default
               (flat)     b                                    bX             default
               (flat)     b Intercept                          bX        (vectorized)
               (flat)     b                                 bXbar             default
               (flat)     b Intercept                       bXbar        (vectorized)
               (flat)     b                                    bZ             default
               (flat)     b Intercept                          bZ        (vectorized)
               (flat)     b                            interceptY             default
               (flat)     b Intercept                  interceptY        (vectorized)
 student_t(3, 0, 2.5)    sd                            interceptY  0          default
 student_t(3, 0, 2.5)    sd              id           interceptY  0      (vectorized)
 student_t(3, 0, 2.5)    sd Intercept    id           interceptY  0      (vectorized)
               (flat)     b                                  Xbar             default
               (flat)     b Intercept                        Xbar        (vectorized)
 student_t(3, 0, 2.5)    sd                                  Xbar  0          default
 student_t(3, 0, 2.5)    sd              id                  Xbar  0      (vectorized)
 student_t(3, 0, 2.5)    sd Intercept    id                  Xbar  0      (vectorized)
b_mru_mv <- brm(latent_formula, data = sim,
prior = c(
set_prior("normal(0, 1)", class = "b", nlpar = "bX"),
set_prior("normal(0, 1)", class = "b", nlpar = "bZ"),
set_prior("normal(0, 1)", class = "b", nlpar = "bXbar"),
set_prior("normal(0, 1)", class = "b", nlpar = "Xbar"),
set_prior("normal(0, 1)", class = "b", nlpar = "interceptY"),
set_prior("exponential(1)", class = "sd", nlpar = "interceptY"),
set_prior("exponential(1)", class = "sd", nlpar = "Xbar")))
b_mru_mv
Family: gaussian
Links: mu = identity; sigma = identity
Formula: Y ~ interceptY + bX * X + bXbar * Xbar + bZ * Z
interceptY ~ 1 + (1 | id)
bX ~ 1
bZ ~ 1
bXbar ~ 1
X ~ Xbar
Xbar ~ 1 + (1 | id)
Data: sim (Number of observations: 646)
Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
total post-warmup draws = 4000
Group-Level Effects:
~id (Number of levels: 300)
Estimate Est.Error l-95% CI u-95% CI Rhat
sd(interceptY_Intercept) 0.20 0.11 0.01 0.36 1.12
sd(Xbar_Intercept) 0.55 0.44 0.05 1.75 1.02
Bulk_ESS Tail_ESS
sd(interceptY_Intercept) 30 194
sd(Xbar_Intercept) 242 255
Population-Level Effects:
Estimate Est.Error l-95% CI u-95% CI Rhat
interceptY_Intercept 0.22 0.51 -0.95 1.24 1.01
bX_Intercept 0.08 0.83 -1.55 1.73 1.04
bZ_Intercept -0.03 0.02 -0.07 0.02 1.00
bXbar_Intercept 0.11 0.83 -1.52 1.74 1.05
Xbar_Intercept 0.02 0.75 -1.45 1.53 1.01
Bulk_ESS Tail_ESS
interceptY_Intercept 1921 1750
bX_Intercept 79 1543
bZ_Intercept 2207 2403
bXbar_Intercept 52 1353
Xbar_Intercept 1418 1121
Family Specific Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma 0.31 0.01 0.29 0.34 1.00 1625 1830
Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
draws <- bind_rows(
rethinking_naive = mn %>% gather_draws(b_X, b_Z),
rethinking_fixed = mf %>% gather_draws(b_X, b_Z),
rethinking_multilevel = mr %>% gather_draws(b_X, b_Z),
rethinking_mundlak = mrx %>% gather_draws(b_X, b_Z, buy),
brms_naive = b_mn %>% gather_draws(b_X, b_Z),
brms_fixed = b_mf %>% gather_draws(b_X, b_Z),
brms_multilevel = b_mr %>% gather_draws(b_X, b_Z),
brms_mundlak = b_mrx %>% gather_draws(b_X, b_Z, b_Xbar),
brms_mundlak_centered = b_mrc %>% gather_draws(b_X, b_Z),
rethinking_latent_mundlak = mru %>% gather_draws(b_X, b_Z, buy),
brms_latent_mundlak = b_mru_gr %>% gather_draws(b_X, b_Z, bsp_meXbarXsegrEQid),
brms_latent_mundlak_mi = b_mru_mi %>% gather_draws(b_Y_X, b_Y_Z, bsp_Y_miX2),
.id = "model") %>%
separate(model, c("package", "model"), extra = "merge") %>%
mutate(model = fct_inorder(factor(model)),
.variable = str_replace(.variable, "_Y", ""),
.variable = recode(.variable,
"buy" = "b_Xbar",
"bsp_miX2" = "b_Xbar",
"bsp_meXbarXsegrEQid" = "b_Xbar"),
.variable = factor(.variable, levels = c("b_X", "b_Z", "b_Xbar")))
draws <- draws %>% group_by(package, model, .variable) %>%
mean_hdci(.width = c(.95, .99)) %>%
ungroup()
ggplot(draws, aes(y = package, x = .value, xmin = .lower, xmax = .upper)) +
geom_pointinterval(position = position_dodge(width = .4)) +
ggrepel::geom_text_repel(aes(label = if_else(.width == .95, sprintf("%.2f", .value), NA_character_)), nudge_y = .1) +
geom_vline(aes(xintercept = true_val), linetype = 'dashed', data = tibble(true_val = c(b_X, b_Z, b_Ug - b_X), .variable = factor(c("b_X", "b_Z", "b_Xbar"), levels = c("b_X", "b_Z", "b_Xbar")))) + scale_color_discrete(breaks = rev(levels(draws$model))) +
facet_grid(model ~ .variable, scales = "free_x") +
theme_bw() +
theme(legend.position = c(0.99,0.99),
legend.justification = c(1,1))
For some reason, I did not manage to specify a uniform prior in rethinking. I think this explains the poor showing for the rethinking fixed effects model. Fixed effects regression + Bayes is not really a happy couple.
Then, I simulated a Gaussian outcome instead. I didn’t run the rethinking models here, except for the latent Mundlak model.
set.seed(201910)
# families
N_groups <- 300
a0 <- (-2)
b_X <- (1)
b_Z <- (-0.5)
b_Ug <- (3)
# 2 or more siblings
g_sizes <- 2 + rpois(N_groups, lambda = 0.2) # sample into groups
table(g_sizes)
N_id <- sum(g_sizes)
g <- rep(1:N_groups, times = g_sizes)
Ug <- rnorm(N_groups, sd = 0.8) # group confounds
X <- rnorm(N_id, Ug[g] ) # individual varying trait
Z <- rnorm(N_groups) # group varying trait (observed)
Y <- a0 + b_X * X + b_Ug*Ug[g] + b_Z*Z[g] + rnorm(N_id)
groups <- tibble(id = factor(1:N_groups), Ug, Z)
sim <- tibble(id = factor(g), X, Y) %>% full_join(groups, by = "id") %>% arrange(id) %>% group_by(id) %>%
mutate(Xbar = mean(X)) %>% ungroup()
sim %>% distinct(id, Ug, Xbar) %>% select(-id) %>% cor(use = "p")
lm(Y ~ X + Z, data = sim)
lm(Y ~ Ug + X + Z, data = sim)
lm(Y ~ id + X + Z, data = sim)
sim <- sim %>% group_by(id) %>%
mutate(Xse = sd(X)/sqrt(n())) %>% ungroup() %>%
mutate(X2 = X)
b_mn <- brm(Y ~ X + Z, data = sim,
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X"),
set_prior("normal(0, 1)", class = "b", coef = "Z"),
set_prior("normal(0, 2)", class = "Intercept")))
b_mn
b_mn %>% gather_draws(b_X, b_Z) %>% mean_hdci()
b_mf <- brm(Y ~ 1 + id + X + Z, data = sim,
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X"),
set_prior("normal(0, 1)", class = "b", coef = "Z"),
set_prior("normal(0, 2)", class = "b"))
# , sample_prior = "only"
)
# b_mf %>% gather_draws(`b_id.*`, regex=T) %>%
# ggplot(aes(inv_logit(.value))) + geom_histogram(binwidth = .01)
b_mf %>% gather_draws(b_X, b_Z) %>% mean_hdci()
b_mr <- brm(Y ~ (1|id) + X + Z, data = sim,
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X"),
set_prior("normal(0, 1)", class = "b", coef = "Z"),
set_prior("exponential(1)", class = "sd")))
b_mr
b_mr %>% gather_draws(b_X, b_Z) %>% mean_hdci()
b_mrx <- brm(Y ~ (1|id) + X + Z + Xbar, data = sim,
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X"),
set_prior("normal(0, 1)", class = "b", coef = "Z"),
set_prior("normal(0, 1)", class = "b", coef = "Xbar"),
set_prior("exponential(1)", class = "sd")))
b_mrx
b_mrx %>% gather_draws(b_X, b_Z) %>% mean_hdci()
b_mrc <- brm(Y ~ (1|id) + X + Z, data = sim %>% mutate(X = X - Xbar),
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X"),
set_prior("normal(0, 1)", class = "b", coef = "Z"),
set_prior("exponential(1)", class = "sd")))
b_mrc
b_mrc %>% gather_draws(b_X, b_Z) %>% mean_hdci()
dat <- list(Y = Y, X = X, g = g, Ng = N_groups, Z = Z)
# The Latent Mundlak Machine
mru <- ulam(
alist(
# y model
Y ~ normal(muY, sigmaY),
muY <- a[g] + b_X*X + b_Z*Z[g] + buy*u[g],
transpars> vector[Ng]:a <<- abar + z*tau,
# X model
X ~ normal(mu,sigma),
mu <- aX + bux*u[g],
vector[Ng]:u ~ normal (0,1),
# priors
z[g] ~ dnorm(0,1),
c(aX, b_X, buy, b_Z) ~ dnorm(0, 1),
bux ~ dexp(1),
abar ~ dnorm (0,1),
tau ~ dexp(1),
c(sigma, sigmaY) ~ dexp(1)
),iter = 2000,
data = dat, chains = 4, cores=4, sample=TRUE)
mru %>% gather_draws(b_X, b_Z) %>% mean_hdci()
b_mru_gr <- brm(Y ~ 1 +(1|id) + X + Z + me(Xbar, Xse, gr = id), data = sim,
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X"),
set_prior("normal(0, 1)", class = "b", coef = "Z"),
set_prior("normal(0, 1)", class = "b", coef = "meXbarXsegrEQid"),
set_prior("exponential(1)", class = "sd"),
set_prior("exponential(1)", class = "sdme")))
b_mru_mi <- brm(bf(Y ~ (1|id) + X + Z + mi(X2)) +
bf(X2 | mi(Xse) ~ (1|id), family = gaussian()), data = sim %>% mutate(X2 = X),
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X", resp = "Y"),
set_prior("normal(0, 1)", class = "b", coef = "Z", resp = "Y"),
set_prior("normal(0, 1)", class = "b", coef = "miX2", resp = "Y"),
set_prior("exponential(1)", class = "sd", resp = "Y"),
set_prior("exponential(1)", class = "sd", resp = "X2"),
set_prior("exponential(1)", class = "sigma", resp = "X2")))
b_mru_mi
b_mru_mi %>% gather_draws(b_Y_X, b_Y_Z) %>% mean_hdci()
latent_formula <- bf(nl = TRUE,
# Y model
Y ~ interceptY + bX*X + X2l + bZ*Z,
X2l ~ 0 + mi(X2),
interceptY ~ 1 + (1 | id),
bX + bZ ~ 1, family = gaussian()) +
bf(X2 | mi(Xse) ~ 1 + (1|id), family = gaussian()) #+
# bernoulli()
# get_prior(latent_formula, data = sim)
b_mru_minl <- brm(latent_formula, data = sim,
prior = c(
set_prior("normal(0, 1)", class = "b", nlpar = "X2l", resp = "Y"),
set_prior("normal(0, 1)", class = "b", nlpar = "bX", resp = "Y"),
set_prior("normal(0, 1)", class = "b", nlpar = "bZ", resp = "Y"),
# set_prior("normal(0, 1)", class = "b", nlpar = "bX2", resp = "Y"),
set_prior("normal(0, 1)", class = "b", nlpar = "interceptY", resp = "Y"),
set_prior("exponential(1)", class = "sd", nlpar = "interceptY", resp = "Y"),
set_prior("normal(0, 1)", class = "Intercept", resp = "X2"),
set_prior("exponential(1)", class = "sd", resp = "X2")))
b_mru_minl
This is the approach Matti Vuorre posted on his blog, except that I rearranged it to estimate the slope of Xbar instead. This approach does not work at all, as far as I can tell.
latent_formula <- bf(
# Y model
Y ~ interceptY + bX*X + bXbar*Xbar + bZ*Z,
interceptY ~ 1 + (1 | id),
bX + bZ + bXbar ~ 1,
# Xbar model
nlf(X ~ Xbar),
Xbar ~ 1 + (1 | id),
nl = TRUE
) #+
# bernoulli()
get_prior(latent_formula, data = sim)
b_mru_mv <- brm(latent_formula, data = sim,
prior = c(
set_prior("normal(0, 1)", class = "b", nlpar = "bX"),
set_prior("normal(0, 1)", class = "b", nlpar = "bZ"),
set_prior("normal(0, 1)", class = "b", nlpar = "bXbar"),
set_prior("normal(0, 1)", class = "b", nlpar = "Xbar"),
set_prior("normal(0, 1)", class = "b", nlpar = "interceptY"),
set_prior("exponential(1)", class = "sd", nlpar = "interceptY"),
set_prior("exponential(1)", class = "sd", nlpar = "Xbar")))
b_mru_mv
Family: gaussian
Links: mu = identity; sigma = identity
Formula: Y ~ interceptY + bX * X + bXbar * Xbar + bZ * Z
interceptY ~ 1 + (1 | id)
bX ~ 1
bZ ~ 1
bXbar ~ 1
X ~ Xbar
Xbar ~ 1 + (1 | id)
Data: sim (Number of observations: 646)
Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
total post-warmup draws = 4000
Group-Level Effects:
~id (Number of levels: 300)
Estimate Est.Error l-95% CI u-95% CI Rhat
sd(interceptY_Intercept) 1.76 1.19 0.06 3.43 1.35
sd(Xbar_Intercept) 1.54 0.99 0.03 3.67 1.19
Bulk_ESS Tail_ESS
sd(interceptY_Intercept) 10 88
sd(Xbar_Intercept) 16 84
Population-Level Effects:
Estimate Est.Error l-95% CI u-95% CI Rhat
interceptY_Intercept -0.72 0.92 -2.25 1.14 1.03
bX_Intercept 0.11 1.10 -2.08 2.16 1.24
bZ_Intercept -0.41 0.18 -0.75 -0.04 1.01
bXbar_Intercept 0.11 1.08 -1.98 2.13 1.24
Xbar_Intercept 0.06 1.17 -1.80 3.70 1.19
Bulk_ESS Tail_ESS
interceptY_Intercept 190 554
bX_Intercept 12 140
bZ_Intercept 1252 1963
bXbar_Intercept 12 136
Xbar_Intercept 15 51
Family Specific Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma 1.47 0.06 1.37 1.59 1.02 257 96
Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
draws <- bind_rows(
brms_naive = b_mn %>% gather_draws(b_X, b_Z),
brms_fixed = b_mf %>% gather_draws(b_X, b_Z),
brms_multilevel = b_mr %>% gather_draws(b_X, b_Z),
brms_mundlak = b_mrx %>% gather_draws(b_X, b_Z, b_Xbar),
brms_mundlak_centered = b_mrc %>% gather_draws(b_X, b_Z),
rethinking_latent_mundlak = mru %>% gather_draws(b_X, b_Z, buy),
brms_latent_mundlak = b_mru_gr %>% gather_draws(b_X, b_Z, bsp_meXbarXsegrEQid),
brms_latent_mundlak_mi = b_mru_mi %>% gather_draws(b_Y_X, b_Y_Z, bsp_Y_miX2),
.id = "model") %>%
separate(model, c("package", "model"), extra = "merge") %>%
mutate(model = fct_inorder(factor(model)),
.variable = str_replace(.variable, "_Y", ""),
.variable = recode(.variable,
"buy" = "b_Xbar",
"bsp_miX2" = "b_Xbar",
"bsp_meXbarXsegrEQid" = "b_Xbar"),
.variable = factor(.variable, levels = c("b_X", "b_Z", "b_Xbar")))
draws <- draws %>% group_by(package, model, .variable) %>%
mean_hdci(.width = c(.95, .99)) %>%
ungroup()
ggplot(draws, aes(y = package, x = .value, xmin = .lower, xmax = .upper)) +
geom_pointinterval(position = position_dodge(width = .4)) +
ggrepel::geom_text_repel(aes(label = if_else(.width == .95, sprintf("%.2f", .value), NA_character_)), nudge_y = .1) +
geom_vline(aes(xintercept = true_val), linetype = 'dashed', data = tibble(true_val = c(b_X, b_Z, b_Ug - b_X), .variable = factor(c("b_X", "b_Z", "b_Xbar"), levels = c("b_X", "b_Z", "b_Xbar")))) + scale_color_discrete(breaks = rev(levels(draws$model))) +
facet_grid(model ~ .variable, scales = "free_x") +
theme_bw() +
theme(legend.position = c(0.99,0.99),
legend.justification = c(1,1))
Finally, I simulated a Gaussian outcome and a binary exposure (X). Especially in small groups, the group average of a binary variable can be pretty far off. I didn’t run the rethinking models here, except for the latent Mundlak model.
I was not able to implement this model in brms. brms does not permit the specification of a binary variable as missing using mi(). And adjusting for Xbar using a linear probability model did not work.
set.seed(201910)
# families
N_groups <- 300
a0 <- (-2)
b_X <- (1)
b_Z <- (-0.5)
b_Ug <- (3)
# 2 or more siblings
g_sizes <- 2 + rpois(N_groups, lambda = 0.2) # sample into groups
table(g_sizes)
N_id <- sum(g_sizes)
g <- rep(1:N_groups, times = g_sizes)
Ug <- rnorm(N_groups, sd = 0.8) # group confounds
X <- rbern(N_id, p=inv_logit(rnorm(N_id, Ug[g] ) ) ) # individual varying trait
table(X)
Z <- rnorm(N_groups) # group varying trait (observed)
Y <- a0 + b_X * X + b_Ug*Ug[g] + b_Z*Z[g] + rnorm(N_id)
groups <- tibble(id = factor(1:N_groups), Ug, Z)
sim <- tibble(id = factor(g), X, Y) %>% full_join(groups, by = "id") %>% arrange(id) %>% group_by(id) %>%
mutate(Xbar = mean(X)) %>% ungroup()
sim %>% distinct(id, Ug, Xbar) %>% select(-id) %>% cor(use = "p")
lm(Y ~ X + Z, data = sim)
lm(Y ~ Ug + X + Z, data = sim)
lm(Y ~ id + X + Z, data = sim)
sim <- sim %>% group_by(id) %>%
mutate(Xse = sd(X)/sqrt(n())) %>% ungroup() %>%
mutate(X2 = X)
b_mn <- brm(Y ~ X + Z, data = sim,
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X"),
set_prior("normal(0, 1)", class = "b", coef = "Z"),
set_prior("normal(0, 2)", class = "Intercept")))
b_mn
b_mn %>% gather_draws(b_X, b_Z) %>% mean_hdci()
b_mf <- brm(Y ~ 1 + id + X + Z, data = sim,
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X"),
set_prior("normal(0, 1)", class = "b", coef = "Z"),
set_prior("normal(0, 2)", class = "b"))
# , sample_prior = "only"
)
# b_mf %>% gather_draws(`b_id.*`, regex=T) %>%
# ggplot(aes(inv_logit(.value))) + geom_histogram(binwidth = .01)
b_mf %>% gather_draws(b_X, b_Z) %>% mean_hdci()
b_mr <- brm(Y ~ (1|id) + X + Z, data = sim,
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X"),
set_prior("normal(0, 1)", class = "b", coef = "Z"),
set_prior("exponential(1)", class = "sd")))
b_mr
b_mr %>% gather_draws(b_X, b_Z) %>% mean_hdci()
b_mrx <- brm(Y ~ (1|id) + X + Z + Xbar, data = sim,
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X"),
set_prior("normal(0, 1)", class = "b", coef = "Z"),
set_prior("normal(0, 1)", class = "b", coef = "Xbar"),
set_prior("exponential(1)", class = "sd")))
b_mrx
b_mrx %>% gather_draws(b_X, b_Z) %>% mean_hdci()
b_mrc <- brm(Y ~ (1|id) + X + Z, data = sim %>% mutate(X = X - Xbar),
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X"),
set_prior("normal(0, 1)", class = "b", coef = "Z"),
set_prior("exponential(1)", class = "sd")))
b_mrc
b_mrc %>% gather_draws(b_X, b_Z) %>% mean_hdci()
dat <- list(Y = Y, X = X, g = g, Ng = N_groups, Z = Z)
# The Latent Mundlak Machine
mru <- ulam(
alist(
# y model
Y ~ normal(muY, sigmaY),
muY <- a[g] + b_X*X + b_Z*Z[g] + buy*u[g],
transpars> vector[Ng]:a <<- abar + z*tau,
# X model
X ~ bernoulli(p),
logit(p) <- aX + bux*u[g],
vector[Ng]:u ~ normal (0,1),
# priors
z[g] ~ dnorm(0,1),
c(aX, b_X, buy, b_Z) ~ dnorm(0, 1),
bux ~ dexp(1),
abar ~ dnorm (0,1),
tau ~ dexp(1),
c(sigma, sigmaY) ~ dexp(1)
),
data = dat, chains = 4, cores=4, sample=TRUE)
mru %>% gather_draws(b_X, b_Z) %>% mean_hdci()
I had to add .01 to the Xse because brms does not permit 0 measurement error (when modelling measurement error).
b_mru_gr <- brm(Y ~ 1 +(1|id) + X + Z + me(Xbar, Xse, gr = id), data = sim %>% mutate(Xse = Xse + 0.01),
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X"),
set_prior("normal(0, 1)", class = "b", coef = "Z"),
set_prior("normal(0, 1)", class = "b", coef = "meXbarXsegrEQid"),
set_prior("exponential(1)", class = "sd"),
set_prior("exponential(1)", class = "sdme")))
brms does not support the mi() notation for binary variables. Also, it does not accept zero measurement error, so I averaged across groups to take the mean Xse.
b_mru_mi <- brm(bf(Y ~ (1|id) + X + Z + mi(X2)) +
bf(X2 | mi(Xse) ~ (1|id)), family = gaussian(), data = sim %>% mutate(Xse = Xse + 0.01), iter = 4000,
prior = c(
set_prior("normal(0, 1)", class = "b", coef = "X", resp = "Y"),
set_prior("normal(0, 1)", class = "b", coef = "Z", resp = "Y"),
set_prior("normal(0, 1)", class = "b", coef = "miX2", resp = "Y"),
set_prior("exponential(1)", class = "sd", resp = "Y"),
set_prior("exponential(1)", class = "sd", resp = "X2"),
set_prior("constant(0.01)", class = "sigma", resp = "X2")))
b_mru_mi
b_mru_mi %>% gather_draws(b_Y_X, b_Y_Z) %>% mean_hdci()
draws <- bind_rows(
brms_naive = b_mn %>% gather_draws(b_X, b_Z),
brms_fixed = b_mf %>% gather_draws(b_X, b_Z),
brms_multilevel = b_mr %>% gather_draws(b_X, b_Z),
brms_mundlak = b_mrx %>% gather_draws(b_X, b_Z, b_Xbar),
brms_mundlak_centered = b_mrc %>% gather_draws(b_X, b_Z),
rethinking_latent_mundlak = mru %>% gather_draws(b_X, b_Z, buy),
brms_latent_mundlak = b_mru_gr %>% gather_draws(b_X, b_Z, bsp_meXbarXsegrEQid),
brms_latent_mundlak_mi = b_mru_mi %>% gather_draws(b_Y_X, b_Y_Z, bsp_Y_miX2),
.id = "model") %>%
separate(model, c("package", "model"), extra = "merge") %>%
mutate(model = fct_inorder(factor(model)),
.variable = str_replace(.variable, "_Y", ""),
.variable = recode(.variable,
"buy" = "b_Xbar",
"bsp_miX2" = "b_Xbar",
"bsp_meXbarXsegrEQid" = "b_Xbar"),
.variable = factor(.variable, levels = c("b_X", "b_Z", "b_Xbar")))
draws <- draws %>% group_by(package, model, .variable) %>%
mean_hdci(.width = c(.95, .99)) %>%
ungroup()
ggplot(draws, aes(y = package, x = .value, xmin = .lower, xmax = .upper)) +
geom_pointinterval(position = position_dodge(width = .4)) +
ggrepel::geom_text_repel(aes(label = if_else(.width == .95, sprintf("%.2f", .value), NA_character_)), nudge_y = .1) +
geom_vline(aes(xintercept = true_val), linetype = 'dashed', data = tibble(true_val = c(b_X, b_Z, b_Ug - b_X), .variable = factor(c("b_X", "b_Z", "b_Xbar"), levels = c("b_X", "b_Z", "b_Xbar")))) + scale_color_discrete(breaks = rev(levels(draws$model))) +
facet_grid(model ~ .variable, scales = "free_x") +
theme_bw() +
theme(legend.position = c(0.99,0.99),
legend.justification = c(1,1))
1. The group here can be an individual measured several times, a sibling group, a school, a nation, etc.
2. Mundlak, Y. 1978. On the pooling of time series and cross section data. Econometrica 46: 69–85.
3. What McElreath calls Full Luxury Bayes: the latent Mundlak machine.
4. I think the reason he thought it works is that in the data he used, from McNeish & Hamaker (2020), the group mean varies very little except for measurement error, so we expect it not to make a difference.