Estimating a Bayesian proportional hazards model


A recent conversation with a colleague about a large stepped-wedge cluster randomized trial (SW-CRT) piqued my interest, because the primary outcome is time-to-event. This is not something I’ve seen before. A quick dive into the literature suggested that time-to-event outcomes are uncommon in SW-CRTs, and that the best analytic approach is not obvious. I was intrigued by how to analyze the data to estimate a hazard ratio while accounting for clustering and potential secular trends that might influence the time to the event.

Of course, my first thought was: How would I simulate data to explore different modeling approaches? And then: Could a Bayesian approach be useful here?

Generating data with clustering and a secular time trend turned out to be quite straightforward, which I’ll share in a future post. Here, I’m focusing on my first attempts to implement a Bayesian model that might eventually accommodate cluster-level random effects and flexible secular trends using splines, something I explored more generally in an earlier post.

Below, I start by generating a simple set of time-to-event outcomes (without any clustering or time trends) and fit a traditional Cox proportional hazards model to serve as a comparison. I then fit a Bayesian proportional hazards model using Stan code drawn from the online guide. That model works fine, but it has a key limitation that I try to address, first with partial success and then with considerably more. This post walks through all of these steps.

Simulating an RCT with time-to-event outcomes

Here are the R packages that are used in this post:

library(simstudy)
library(data.table)
library(survival)
library(cmdstanr)

And here are the data definitions for a two-arm randomized controlled trial that is stratified by a variable \(M\). Both the treatment \(A\) and covariate \(M\) are associated with the time-to-event outcome, as specified in defS. On average, the treatment \(A\) speeds up the time-to-event, and \(M\) slows things down. (In simstudy, survival times are generated using a Weibull data generation process.)

defI <-
  defData(varname = "M", formula = 0.3, dist = "binary") |>
  defData(varname = "A", formula = "1;1", variance = "M", dist = "trtAssign")

defS <- 
  defSurv(
    varname = "timeEvent", 
    formula = "-11.6 + ..delta * A + ..beta_m * M",
    shape = 0.30)  |>
  defSurv(varname = "censorTime", formula = -11.3, shape = .35)

## Parameters

delta <- 1.5
beta_m <- -1.0

We are generating 1,000 independent observations:

set.seed(123)

dd <- genData(1000, defI)
dd <- genSurv(dd, defS, timeName = "tte", censorName = "censorTime", eventName = "event")
dd
## Key: <id>
## Index: <type>
##          id     M     A    tte event       type
##       <int> <int> <int>  <num> <num>     <char>
##    1:     1     0     0 26.974     1  timeEvent
##    2:     2     1     0 38.353     1  timeEvent
##    3:     3     0     0 32.836     1  timeEvent
##    4:     4     1     0 28.768     0 censorTime
##    5:     5     1     0 54.366     1  timeEvent
##   ---                                          
##  996:   996     1     1 11.012     0 censorTime
##  997:   997     0     0 15.420     1  timeEvent
##  998:   998     0     0 21.212     1  timeEvent
##  999:   999     1     0 41.153     0 censorTime
## 1000:  1000     0     1 25.659     1  timeEvent

Here is a Kaplan-Meier plot showing the “survival” times for each level of \(M\) and each treatment arm:
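
A plot along these lines can be generated with survfit() and base graphics; this is just a quick sketch rather than the code behind the original figure:

km_fit <- survfit(Surv(tte, event) ~ A + M, data = dd)

plot(km_fit, col = 1:4, lty = 1, xlab = "time", ylab = "proportion without event")
legend("topright", legend = names(km_fit$strata), col = 1:4, lty = 1, bty = "n")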

Fitting a traditional Cox proportional hazards model, we can see that the log hazard ratio for treatment \(A\) is greater than zero, suggesting that, on average, the time-to-event is shorter for those in the treatment arm. Likewise, those with \(M=1\) have longer times-to-event, and the log hazard ratio is less than zero:

cox_model <- coxph(Surv(tte, event) ~ A + M, data = dd)
summary(cox_model)
## Call:
## coxph(formula = Surv(tte, event) ~ A + M, data = dd)
## 
##   n= 1000, number of events= 821 
## 
##       coef exp(coef) se(coef)      z Pr(>|z|)    
## A  1.44309   4.23374  0.08018  18.00   <2e-16 ***
## M -0.92537   0.39638  0.08302 -11.15   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
##   exp(coef) exp(-coef) lower .95 upper .95
## A    4.2337     0.2362    3.6181    4.9542
## M    0.3964     2.5228    0.3369    0.4664
## 
## Concordance= 0.695  (se = 0.009 )
## Likelihood ratio test= 415.5  on 2 df,   p=<2e-16
## Wald test            = 389  on 2 df,   p=<2e-16
## Score (logrank) test = 423.4  on 2 df,   p=<2e-16

First Bayes model

As I mentioned before, I turned to the Stan documentation for the code that follows. I won’t go into the detailed derivation of the partial likelihood here, since that is covered very nicely in the document. However, it is useful to see the final likelihood that is then reflected in the code.

The likelihood is written as follows (note that the index \(j\) runs only over cases with observed event times, while \(j'\) runs over all cases, both observed and censored):

\[ L(\boldsymbol{\beta})= \prod_{j=1}^{N^{obs}} \left( \frac{\exp(\mathbf{x}_j^\top \boldsymbol{\beta})}{\sum_{j'=j}^{N} \exp(\mathbf{x}_{j'}^\top \boldsymbol{\beta})}\right) \]

where:

  • \(N^{obs}\) is the number of observed event times,
  • \(N\) is the total number of observations (including censored),
  • \(\mathbf{x}_j\) is the vector of covariates for the \(j\)-th observation,
  • \(\boldsymbol{\beta}\) is the vector of coefficients.

We want the log likelihood, which transforms the product to a sum of logs:

\[ \begin{aligned} \log Pr[\text{obs. fails ordered } 1, \dots N^{obs}|\mathbf{x}, \boldsymbol{\beta}] &= \sum_{j=1}^{N^{obs}} \log \left( \frac{\exp(\mathbf{x}_j^\top \boldsymbol{\beta})}{\sum_{j'=j}^{N} \exp(\mathbf{x}_{j'}^\top \boldsymbol{\beta})}\right) \\ \\ &= \sum_{j=1}^{N^{obs}} \left(\mathbf{x}_j^\top \boldsymbol{\beta} - \log \sum_{j'=j}^{N} \exp(\mathbf{x}_{j'}^\top \boldsymbol{\beta})\right) \end{aligned} \]

In Stan, the function log_sum_exp can be used to efficiently calculate

\[ \log \sum_{j'=j}^{N} \exp(\mathbf{x}_{j'}^\top \boldsymbol{\beta}). \]
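
As a quick aside, the trick behind log_sum_exp is to subtract the maximum before exponentiating, which avoids numerical overflow. Here is a small R sketch of the idea (the helper name lse is my own):

lse <- function(x) {          # stable version of log(sum(exp(x)))
  m <- max(x)
  m + log(sum(exp(x - m)))
}

lse(c(1000, 1001))            # finite
log(sum(exp(c(1000, 1001))))  # overflows to Inf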

This partial likelihood is implemented below in Stan. One aspect that confused me (at least initially) is the way censoring is handled: all event times for censored cases are assumed to occur after the last observed event time. That is, every censored case remains in the risk set for every observed event, something I had not seen before. This is a pretty strong assumption, and it has implications for data where the actual censoring times occur before the last observed event time.
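
To get a sense of how consequential this assumption is for the simulated data, here is a quick check (my own, not part of the model code) of the proportion of censored times that fall before the last observed event time:

# share of censored observations whose censoring time comes before the
# last observed event time
dd[event == 0, mean(tte < dd[event == 1, max(tte)])]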

The code might be a little confusing because the data are delivered to Stan in reverse (decreasing) time order. This is done to make the calculation of the log likelihood more efficient. If you are trying to line the code up with the equations above, keep this in mind.

stan_code <-
"
data {
  int<lower=0> K;          // num covariates

  int<lower=0> N;          // num uncensored obs
  vector[N] t;             // event time (non-strict decreasing)
  matrix[N, K] x;          // covariates for uncensored obs

  int N_c;                 // num censored obs
  real <lower=t[N]> t_c;   // censoring time
  matrix[N_c, K] x_c;      // covariates for censored obs
}

parameters {
  vector[K] beta;          // slopes (no intercept)
}

transformed parameters {
  vector[N] log_theta = x * beta;
  vector[N_c] log_theta_c = x_c * beta;
}

model {
  beta ~ normal(0, 4);          // prior on the log hazard ratios
  
  // start the running denominator with all censored cases, which are
  // assumed to remain in the risk set for every observed event
  real log_denom = log_sum_exp(log_theta_c);
  
  // times arrive in decreasing order, so each event joins the running
  // risk-set denominator before its own contribution is added
  for (n in 1:N) {
    log_denom = log_sum_exp(log_denom, log_theta[n]);
    target += log_theta[n] - log_denom;   // log likelihood contribution
  }
}
"

This code prepares the R data for Stan:

dd.o <- dd[event == 1]
setorder(dd.o, -tte)
x.o <- data.frame(dd.o[, .(A, M)])
N.o <- dd.o[, .N]
t.o <- dd.o[, tte]

dd.c <- dd[event == 0]
setorder(dd.c, -tte)
x.c <- data.frame(dd.c[, .(A, M)])
N.c <- dd.c[, .N]
t.c <- dd.c[, tte]

K <- ncol(x.o)          # num covariates

stan_data <- list(
  K = K,
  N = N.o,
  t = t.o,
  x = x.o,
  N_c = N.c,
  t_c = max(t.c),
  x_c = x.c
)
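
Before fitting, a couple of quick sanity checks (my own additions) that the list lines up with what the Stan program expects:

all(diff(t.o) <= 0)    # event times passed in non-increasing order
max(t.c) >= t.o[N.o]   # censoring time satisfies the declared lower bound <lower=t[N]>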

I’m using cmdstanr to interface with Stan. First we compile the Stan code.

stan_model <- cmdstan_model(write_stan_file(stan_code))

And then we fit the model. Even with 1,000 observations, the model is estimated in just a couple of seconds on my laptop.

fit <- stan_model$sample(
  data = stan_data, 
  iter_warmup = 1000,
  iter_sampling = 4000,
  chains = 4,
  parallel_chains = 4,
  refresh = 0
)
## Running MCMC with 4 parallel chains...
## 
## Chain 3 finished in 1.7 seconds.
## Chain 1 finished in 1.9 seconds.
## Chain 4 finished in 1.8 seconds.
## Chain 2 finished in 1.9 seconds.
## 
## All 4 chains finished successfully.
## Mean chain execution time: 1.8 seconds.
## Total execution time: 2.0 seconds.

Looking at the log hazard ratios, something seems awry. The Bayesian estimates are attenuated relative to the original Cox PH estimates, and this is not due to the prior distribution assumption. Rather, it is the result of assuming that all censored times are longer than the longest observed time-to-event. I’m not showing this here, but the attenuation does largely go away if there is no censoring.

fit$summary(variables = "beta")
## # A tibble: 2 × 10
##   variable   mean median     sd    mad     q5    q95  rhat ess_bulk ess_tail
##   <chr>     <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl> <dbl>    <dbl>    <dbl>
## 1 beta[1]   1.03   1.03  0.0723 0.0720  0.916  1.15   1.00   14541.   11205.
## 2 beta[2]  -0.588 -0.588 0.0802 0.0802 -0.721 -0.455  1.00   13062.   10867.
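
For a side-by-side look at the attenuation, the posterior means can be lined up against the coxph coefficients (a quick comparison of my own):

# compare posterior means with the Cox partial likelihood estimates
cbind(
  cox   = coef(cox_model),
  bayes = fit$summary(variables = "beta")$mean
)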

Given this limitation, I decided to try to implement an algorithm that accommodates dynamic risk sets, effectively removing censored cases from the risk set as soon as they are censored. This is what the coxph model estimated above does.

Second Bayes model

The partial likelihood for my alternative Cox proportional hazards model is given by:

\[ L(\boldsymbol{\beta}) = \prod_{i=1}^{N} \left( \frac{\exp(\mathbf{x}_i^\top \boldsymbol{\beta})}{\sum_{j \in R(t_i)} \exp(\mathbf{x}_j^\top \boldsymbol{\beta})} \right)^{\delta_i} \]

where:

  • \(N\) is the number of observations (censored or not),
  • \(\mathbf{x}_i\) is the vector of covariates for the \(i\)-th observation,
  • \(\boldsymbol{\beta}\) is the vector of coefficients,
  • \(t_i\) is the observed time for the \(i\)-th observation,
  • \(R(t_i)\) is the risk set at time \(t_i\): the set of individuals, censored or not, still at risk just before time \(t_i\), so a censored case belongs to \(R(t_i)\) only if its censoring time comes after \(t_i\) (a short R sketch of this follows the list),
  • \(\delta_i\) is the event indicator (\(\delta_i = 1\) if the event occurred, \(\delta_i = 0\) if censored).
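
Here is that rough R sketch of the risk set definition (my own illustration, not code used in the model):

# the risk set at an event time t_i: everyone, censored or not, whose
# observed time is at least t_i
risk_set <- function(t_i, times) which(times >= t_i)

# for example, the number still at risk at the 10th smallest time
length(risk_set(sort(dd$tte)[10], dd$tte))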

The Stan code below implements the log likelihood that follows from this. In contrast to the first version, the data are passed to Stan in ascending time order. The one major complication is that I needed a search function to define the risk set for each event. (Actually, I asked DeepSeek to write it for me.) The fundamental difference between this version and the first is the calculation of the denominator in the log likelihood.

stan_code <-
"
functions {
  int binary_search(vector v, real tar_val) {
    int low = 1;
    int high = num_elements(v);
    int result = -1;

    while (low <= high) {
      int mid = (low + high) %/% 2;
      if (v[mid] == tar_val) {
        result = mid; // Store the index
        high = mid - 1; // Look for earlier occurrences
      } else if (v[mid] < tar_val) {
        low = mid + 1;
      } else {
        high = mid - 1;
      }
    }
    return result;
  }
}

data {
  int<lower=0> K;          // Number of covariates

  int<lower=0> N_o;        // Number of uncensored observations
  vector[N_o] t_o;         // Event times (sorted in ascending order)
  matrix[N_o, K] x_o;      // Covariates for uncensored observations

  int<lower=0> N;          // Number of total observations
  vector[N] t;             // All individual times (ascending), event or censored
  matrix[N, K] x;          // Covariates for all observations
}

parameters {
  vector[K] beta;          // Fixed effects for covariates
}

model {
  
  // Prior

  beta ~ normal(0, 4);
  
  // Model

  vector[N] log_theta = x * beta;

  for (n_o in 1:N_o) {
    int start_risk = binary_search(t, t_o[n_o]); // Use binary search
    real log_denom = log_sum_exp(log_theta[start_risk:N]);
    target += log_theta[start_risk] - log_denom;
  }

}
"

Preparing the data is a little different. This time, I am passing the observed data and the full data, both in ascending order:

dx <- copy(dd)
setorder(dx, tte)

dx.o <- dx[event == 1]
x_o <- data.frame(dx.o[, .(A, M)])
N_o <- dx.o[, .N]
t_o <- dx.o[, tte]

x_all <- data.frame(dx[, .(A, M)])
N_all <- dx[, .N]
t_all <- dx[, tte]

K <- ncol(x_o)          # num covariates

stan_data <- list(
  K = K,
  N_o = N_o,
  t_o = t_o,
  x_o = x_o,
  N = N_all,
  t = t_all,
  x = x_all
)
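
Two quick checks (again, my own) that the assumptions baked into the Stan program hold, namely that the full time vector is ascending and that every event time can be located by the binary search:

all(diff(t_all) >= 0)   # full time vector sorted in ascending order
all(t_o %in% t_all)     # every observed event time appears in t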

stan_model <- cmdstan_model(write_stan_file(stan_code))

fit <- stan_model$sample(
  data = stan_data,
  iter_warmup = 1000,
  iter_sampling = 4000,
  chains = 4,
  parallel_chains = 4,
  refresh = 0
)
## Running MCMC with 4 parallel chains...
## 
## Chain 2 finished in 57.9 seconds.
## Chain 4 finished in 67.2 seconds.
## Chain 1 finished in 67.8 seconds.
## Chain 3 finished in 67.9 seconds.
## 
## All 4 chains finished successfully.
## Mean chain execution time: 65.2 seconds.
## Total execution time: 68.1 seconds.
fit$summary(variables = "beta")
## # A tibble: 2 × 10
##   variable   mean median     sd    mad    q5    q95  rhat ess_bulk ess_tail
##   <chr>     <dbl>  <dbl>  <dbl>  <dbl> <dbl>  <dbl> <dbl>    <dbl>    <dbl>
## 1 beta[1]   1.44   1.43  0.0801 0.0801  1.31  1.57   1.00   12604.   10436.
## 2 beta[2]  -0.925 -0.924 0.0816 0.0817 -1.06 -0.791  1.00   12492.   10436.

Two things to note about this model. First, the estimates appear to be spot on: they mirror the estimates from the coxph model in the survival package, which is encouraging. Second, the implementation is very inefficient, taking more than a minute to run. That is less encouraging, and it does not bode well for a more complex model that incorporates random effects and splines.

Final Bayes model

This time I asked ChatGPT to see if it could make my code more efficient. (I’ve been comparing ChatGPT and DeepSeek; both have been pretty impressive.) It recognized that my initial brute-force approach recalculated each denominator from scratch for every observed event, which is on the order of \(O(N^2)\). (ChatGPT offered that observation unprompted.) The reconfigured algorithm pre-computes the denominators, starting from the last time point (censored or observed), much like the first approach; this brings the cost down to \(O(N)\).
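
To illustrate the idea in R before looking at the Stan code (the function names here are my own): with times in ascending order, the risk-set denominators form a suffix (reverse cumulative) log-sum-exp that can be built in a single backward pass, which is the role log_sum_exp_theta plays below.

lse <- function(x) { m <- max(x); m + log(sum(exp(x - m))) }  # stable log-sum-exp

# build every risk-set denominator in one backward pass over the linear
# predictors, assuming the corresponding times are sorted in ascending order
suffix_lse <- function(log_theta) {
  n <- length(log_theta)
  out <- numeric(n)
  out[n] <- log_theta[n]
  for (i in rev(seq_len(n - 1))) out[i] <- lse(c(log_theta[i], out[i + 1]))
  out
}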

stan_code <-
"
...

model {
  
  // Prior
  
  beta ~ normal(0, 4);
  
  // Likelihood
  
  vector[N] theta = x * beta;
  vector[N] log_sum_exp_theta;
  
  // Precompute the risk-set denominators as a reverse (suffix) cumulative
  // log-sum-exp, working backward from the latest time
  
  log_sum_exp_theta[N] = theta[N]; // initialize with the last (latest) time
  
  for (i in tail(sort_indices_desc(t), N-1)) {
    log_sum_exp_theta[i] = log_sum_exp(theta[i], log_sum_exp_theta[i + 1]);
  }

  for (n_o in 1:N_o) {
    int start_risk = binary_search(t, t_o[n_o]); // Use binary search
    real log_denom = log_sum_exp_theta[start_risk];
    target += theta[start_risk] - log_denom;
  }
}
"

The data requirements for this model are the same as for the second, so no changes are needed there.

stan_model <- cmdstan_model(write_stan_file(stan_code))

fit <- stan_model$sample(
  data = stan_data,
  iter_warmup = 1000,
  iter_sampling = 4000,
  chains = 4,
  parallel_chains = 4,
  refresh = 0
)
## Running MCMC with 4 parallel chains...
## 
## Chain 3 finished in 2.2 seconds.
## Chain 2 finished in 2.4 seconds.
## Chain 4 finished in 2.4 seconds.
## Chain 1 finished in 2.7 seconds.
## 
## All 4 chains finished successfully.
## Mean chain execution time: 2.4 seconds.
## Total execution time: 2.8 seconds.
fit$summary(variables = "beta")
## # A tibble: 2 × 10
##   variable   mean median     sd    mad    q5    q95  rhat ess_bulk ess_tail
##   <chr>     <dbl>  <dbl>  <dbl>  <dbl> <dbl>  <dbl> <dbl>    <dbl>    <dbl>
## 1 beta[1]   1.44   1.44  0.0800 0.0794  1.30  1.57   1.00   12036.   10499.
## 2 beta[2]  -0.925 -0.925 0.0831 0.0834 -1.06 -0.789  1.00   11887.   10386.

This model also works well: the estimates match both the previous Bayesian model and the coxph model. More importantly, the computation time is reduced considerably, to about 3 seconds. My hope is that this final model is flexible enough to handle the extensions I need for the data structure that sparked this whole exploration.
