Richard McElreath recently shared a tutorial on how to estimate latent group means (à la Mundlak) in Stan. I try to follow his code, do the same with brms, and try to implement a solution Matti Vuorre posted.

Ruben C. Arslan

In a recent lecture, Richard McElreath gave an introduction to Mundlak devices (code). Briefly, social scientists would often like to remove confounding at the group level.[1]

Econometricians and epidemiologists tend to use fixed effects regression to do so. Psychologists often use random effects regressions and adjust for the mean of the predictor of interest. The two solutions yield similar results for the within-group effect, but differ on how they treat group-level effects.
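
For a linear model fit by ordinary least squares, the agreement on the within-group effect is exact rather than approximate. The following is a minimal sketch in base R, assuming made-up toy data (all variable names and parameter values are illustrative, not taken from the simulations below) and using plain lm() as a stand-in for the multilevel model to keep it dependency-free.

```r
# Hypothetical toy data; names and values are illustrative only
set.seed(1)
n_groups <- 200
g <- rep(seq_len(n_groups), each = 3)
u <- rnorm(n_groups)                 # unobserved group-level confound
x <- rnorm(length(g), mean = u[g])   # exposure, correlated with the confound
y <- 1 * x + 3 * u[g] + rnorm(length(g))
xbar <- ave(x, g)                    # observed group mean of the exposure

b_naive   <- coef(lm(y ~ x))[["x"]]              # ignores the grouping
b_fixed   <- coef(lm(y ~ x + factor(g)))[["x"]]  # one dummy per group
b_mundlak <- coef(lm(y ~ x + xbar))[["x"]]       # adjust for the group mean

round(c(naive = b_naive, fixed = b_fixed, mundlak = b_mundlak), 3)
```

The naive slope absorbs the group-level confound, while the other two slopes coincide up to floating-point error.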

Arguably, there are more ways to mess up the random effects model, but it’s more flexible and efficient, and it makes it easier to model, e.g., cross-level interactions and effect heterogeneity.

One way to mess up is to have small groups (e.g., sibling groups in a low-fertility country), so that your estimate of the group mean is a poor stand-in for the confounds you’d like to adjust for. A solution is to estimate the latent group mean instead, i.e., to account for the fact that we are estimating it with uncertainty. Doing so[3] is fairly easy in Stan, but it’s less clear how to do it with everyone’s favourite Stan interface, brms.

In which way does this sampling error at the group level bias your results? It attenuates your estimate of b_Xbar (Lüdtke’s bias). I thought it would also bias my estimate of b_X, because I’m underadjusting by ignoring the measurement error in my covariate. That is not so. Why? To get the intuition, it helped me to consider the case where I first subtract Xbar from X. Xdiff is necessarily uncorrelated with Xbar, meaning any changes in the association of Xbar and Y (which we get by modelling sampling error) are irrelevant to Y ~ Xdiff.
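
The argument in the previous paragraph can be checked numerically. This is a rough sketch in base R, assuming toy data with groups of two and a linear outcome; the variable names, effect sizes, and the use of lm() instead of a multilevel model are my assumptions for illustration. Because the data are simulated, the latent group mean is known and can be used as an (otherwise infeasible) benchmark.

```r
set.seed(2)
n_groups <- 2000
g <- rep(seq_len(n_groups), each = 2)        # pairs: very noisy group means
mu_x <- rnorm(n_groups, sd = 0.8)            # true (latent) group mean of x
x <- rnorm(length(g), mean = mu_x[g])
y <- 1 * x + 2 * mu_x[g] + rnorm(length(g))  # within effect 1, group-level effect 2

xbar  <- ave(x, g)    # observed group mean = latent mean + sampling error
xdiff <- x - xbar     # within-group deviation
xlat  <- mu_x[g]      # latent mean, known only because we simulated it

fit_observed <- coef(lm(y ~ x + xbar))
fit_latent   <- coef(lm(y ~ x + xlat))

cor(xdiff, xbar)      # zero up to floating-point error
round(c(b_x_observed    = fit_observed[["x"]],
        b_x_latent      = fit_latent[["x"]],
        b_xbar_observed = fit_observed[["xbar"]],
        b_xbar_latent   = fit_latent[["xlat"]]), 2)
```

The slope of x is close to 1 in both fits, whereas the coefficient of the observed group mean falls well short of the one obtained with the latent mean.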

I think I’m not the only one who is confused about this. Matti Vuorre shared a solution to center X with the latent group mean using brms’ non-linear syntax. Centering with the latent group mean is not necessary, though. Also, I tried the syntax and it doesn’t correctly recover simulated effects.[4]
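
Why centering is not needed for the slope of X follows from a simple reparametrisation: adjusting for the group mean and subtracting it span the same predictors. A small sketch in base R, assuming toy data and lm() as a stand-in (names and values are illustrative; unlike the centered brms model further down, both fits here keep the group mean as a predictor):

```r
set.seed(5)
n_groups <- 300
g <- rep(seq_len(n_groups), each = 2)
u <- rnorm(n_groups, sd = 0.8)
x <- rnorm(length(g), mean = u[g])
y <- -2 + 1 * x + 3 * u[g] + rnorm(length(g))
xbar  <- ave(x, g)
xdiff <- x - xbar

adjusted <- coef(lm(y ~ x + xbar))      # adjust for the group mean
centered <- coef(lm(y ~ xdiff + xbar))  # subtract the group mean first

# Same within-group slope; the xbar coefficients differ by exactly that slope
c(adjusted[["x"]], centered[["xdiff"]])
c(centered[["xbar"]], adjusted[["x"]] + adjusted[["xbar"]])
```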

My two attempts at a solution:

Using me()

  1. Calculate the group mean Xbar, e.g. df %>% group_by(id) %>% mutate(Xbar = mean(X)).
  2. Calculate the standard error of the mean, e.g. df %>% group_by(id) %>% mutate(Xse = sd(X)/sqrt(n())).
  3. Adjust for the term me(Xbar, Xse, gr = id) explicitly specifying the grouping variable.
  4. Profit
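
The first three steps can be sketched in base R as follows. The toy data frame is made up, ave() is used instead of dplyr only to keep the example dependency-free, and the formula is merely constructed, not fitted. The check for unusable standard errors is my addition; it matters because a group with one member or without within-group variation has no positive standard error.

```r
# Toy data frame; the column names mirror the post, the values are made up
df <- data.frame(id = factor(c(1, 1, 2, 2, 2, 3, 3)),
                 X  = c(0.2, 1.1, -0.5, 0.3, 0.9, 1.5, 1.5))

n_in_group <- ave(df$X, df$id, FUN = length)
df$Xbar <- ave(df$X, df$id)                               # step 1: group mean
df$Xse  <- ave(df$X, df$id, FUN = sd) / sqrt(n_in_group)  # step 2: its standard error

# Groups with a single member (sd is NA) or identical values (sd is 0)
# have no usable standard error and need separate handling
problem_groups <- unique(df$id[is.na(df$Xse) | df$Xse == 0])

# Step 3: the term as it would enter the model formula (not fitted here)
f <- Y ~ 1 + (1 | id) + X + Z + me(Xbar, Xse, gr = id)
```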

Using mi()

  1. Duplicate the exposure X as X2.
  2. Calculate the SE of the group mean outside the brms call.
  3. Use a second formula bf(X2 | mi(Xse) ~ (1|id)) in brms to estimate the latent group mean with shrinkage.
  4. Adjust for mi(X2) in the regression on Y.
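
What “with shrinkage” buys can be illustrated without fitting a model. This is a rough sketch under a normal-normal model with the variances treated as known (an assumption for illustration only; brms estimates them and propagates their uncertainty): each observed group mean is pulled toward the grand mean in proportion to its unreliability. All names and values are made up.

```r
set.seed(3)
n_groups <- 500
g <- rep(seq_len(n_groups), each = 2)
tau <- 0.8    # between-group SD of the latent means (treated as known)
sigma <- 1    # within-group SD (treated as known)
mu_x <- rnorm(n_groups, sd = tau)
x <- rnorm(length(g), mean = mu_x[g], sd = sigma)

xbar_g <- tapply(x, g, mean)      # one observed mean per group
n_g    <- tapply(x, g, length)
reliability <- tau^2 / (tau^2 + sigma^2 / n_g)
shrunk_g <- mean(x) + reliability * (xbar_g - mean(x))  # partially pooled means

# Root mean squared error against the simulated latent means
rmse <- function(est) sqrt(mean((est - mu_x)^2))
c(observed = rmse(xbar_g), shrunk = rmse(shrunk_g))
```

With groups this small, the partially pooled means track the latent means more closely than the raw group means do.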

How did it go? High-level summary

The simulations and their results are documented in detail below. The model code is listed under “Implementations of the various models” in each section.

Bernoulli simulations

I started with Richard’s simulations of a binary outcome. I was not able to reproduce his model performances exactly and ended up increasing the simulated sample size to make estimates a bit more stable (to more clearly show bias vs. variance).

Up to the latent Mundlak model, the brms equivalents perform, well, equivalently. For the latent Mundlak model, brms gets slightly different coefficients, but a similar overall result.

Show code
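# Packages the code below relies on (inferred from the functions it calls)
library(tidyverse)   # tibble, dplyr, tidyr, forcats, stringr, ggplot2
library(rethinking)  # ulam(), rbern(), inv_logit()
library(brms)        # brm(), bf(), set_prior()
library(tidybayes)   # gather_draws(), mean_hdci(); ulam fits may also need tidybayes.rethinking

options(brms.backend = "cmdstanr",  # I use the cmdstanr backend
        mc.cores = 8, 
        brms.threads = 1,           # which allows me to multithread
        brms.file_refit = "on_change") # this is useful when doing 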


# families
N_groups <- 300
a0 <- (-2)
b_X <- (1)
b_Z <- (-0.5)
b_Ug <- (3)
# 2 or more siblings
g_sizes <- 2 + rpois(N_groups, lambda = 0.2) # sample into groups
N_id <- sum(g_sizes)
g <- rep(1:N_groups, times = g_sizes)
Ug <- rnorm(N_groups, sd = 0.8) # group confounds
X <- rnorm(N_id, Ug[g] ) # individual varying trait
Z <- rnorm(N_groups) # group varying trait (observed)
Y <- rbern(N_id, p=inv_logit( a0 + b_X * X + b_Ug*Ug[g] + b_Z*Z[g] ) )

# glm(Y ~ X + Z[g] + Ug[g], binomial())
# glm(Y ~ X + Z[g], binomial())
groups <- tibble(id = factor(1:N_groups), Ug, Z)
sim <- tibble(id = factor(g), X, Y) %>% full_join(groups, by = "id") %>% arrange(id) %>% group_by(id) %>% 
  mutate(Xbar = mean(X)) %>% ungroup()
sim %>% distinct(id, Ug, Xbar) %>% select(-id) %>% cor(use = "p")
glm(Y ~ X + Z, data = sim, binomial(link = "logit"))
glm(Y ~ Ug + X + Z, data = sim, binomial(link = "logit"))
glm(Y ~ id + X + Z, data = sim, binomial(link = "logit"))

sim <- sim %>% group_by(id) %>% 
  mutate(Xse = sd(X)/sqrt(n())) %>% ungroup() %>% 
  mutate(X2 = X)
Implementations of the various models



Show code
dat <-  list(Y = Y, X = X, g = g, Ng = N_groups, Z = Z)
mn <- ulam(alist(
        Y ~ bernoulli(p),
        logit(p) <- a + b_X*X + b_Z*Z[g],
        a ~ dnorm(0,2),
        c(b_X,b_Z) ~ dnorm(0,1)
    ) , data=dat , chains=4 , cores=4 )

mn %>% gather_draws(b_X, b_Z) %>% mean_hdci()


Show code
b_mn <- brm(Y ~ X + Z, data = sim,
            family = bernoulli(),
            prior = c(
              set_prior("normal(0, 1)", class = "b", coef = "X"),
              set_prior("normal(0, 1)", class = "b", coef = "Z"),
              set_prior("normal(0, 2)", class = "Intercept")))

b_mn %>% gather_draws(b_X, b_Z) %>% mean_hdci()

Fixed effects


Show code
dat <-  list(Y = Y, X = X, g = g, Ng = N_groups, Z = Z)
# fixed effects
mf <- ulam(alist(
  Y ~ bernoulli(p),
  logit(p) <- a[g] + b_X*X + b_Z*Z[g],
  a[g] ~ dnorm(0,1.5), # can't get it to work using dunif(-1,1)
  c(b_X,b_Z) ~ dnorm(0,1)
), data=dat, chains = 4, cores=4 )

mf %>% gather_draws(b_X, b_Z) %>% mean_hdci()


Show code
b_mf <- brm(Y ~ 1 + id + X + Z, data = sim,
            family = bernoulli(),
            prior = c(
              set_prior("normal(0, 1)", class = "b", coef = "X"),
              set_prior("normal(0, 1)", class = "b", coef = "Z")
              ,set_prior("uniform(-1, 1)",lb = -1, ub = 1, class = "b"))
            # , sample_prior = "only"
            )

# b_mf %>% gather_draws(`b_id.*`, regex=T) %>% 
#   ggplot(aes(inv_logit(.value))) + geom_histogram(binwidth = .01)
b_mf %>% gather_draws(b_X, b_Z) %>% mean_hdci()



Show code
# varying effects (non-centered - next week! )
mr <- ulam(alist(
    Y ~ bernoulli(p),
    logit(p) <- a[g] + b_X*X + b_Z*Z[g],
    transpars > vector[Ng]:a <<- abar + z*tau,
    z[g] ~ dnorm(0,1),
    c(b_X,b_Z) ~ dnorm(0,1),
    abar ~ dnorm(0,1),
    tau ~ dexp(1)
  ), data=dat , chains=4, cores=4, sample=TRUE)

mr %>% gather_draws(b_X, b_Z) %>% mean_hdci()


Show code
b_mr <- brm(Y ~ (1|id) + X + Z, data = sim,
            family = bernoulli(),
            prior = c(
              set_prior("normal(0, 1)", class = "b", coef = "X"),
              set_prior("normal(0, 1)", class = "b", coef = "Z"),
              set_prior("exponential(1)", class = "sd")))

b_mr %>% gather_draws(b_X, b_Z) %>% mean_hdci()

Multilevel Mundlak

Show code
# The Mundlak Machine
xbar <- sapply(1:N_groups, function(j) mean (X[g==j]))
dat$Xbar <- xbar
mrx <- ulam(alist(
    Y ~ bernoulli(p),
    logit(p) <- a[g] + b_X*X + b_Z*Z[g] + buy*Xbar[g],
    transpars> vector[Ng]:a <<- abar + z*tau,
    z[g] ~ dnorm(0,1),
    c(b_X, buy, b_Z) ~ dnorm(0,1),
    abar ~ dnorm(0,1),
    tau ~ dexp(1)
  ), data=dat, chains=4, cores=4, sample=TRUE)

mrx %>% gather_draws(b_X, b_Z) %>% mean_hdci()


Show code
b_mrx <- brm(Y ~ (1|id) + X + Z + Xbar, data = sim,
            family = bernoulli(),
            prior = c(
              set_prior("normal(0, 1)", class = "b", coef = "X"),
              set_prior("normal(0, 1)", class = "b", coef = "Z"),
              set_prior("normal(0, 1)", class = "b", coef = "Xbar"),
              set_prior("exponential(1)", class = "sd")))

b_mrx %>% gather_draws(b_X, b_Z) %>% mean_hdci()
subtracting the group mean instead of adjusting for it
Show code
b_mrc <- brm(Y ~ (1|id) + X + Z, data = sim %>% mutate(X = X - Xbar),
            family = bernoulli(),
            prior = c(
              set_prior("normal(0, 1)", class = "b", coef = "X"),
              set_prior("normal(0, 1)", class = "b", coef = "Z"),
              set_prior("exponential(1)", class = "sd")))

b_mrc %>% gather_draws(b_X, b_Z) %>% mean_hdci()

Multilevel Latent Mundlak

Show code
# The Latent Mundlak Machine
mru <- ulam(alist(
    # y model 
    Y ~ bernoulli(p),
    logit(p) <- a[g] + b_X*X + b_Z*Z[g] + buy*u[g],
    transpars> vector[Ng]:a <<- abar + z*tau,
    # X model
    X ~ normal(mu,sigma),
    mu <- aX + bux*u[g],
    vector[Ng]:u ~ normal(0,1),
    # priors
    z[g] ~ dnorm(0,1),
    c(aX, b_X, buy, b_Z) ~ dnorm(0, 1),
    bux ~ dexp(1),
    abar ~ dnorm(0,1),
    tau ~ dexp(1),
    sigma ~ dexp(1)
  ), data = dat, chains = 4, cores=4, sample=TRUE)

mru %>% gather_draws(b_X, b_Z) %>% mean_hdci()


brms me() measurement error notation
Show code
b_mru_gr <- brm(Y ~ 1 +(1|id) + X + Z + me(Xbar, Xse, gr = id), data = sim, 
                family = bernoulli(),
             prior = c(
                 set_prior("normal(0, 1)", class = "b", coef = "X"),
                 set_prior("normal(0, 1)", class = "b", coef = "Z"),
                 set_prior("normal(0, 1)", class = "b", coef = "meXbarXsegrEQid"),
                 set_prior("exponential(1)", class = "sd"),
                 set_prior("exponential(1)", class = "sdme")))
brms mi() missingness notation
Show code
b_mru_mi <- brm(bf(Y ~ (1|id) + X + Z + mi(X2), family = bernoulli()) +
             bf(X2 | mi(Xse) ~ (1|id), 
                family = gaussian()), data = sim,
            prior = c(
              set_prior("normal(0, 1)", class = "b", coef = "X", resp = "Y"),
              set_prior("normal(0, 1)", class = "b", coef = "Z", resp = "Y"),
              set_prior("normal(0, 1)", class = "b", coef = "miX2", resp = "Y"),
              set_prior("exponential(1)", class = "sd", resp = "Y"),
              set_prior("exponential(1)", class = "sd", resp = "X2"),
              set_prior("exponential(1)", class = "sigma", resp = "X2")))

b_mru_mi %>% gather_draws(b_Y_X, b_Y_Z) %>% mean_hdci()

non-working implementation of Matti Vuorre’s approach

This is the approach Matti Vuorre posted on his blog, except that I rearranged it to estimate the slope of Xbar instead. This approach does not work at all, as far as I can tell.

Show code
latent_formula <- bf(
  # Y model
  Y ~ interceptY  + bX*X + bXbar*Xbar + bZ*Z,
  interceptY ~ 1 + (1 | id),
  bX + bZ + bXbar ~ 1,

  # Xbar model
  nlf(X ~ Xbar),
  Xbar ~ 1 + (1 | id),
  nl = TRUE
) #+
#  bernoulli()

get_prior(latent_formula, data = sim)
                prior class      coef group resp dpar      nlpar lb
 student_t(3, 0, 2.5) sigma                                       0
               (flat)     b                                   bX   
               (flat)     b Intercept                         bX   
               (flat)     b                                bXbar   
               (flat)     b Intercept                      bXbar   
               (flat)     b                                   bZ   
               (flat)     b Intercept                         bZ   
               (flat)     b                           interceptY   
               (flat)     b Intercept                 interceptY   
 student_t(3, 0, 2.5)    sd                           interceptY  0
 student_t(3, 0, 2.5)    sd              id           interceptY  0
 student_t(3, 0, 2.5)    sd Intercept    id           interceptY  0
               (flat)     b                                 Xbar   
               (flat)     b Intercept                       Xbar   
 student_t(3, 0, 2.5)    sd                                 Xbar  0
 student_t(3, 0, 2.5)    sd              id                 Xbar  0
 student_t(3, 0, 2.5)    sd Intercept    id                 Xbar  0
 ub       source
Show code
b_mru_mv <- brm(latent_formula, data = sim,
            prior = c(
              set_prior("normal(0, 1)", class = "b", nlpar = "bX"),
              set_prior("normal(0, 1)", class = "b", nlpar = "bZ"),
              set_prior("normal(0, 1)", class = "b", nlpar = "bXbar"),
              set_prior("normal(0, 1)", class = "b", nlpar = "Xbar"),
              set_prior("normal(0, 1)", class = "b", nlpar = "interceptY"),
              set_prior("exponential(1)", class = "sd", nlpar = "interceptY"),
              set_prior("exponential(1)", class = "sd", nlpar = "Xbar")))

Show code
draws <- bind_rows(
  rethinking_naive = mn %>% gather_draws(b_X, b_Z),
  rethinking_fixed = mf %>% gather_draws(b_X, b_Z),
  rethinking_multilevel = mr %>% gather_draws(b_X, b_Z),
  rethinking_mundlak = mrx %>% gather_draws(b_X, b_Z, buy),
  brms_naive = b_mn %>% gather_draws(b_X, b_Z),
  brms_fixed = b_mf %>% gather_draws(b_X, b_Z),
  brms_multilevel = b_mr %>% gather_draws(b_X, b_Z),
  brms_mundlak = b_mrx %>% gather_draws(b_X, b_Z, b_Xbar),
  brms_mundlak_centered = b_mrc %>% gather_draws(b_X, b_Z),
  rethinking_latent_mundlak = mru %>% gather_draws(b_X, b_Z, buy),
  brms_latent_mundlak = b_mru_gr %>% gather_draws(b_X, b_Z, bsp_meXbarXsegrEQid),
  brms_latent_mundlak_mi = b_mru_mi %>% gather_draws(b_Y_X, b_Y_Z, bsp_Y_miX2),
 .id = "model") %>% 
  separate(model, c("package", "model"), extra = "merge") %>% 
  mutate(model = fct_inorder(factor(model)),
         .variable = str_replace(.variable, "_Y", ""),
         .variable = recode(.variable, 
                            "buy" = "b_Xbar",
                            "bsp_miX2" = "b_Xbar",
                            "bsp_meXbarXsegrEQid" = "b_Xbar"),
         .variable = factor(.variable, levels = c("b_X", "b_Z", "b_Xbar")))

draws <- draws %>% group_by(package, model, .variable) %>% 
  mean_hdci(.width = c(.95, .99))

ggplot(draws, aes(y = package, x = .value, xmin = .lower, xmax = .upper)) +
  geom_pointinterval(position = position_dodge(width = .4)) +
  ggrepel::geom_text_repel(aes(label = if_else(.width == .95, sprintf("%.2f", .value), NA_character_)), nudge_y = .1) +
  geom_vline(aes(xintercept = true_val), linetype = 'dashed', data = tibble(true_val = c(b_X, b_Z, b_Ug - b_X), .variable = factor(c("b_X", "b_Z", "b_Xbar"), levels = c("b_X", "b_Z", "b_Xbar")))) +  scale_color_discrete(breaks = rev(levels(draws$model))) +
  facet_grid(model ~ .variable, scales = "free_x") +
  theme_bw() +
  theme(legend.position = c(0.99,0.99),
        legend.justification = c(1,1))

Figure 1: Estimated coefficients and the true values (dashed line)

For some reason, I did not manage to specify a uniform prior in rethinking. I think this explains the poor showing for the rethinking fixed effects model. Fixed effects regression + Bayes is not really a happy couple.

Gaussian simulations

Then, I simulated a Gaussian outcome instead. I didn’t run the rethinking models here, except for the latent Mundlak model.

Show code
# families
N_groups <- 300
a0 <- (-2)
b_X <- (1)
b_Z <- (-0.5)
b_Ug <- (3)
# 2 or more siblings
g_sizes <- 2 + rpois(N_groups, lambda = 0.2) # sample into groups
N_id <- sum(g_sizes)
g <- rep(1:N_groups, times = g_sizes)
Ug <- rnorm(N_groups, sd = 0.8) # group confounds
X <- rnorm(N_id, Ug[g] ) # individual varying trait
Z <- rnorm(N_groups) # group varying trait (observed)
Y <- a0 + b_X * X + b_Ug*Ug[g] + b_Z*Z[g] + rnorm(N_id)

groups <- tibble(id = factor(1:N_groups), Ug, Z)
sim <- tibble(id = factor(g), X, Y) %>% full_join(groups, by = "id") %>% arrange(id) %>% group_by(id) %>% 
  mutate(Xbar = mean(X)) %>% ungroup()
sim %>% distinct(id, Ug, Xbar) %>% select(-id) %>% cor(use = "p")
lm(Y ~ X + Z, data = sim)
lm(Y ~ Ug + X + Z, data = sim)
lm(Y ~ id + X + Z, data = sim)

sim <- sim %>% group_by(id) %>% 
  mutate(Xse = sd(X)/sqrt(n())) %>% ungroup() %>% 
  mutate(X2 = X)
Implementations of the various models



Show code
b_mn <- brm(Y ~ X + Z, data = sim,
            prior = c(
              set_prior("normal(0, 1)", class = "b", coef = "X"),
              set_prior("normal(0, 1)", class = "b", coef = "Z"),
              set_prior("normal(0, 2)", class = "Intercept")))

b_mn %>% gather_draws(b_X, b_Z) %>% mean_hdci()

Fixed effects


Show code
b_mf <- brm(Y ~ 1 + id + X + Z, data = sim,
            prior = c(
              set_prior("normal(0, 1)", class = "b", coef = "X"),
              set_prior("normal(0, 1)", class = "b", coef = "Z"),
              set_prior("normal(0, 2)", class = "b"))
            # , sample_prior = "only"
            )

# b_mf %>% gather_draws(`b_id.*`, regex=T) %>% 
#   ggplot(aes(inv_logit(.value))) + geom_histogram(binwidth = .01)
b_mf %>% gather_draws(b_X, b_Z) %>% mean_hdci()



Show code
b_mr <- brm(Y ~ (1|id) + X + Z, data = sim,
            prior = c(
              set_prior("normal(0, 1)", class = "b", coef = "X"),
              set_prior("normal(0, 1)", class = "b", coef = "Z"),
              set_prior("exponential(1)", class = "sd")))

b_mr %>% gather_draws(b_X, b_Z) %>% mean_hdci()

Multilevel Mundlak


Show code
b_mrx <- brm(Y ~ (1|id) + X + Z + Xbar, data = sim,
            prior = c(
              set_prior("normal(0, 1)", class = "b", coef = "X"),
              set_prior("normal(0, 1)", class = "b", coef = "Z"),
              set_prior("normal(0, 1)", class = "b", coef = "Xbar"),
              set_prior("exponential(1)", class = "sd")))

b_mrx %>% gather_draws(b_X, b_Z) %>% mean_hdci()
subtracting the group mean instead of adjusting for it
Show code
b_mrc <- brm(Y ~ (1|id) + X + Z, data = sim %>% mutate(X = X - Xbar),
            prior = c(
              set_prior("normal(0, 1)", class = "b", coef = "X"),
              set_prior("normal(0, 1)", class = "b", coef = "Z"),
              set_prior("exponential(1)", class = "sd")))

b_mrc %>% gather_draws(b_X, b_Z) %>% mean_hdci()

Multilevel Latent Mundlak


Show code
dat <-  list(Y = Y, X = X, g = g, Ng = N_groups, Z = Z)
# The Latent Mundlak Machine
mru <- ulam(alist(
    # y model 
    Y ~ normal(muY, sigmaY),
    muY <- a[g] + b_X*X + b_Z*Z[g] + buy*u[g],
    transpars> vector[Ng]:a <<- abar + z*tau,
    # X model
    X ~ normal(mu,sigma),
    mu <- aX + bux*u[g],
    vector[Ng]:u ~ normal (0,1),
    # priors
    z[g] ~ dnorm(0,1),
    c(aX, b_X, buy, b_Z) ~ dnorm(0, 1),
    bux ~ dexp(1),
    abar ~ dnorm (0,1),
    tau ~ dexp(1),
    c(sigma, sigmaY) ~ dexp(1)
  ),iter = 2000,
  data = dat, chains = 4, cores=4, sample=TRUE)

mru %>% gather_draws(b_X, b_Z) %>% mean_hdci()


brms me() measurement error notation
Show code
b_mru_gr <- brm(Y ~ 1 +(1|id) + X + Z + me(Xbar, Xse, gr = id), data = sim, 
             prior = c(
                 set_prior("normal(0, 1)", class = "b", coef = "X"),
                 set_prior("normal(0, 1)", class = "b", coef = "Z"),
                 set_prior("normal(0, 1)", class = "b", coef = "meXbarXsegrEQid"),
                 set_prior("exponential(1)", class = "sd"),
                 set_prior("exponential(1)", class = "sdme")))
brms mi() missingness notation
Show code
b_mru_mi <- brm(bf(Y ~ (1|id) + X + Z + mi(X2)) +
             bf(X2 | mi(Xse) ~ (1|id), family = gaussian()), data = sim %>% mutate(X2 = X),
            prior = c(
              set_prior("normal(0, 1)", class = "b", coef = "X", resp = "Y"),
              set_prior("normal(0, 1)", class = "b", coef = "Z", resp = "Y"),
              set_prior("normal(0, 1)", class = "b", coef = "miX2", resp = "Y"),
              set_prior("exponential(1)", class = "sd", resp = "Y"),
              set_prior("exponential(1)", class = "sd", resp = "X2"),
              set_prior("exponential(1)", class = "sigma", resp = "X2")))

b_mru_mi %>% gather_draws(b_Y_X, b_Y_Z) %>% mean_hdci()
working implementation of the mi model using nonlinear syntax
Show code
latent_formula <- bf(nl = TRUE,
  # Y model
  Y ~ interceptY  + bX*X + X2l + bZ*Z,
  X2l ~ 0 + mi(X2),
  interceptY ~ 1 + (1 | id),
  bX + bZ ~ 1, family = gaussian()) +
  bf(X2 | mi(Xse) ~ 1 + (1|id), family = gaussian()) #+
#  bernoulli()

# get_prior(latent_formula, data = sim)

b_mru_minl <- brm(latent_formula, data = sim,
            prior = c(
              set_prior("normal(0, 1)", class = "b", nlpar = "X2l", resp = "Y"),
              set_prior("normal(0, 1)", class = "b", nlpar = "bX", resp = "Y"),
              set_prior("normal(0, 1)", class = "b", nlpar = "bZ", resp = "Y"),
              # set_prior("normal(0, 1)", class = "b", nlpar = "bX2", resp = "Y"),
              set_prior("normal(0, 1)", class = "b", nlpar = "interceptY", resp = "Y"),
              set_prior("exponential(1)", class = "sd", nlpar = "interceptY", resp = "Y"),
              set_prior("normal(0, 1)", class = "Intercept", resp = "X2"),
              set_prior("exponential(1)", class = "sd", resp = "X2")))

non-working implementation of Matti Vuorre’s approach

This is the approach Matti Vuorre posted on his blog, except that I rearranged it to estimate the slope of Xbar instead. This approach does not work at all, as far as I can tell.

Show code
latent_formula <- bf(
  # Y model
  Y ~ interceptY  + bX*X + bXbar*Xbar + bZ*Z,
  interceptY ~ 1 + (1 | id),
  bX + bZ + bXbar ~ 1,

  # Xbar model
  nlf(X ~ Xbar),
  Xbar ~ 1 + (1 | id),
  nl = TRUE
) #+
#  bernoulli()

get_prior(latent_formula, data = sim)

b_mru_mv <- brm(latent_formula, data = sim,
            prior = c(
              set_prior("normal(0, 1)", class = "b", nlpar = "bX"),
              set_prior("normal(0, 1)", class = "b", nlpar = "bZ"),
              set_prior("normal(0, 1)", class = "b", nlpar = "bXbar"),
              set_prior("normal(0, 1)", class = "b", nlpar = "Xbar"),
              set_prior("normal(0, 1)", class = "b", nlpar = "interceptY"),
              set_prior("exponential(1)", class = "sd", nlpar = "interceptY"),
              set_prior("exponential(1)", class = "sd", nlpar = "Xbar")))
Show code
 Family: gaussian 
  Links: mu = identity; sigma = identity 
Formula: Y ~ interceptY + bX * X + bXbar * Xbar + bZ * Z 
         interceptY ~ 1 + (1 | id)
         bX ~ 1
         bZ ~ 1
         bXbar ~ 1
         X ~ Xbar
         Xbar ~ 1 + (1 | id)
   Data: sim (Number of observations: 646) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Group-Level Effects: 
~id (Number of levels: 300) 
                         Estimate Est.Error l-95% CI u-95% CI Rhat
sd(interceptY_Intercept)     1.76      1.19     0.06     3.43 1.35
sd(Xbar_Intercept)           1.54      0.99     0.03     3.67 1.19
                         Bulk_ESS Tail_ESS
sd(interceptY_Intercept)       10       88
sd(Xbar_Intercept)             16       84

Population-Level Effects: 
                     Estimate Est.Error l-95% CI u-95% CI Rhat
interceptY_Intercept    -0.72      0.92    -2.25     1.14 1.03
bX_Intercept             0.11      1.10    -2.08     2.16 1.24
bZ_Intercept            -0.41      0.18    -0.75    -0.04 1.01
bXbar_Intercept          0.11      1.08    -1.98     2.13 1.24
Xbar_Intercept           0.06      1.17    -1.80     3.70 1.19
                     Bulk_ESS Tail_ESS
interceptY_Intercept      190      554
bX_Intercept               12      140
bZ_Intercept             1252     1963
bXbar_Intercept            12      136
Xbar_Intercept             15       51

Family Specific Parameters: 
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma     1.47      0.06     1.37     1.59 1.02      257       96

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).


Show code
draws <- bind_rows(
  brms_naive = b_mn %>% gather_draws(b_X, b_Z),
  brms_fixed = b_mf %>% gather_draws(b_X, b_Z),
  brms_multilevel = b_mr %>% gather_draws(b_X, b_Z),
  brms_mundlak = b_mrx %>% gather_draws(b_X, b_Z, b_Xbar),
  brms_mundlak_centered = b_mrc %>% gather_draws(b_X, b_Z),
  rethinking_latent_mundlak = mru %>% gather_draws(b_X, b_Z, buy),
  brms_latent_mundlak = b_mru_gr %>% gather_draws(b_X, b_Z, bsp_meXbarXsegrEQid),
  brms_latent_mundlak_mi = b_mru_mi %>% gather_draws(b_Y_X, b_Y_Z, bsp_Y_miX2),
 .id = "model") %>% 
  separate(model, c("package", "model"), extra = "merge") %>% 
  mutate(model = fct_inorder(factor(model)),
         .variable = str_replace(.variable, "_Y", ""),
         .variable = recode(.variable, 
                            "buy" = "b_Xbar",
                            "bsp_miX2" = "b_Xbar",
                            "bsp_meXbarXsegrEQid" = "b_Xbar"),
         .variable = factor(.variable, levels = c("b_X", "b_Z", "b_Xbar")))

draws <- draws %>% group_by(package, model, .variable) %>% 
  mean_hdci(.width = c(.95, .99))

ggplot(draws, aes(y = package, x = .value, xmin = .lower, xmax = .upper)) +
  geom_pointinterval(position = position_dodge(width = .4)) +
  ggrepel::geom_text_repel(aes(label = if_else(.width == .95, sprintf("%.2f", .value), NA_character_)), nudge_y = .1) +
  geom_vline(aes(xintercept = true_val), linetype = 'dashed', data = tibble(true_val = c(b_X, b_Z, b_Ug - b_X), .variable = factor(c("b_X", "b_Z", "b_Xbar"), levels = c("b_X", "b_Z", "b_Xbar")))) +  scale_color_discrete(breaks = rev(levels(draws$model))) +
  facet_grid(model ~ .variable, scales = "free_x") +
  theme_bw() +
  theme(legend.position = c(0.99,0.99),
        legend.justification = c(1,1))

Figure 2: Estimated coefficients and the true values (dashed line)

Gaussian + binary exposure simulations

Finally, I simulated a Gaussian outcome and a binary exposure (X). Especially in small groups, the group average of a binary variable can be pretty far off. I didn’t run the rethinking models here, except for the latent Mundlak model.
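
How coarse the group mean of a binary exposure is in sibling pairs can be seen in a few lines. This is a sketch in base R that loosely mirrors the simulation in this section, with rbinom() and plogis() standing in for rethinking's rbern() and inv_logit(), and with every group fixed at two members (my simplification).

```r
set.seed(4)
n_groups <- 300
g <- rep(seq_len(n_groups), each = 2)
ug <- rnorm(n_groups, sd = 0.8)   # group confound
x <- rbinom(length(g), size = 1,
            prob = plogis(rnorm(length(g), mean = ug[g])))

xbar_g <- tapply(x, g, mean)
xse_g  <- tapply(x, g, sd) / sqrt(2)

table(xbar_g)      # only 0, 0.5 and 1 are possible for pairs
mean(xse_g == 0)   # share of groups whose naive standard error is exactly zero
```

Whenever both siblings share the same value, the observed mean sits at 0 or 1 and the naive standard error is zero, which understates the uncertainty about the latent group mean.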

I was not able to implement this model in brms. brms does not permit the specification of a binary variable as missing using mi(). And adjusting for Xbar using a linear probability model did not work.

Show code
# families
N_groups <- 300
a0 <- (-2)
b_X <- (1)
b_Z <- (-0.5)
b_Ug <- (3)
# 2 or more siblings
g_sizes <- 2 + rpois(N_groups, lambda = 0.2) # sample into groups
N_id <- sum(g_sizes)
g <- rep(1:N_groups, times = g_sizes)
Ug <- rnorm(N_groups, sd = 0.8) # group confounds
X <- rbern(N_id, p=inv_logit(rnorm(N_id, Ug[g] ) ) )   # individual varying trait

Z <- rnorm(N_groups) # group varying trait (observed)
Y <- a0 + b_X * X + b_Ug*Ug[g] + b_Z*Z[g] + rnorm(N_id)

groups <- tibble(id = factor(1:N_groups), Ug, Z)
sim <- tibble(id = factor(g), X, Y) %>% full_join(groups, by = "id") %>% arrange(id) %>% group_by(id) %>% 
  mutate(Xbar = mean(X)) %>% ungroup()
sim %>% distinct(id, Ug, Xbar) %>% select(-id) %>% cor(use = "p")
lm(Y ~ X + Z, data = sim)
lm(Y ~ Ug + X + Z, data = sim)
lm(Y ~ id + X + Z, data = sim)

sim <- sim %>% group_by(id) %>% 
  mutate(Xse = sd(X)/sqrt(n())) %>% ungroup() %>% 
  mutate(X2 = X)
Implementations of the various models



Show code
b_mn <- brm(Y ~ X + Z, data = sim,
            prior = c(
              set_prior("normal(0, 1)", class = "b", coef = "X"),
              set_prior("normal(0, 1)", class = "b", coef = "Z"),
              set_prior("normal(0, 2)", class = "Intercept")))

b_mn %>% gather_draws(b_X, b_Z) %>% mean_hdci()

Fixed effects


Show code
b_mf <- brm(Y ~ 1 + id + X + Z, data = sim,
            prior = c(
              set_prior("normal(0, 1)", class = "b", coef = "X"),
              set_prior("normal(0, 1)", class = "b", coef = "Z"),
              set_prior("normal(0, 2)", class = "b"))
            # , sample_prior = "only"
            )

# b_mf %>% gather_draws(`b_id.*`, regex=T) %>% 
#   ggplot(aes(inv_logit(.value))) + geom_histogram(binwidth = .01)
b_mf %>% gather_draws(b_X, b_Z) %>% mean_hdci()



Show code
b_mr <- brm(Y ~ (1|id) + X + Z, data = sim,
            prior = c(
              set_prior("normal(0, 1)", class = "b", coef = "X"),
              set_prior("normal(0, 1)", class = "b", coef = "Z"),
              set_prior("exponential(1)", class = "sd")))

b_mr %>% gather_draws(b_X, b_Z) %>% mean_hdci()

Multilevel Mundlak


Show code
b_mrx <- brm(Y ~ (1|id) + X + Z + Xbar, data = sim,
            prior = c(
              set_prior("normal(0, 1)", class = "b", coef = "X"),
              set_prior("normal(0, 1)", class = "b", coef = "Z"),
              set_prior("normal(0, 1)", class = "b", coef = "Xbar"),
              set_prior("exponential(1)", class = "sd")))

b_mrx %>% gather_draws(b_X, b_Z) %>% mean_hdci()
subtracting the group mean instead of adjusting for it
Show code
b_mrc <- brm(Y ~ (1|id) + X + Z, data = sim %>% mutate(X = X - Xbar),
            prior = c(
              set_prior("normal(0, 1)", class = "b", coef = "X"),
              set_prior("normal(0, 1)", class = "b", coef = "Z"),
              set_prior("exponential(1)", class = "sd")))

b_mrc %>% gather_draws(b_X, b_Z) %>% mean_hdci()

Multilevel Latent Mundlak


Show code
dat <-  list(Y = Y, X = X, g = g, Ng = N_groups, Z = Z)
# The Latent Mundlak Machine
mru <- ulam(alist(
    # y model 
    Y ~ normal(muY, sigmaY),
    muY <- a[g] + b_X*X + b_Z*Z[g] + buy*u[g],
    transpars> vector[Ng]:a <<- abar + z*tau,
    # X model
    X ~ bernoulli(p),
    logit(p) <- aX + bux*u[g],
    vector[Ng]:u ~ normal(0,1),
    # priors
    z[g] ~ dnorm(0,1),
    c(aX, b_X, buy, b_Z) ~ dnorm(0, 1),
    bux ~ dexp(1),
    abar ~ dnorm(0,1),
    tau ~ dexp(1),
    c(sigma, sigmaY) ~ dexp(1)
  ), data = dat, chains = 4, cores=4, sample=TRUE)

mru %>% gather_draws(b_X, b_Z) %>% mean_hdci()


brms me() measurement error notation

I had to add .01 to the Xse because brms does not permit 0 measurement error (when modelling measurement error).

Show code
b_mru_gr <- brm(Y ~ 1 +(1|id) + X + Z + me(Xbar, Xse, gr = id), data = sim %>% mutate(Xse = Xse + 0.01), 
             prior = c(
                 set_prior("normal(0, 1)", class = "b", coef = "X"),
                 set_prior("normal(0, 1)", class = "b", coef = "Z"),
                 set_prior("normal(0, 1)", class = "b", coef = "meXbarXsegrEQid"),
                 set_prior("exponential(1)", class = "sd"),
                 set_prior("exponential(1)", class = "sdme")))
brms mi() missingness notation

brms does not support the mi() notation for binary variables. Also, it does not accept zero measurement error, so I averaged across groups to take the mean Xse.

Show code
b_mru_mi <- brm(bf(Y ~ (1|id) + X + Z + mi(X2)) +
             bf(X2 | mi(Xse) ~ (1|id)), family = gaussian(), data = sim %>% mutate(Xse = Xse + 0.01), iter = 4000,
            prior = c(
              set_prior("normal(0, 1)", class = "b", coef = "X", resp = "Y"),
              set_prior("normal(0, 1)", class = "b", coef = "Z", resp = "Y"),
              set_prior("normal(0, 1)", class = "b", coef = "miX2", resp = "Y"),
              set_prior("exponential(1)", class = "sd", resp = "Y"),
              set_prior("exponential(1)", class = "sd", resp = "X2"),
              set_prior("constant(0.01)", class = "sigma", resp = "X2")))

b_mru_mi %>% gather_draws(b_Y_X, b_Y_Z) %>% mean_hdci()


Show code
draws <- bind_rows(
  brms_naive = b_mn %>% gather_draws(b_X, b_Z),
  brms_fixed = b_mf %>% gather_draws(b_X, b_Z),
  brms_multilevel = b_mr %>% gather_draws(b_X, b_Z),
  brms_mundlak = b_mrx %>% gather_draws(b_X, b_Z, b_Xbar),
  brms_mundlak_centered = b_mrc %>% gather_draws(b_X, b_Z),
  rethinking_latent_mundlak = mru %>% gather_draws(b_X, b_Z, buy),
  brms_latent_mundlak = b_mru_gr %>% gather_draws(b_X, b_Z, bsp_meXbarXsegrEQid),
  brms_latent_mundlak_mi = b_mru_mi %>% gather_draws(b_Y_X, b_Y_Z, bsp_Y_miX2),
 .id = "model") %>% 
  separate(model, c("package", "model"), extra = "merge") %>% 
  mutate(model = fct_inorder(factor(model)),
         .variable = str_replace(.variable, "_Y", ""),
         .variable = recode(.variable, 
                            "buy" = "b_Xbar",
                            "bsp_miX2" = "b_Xbar",
                            "bsp_meXbarXsegrEQid" = "b_Xbar"),
         .variable = factor(.variable, levels = c("b_X", "b_Z", "b_Xbar")))

draws <- draws %>% group_by(package, model, .variable) %>% 
  mean_hdci(.width = c(.95, .99))

ggplot(draws, aes(y = package, x = .value, xmin = .lower, xmax = .upper)) +
  geom_pointinterval(position = position_dodge(width = .4)) +
  ggrepel::geom_text_repel(aes(label = if_else(.width == .95, sprintf("%.2f", .value), NA_character_)), nudge_y = .1) +
  geom_vline(aes(xintercept = true_val), linetype = 'dashed', data = tibble(true_val = c(b_X, b_Z, b_Ug - b_X), .variable = factor(c("b_X", "b_Z", "b_Xbar"), levels = c("b_X", "b_Z", "b_Xbar")))) +  scale_color_discrete(breaks = rev(levels(draws$model))) +
  facet_grid(model ~ .variable, scales = "free_x") +
  theme_bw() +
  theme(legend.position = c(0.99,0.99),
        legend.justification = c(1,1))

Figure 3: Estimated coefficients and the true values (dashed line)

  1. The group here can be an individual measured several times, a sibling group, a school, a nation, etc.

  2. Mundlak, Y. 1978: On the pooling of time series and cross section data. Econometrica 46:69-85.

  3. What McElreath calls Full Luxury Bayes: Latent Mundlak machine

  4. I think the reason he thought it works is that in the data he used, from McNeish & Hamaker 2020, the group mean varies very little except for measurement error, so we expect it not to make a difference.


