Bayesian proportional hazards model for a stepped-wedge design

[This article was first published on ouR data generation, and kindly contributed to R-bloggers.]

We’ve finally reached the end of the road. This is the fifth and last post in a series building up to a Bayesian proportional hazards model for analyzing a stepped-wedge cluster-randomized trial. If you are just joining in, you may want to start at the beginning.

The model presented here integrates non-linear time trends and cluster-specific random effects—elements we’ve previously explored in isolation. There’s nothing fundamentally new in this post; it brings everything together. Given that the groundwork has already been laid, I’ll keep the commentary brief and focus on providing the code.

Simulating data from a stepped-wedge CRT

I’ll generate a single data set for 25 sites, each enrolling study participants over a 30-month period. Sites will transition from control to intervention sequentially, with one new site starting the intervention each month. Each site will enroll 25 patients per month.
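The design arithmetic can be checked with a small base-R sketch of the treatment schedule. It makes two assumptions. First, the 30 periods are numbered k = 0, ..., 29; this is an assumption about how simstudy numbers periods. Second, following the comment in the data-generation code, site 1 starts the intervention in period 3, site 2 in period 4, and so on. The object names are illustrative.

```r
# Stepped-wedge schedule sketch. Assumptions: periods are k = 0, ..., 29 and
# site i switches to the intervention in period i + 2 (site 1 in period 3)
n_sites      <- 25
n_periods    <- 30
n_per_period <- 25   # patients enrolled per site per month

# Rows are sites, columns are periods; 1 = intervention, 0 = control
schedule <- outer(1:n_sites, 0:(n_periods - 1),
                  function(i, k) as.integer(k >= i + 2))
dimnames(schedule) <- list(site = 1:n_sites, k = 0:(n_periods - 1))

sum(schedule)                        # 375 site-periods under intervention
sum(schedule == 0)                   # 375 site-periods under control
n_sites * n_periods * n_per_period   # 18750 patients in total
```

Under these assumptions:

- exactly half of the 750 site-periods are in the intervention condition;
- every site contributes at least three control months;
- all sites are in the intervention condition during the last three months.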

The outcome (\(Y\)) is the number of days to an event. The treatment (\(A\)) reduces the time to event. Survival times also depend on the enrollment month—an effect I’ve exaggerated for illustration. Additionally, each site \(i\) has a site-specific effect \(b_i \sim N(\mu=0, \sigma = 0.5)\), which influences the time to event among its participants.

Here are the libraries needed to run the code in this post:

library(simstudy)
library(data.table)
library(splines)
library(splines2)
library(survival)
library(survminer)
library(coxme)
library(cmdstanr)

Definitions

def <- defData(varname = "b", formula = 0, variance = 0.5^2)

defS <-
  defSurv(
    varname = "eventTime",
    formula = 
      "..int + ..delta_f * A + ..beta_1 * k + ..beta_2 * k^2 + ..beta_3 * k^3 + b",
    shape = 0.30)  |>
  defSurv(varname = "censorTime", formula = -11.3, shape = 0.36)

Parameters

int <- -11.6
delta_f <-  0.80

beta_1 <-  0.05
beta_2 <-  -0.025
beta_3 <- 0.001
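To get a feel for how strong these effects are, the snippet below computes two quantities:

- the cubic period effect implied by beta_1, beta_2, and beta_3;
- the treatment hazard ratio implied by delta_f.

It assumes that the defSurv formula acts on the log-hazard scale, so positive values mean shorter times to event. This is consistent with the positive treatment estimates from both models later in the post.

```r
# Parameter values copied from the definitions above
delta_f <- 0.80
beta_1  <- 0.05
beta_2  <- -0.025
beta_3  <- 0.001

# Cubic period effect on the (assumed) log-hazard scale, relative to period 0
periods     <- 0:29
time_effect <- beta_1 * periods + beta_2 * periods^2 + beta_3 * periods^3

exp(delta_f)                    # treatment hazard ratio, about 2.23
round(range(time_effect), 2)    # -1.50 to 4.81
periods[which.min(time_effect)] # 16: period with the lowest hazard
```

Relative to the first month, the hazard dips by a factor of about exp(-1.5) ≈ 0.22 around month 16. It then rises to roughly exp(4.81) ≈ 123 times the baseline by month 29. That is far larger than the treatment effect, which is the exaggerated time trend mentioned above.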

Data generation

set.seed(28271)

### Site level data

ds <- genData(25, def, id = "site")                 
ds <- addPeriods(ds, 30, "site", perName = "k") 

# Each site has a unique starting point, site 1 starts period 3, site 2 period 4, etc.

ds <- trtStepWedge(ds, "site", nWaves = 25,     
                   lenWaves = 1, startPer = 3, 
                   grpName = "A", perName = "k")

### Individual level data

dd <- genCluster(ds, "timeID", numIndsVar = 25, level1ID = "id") 
dd <- genSurv(dd, defS, timeName = "Y", censorName = "censorTime", digits = 0,
              eventName = "event", typeName = "eventType")

### Final observed data set

dd <- dd[, .(id, site, k, A, Y, event)]

Here is a set of Kaplan-Meier plots for each site and enrollment period. When a site is in the intervention condition, the K-M curve is red. For simplicity, censoring is not shown, though about 20% of cases in this dataset are censored.
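Each curve in these plots is a Kaplan-Meier estimate. Below is a minimal base-R sketch of the product-limit calculation, showing what it does with censored cases. It runs on hypothetical toy data and is checked against survival::survfit; the function name km_estimate is illustrative.

```r
library(survival)

# Kaplan-Meier (product-limit) estimate at each distinct event time t_j:
# S(t_j) = product over event times t_l <= t_j of (1 - d_l / n_l)
km_estimate <- function(time, event) {
  t_event <- sort(unique(time[event == 1]))
  n_risk  <- sapply(t_event, function(tj) sum(time >= tj))               # n_l: at risk
  n_event <- sapply(t_event, function(tj) sum(time == tj & event == 1))  # d_l: events
  data.frame(time = t_event, surv = cumprod(1 - n_event / n_risk))
}

# Hypothetical toy data: six patients, two of them censored (event = 0)
toy_time  <- c(3, 5, 5, 8, 10, 12)
toy_event <- c(1, 1, 0, 1, 0, 1)

km <- km_estimate(toy_time, toy_event)
km$surv   # 5/6, 2/3, 4/9, 0

# Same estimates from survival::survfit
km_fit <- survfit(Surv(toy_time, toy_event) ~ 1)
all.equal(km$surv, summary(km_fit)$surv)   # TRUE
```

A censored patient, such as the one censored at time 5, stays in the risk set up to its censoring time but never lowers the curve itself.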

Model estimation

This model has quite a few components relative to the earlier models, but nothing is really new. There is a penalized spline for the effect of time and a random effect for each site. The primary parameter of interest is still \(\beta\).

For completeness, here is the model specification:

\[ \log L(\beta) = \sum_{j=1}^{J} \left[ \sum_{i \in D_j} \left(\beta A_i + \sum_{m=1}^M \gamma_m X_{m_i} + b_{s[i]} \right) - \sum_{r=0}^{d_j-1} \log \left( \sum_{k \in R_j} \exp\left(\beta A_k + \sum_{m=1}^M \gamma_m X_{m_k} + b_{s[k]} \right) - r \cdot \bar{w}_j \right) \right] - \lambda \left\lVert Q^{(2)} \gamma \right\rVert^2 \]

where

  • \(J\): number of unique event times
  • \(M\): number of spline basis functions
  • \(D_j\) is the set of individuals who experience an event at time \(t_j\).
  • \(R_j\) is the risk set at time \(t_j\), including all individuals who are still at risk at that time.
  • \(d_j\) is the number of events occurring at time \(t_j\).
  • \(r\) ranges from 0 to \(d_j - 1\), iterating over the tied events.
  • \(\bar{w}_j\) represents the average risk weight of individuals experiencing an event at \(t_j\):

\[\bar{w}_j = \frac{1}{d_j} \sum_{i \in D_j} \exp\left(\beta A_i + \sum_{m=1}^M \gamma_m X_{m_i} + b_{s[i]} \right)\]

  • \(A_i\): binary indicator for treatment
  • \(X_{m_i}\): value of the \(m^{\text{th}}\) spline basis function for the \(i^{\text{th}}\) observation
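  • \(Q^{(2)}\): the matrix of second derivatives of the spline basis functions, evaluated at each observation's enrollment period (one row per observation), so that \(Q^{(2)}\gamma\) is the second derivative of the fitted time curve at each observation; the penalty is the sum of its squared elements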
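To make the Efron correction concrete, here is a minimal R sketch of the unpenalized log partial likelihood. It takes a given linear predictor \(\eta_i = \beta A_i + \sum_m \gamma_m X_{m_i} + b_{s[i]}\) and follows the definitions above. The function name efron_loglik is hypothetical. The loop over distinct event times is meant for checking small examples by hand, not for fitting.

```r
# Efron log partial likelihood for a linear predictor eta
# (eta_i = beta * A_i + sum_m gamma_m * X_mi + b_s[i]); the spline penalty is not included
efron_loglik <- function(time, event, eta) {
  ll <- 0
  for (tj in sort(unique(time[event == 1]))) {
    D <- which(time == tj & event == 1)   # D_j: events at t_j
    R <- which(time >= tj)                # R_j: risk set at t_j
    d <- length(D)
    risk_sum <- sum(exp(eta[R]))
    w_bar <- mean(exp(eta[D]))            # average risk weight of the tied events
    ll <- ll + sum(eta[D]) - sum(log(risk_sum - (0:(d - 1)) * w_bar))
  }
  ll
}

# Two tied events, eta = 0: denominators log(2) and log(2 - 1)
efron_loglik(time = c(1, 1), event = c(1, 1), eta = c(0, 0))              # -log(2)

# Three distinct event times, eta = 0: -log(3 * 2 * 1)
efron_loglik(time = c(1, 2, 3), event = c(1, 1, 1), eta = c(0, 0, 0))     # -log(6)
```

Compare the two approximations for two tied events with \(\eta = 0\):

- Breslow's approximation uses the full risk-set sum twice, giving \(-2\log 2\).
- Efron's correction removes half of the tied weight from the second term, giving \(-\log 2\).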

The parameters of the model are:

  • \(\beta\): treatment coefficient
  • \(\gamma_m\): spline coefficient for the \(m^\text{th}\) spline basis function
  • \(b_{s[i]}\): cluster-specific random effect, where \(s[i]\) is the cluster of patient \(i\)
  • \(\lambda\): the smoothing penalty; it is not estimated but fixed by the user (\(\lambda = 0.15\) in the code below)
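The penalty can be illustrated outside of Stan. The sketch below builds a cubic B-spline basis and its second derivatives for periods 0 to 29 with splines::splineDesign. Note how it differs from the post's setup:

- It uses illustrative interior knots at 5, 10, 15, 20, and 25; the post uses quantile-based knots via bs() and splines2::dbs().
- It sums the squared second derivatives over the distinct periods rather than over all N observations.

It shows that a straight-line time trend carries no penalty. The penalty term on its own therefore pulls the time curve toward a line rather than toward zero.

```r
library(splines)

# Illustrative setup (not the post's exact knots): cubic B-splines on periods 0-29
k_grid <- 0:29
ord    <- 4                                        # cubic splines have order 4
inner  <- c(5, 10, 15, 20, 25)                     # hypothetical interior knots
knots_demo <- c(rep(0, ord), inner, rep(29, ord))  # clamped knot vector

B_demo  <- splineDesign(knots_demo, k_grid, ord = ord)            # basis values, 30 x 9
Q2_demo <- splineDesign(knots_demo, k_grid, ord = ord,
                        derivs = rep(2L, length(k_grid)))         # 2nd derivatives, 30 x 9

# Penalty: lambda * sum of squared second derivatives of the fitted curve
penalty <- function(gamma, lambda = 0.15) lambda * sum((Q2_demo %*% gamma)^2)

# Greville abscissae: coefficients for which the spline equals the line f(k) = k
greville <- sapply(seq_len(ncol(B_demo)),
                   function(m) mean(knots_demo[(m + 1):(m + ord - 1)]))

max(abs(B_demo %*% greville - k_grid))              # ~0: spline reproduces the line
penalty(greville)                                   # ~0: a straight line is not penalized
penalty(rep(c(-1, 1), length.out = ncol(B_demo)))   # > 0: wiggly curves are penalized
```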

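The prior distributions for \(\beta\), the random effects, and the spline coefficients, as implemented in the Stan code, are:

\[ \begin{aligned} \beta &\sim N(0,1) \\ b_i &\sim N(0,\sigma_b) \\ \sigma_b &\sim N^{+}(0, 0.5) \\ \gamma_m &\sim N(0,2) \end{aligned} \]

where the second argument of \(N\) is the standard deviation, and \(N^{+}\) is a normal distribution truncated at zero (a half-normal).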
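To see what the priors in the Stan model block imply, note two things. Stan's normal() takes a standard deviation. Because sigma_b is declared with <lower=0>, sigma_b ~ normal(0, 0.5) is a half-normal prior. A few quantiles in base R:

```r
# Prior on the treatment log hazard ratio in the Stan code: beta ~ normal(0, 1)
exp(qnorm(c(0.025, 0.975), mean = 0, sd = 1))  # 95% prior range for the hazard ratio: ~0.14 to ~7.10

# Prior on the site SD: sigma_b ~ normal(0, 0.5) restricted to sigma_b > 0 (half-normal)
0.5 * qnorm(0.975)            # 95th percentile of the half-normal: ~0.98
2 * (1 - pnorm(0.5 / 0.5))    # prior probability that sigma_b exceeds 0.5: ~0.32
```

Both priors are weakly informative on these scales. They comfortably cover the values used to generate the data (a log hazard ratio of 0.8 and a site SD of 0.5).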

And here is the implementation of the model in Stan:

stan_code <- 
"
data {
  
  int<lower=1> S;          // Number of clusters
  int<lower=1> K;          // Number of covariates
  
  int<lower=1> N_o;        // Number of uncensored observations
  array[N_o] int i_o;      // Indices of uncensored observations (data sorted by increasing Y)

  int<lower=1> N;          // Number of total observations
  matrix[N, K] x;          // Covariates for all observations
  array[N] int<lower=1,upper=S> s;          // Cluster
  
  // Spline-related data
  
  int<lower=1> Q;          // Number of basis functions
  matrix[N, Q] B;          // Spline basis matrix
  matrix[N, Q] Q2_spline;  // 2nd derivative for penalization
  real<lower=0> lambda;    // penalization term
  
  array[N] int index;        // Position of each observation (data sorted by Y)

  int<lower=0> T;            // Number of tied records
  int<lower=1> J;            // Number of groups of ties
  array[T] int t_grp;        // Tie group of each tied record
  array[T] int t_index;      // Index of each tied record in data set
  vector[T] t_adj;           // Efron adjustment r/d_j for each tied record
  
}

parameters {
  
  vector[K] beta;          // Fixed effects for covariates
  
  vector[S] b;             // Random effects
  real<lower=0> sigma_b;   // SD of random effect
  
  vector[Q] gamma;         // Spline coefficients
  
}

model {
  
  // Priors
  
  beta ~ normal(0, 1);
  
  // Random effects
  
  b ~ normal(0, sigma_b);
  sigma_b ~ normal(0, 0.5);

  
  // Spline coefficients prior
  
  gamma ~ normal(0, 2);
  
  // Penalization term for spline second derivative
  
  target += -lambda * sum(square(Q2_spline * gamma));
  
  // Reverse cumulative sum of exp(theta), computed in log space for numerical stability
  
  vector[N] theta;
  vector[N] log_sum_exp_theta;
  vector[J] exp_theta_grp = rep_vector(0, J);
  
  int first_in_grp;
  
  // Calculate theta for each observation
  
  for (i in 1:N) {
    theta[i] = dot_product(x[i], beta) + dot_product(B[i], gamma) + b[s[i]];
  }
  
  // log_sum_exp_theta[i] = log(sum(exp(theta[i:N]))), accumulated from last to first observation
  
  log_sum_exp_theta = rep_vector(0.0, N);
  log_sum_exp_theta[N] = theta[N];
  
  for (i in tail(sort_indices_desc(index), N-1)) {
    log_sum_exp_theta[i] = log_sum_exp(theta[i], log_sum_exp_theta[i + 1]);
  }

  // Efron approximation: adjust the risk-set sums for tied event times
  
  for (i in 1:T) {
    exp_theta_grp[t_grp[i]] += exp(theta[t_index[i]]);
  }

  for (i in 1:T) {
  
    if (t_adj[i] == 0) {
      first_in_grp = t_index[i];
    }

    log_sum_exp_theta[t_index[i]] =
      log( exp(log_sum_exp_theta[first_in_grp]) - t_adj[i] * exp_theta_grp[t_grp[i]]);
  }
  
  // Likelihood for uncensored observations

  for (n_o in 1:N_o) {
    target += theta[i_o[n_o]] - log_sum_exp_theta[i_o[n_o]];
  }
}
"

Compiling the model:

stan_model <- cmdstan_model(write_stan_file(stan_code))
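The model is passed data sorted by increasing Y. At an untied time, log_sum_exp_theta[i] is therefore the log of the summed exp(theta) over observation i and every later observation, which is the risk set at Y_i. The Stan loop builds this backward with log_sum_exp. The sketch below mirrors that recursion in R and checks it against a direct calculation. It uses a hypothetical function name and is not part of the fitting code.

```r
# Log of reverse cumulative sums: out[i] = log(sum(exp(theta[i:n]))),
# accumulated from the last element backward with a stable log-sum-exp
rev_cum_logsumexp <- function(theta) {
  n <- length(theta)
  out <- numeric(n)
  out[n] <- theta[n]
  for (i in rev(seq_len(n - 1))) {
    m <- max(theta[i], out[i + 1])
    out[i] <- m + log(exp(theta[i] - m) + exp(out[i + 1] - m))
  }
  out
}

set.seed(1)
theta <- rnorm(6)
all.equal(rev_cum_logsumexp(theta), log(rev(cumsum(rev(exp(theta))))))  # TRUE

# Large linear predictors do not overflow, unlike log(sum(exp(...)))
rev_cum_logsumexp(c(1000, 1000))   # 1000 + log(2), then 1000
```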

Getting the data from R to Stan:

dx <- copy(dd)

# Sort by event time; within a tied time, put events ahead of censored records
# so that the reverse cumulative sum at the first event covers the full risk set

setorder(dx, Y, -event)
dx[, index := .I]

dx.obs <- dx[event == 1]
N_obs <- dx.obs[, .N]
i_obs <- dx.obs[, index]

N_all <- dx[, .N]
x_all <- data.frame(dx[, .(A)])
s_all <- dx[, site]

K <- ncol(x_all)                 # num covariates - in this case just A
S <- dx[, length(unique(site))]

# Spline-related info

n_knots <- 5
spline_degree <- 3
knot_dist <- 1/(n_knots + 1)
probs <- seq(knot_dist, 1 - knot_dist, by = knot_dist)
knots <- quantile(dx$k, probs = probs)
spline_basis <- bs(dx$k, knots = knots, degree = spline_degree, intercept = TRUE)
B <- as.matrix(spline_basis)

Q2 <- dbs(dx$k, knots = knots, degree = spline_degree, derivs = 2, intercept = TRUE)
Q2_spline <- as.matrix(Q2)

# Efron ties: groups of events that share a time (censored records stay in the
# risk set but do not contribute to the tie adjustment)

ties <- dx[event == 1, .N, keyby = Y][N > 1, .(grp = .I, Y)]
ties <- merge(ties, dx[event == 1], by = "Y")[, .(grp, index)]
ties[, adj := 0:(.N-1)/.N, keyby = grp]
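To see what the t_grp, t_index, and t_adj inputs look like, here is the tie-table construction run on hypothetical toy data. Every record is an event, with a three-way tie at Y = 2 and a two-way tie at Y = 7. Within a group of d tied events, the r-th event (r = 0, ..., d - 1) gets adjustment r/d. The Stan code multiplies that adjustment by the group's summed exp(theta). The toy objects use new names so that running the snippet does not overwrite ties.

```r
library(data.table)

# Hypothetical toy data: every record is an event; ties at Y = 2 (three) and Y = 7 (two)
toy <- data.table(Y = c(2, 2, 2, 4, 7, 7), event = 1)
setorder(toy, Y)
toy[, index := .I]

# Groups of tied times, the tied records in each group, and the Efron adjustment r / d
toy_ties <- toy[, .N, keyby = Y][N > 1, .(grp = .I, Y)]
toy_ties <- merge(toy_ties, toy, by = "Y")[, .(grp, index)]
toy_ties[, adj := 0:(.N - 1) / .N, keyby = grp]

toy_ties   # index 1, 2, 3 get adj 0, 1/3, 2/3; index 5, 6 get adj 0, 1/2
```

The untied record at Y = 4 does not appear in the table and needs no adjustment.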

stan_data <- list(
  S = S,
  K = K,
  N_o = N_obs,
  i_o = i_obs,
  N = N_all,
  x = x_all,
  s = s_all,
  Q = ncol(B),
  B = B,
  Q2_spline = Q2_spline,
  lambda = 0.15,
  index = dx$index,
  T = nrow(ties),
  J = max(ties$grp),
  t_grp = ties$grp,
  t_index = ties$index,
  t_adj = ties$adj
)

Now we sample from the posterior. As the timing output shows, this takes quite a while to run (about 40 minutes in total on my 2020 MacBook Pro M1 with 8GB RAM):

fit_mcmc <- stan_model$sample(
  data = stan_data,
  seed = 1234,
  iter_warmup = 1000,
  iter_sampling = 4000,
  chains = 4,
  parallel_chains = 4,
  refresh = 0
)
## Running MCMC with 4 parallel chains...
## Chain 4 finished in 1847.8 seconds.
## Chain 1 finished in 2202.8 seconds.
## Chain 3 finished in 2311.8 seconds.
## Chain 2 finished in 2414.9 seconds.
## 
## All 4 chains finished successfully.
## Mean chain execution time: 2194.3 seconds.
## Total execution time: 2415.3 seconds.
fit_mcmc$summary(variables = c("beta", "sigma_b"))
## # A tibble: 2 × 10
##   variable  mean median     sd    mad    q5   q95  rhat ess_bulk ess_tail
##   <chr>    <dbl>  <dbl>  <dbl>  <dbl> <dbl> <dbl> <dbl>    <dbl>    <dbl>
## 1 beta[1]  0.815  0.815 0.0298 0.0298 0.767 0.865  1.00    3513.    5077.
## 2 sigma_b  0.543  0.535 0.0775 0.0739 0.432 0.683  1.00    3146.    5110.

Estimating a “frequentist” random-effects model

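After all that, it turns out that you can simply fit a frailty model with a random effect for site and a spline for time period \(k\) using the coxme package. This is much simpler than everything I have presented here, though note that coxme treats the natural spline terms (ns(k, df = 3)) as ordinary fixed effects rather than as a penalized B-spline.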

frailty_model <- coxme(Surv(Y, event) ~ A + ns(k, df = 3) + (1 | site), data = dd)
summary(frailty_model)
## Mixed effects coxme model
##  Formula: Surv(Y, event) ~ A + ns(k, df = 3) + (1 | site) 
##     Data: dd 
## 
##   events, n = 14989, 18750
## 
## Random effects:
##   group  variable        sd  variance
## 1  site Intercept 0.5306841 0.2816256
##                   Chisq    df p   AIC   BIC
## Integrated loglik 18038  5.00 0 18028 17990
##  Penalized loglik 18185 27.85 0 18129 17917
## 
## Fixed effects:
##                    coef exp(coef) se(coef)      z      p
## A               0.80966   2.24714  0.02959  27.36 <2e-16
## ns(k, df = 3)1 -2.71392   0.06628  0.04428 -61.29 <2e-16
## ns(k, df = 3)2  1.04004   2.82933  0.07851  13.25 <2e-16
## ns(k, df = 3)3  4.48430  88.61492  0.04729  94.83 <2e-16
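Both fits recover the treatment effect well. The comparison is easier to read on the hazard-ratio scale. The snippet below uses only the numbers printed above; q5 and q95 bound a 90% posterior interval.

```r
# Estimates of the treatment log hazard ratio, copied from the output above
bayes_beta <- c(mean = 0.815, q5 = 0.767, q95 = 0.865)   # Stan posterior summary
coxme_beta <- 0.80966                                     # coxme fixed effect for A
true_beta  <- 0.80                                        # delta_f used in the simulation

round(exp(bayes_beta), 3)   # hazard ratio 2.259 (90% interval 2.153 to 2.375)
round(exp(coxme_beta), 3)   # 2.247
round(exp(true_beta), 3)    # 2.226
```

The estimated site-level standard deviations are similarly close to the value of 0.5 used to generate the data: 0.543 from Stan and 0.531 from coxme.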

The advantage of the Bayesian model, however, is its flexibility. The current model uses a single study-wide time spline, but it could naturally be extended to include site-specific spline curves, analogous to site-specific time effects. I initially hoped to fit site-specific splines using the mgcv package, but the models did not converge. I am quite confident that a Bayesian extension would converge, though it would likely require substantial computing resources. If someone wants me to try that, I certainly could, but for now, I’ll stop here.
