Estimating a Bayesian proportional hazards model
A recent conversation with a colleague about a large stepped-wedge cluster randomized trial (SW-CRT) piqued my interest, because the primary outcome is time-to-event. This is not something I’ve seen before. A quick dive into the literature suggested that time-to-event outcomes are uncommon in SW-CRTs, and that the best analytic approach is not obvious. I was intrigued by how to analyze the data to estimate a hazard ratio while accounting for clustering and potential secular trends that might influence the time to the event.
Of course, my first thought was: How would I simulate data to explore different modeling approaches? And then: Could a Bayesian approach be useful here?
Generating data with clustering and a secular time trend turned out to be quite straightforward, which I’ll share in a future post. Here, I’m focusing on my first attempts to implement a Bayesian model that might eventually accommodate cluster-level random effects and flexible secular trends using splines, something I explored more generally in an earlier post.
Below, I start by generating a simple set of time-to-event outcomes (without any clustering or time trends) and fit a traditional Cox proportional hazards model to serve as a comparison. I then fit a Bayesian proportional hazards model using Stan code drawn from the online guide. That model works fine, but it has a key limitation that I try to address, first with partial success and then with much better results. This post walks through all these steps.
Simulating an RCT with time-to-event outcomes
Here are the R packages that are used in this post:
library(simstudy)
library(data.table)
library(survival)
library(cmdstanr)
And here are the data definitions for a two-arm randomized controlled trial that is stratified by a variable \(M\). Both the treatment \(A\) and covariate \(M\) are associated with the time-to-event outcome, as specified in defS. On average, the treatment \(A\) speeds up the time-to-event, and \(M\) slows things down. (In simstudy, survival times are generated using a Weibull data generation process.)
defI <-
  defData(varname = "M", formula = 0.3, dist = "binary") |>
  defData(varname = "A", formula = "1;1", variance = "M", dist = "trtAssign")

defS <-
  defSurv(
    varname = "timeEvent",
    formula = "-11.6 + ..delta * A + ..beta_m * M",
    shape = 0.30) |>
  defSurv(varname = "censorTime", formula = -11.3, shape = .35)

## Parameters

delta <- 1.5
beta_m <- -1.0
We are generating 1,000 independent observations:
set.seed(123)

dd <- genData(1000, defI)
dd <- genSurv(dd, defS, timeName = "tte", censorName = "censorTime", eventName = "event")

dd
## Key: <id>
## Index: <type>
##           id     M     A    tte event       type
##        <int> <int> <int>  <num> <num>     <char>
##    1:      1     0     0 26.974     1  timeEvent
##    2:      2     1     0 38.353     1  timeEvent
##    3:      3     0     0 32.836     1  timeEvent
##    4:      4     1     0 28.768     0 censorTime
##    5:      5     1     0 54.366     1  timeEvent
##   ---
##  996:    996     1     1 11.012     0 censorTime
##  997:    997     0     0 15.420     1  timeEvent
##  998:    998     0     0 21.212     1  timeEvent
##  999:    999     1     0 41.153     0 censorTime
## 1000:   1000     0     1 25.659     1  timeEvent
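Before fitting anything, it can be useful to check how much censoring the data generation produced. This quick summary is not part of the original workflow, but it uses only columns already in dd to report the event proportion by arm and stratum:

# proportion of observations with an observed event, by treatment arm and stratum
dd[, .(n = .N, prop_event = mean(event)), keyby = .(A, M)]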
Here is a Kaplan-Meier plot showing the “survival” times for each level of \(M\) and each treatment arm:
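The plotting code isn’t shown here; a minimal sketch using survfit from the survival package (with base graphics rather than whatever produced the original figure) might look like this:

# Kaplan-Meier curves stratified by treatment arm and M
km_fit <- survfit(Surv(tte, event) ~ A + M, data = dd)
plot(km_fit, col = 1:4, xlab = "time", ylab = "proportion without event")
legend("topright", legend = names(km_fit$strata), col = 1:4, lty = 1)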
Fitting a traditional Cox proportional hazards model, we can see that the log hazard ratio for treatment \(A\) is greater than 0, suggesting that on average the time-to-event is shorter for those in the treatment arm. Likewise, those with \(M=1\) have longer times-to-events and the log hazard ratio is less than zero:
cox_model <- coxph(Surv(tte, event) ~ A + M, data = dd)
summary(cox_model)
## Call:
## coxph(formula = Surv(tte, event) ~ A + M, data = dd)
##
##   n= 1000, number of events= 821
##
##       coef exp(coef) se(coef)      z Pr(>|z|)
## A  1.44309   4.23374  0.08018  18.00   <2e-16 ***
## M -0.92537   0.39638  0.08302 -11.15   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
##   exp(coef) exp(-coef) lower .95 upper .95
## A    4.2337     0.2362    3.6181    4.9542
## M    0.3964     2.5228    0.3369    0.4664
##
## Concordance= 0.695  (se = 0.009 )
## Likelihood ratio test= 415.5  on 2 df,   p=<2e-16
## Wald test            = 389    on 2 df,   p=<2e-16
## Score (logrank) test = 423.4  on 2 df,   p=<2e-16
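If the hazard ratios themselves (rather than the log hazard ratios) are of interest, they can be pulled out directly. This is just a convenience step, not something done in the original analysis:

# hazard ratios with 95% confidence intervals
exp(cbind(HR = coef(cox_model), confint(cox_model)))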
First Bayes model
As I mentioned before, I turned to the Stan documentation for the code that follows. I won’t go into the detailed derivation of the partial likelihood here, since that is covered very nicely in the document. However, it is useful to see the final likelihood that is then reflected in the code.
The likelihood is written as follows (and note that the \(j\)’s index only the cases with observed times, while the \(j'\)’s also include censored cases):
\[ L(\boldsymbol{\beta})= \prod_{j=1}^{N^{obs}} \left( \frac{\exp(\mathbf{x}_j^\top \boldsymbol{\beta})}{\sum_{j'=j}^{N} \exp(\mathbf{x}_{j'}^\top \boldsymbol{\beta})}\right) \]
where:
- \(N^{obs}\) is the number of observed times,
- \(N\) is the total number of observations (including censored),
- \(\mathbf{x}_j\) is the vector of covariates for the \(j\)-th observation,
- \(\boldsymbol{\beta}\) is the vector of coefficients.
We want the log likelihood, which transforms the product to a sum of logs:
\[ \begin{aligned} \log Pr[\text{obs. fails ordered } 1, \dots N^{obs}|\mathbf{x}, \boldsymbol{\beta}] &= \sum_{j=1}^{N^{obs}} \log \left( \frac{\exp(\mathbf{x}_j^\top \boldsymbol{\beta})}{\sum_{j'=j}^{N} \exp(\mathbf{x}_{j'}^\top \boldsymbol{\beta})}\right) \\ \\ &= \sum_{j=1}^{N^{obs}} \left(\mathbf{x}_j^\top \boldsymbol{\beta} - \log \sum_{j'=j}^{N} \exp(\mathbf{x}_{j'}^\top \boldsymbol{\beta})\right) \end{aligned} \]
In Stan, the function log_sum_exp can be used to efficiently calculate
\[ \log \sum_{j'=j}^{N} \exp(\mathbf{x}_{j'}^\top \boldsymbol{\beta}). \]
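To see why working on the log scale matters, here is a small R illustration (not from the original post) of the log-sum-exp trick that Stan’s log_sum_exp implements, compared with the naive calculation:

# naive computation overflows for large linear predictors; log-sum-exp stays stable
lse <- function(v) { m <- max(v); m + log(sum(exp(v - m))) }

v <- c(1000, 1001, 1002)   # extreme values of x'beta
log(sum(exp(v)))           # Inf (overflow)
lse(v)                     # approximately 1002.41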
This partial likelihood is implemented below in Stan. One confusing aspect (at least to me) is the way censoring is handled. Essentially, all event times for censored cases are assumed to occur after the last observed time. That is, all censored cases are part of the risk set for every observed event, something I had not seen before. This is a pretty big assumption and has implications for data where the actual censoring times occur before the last observed event time.
The code might be a little confusing, because the data are delivered to Stan in reverse order. It is done this way to make calculation of the log likelihood more efficient. If you are trying to follow along with the code to see how it lines up with the equations above, keep this in mind.
stan_code <- " data { int<lower=0> K; // num covariates int<lower=0> N; // num uncensored obs vector[N] t; // event time (non-strict decreasing) matrix[N, K] x; // covariates for uncensored obs int N_c; // num censored obs real <lower=t[N]> t_c; // censoring time matrix[N_c, K] x_c; // covariates for censored obs } parameters { vector[K] beta; // slopes (no intercept) } transformed parameters { vector[N] log_theta = x * beta; vector[N_c] log_theta_c = x_c * beta; } model { beta ~ normal(0, 4); real log_denom = log_sum_exp(log_theta_c); for (n in 1:N) { log_denom = log_sum_exp(log_denom, log_theta[n]); target += log_theta[n] - log_denom; // log likelihood } } "
This code prepares the R data for Stan:
dd.o <- dd[event == 1]
setorder(dd.o, -tte)

x.o <- data.frame(dd.o[, .(A, M)])
N.o <- dd.o[, .N]
t.o <- dd.o[, tte]

dd.c <- dd[event == 0]
setorder(dd.c, -tte)

x.c <- data.frame(dd.c[, .(A, M)])
N.c <- dd.c[, .N]
t.c <- dd.c[, tte]

K <- ncol(x.o)   # num covariates

stan_data <- list(
  K = K,
  N = N.o,
  t = t.o,
  x = x.o,
  N_c = N.c,
  t_c = max(t.c),
  x_c = x.c
)
I’m using cmdstanr to interface with Stan. First we compile the Stan code.
stan_model <- cmdstan_model(write_stan_file(stan_code))
And then we fit the model. Even with 1,000 observations, estimation takes just a couple of seconds on my laptop.
fit <- stan_model$sample(
  data = stan_data,
  iter_warmup = 1000,
  iter_sampling = 4000,
  chains = 4,
  parallel_chains = 4,
  refresh = 0
)
## Running MCMC with 4 parallel chains...
##
## Chain 3 finished in 1.7 seconds.
## Chain 1 finished in 1.9 seconds.
## Chain 4 finished in 1.8 seconds.
## Chain 2 finished in 1.9 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 1.8 seconds.
## Total execution time: 2.0 seconds.
Looking at the log hazard ratios, something seems awry. The Bayesian estimates are attenuated relative to the original Cox PH estimates, and this is not due to the prior distribution assumption. Rather, it is the result of assuming that all censored times are longer than the longest observed time-to-event. I’m not showing this here, but the attenuation does largely go away if there is no censoring.
fit$summary(variables = "beta")
## # A tibble: 2 × 10
##   variable   mean median     sd    mad     q5    q95  rhat ess_bulk ess_tail
##   <chr>     <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl> <dbl>    <dbl>    <dbl>
## 1 beta[1]   1.03   1.03  0.0723 0.0720  0.916  1.15   1.00   14541.   11205.
## 2 beta[2]  -0.588 -0.588 0.0802 0.0802 -0.721 -0.455  1.00   13062.   10867.
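To see the attenuation side by side, the posterior means can be lined up against the coxph estimates (a quick check that is not part of the original code):

# compare posterior means with the frequentist Cox estimates
data.frame(
  coxph = coef(cox_model),
  bayes = fit$summary(variables = "beta")$mean
)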
Given this limitation, I decided to try to implement an algorithm that accommodates dynamic risk sets, effectively removing censored cases from the risk set as soon as they are censored. This is what the coxph model estimated above does.
Second Bayes model
The partial likelihood for my alternative Cox proportional hazards model is given by:
\[ L(\boldsymbol{\beta}) = \prod_{i=1}^{N} \left( \frac{\exp(\mathbf{x}_i^\top \boldsymbol{\beta})}{\sum_{j \in R(t_i)} \exp(\mathbf{x}_j^\top \boldsymbol{\beta})} \right)^{\delta_i} \]
where:
- \(N\) is the number of observations (censored or not),
- \(\mathbf{x}_i\) is the vector of covariates for the \(i\)-th observation,
- \(\boldsymbol{\beta}\) is the vector of coefficients,
- \(t_i\) is the observed time for the \(i\)-th observation,
- \(R(t_i)\) is the risk set at time \(t_i\) (the set of individuals still at risk just before time \(t_i\)), which includes censored cases only if their censoring time comes after \(t_i\),
- \(\delta_i\) is the event indicator (\(\delta_i = 1\) if the event occurred, \(\delta_i = 0\) if censored).
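As a rough R check of this partial likelihood (not in the original post, and assuming no tied event times), it can be computed directly and evaluated at the coxph estimates, where it should reproduce cox_model$loglik[2]:

# direct calculation of the Cox log partial likelihood (assumes no tied times)
log_pl <- function(beta, dt) {
  dt <- dt[order(tte)]
  eta <- as.vector(as.matrix(dt[, .(A, M)]) %*% beta)
  denom <- rev(cumsum(rev(exp(eta))))   # sum over the risk set at each time
  sum((eta - log(denom))[dt$event == 1])
}

log_pl(coef(cox_model), dd)             # compare with cox_model$loglik[2]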
The Stan code below implements the log likelihood that follows from this. In contrast to the first version, the data are passed to Stan in ascending order. The one major complication is that I needed to create a search function in order to define the “risk” set. (Actually, I asked DeepSeek to do this for me.) The fundamental difference between this version and the first is the calculation of the denominator in the log likelihood.
stan_code <- " functions { int binary_search(vector v, real tar_val) { int low = 1; int high = num_elements(v); int result = -1; while (low <= high) { int mid = (low + high) %/% 2; if (v[mid] == tar_val) { result = mid; // Store the index high = mid - 1; // Look for earlier occurrences } else if (v[mid] < tar_val) { low = mid + 1; } else { high = mid - 1; } } return result; } } data { int<lower=0> K; // Number of covariates int<lower=0> N_o; // Number of uncensored observations vector[N_o] t_o; // Event times (sorted in decreasing order) matrix[N_o, K] x_o; // Covariates for uncensored observations int<lower=0> N; // Number of total observations vector[N] t; // Individual times matrix[N, K] x; // Covariates for all observations } parameters { vector[K] beta; // Fixed effects for covariates } model { // Prior beta ~ normal(0, 4); // Model vector[N] log_theta = x * beta; for (n_o in 1:N_o) { int start_risk = binary_search(t, t_o[n_o]); // Use binary search real log_denom = log_sum_exp(log_theta[start_risk:N]); target += log_theta[start_risk] - log_denom; } } "
Preparing the data is a little different. This time, I am passing the observed data and the full data, both in ascending order:
dx <- copy(dd)
setorder(dx, tte)

dx.o <- dx[event == 1]

x_o <- data.frame(dx.o[, .(A, M)])
N_o <- dx.o[, .N]
t_o <- dx.o[, tte]

x_all <- data.frame(dx[, .(A, M)])
N_all <- dx[, .N]
t_all <- dx[, tte]

K <- ncol(x_o)   # num covariates

stan_data <- list(
  K = K,
  N_o = N_o,
  t_o = t_o,
  x_o = x_o,
  N = N_all,
  t = t_all,
  x = x_all
)

stan_model <- cmdstan_model(write_stan_file(stan_code))

fit <- stan_model$sample(
  data = stan_data,
  iter_warmup = 1000,
  iter_sampling = 4000,
  chains = 4,
  parallel_chains = 4,
  refresh = 0
)
## Running MCMC with 4 parallel chains...
##
## Chain 2 finished in 57.9 seconds.
## Chain 4 finished in 67.2 seconds.
## Chain 1 finished in 67.8 seconds.
## Chain 3 finished in 67.9 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 65.2 seconds.
## Total execution time: 68.1 seconds.

fit$summary(variables = "beta")
## # A tibble: 2 × 10
##   variable   mean median     sd    mad     q5    q95  rhat ess_bulk ess_tail
##   <chr>     <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl> <dbl>    <dbl>    <dbl>
## 1 beta[1]   1.44   1.43  0.0801 0.0801  1.31   1.57   1.00   12604.   10436.
## 2 beta[2]  -0.925 -0.924 0.0816 0.0817 -1.06  -0.791  1.00   12492.   10436.
Two things to note about this model. First, the estimates appear to be spot on! They mirror the estimates from the coxph model fit using the survival package. That is encouraging. Second, the implementation is very inefficient, taking more than a minute to run! This is less encouraging, and does not bode well for a more complex model that incorporates random effects and splines.
Final Bayes model
This time I asked ChatGPT to see if it could make my code more efficient. (I’ve been comparing ChatGPT and DeepSeek; both have been pretty impressive.) It recognized that my initial brute-force approach recalculated each denominator from scratch for every observed event, which is highly inefficient, on the order of \(O(N^2)\). (Unprompted, ChatGPT provided me with this information.) The algorithm is reconfigured so that the denominators are pre-calculated, starting with the last time point (censored or observed), similar to the first approach. It turns out this is much more efficient, on the order of \(O(N)\).
stan_code <- " ... model { // Prior beta ~ normal(0, 4); // Likelihood vector[N] theta = x * beta; vector[N] log_sum_exp_theta; // Compute cumulative sum of exp(theta) in log space log_sum_exp_theta[N] = theta[N]; // Initialize the last element for (i in tail(sort_indices_desc(t), N-1)) { log_sum_exp_theta[i] = log_sum_exp(theta[i], log_sum_exp_theta[i + 1]); } for (n_o in 1:N_o) { int start_risk = binary_search(t, t_o[n_o]); // Use binary search real log_denom = log_sum_exp_theta[start_risk]; target += theta[start_risk] - log_denom; } } "
The data requirements for this are the same as the second model, so no changes are needed there.
stan_model <- cmdstan_model(write_stan_file(stan_code))

fit <- stan_model$sample(
  data = stan_data,
  iter_warmup = 1000,
  iter_sampling = 4000,
  chains = 4,
  parallel_chains = 4,
  refresh = 0
)
## Running MCMC with 4 parallel chains...
##
## Chain 3 finished in 2.2 seconds.
## Chain 2 finished in 2.4 seconds.
## Chain 4 finished in 2.4 seconds.
## Chain 1 finished in 2.7 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 2.4 seconds.
## Total execution time: 2.8 seconds.

fit$summary(variables = "beta")
## # A tibble: 2 × 10
##   variable   mean median     sd    mad     q5    q95  rhat ess_bulk ess_tail
##   <chr>     <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl> <dbl>    <dbl>    <dbl>
## 1 beta[1]   1.44   1.44  0.0800 0.0794  1.30   1.57   1.00   12036.   10499.
## 2 beta[2]  -0.925 -0.925 0.0831 0.0834 -1.06  -0.789  1.00   11887.   10386.
This model also works well, as the estimates match both the previous Bayesian model and the coxph model. More importantly, the computation time is reduced considerably, to about 3 seconds. My hope is that this final model is flexible enough to handle the extensions I need for the data structure that sparked this whole exploration.