1 Introduction

Tutorials on survival models for recurrent events are scarce, and few cover the whole workflow with Bayesian inference in R. Hence, I developed a Bayesian workflow for survival analysis that covers model specification, priors, visualization, validation and deployment in Shiny.

This tutorial extends the vignettes of Brilleman et al.[1] to the recurrent-events setting. Regression models were fitted in R with the survival and rstanarm packages and visualized with ggplot2. Bayesian inference used the Hamiltonian Monte Carlo No-U-Turn Sampler (NUTS) with 10 chains of 2000 iterations each (50% warmup, 50% sampling), giving 10,000 post-warmup draws. The M-spline baseline hazard used 10 degrees of freedom (df = 10).

Briefly, survival analysis is a branch of statistics that analyzes the expected time until one or more events happen. These events could be death in a placebo-controlled treatment trial, mechanical failure in engines, or customer departure in churn analysis. The name “Survival Analysis” might sound a bit grim, but it is not always about life and death. It is about understanding the ‘lifespan’ of a subject in a system. For example, it can be used to predict when a machine part might fail, or how long a user might stay subscribed to a service. Here are some key points to understand about survival analysis:

Censoring: One of the unique aspects of Survival Analysis is its ability to handle ‘censoring’. Censoring occurs when we have incomplete information about the time of the event. For example, if we are studying the lifespan of a group of people, some of them might still be alive at the end of the study. Their exact lifespans are unknown or ‘censored’, but Survival Analysis can still use this information.

Survival Function: \(S(t)=P(T>t)\), the probability that a subject remains event-free beyond time \(t\).

Hazard Function: \(h(t)\), the instantaneous rate at which the event occurs at time \(t\), given that the subject has survived up to \(t\). It is a rate, not a probability. Over a short interval \(\Delta t\), the probability of the event is approximately \(h(t)\Delta t\).

Applications: Survival Analysis has wide applications in various fields such as medical research, engineering, economics, social sciences, and even customer analytics in business.

Remember, the goal of Survival Analysis is not just to predict when an event will happen, but to understand the underlying factors that influence the timing of the event. It’s a powerful tool in the statistician’s toolbox, helping us make sense of the complex, uncertain world around us.
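The survival function follows from the hazard through the cumulative hazard, \(S(t)=\exp\left(-\int_0^t h(u)\,du\right)\). As a small illustration that is not part of the original workflow, the base-R sketch below integrates a Weibull hazard numerically and checks the result against the Weibull closed form. The shape and scale values are arbitrary assumptions, not estimates from the CGD data.

```r
# Numerical check of S(t) = exp(-H(t)), where H(t) is the integral of h(u) from 0 to t.
# The Weibull shape/scale are illustrative values, not fitted to any data.
weib_shape <- 1.5
weib_scale <- 200

# Weibull hazard h(t) = (shape / scale) * (t / scale)^(shape - 1)
weibull_hazard <- function(t) (weib_shape / weib_scale) * (t / weib_scale)^(weib_shape - 1)

# Survival from a hazard: integrate the hazard, then exponentiate
surv_from_hazard <- function(t, hazard) {
  cum_haz <- integrate(hazard, lower = 0, upper = t)$value  # H(t)
  exp(-cum_haz)                                             # S(t)
}

t_grid    <- c(30, 180, 365)                                # days
S_numeric <- sapply(t_grid, surv_from_hazard, hazard = weibull_hazard)
S_closed  <- exp(-(t_grid / weib_scale)^weib_shape)         # Weibull closed form

print(round(cbind(t = t_grid, S_numeric, S_closed), 4))
```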

2 Requirements

For this analysis, we will use the survival feature branch of rstanarm. At the time of writing, it has not been merged into the main development branch, but it is fully usable for survival analyses. For more complex (spline-based) hazard functions, I recommend installing the splines2 package. The full list of packages is below:

# installation (requires the devtools package)
devtools::install_github("stan-dev/rstanarm", ref = "feature/survival", build_vignettes = FALSE)
install.packages("splines2")

# required 
library(rstanarm)
library(survival)
library(tidyverse)
library(patchwork)
library(bayesplot)
library(loo)
library(gtsummary)
library(data.table)

2.1 Building time-dependent sets with tmerge

data<-cgd0
cgd0[1:4,]
##   id center random treat sex age height weight inherit steroids propylac
## 1  1    204  82888     1   2  12    147   62.0       2        2        2
## 2  2    204  82888     0   1  15    159   47.5       2        2        1
## 3  3    204  82988     1   1  19    171   72.7       1        2        1
## 4  4    204  91388     1   1  12    142   34.0       1        2        1
##   hos.cat futime etime1 etime2 etime3 etime4 etime5 etime6 etime7
## 1       2    414    219    373     NA     NA     NA     NA     NA
## 2       2    439      8     26    152    241    249    322    350
## 3       2    382     NA     NA     NA     NA     NA     NA     NA
## 4       2    388     NA     NA     NA     NA     NA     NA     NA

Data are from the well-known controlled trial in chronic granulomatous disease (CGD). cgd0 has one row per patient (128 patients). The follow-up time is stored in futime, and the times of serious infections observed up to the end of the study are stored in wide format (etime1 to etime7).[2] tmerge converts this into counting-process format by splitting each patient's follow-up into (tstart, tstop] intervals at every infection, which gives 203 rows. For example, patient 1 was followed for 414 days and had infections on days 219 and 373. Patient 2 was followed for 439 days and had 7 infections. Patients 3 and 4 had no infections and were censored on days 382 and 388.

data2<-tmerge(
  cgd0[, 1:13],
  cgd0,
  id = id,
  tstop = futime,
  infect = event(etime1),
  infect = event(etime2),
  infect = event(etime3),
  infect = event(etime4),
  infect = event(etime5),
  infect = event(etime6),
  infect = event(etime7)
)

data2 <- tmerge(data2, data2, id= id, enum = cumtdc(tstart))
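To see what the two tmerge() calls produce, the sketch below applies the same pattern to a two-patient toy dataset. The `toy` data frame is invented for illustration. Each non-missing event time splits follow-up into a new (tstart, tstop] interval, and `enum` numbers the intervals within each patient.

```r
library(survival)

# Hypothetical toy data in the same wide layout as cgd0:
# patient 1 has infections on days 3 and 7 and is followed to day 10,
# patient 2 has no infection and is censored on day 8.
toy <- data.frame(id = 1:2, treat = c(1, 0), futime = c(10, 8),
                  etime1 = c(3, NA), etime2 = c(7, NA))

# First call: define follow-up (tstop) and add one event per etime column
toy_long <- tmerge(toy[, 1:3], toy, id = id, tstop = futime,
                   infect = event(etime1), infect = event(etime2))

# Second call: number the intervals within each patient
toy_long <- tmerge(toy_long, toy_long, id = id, enum = cumtdc(tstart))

print(toy_long[, c("id", "treat", "tstart", "tstop", "infect", "enum")])
```

Patient 2 keeps a single censored row. Patient 1 gets one row per infection plus a final censored interval. This is why the CGD counting-process data have more rows (203) than patients (128).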

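2.2 Survival baseline: Cox Proportional Hazards

f.null<-formula(Surv(tstart, tstop, infect) ~ 1)
f.full<-formula(Surv(tstart, tstop, infect) ~ treat + inherit + steroids)
# cluster = id gives robust standard errors for repeated infections within a patient
m.cox.null <- coxph(f.null, data = data2)
m.cox.full <- coxph(formula = f.full, data = data2, cluster = id)
summary(m.cox.full)
AIC(m.cox.null, m.cox.full)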
## Call:
## coxph(formula = f.full, data = data2, cluster = id)
## 
##   n= 203, number of events= 76 
## 
##             coef exp(coef) se(coef) robust se      z Pr(>|z|)    
## treat    -1.0722    0.3422   0.2619    0.3118 -3.438 0.000585 ***
## inherit   0.1777    1.1944   0.2356    0.3180  0.559 0.576395    
## steroids -0.7726    0.4618   0.5169    0.4687 -1.648 0.099310 .  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
##          exp(coef) exp(-coef) lower .95 upper .95
## treat       0.3422     2.9219    0.1857    0.6306
## inherit     1.1944     0.8372    0.6404    2.2278
## steroids    0.4618     2.1653    0.1843    1.1573
## 
## Concordance= 0.652  (se = 0.04 )
## Likelihood ratio test= 22.49  on 3 df,   p=5e-05
## Wald test            = 16.81  on 3 df,   p=8e-04
## Score (logrank) test = 22.7  on 3 df,   p=5e-05,   Robust = 10.44  p=0.02
## 
##   (Note: the likelihood ratio and score tests assume independence of
##      observations within a cluster, the Wald and robust score tests do not).
##            df      AIC
## m.cox.null  0 684.2894
## m.cox.full  3 667.7961
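Because the model was fitted with `cluster = id`, the printed 95% intervals for exp(coef) use the robust (sandwich) standard error. The robust SE accounts for the correlation between intervals of the same patient. The arithmetic can be reproduced from the printed `treat` row:

```r
# Values copied from the coxph summary above (treat row)
coef_treat <- -1.0722   # log hazard ratio
robust_se  <-  0.3118   # robust (sandwich) SE, because cluster = id

z     <- qnorm(0.975)                                # about 1.96 for a 95% interval
hr    <- exp(coef_treat)                             # hazard ratio
hr_ci <- exp(coef_treat + c(-1, 1) * z * robust_se)  # Wald interval built on the log scale

round(c(HR = hr, lower = hr_ci[1], upper = hr_ci[2]), 4)
```

Using the naive `se(coef)` (0.2619) instead would give a narrower interval that ignores the within-patient correlation.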

3 Bayesian Inference

The hazard of infection was modelled in a Bayesian survival framework using a spline-based (M-spline) hazard regression model.[2] The M-spline hazard function is defined as: \[ \begin{aligned} h_{i}(t)=\sum_{l=1}^{L} \gamma_{l}M_{l}(t;k,\delta)\exp(\eta_{i}(t))\\ \end{aligned} \] The terms are defined as follows:

- \(h_i(t)\) is the hazard of the event for individual \(i\) at time \(t\).
- \(\eta_{i}(t)\) is the linear predictor, which may include time-dependent covariates.
- \(M_{l}(t;k,\delta)\) \((l=1,…,L)\) is the \(l^{th}\) basis term of a degree-\(\delta\) M-spline evaluated with knot locations \(k=\{ k_1,…,k_J \}\).
- \(\gamma_l\) is the \(l^{th}\) M-spline coefficient.

In this tutorial \(L=10\) (df = 10) and \(\delta=3\) (cubic M-splines, the rstanarm default). For our example, we estimated hazard ratios (HRs) and survival probability curves for treated and untreated patients.
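To make the M-spline formula concrete, the sketch below builds a cubic (δ = 3) M-spline basis with L = 10 terms from base R's `splines::splineDesign()`. Each B-spline basis function is rescaled by (δ + 1) divided by the width of its support, so that it integrates to one. Several inputs are illustrative assumptions: rstanarm chooses its own knot locations and estimates γ, whereas the knots here are evenly spaced over 0 to 439 days and the coefficients γ and the treatment effect are hypothetical values.

```r
library(splines)  # part of base R

# Illustrative settings (assumptions, not rstanarm's internal choices)
degree  <- 3                    # delta: cubic M-splines
ord     <- degree + 1           # spline order
L       <- 10                   # number of basis terms (df = 10)
t_max   <- 439                  # end of follow-up in the CGD data (days)
n_inner <- L - ord              # 6 interior knots
inner   <- seq(0, t_max, length.out = n_inner + 2)[-c(1, n_inner + 2)]
knots   <- c(rep(0, ord), inner, rep(t_max, ord))  # extended knot sequence

# M-spline basis: rescale each B-spline so that it integrates to one
mspline_basis <- function(t) {
  B      <- splineDesign(knots, t, ord = ord)      # length(t) x L matrix
  widths <- knots[seq_len(L) + ord] - knots[seq_len(L)]
  sweep(B, 2, ord / widths, `*`)
}

# h_i(t) = sum_l gamma_l * M_l(t) * exp(eta_i); gamma and eta are hypothetical
gamma   <- seq(0.5, 1.5, length.out = L) / 100
eta_trt <- -0.874                                  # illustrative log HR for treat
hazard  <- function(t, treat) {
  as.vector(mspline_basis(t) %*% gamma) * exp(eta_trt * treat)
}

t_grid <- c(10, 100, 200, 300, 430)
h0 <- hazard(t_grid, treat = 0)
h1 <- hazard(t_grid, treat = 1)
print(round(cbind(t = t_grid, untreated = h0, treated = h1, ratio = h1 / h0), 5))
```

The covariates enter only through exp(η), so the ratio column is constant over time. This is the proportional-hazards assumption. The M-spline terms only shape the baseline hazard.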

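We start with rstanarm's default priors and sample from the prior predictive distribution (prior_PD = TRUE), that is, without conditioning on the data: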

# when chains>1 r makes use of viewer
CHAINS <- 10
CORES <- 10
ITER <- 2000
SEED <- 42
# draw from the prior predictive distribution of the stan_surv survival model
prior.stan.cgd <- stan_surv(
  formula = f.full,
  data = data2,
  basehaz = "exp",
  prior_PD = TRUE,
  chains = CHAINS,
  cores = CORES,
  iter = ITER,refresh=2000,
  seed = SEED)

Let’s use more informative priors: normal(0, 1) for the intercept and normal(0, 0.5) for the log hazard ratios of the covariates:

prior.stan.cgd2 <- update(prior.stan.cgd,
                            prior_intercept = normal(0, 1),
                            prior = normal(0, .5))
print(prior.stan.cgd2, digits = 3)
## stan_surv
##  baseline hazard: exponential
##  formula:         Surv(tstart, tstop, infect) ~ treat + inherit + steroids
##  observations:    203
##  events:          76 (37.4%)
##  right censored:  127 (62.6%)
##  delayed entry:   yes
## ------
##             Median MAD_SD exp(Median)
## (Intercept) -6.219  1.653     NA     
## treat        0.004  0.501  1.004     
## inherit     -0.002  0.492  0.998     
## steroids     0.004  0.518  1.004     
## 
## ------
## * For help interpreting the printed output see ?print.stanreg
## * For info on the priors used see ?prior_summary.stanreg
#Compare them
mcmc_intervals(prior.stan.cgd)

mcmc_intervals(prior.stan.cgd2)
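A quick way to judge the `normal(0, .5)` prior is to map it to the hazard-ratio scale. The central 95% of prior mass for each log HR lies within ±1.96 × 0.5, which corresponds to HRs between roughly 0.38 and 2.66. The second call uses a wider normal(0, 2.5) purely for contrast.

```r
# Central prior interval for a hazard ratio when log(HR) ~ normal(0, sd)
prior_hr_interval <- function(sd, prob = 0.95) {
  alpha <- (1 - prob) / 2
  exp(qnorm(c(alpha, 0.5, 1 - alpha), mean = 0, sd = sd))
}

round(prior_hr_interval(0.5), 2)   # prior used in prior.stan.cgd2: 0.38 1.00 2.66
round(prior_hr_interval(2.5), 2)   # a much wider normal(0, 2.5), for contrast
```

A hazard ratio of 0.38 or 2.66 is already a large effect for a single covariate. The narrower prior therefore keeps the prior predictive draws in a plausible range while leaving room for the data to move the estimates.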

3.1 Sampling exp+mspline

# Null
fit.stan.cgd.exp.f.null <- update(prior.stan.cgd2,  
                             prior_PD = FALSE,
                             formula=f.null,
                             basehaz = "exp")
# cubic M-spline baseline hazard with df = 10 (null model)
fit.stan.cgd.ms10.f.null <- update(fit.stan.cgd.exp.f.null,
                        basehaz = "ms",
                        basehaz_ops = list(df = 10))
# Full
fit.stan.cgd.exp.f.full <- update(prior.stan.cgd2,  
                             prior_PD = FALSE,
                             formula=f.full,
                             basehaz = "exp")
fit.stan.cgd.ms10.f.full <- update(fit.stan.cgd.exp.f.full,
                        basehaz = "ms",
                        basehaz_ops = list(df = 10))
fits_stan <- list("exp.f.null"  = fit.stan.cgd.exp.f.null,
                  "exp.f.full"  = fit.stan.cgd.exp.f.full,
                  "ms10.f.null" = fit.stan.cgd.ms10.f.null,
                  "ms10.f.full" = fit.stan.cgd.ms10.f.full)
print(fit.stan.cgd.exp.f.full, digits = 3)
## stan_surv
##  baseline hazard: exponential
##  formula:         Surv(tstart, tstop, infect) ~ treat + inherit + steroids
##  observations:    203
##  events:          76 (37.4%)
##  right censored:  127 (62.6%)
##  delayed entry:   yes
## ------
##             Median MAD_SD exp(Median)
## (Intercept) -5.527  0.815     NA     
## treat       -0.825  0.225  0.438     
## inherit      0.164  0.214  1.178     
## steroids    -0.296  0.384  0.744     
## 
## ------
## * For help interpreting the printed output see ?print.stanreg
## * For info on the priors used see ?prior_summary.stanreg
print(fit.stan.cgd.ms10.f.full, digits = 3)
## stan_surv
##  baseline hazard: M-splines on hazard scale
##  formula:         Surv(tstart, tstop, infect) ~ treat + inherit + steroids
##  observations:    203
##  events:          76 (37.4%)
##  right censored:  127 (62.6%)
##  delayed entry:   yes
## ------
##                  Median MAD_SD exp(Median)
## (Intercept)       0.682  0.868     NA     
## treat            -0.874  0.232  0.417     
## inherit           0.169  0.217  1.185     
## steroids         -0.318  0.403  0.728     
## m-splines-coef1   0.051  0.026     NA     
## m-splines-coef2   0.032  0.029     NA     
## m-splines-coef3   0.035  0.029     NA     
## m-splines-coef4   0.086  0.044     NA     
## m-splines-coef5   0.077  0.046     NA     
## m-splines-coef6   0.068  0.047     NA     
## m-splines-coef7   0.262  0.079     NA     
## m-splines-coef8   0.132  0.100     NA     
## m-splines-coef9   0.130  0.103     NA     
## m-splines-coef10  0.052  0.052     NA     
## 
## ------
## * For help interpreting the printed output see ?print.stanreg
## * For info on the priors used see ?prior_summary.stanreg

Further information on calculating the hazard curve and the RMST can be found in [3-6].

3.2 Posterior uncertainty intervals

m.cox.full %>%
  tbl_regression(exponentiate = TRUE)
Characteristic  HR¹    95% CI¹      p-value
treat           0.34   0.19, 0.63   <0.001
inherit         1.19   0.64, 2.23   0.6
steroids        0.46   0.18, 1.16   0.10
¹ HR = Hazard Ratio, CI = Confidence Interval
mcmc_post_ci(fit.stan.cgd.exp.f.full,.95,4)
##                           HR
## treat    0.44 (0.28 to 0.67)
## inherit  1.18 (0.78 to 1.79)
## steroids 0.74 (0.36 to 1.63)
mcmc_post_ci(fit.stan.cgd.ms10.f.full,.95,4)
##                           HR
## treat    0.42 (0.26 to 0.65)
## inherit  1.18 (0.76 to 1.81)
## steroids 0.73 (0.34 to 1.68)
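`mcmc_post_ci()` is a helper from the author's own code, and its definition is not shown here. The sketch below assumes it reports the posterior median and the central interval of exp(β). `post_hr_ci()` is a hypothetical name, and it is checked on simulated draws rather than on the fitted models.

```r
# Hypothetical helper: summarise exp(draws) as "median (lower to upper)".
# draws: matrix of posterior draws (rows = draws, columns = log-HR parameters)
post_hr_ci <- function(draws, prob = 0.95, digits = 2) {
  alpha <- (1 - prob) / 2
  hr    <- exp(draws)                                   # log HR -> HR
  q     <- apply(hr, 2, quantile, probs = c(0.5, alpha, 1 - alpha))
  fmt   <- function(x) formatC(x, format = "f", digits = digits)
  data.frame(HR = paste0(fmt(q[1, ]), " (", fmt(q[2, ]), " to ", fmt(q[3, ]), ")"),
             row.names = colnames(draws))
}

# With an rstanarm fit, as.matrix() returns the draws matrix, e.g.
# post_hr_ci(as.matrix(fit.stan.cgd.ms10.f.full)[, c("treat", "inherit", "steroids")])

# Self-contained check with simulated draws
set.seed(1)
fake_draws <- cbind(treat   = rnorm(4000, -0.87, 0.23),
                    inherit = rnorm(4000,  0.17, 0.22))
post_hr_ci(fake_draws)
```

exp() is monotone, so the median of the exponentiated draws equals the exponentiated median. This is why the `exp(Median)` column printed by rstanarm matches the HR point estimates above.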

3.3 Hazard curves

plots <- map(fits_stan,plot)

a<-plots[[2]]+
  labs(title = "Constant (exp)")+
  coord_cartesian(ylim = c(0,.1))+
  theme(plot.title = element_text(hjust = .5))

b<-plots[[4]]+labs(title = "M-splines  (df=10)")+
  coord_cartesian(ylim = c(0,.1))+
  theme(plot.title = element_text(hjust = .5))
a+b

3.4 Survival curves: COX PH vs Bayes

data_test <- data.frame(
  id = 1:2,
  treat = c(0, 1),
  inherit = c(1, 1),
  steroids = c(1, 1)
)
ndraws=1000

##### Constant (exponential) 
psb<-posterior_survfit(fit.stan.cgd.exp.f.full,
                      newdata = data_test,
                      times = 0,
                      extrapolate   = T, 
                      condition     = FALSE,
                      return_matrix = F,
                      control = list(edist = 439),
                      draws = ndraws)

psb<-psb %>% 
  left_join(data_test,by="id")
#tidybayes does  not work with posterior_survfit yet
b<-psb %>% as_tibble()%>%
  ggplot(aes(x=time,y=median,col=factor(treat),fill=factor(treat))) +
  # scale_x_continuous(breaks=c(50,100,150,200,250,300,350,400,439))+
  geom_ribbon(aes(ymin = ci_lb, ymax = ci_ub),linewidth=0.1,alpha=0.1) +
  geom_line()+labs(x="Time (days)",y="",subtitle="Bayesian constant",col="Treatment",fill="Treatment")+
  theme_minimal()+theme(legend.position = "none")
# +add_knots(fit.stan.ms10)
####M-spline
psc<-posterior_survfit(fit.stan.cgd.ms10.f.full,
                      newdata = data_test,
                      times = 0,
                      extrapolate   = T, 
                      condition     = FALSE,
                      return_matrix = F,
                      control = list(edist = 439),
                      draws = ndraws)

psc<-psc %>% 
  left_join(data_test,by="id")
#tidybayes does  not work with posterior_survfit yet
c<-psc %>% as_tibble()%>%
  ggplot(aes(x=time,y=median,col=factor(treat),fill=factor(treat))) +
  # scale_x_continuous(breaks=c(50,100,150,200,250,300,350,400,439))+
  geom_ribbon(aes(ymin = ci_lb, ymax = ci_ub),linewidth=0.1,alpha=0.1) +
  geom_line()+labs(x="Time (days)",y="",subtitle="Bayesian M-spline",col="Treatment",fill="Treatment")+
  theme_minimal()+theme(legend.position = "none")
# +add_knots(fit.stan.ms10)

#Lets compare with cox
ps2<-survfit(m.cox.full, newdata = data_test)
ps2cox<-data.frame(time=rep(ps2$time,2),
                   median=c(ps2$surv[,1],ps2$surv[,2]),
                   ci_lb=c(ps2$lower[,1],ps2$lower[,2]),
                   ci_ub=c(ps2$upper[,1],ps2$upper[,2]),
                   treat = rep(c(0, 1),each=length(ps2$time)),
                   inherit = 1,
                   steroids = 1
                   )
a<-ps2cox %>%
  ggplot(aes(x=time,y=median,col=factor(treat),fill=factor(treat))) +
  geom_ribbon(aes(ymin = ci_lb, ymax = ci_ub),linewidth=0.1,alpha=0.1) +
  geom_line()+labs(x="Time (days)",y="Probability of Survival Free\n of Infection",subtitle="COX PH",col="Treatment",fill="Treatment")+
  theme_minimal()+theme(legend.position = "bottom")

legend <- ggpubr::get_legend(a)  # ggpubr is installed as a dependency of survminer
a<-a+theme(legend.position = "none")

(a+b+c)/(plot_spacer()+legend+plot_spacer())

3.5 Beautiful survival curves in Bayes

For publishing purposes, survival plots require additional tweaking in ggplot2. I have taken inspiration from the survminer package, which produces publication-ready plots for Cox PH models.[7] Let’s do the same for our rstanarm model:

annoTextS=4
cbPalette <- c("#4DBBD5","#E64B35")
grid <- seq(0,439,by=100)

data_test <- data.frame(
  id = 1:2,
  treat = c(0, 1),
  inherit = c(1, 1),
  steroids = c(1, 1)
) %>%
  mutate(Strata =ifelse(treat==0, "Untreated", "Treated"))
ndraws=1000
# already collapsed
ps<-posterior_survfit(fit.stan.cgd.ms10.f.full,
                      newdata = data_test,
                      times = 0,
                      extrapolate   = T, 
                      condition     = FALSE,
                      return_matrix = F,
                      control = list(edist = 439),
                      draws = ndraws)
ps<-ps %>% 
  left_join(data_test,by="id")
# prepare HR annotations
text.df<-data.frame(
  x=c(0),
  y=c(0.05),
  label=c("HR Treated=0.42 (0.26 to 0.65)")  # M-spline model, matching the curves below
)
################ survival curves
a<-ps %>%as_tibble()%>%
  ggplot(aes(x=time,y=median,col=Strata)) +
  geom_ribbon(aes(ymin = ci_lb, ymax = ci_ub,fill=Strata), 
              # fill = "gray90",
              alpha=0.2,
              linewidth=0.0) +
  geom_line()+
  scale_color_manual(values=cbPalette)+
  scale_fill_manual(values = cbPalette)+
  scale_x_continuous(breaks = grid)+
  labs(x="",y="Probability of Survival Free\n of Infection",col="Strata")+
  survminer::theme_survminer(base_family = "Times New Roman")
a<-a+annotate(geom="text",x=text.df$x,y=text.df$y,label=text.df$label,
           size=annoTextS,
           hjust=0,family="Times New Roman")+
  theme(legend.position = "right",
        text=element_text(family="Times New Roman"),
        plot.margin = unit(c(0,0,0,0), "cm"))
#obtain legend object
legend <- ggpubr::get_legend(a)  # ggpubr is installed as a dependency of survminer
# a<-a+theme(legend.position = "none")
################ Risk table as ggplot element
datatr <- fit.stan.cgd.ms10.f.full$data %>% ungroup() %>%
  mutate(Strata = factor(
    ifelse(treat==0, "Untreated", "Treated")
  ))
summary(datatr$Strata)
##   Treated Untreated 
##        83       120
patients<-datatr %>% 
  group_by(id) %>% 
  arrange(tstart) %>% 
  slice_head()
riskcounts.df<-rbind(
  RiskSetCount(grid,patients,strataoi ="Untreated"),
  RiskSetCount(grid,patients,strataoi ="Treated")
    )

tabrisk<-ggplot(riskcounts.df, aes(x = time,y = factor(strata),
  label = as.character(value)
  ))  +
  geom_text(size = 4,family = "Times New Roman")+
  coord_cartesian(xlim = c(0,439))+
  scale_x_continuous(breaks=grid)+
  scale_y_discrete(limits=rev(c(
    "Untreated",
    "Treated"
  )),labels=c("",""))+
  labs(x="Time (days)",y="Strata",subtitle = "Number at risk")+
  survminer::theme_survminer(base_family = "Times New Roman")+
  theme(legend.position = "none",
        text = element_text(family = "Times New Roman"),
        axis.text.y = element_text( hjust = 1 ),
        axis.ticks.y = element_line(linewidth = 2,colour = rev(cbPalette)),
        axis.ticks.length.y = unit(15, "pt"),
        plot.margin = unit(c(0,0,0,0), "cm"))

(a<-a / tabrisk+plot_layout(ncol=1,heights = c(3,1)))
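`RiskSetCount()` is another helper from the author's code that is not shown. The sketch below assumes that, for each grid time, it counts the patients of a given stratum whose follow-up (`futime`) has not ended before that time. `risk_set_count()` is a hypothetical name. It returns the `time`, `strata` and `value` columns that `tabrisk` uses.

```r
# Hypothetical helper: number of patients still under follow-up at each grid time.
# Assumes one row per patient, with the end of follow-up in `futime`
# and the group label in `Strata`.
risk_set_count <- function(grid, patients, strataoi) {
  in_stratum <- patients$Strata == strataoi
  data.frame(
    time   = grid,
    strata = strataoi,
    value  = vapply(grid, function(t) sum(patients$futime[in_stratum] >= t), integer(1))
  )
}

# Toy check with four hypothetical patients
toy_patients <- data.frame(futime = c(50, 150, 250, 439),
                           Strata = c("Treated", "Treated", "Untreated", "Untreated"))
rbind(risk_set_count(seq(0, 439, by = 100), toy_patients, "Untreated"),
      risk_set_count(seq(0, 439, by = 100), toy_patients, "Treated"))
```

In this counting-process setup a patient stays in the risk set after an infection, so only the end of follow-up matters for these counts.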

3.6 Time to recurrent infection

Complementary to the HR, we quantify time to recurrence (TTR) using the difference in restricted mean survival time (RMST). Clinically, the RMST is the average event-free time in a population up to a fixed, clinically important follow-up time \((\tau)\).[5-6] It equals the area under the survival curve between the start of follow-up \((t=0)\) and the time horizon \((t=\tau)\). The RMST is denoted by \(\rho(\tau)\) and approximated with a 15-point Gauss-Kronrod quadrature:

\[ \begin{aligned} \rho(\tau)=\int_{0}^{\tau}S(t)\,dt\approx\frac{\tau}{2}\sum_{i=1}^{15}w_{i}S\left(\frac{\tau}{2}+\frac{\tau}{2}\chi_i\right)\\ \end{aligned} \]

where \(\chi_i\) are the quadrature nodes on \([-1,1]\) and \(w_i\) their weights.
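The sketch below implements this quadrature. It hard-codes the standard 15-point Gauss-Kronrod nodes χ_i and weights w_i on [-1, 1] (the values used by QUADPACK). It then checks the approximation against the closed-form RMST of an exponential survival curve, \(\rho(\tau)=(1-e^{-\lambda\tau})/\lambda\). The rate λ is an arbitrary illustrative value. With a fitted model, S(t) would come from the posterior survival draws instead, and whether the author's `rmst_check_plot()` works exactly this way is not shown.

```r
# 15-point Gauss-Kronrod nodes (positive half) and weights on [-1, 1]
gk_pos_nodes <- c(0.991455371120812639, 0.949107912342758525, 0.864864423359769073,
                  0.741531185599394440, 0.586087235467691130, 0.405845151377397167,
                  0.207784955007898468)
gk_pos_wts   <- c(0.022935322010529225, 0.063092092629978553, 0.104790010322250184,
                  0.140653259715525919, 0.169004726639267903, 0.190350578064785410,
                  0.204432940075298892)
gk_mid_wt    <- 0.209482141084727828               # weight of the centre node 0

chi <- c(-gk_pos_nodes, 0, rev(gk_pos_nodes))       # all 15 nodes
w   <- c(gk_pos_wts, gk_mid_wt, rev(gk_pos_wts))    # matching weights

# rho(tau) ~ tau/2 * sum_i w_i * S(tau/2 + tau/2 * chi_i)
rmst_gk15 <- function(surv_fun, tau) {
  tau / 2 * sum(w * surv_fun(tau / 2 + tau / 2 * chi))
}

lambda    <- 0.002                                  # illustrative constant hazard (per day)
surv_exp  <- function(t) exp(-lambda * t)
tau_check <- c(180, 365)

rmst_quad   <- sapply(tau_check, function(x) rmst_gk15(surv_exp, x))
rmst_closed <- (1 - exp(-lambda * tau_check)) / lambda
print(cbind(tau = tau_check, rmst_quad, rmst_closed))
```

Applied to each posterior draw of S(t), the same sum gives a posterior distribution of ρ(τ). That is the kind of output summarised by medians and 95% intervals in the RMST table of this section.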

tau <- c(180, 365)

rmst.trt <-
  map(tau,
      ~ rmst_check_plot(
        fit.stan.cgd.ms10.f.full,
        data_test,
        tau = .
      ))
# number of digits for the table
ndig <- 1
# format draws as "median (2.5th to 97.5th percentile)"
fmt_ci <- function(x, digits = ndig) {
  q <- round(quantile(x, probs = c(0.5, 0.025, 0.975)), digits)
  paste0(q[1], " (", q[2], " to ", q[3], ")")
}
# rmstA/rmstB follow the row order of data_test:
# A = untreated (treat = 0), B = treated (treat = 1)
rmst.table <- NULL
for (i in seq_along(tau)) {
  draws <- rmst.trt[[i]][[1]]
  obs <- data.frame(tau = tau[i],
                    RMST.A = fmt_ci(draws$rmstA),
                    RMST.B = fmt_ci(draws$rmstB),
                    RMST.diff = fmt_ci(draws$diffA_B))
  rmst.table <- rbind(rmst.table, obs)
}
rmst.table%>% 
  kableExtra::kbl() %>% 
  kableExtra::kable_paper("hover",full_width=F) 
tau  RMST.A                  RMST.B                  RMST.diff
180  141.8 (105.8 to 163.5)  162.8 (141.2 to 173.3)  -20.3 (-41.8 to -7.5)
365  222.7 (140.8 to 292.3)  292.8 (225.6 to 333.8)  -67.4 (-112.8 to -29.3)
# stack the draws for both time horizons
rmst.trt.gg<-rbind(
    rmst.trt[[1]][[1]],
    rmst.trt[[2]][[1]]
    )
rmst.trt.gg$tau<-factor(as.character(rmst.trt.gg$tau),levels=c("180","365"))
# wide to long for easy manipulation
rmst.trt.gg<-gather(rmst.trt.gg,condition,time,rmstA:ratioA_B)

# draws and boxplots of RMST for untreated (rmstA) and treated (rmstB) patients
rmst.arms <- rmst.trt.gg %>% filter(condition %in% c("rmstA", "rmstB"))
a <- ggplot(rmst.arms, aes(x = tau, y = time, col = condition)) +
  geom_point(position = position_jitterdodge(), alpha = 0.05, size = 0.01) +
  geom_boxplot(fill = NA, outlier.shape = NA) +
  stat_summary(aes(group = condition), fun = mean, geom = "line") +
  scale_color_manual(values = c("red", "blue"), labels = c("No", "Yes")) +
  labs(x = "Tau (days)", y = "Infection-free time (days)", col = "Treatment") +
  guides(colour = guide_legend(override.aes = list(size = 10))) +
  theme_bw()
a

bayesplot_grid(
    a,
    rmst.trt[[1]]$p3,
    rmst.trt[[2]]$p3,
    # rmst.ar[[3]]$p3,
    grid_args = list(ncol = 2),
    # titles = paste0("RMST (tau=", tau, ")"),
    # subtitles = rep("with medians and 95% CI", 4)
    subtitles = c("Time-free evolution","Tau=180","Tau=365")
    )

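Compared with untreated patients, treated patients had more infection-free days, with the following posterior medians and 95% credible intervals (CrI):

- 6 months of follow-up: 20.3 additional days (95% CrI 7.5 to 41.8).
- 1 year of follow-up: 67.4 additional days (95% CrI 29.3 to 112.8).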

4 Validation

To assess the fit of the regression models, we performed model comparison (M-Spline vs baseline exponential hazard) at different levels.

  • At the computational level (MCMC mixing).

  • At the hazard ratio level, comparing Cox estimates with posterior distributions.

  • At the prediction level (leave-one-group-out cross-validation (LOGO), WAIC, C-index and calibration).

4.1 MCMC

Are chains mixing well?

color_scheme_set("mix-blue-red")
mcmc_trace(fit.stan.cgd.ms10.f.full, 
           pars=c("treat", "(Intercept)"),
           facet_args = list(ncol = 1, strip.position = "left")
           )
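Trace plots are a visual check. The potential scale reduction factor R̂ gives a numerical one, and values close to 1 (commonly below 1.01) indicate that the chains agree. rstanarm's `summary()` output should already report `Rhat`, computed with Stan's own more robust estimator. The base-R sketch below of the classic split-R̂ from Gelman et al. is therefore only meant to show what is being computed, and the simulated chains are hypothetical.

```r
# Classic split-R-hat for one parameter.
# draws: iterations x chains matrix of post-warmup draws
split_rhat <- function(draws) {
  n_half <- floor(nrow(draws) / 2)
  # split every chain into its first and second half
  halves <- cbind(draws[seq_len(n_half), , drop = FALSE],
                  draws[nrow(draws) - n_half + seq_len(n_half), , drop = FALSE])
  n <- nrow(halves)
  W <- mean(apply(halves, 2, var))            # mean within-chain variance
  B <- n * var(colMeans(halves))              # between-chain variance
  var_plus <- (n - 1) / n * W + B / n         # pooled posterior variance estimate
  sqrt(var_plus / W)
}

set.seed(42)
good  <- matrix(rnorm(1000 * 4), nrow = 1000, ncol = 4)  # 4 well-mixed chains
stuck <- good
stuck[, 4] <- stuck[, 4] + 3                              # one chain in a different region

c(good = split_rhat(good), stuck = split_rhat(stuck))

# With the fitted model (as.array() gives iterations x chains x parameters):
# split_rhat(as.array(fit.stan.cgd.ms10.f.full)[, , "treat"])
```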

4.2 Checking hazard ratios (Estimates VS Posteriors)

How far are the posterior distributions from the Cox estimate (green line)? Comparing the prior with the two posteriors, the M-spline posterior is closer to the Cox estimate than the constant (exponential) one.

# HR for treat from the classical Cox model
exp(coef(m.cox.full))[1]
##     treat 
## 0.3422474
base_cox_hr <- vline_at(exp(coef(m.cox.full))[1], color = "green")

a<-mcmc_hist(prior.stan.cgd2,
             pars = c("treat"),
             transformations = exp,
             binwidth = 0.001) + base_cox_hr+labs(subtitle="Priors")

b<-mcmc_hist(fit.stan.cgd.exp.f.full,
             pars = c("treat"),
             transformations = exp,
             binwidth = 0.001) + base_cox_hr+labs(subtitle = "Posterior (const)")
c<-mcmc_hist(fit.stan.cgd.ms10.f.full,
             pars = c("treat"),
             transformations = exp,
             binwidth = 0.001) + base_cox_hr+labs(subtitle = "Posterior (ms-10)")

a+b+c

4.3 LOGO and WAIC

Goodness of fit is examined with cross-validation based on the expected log pointwise predictive density (elpd). Following Gelman et al.,[8] a higher elpd, or equivalently a lower LOOIC = -2 × elpd, indicates better expected out-of-sample fit. Cross-validation guards against the over-optimistic assessment that in-sample fit gives an overfitted model. The intervals of the same patient are correlated, so we leave out a whole patient rather than a single interval (leave-one-group-out, LOGO) for both the goodness-of-fit and calibration schemes.
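Leaving out a patient rather than an interval works because, given the parameters, a patient's log-likelihood is the sum of the log-likelihoods of their intervals. The helper `llkByPatient()` used below is not shown. The sketch below performs that aggregation under the hypothetical name `loglik_by_patient()`. It turns the draws × intervals matrix returned by `log_lik()` (10000 × 203 here) into a draws × patients matrix, which is the S × N layout that `loo::loo()` and `loo::relative_eff()` expect.

```r
# Hypothetical helper: sum the pointwise log-likelihood columns that belong to the
# same patient. llk: draws x intervals matrix; ids: patient id of each interval.
loglik_by_patient <- function(llk, ids) {
  stopifnot(ncol(llk) == length(ids))
  # rowsum() adds up rows sharing a group label, so transpose, sum, transpose back
  t(rowsum(t(llk), group = ids))     # draws x patients, columns ordered by sorted id
}

# Toy check: 3 draws, 4 intervals belonging to patients 1, 1, 2, 3
llk_toy <- matrix(log(c(0.5, 0.4, 0.3,  0.6, 0.5, 0.4,
                        0.9, 0.8, 0.7,  0.2, 0.3, 0.4)), nrow = 3)
loglik_by_patient(llk_toy, ids = c(1, 1, 2, 3))
```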

post <- as.array(fit.stan.cgd.ms10.f.full) 
ids<-fit.stan.cgd.exp.f.full$data$id
chain_id=rep(1:dim(post)[2], each = dim(post)[1])
#### model exp null
#1. Get the log-likelihood per interval (row of data2)
myllk.exp.f.null<-log_lik(fit.stan.cgd.exp.f.null,merge_chains = F)
#2. Join llk by patient 
myllk2.exp.f.null<-llkByPatient(llk = myllk.exp.f.null,ids = ids)
## [1] 10000   203
## [1]   128 10000
# 3. Effective samples
reff.exp.f.null<-relative_eff(myllk2.exp.f.null,chain_id = chain_id,cores = 10)

#### model exp full
#1. Get the log-likelihood per interval (row of data2)
myllk.exp.f.full<-log_lik(fit.stan.cgd.exp.f.full,merge_chains = F)
#2. Join llk by patient 
myllk2.exp.f.full<-llkByPatient(llk = myllk.exp.f.full,ids = ids)
## [1] 10000   203
## [1]   128 10000
# 3. Effective samples
reff.exp.f.full<-relative_eff(myllk2.exp.f.full,chain_id = chain_id,cores = 10)

#model ms10 null
myllk.ms10.f.null<-log_lik(fit.stan.cgd.ms10.f.null)
myllk2.ms10.f.null<-llkByPatient(llk = myllk.ms10.f.null,ids = ids)
## [1] 10000   203
## [1]   128 10000
reff.ms10.f.null<-relative_eff(myllk2.ms10.f.null,chain_id = chain_id,cores = 10)

#model ms10 full
myllk.ms10.f.full<-log_lik(fit.stan.cgd.ms10.f.full)
myllk2.ms10.f.full<-llkByPatient(llk = myllk.ms10.f.full,ids = ids)
## [1] 10000   203
## [1]   128 10000
reff.ms10.f.full<-relative_eff(myllk2.ms10.f.full,chain_id = chain_id,cores = 10)

# Versus frequentist approaches
AIC(m.cox.null,m.cox.full)
##            df      AIC
## m.cox.null  0 684.2894
## m.cox.full  3 667.7961
#leave-one-out ELPD
compare(
  loo(myllk2.exp.f.null, r_eff = reff.exp.f.null),
  loo(myllk2.exp.f.full, r_eff = reff.exp.f.full),
  loo(myllk2.ms10.f.null, r_eff = reff.ms10.f.null),
  loo(myllk2.ms10.f.full, r_eff = reff.ms10.f.full)
)%>% 
  kableExtra::kbl() %>% 
  kableExtra::kable_paper("hover",full_width=F) 
model         elpd_diff   se_diff   elpd_loo  se_elpd_loo      p_loo   se_p_loo     looic  se_looic
ms10.f.full    0.000000  0.000000  -538.4634     76.34679  10.661322  2.1729935  1076.927  152.6936
exp.f.full    -3.248304  3.663320  -541.7117     76.20418   5.078211  1.3992236  1083.423  152.4084
ms10.f.null   -7.332385  4.526938  -545.7958     77.89670   7.215401  1.3588574  1091.592  155.7934
exp.f.null   -10.350965  6.078644  -548.8144     77.78945   2.013492  0.6077481  1097.629  155.5789
# WAIC (on the same patient-level log-likelihood)
compare(
  waic(myllk2.exp.f.null, r_eff = reff.exp.f.null),
  waic(myllk2.exp.f.full, r_eff = reff.exp.f.full),
  waic(myllk2.ms10.f.null, r_eff = reff.ms10.f.null),
  waic(myllk2.ms10.f.full, r_eff = reff.ms10.f.full)
)%>% 
  kableExtra::kbl() %>% 
  kableExtra::kable_paper("hover",full_width=F) 
model         elpd_diff   se_diff  elpd_waic  se_elpd_waic     p_waic  se_p_waic      waic   se_waic
ms10.f.full    0.000000  0.000000  -538.3963      76.32857  10.594244  2.1479533  1076.793  152.6571
exp.f.full    -3.277120  3.663874  -541.6734      76.19422   5.039949  1.3821537  1083.347  152.3884
ms10.f.null   -7.375539  4.535321  -545.7719      77.89051   7.191478  1.3520872  1091.544  155.7810
exp.f.null   -10.416951  6.084981  -548.8133      77.78919   2.012400  0.6074468  1097.627  155.5784

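In both the null and full scenarios, elpd is higher for the M-spline models, and, as expected, the full models outperform the null models. However, the elpd differences are of the same order as their standard errors. For example, the exponential full model is -3.2 behind the M-spline full model with a standard error of 3.7. The evidence for the M-spline baseline is therefore weak.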

4.4 Harrell's C-index

We can also use Harrell's C-index (also known as the concordance index).[9] Let’s focus only on the full models for the concordance metric.

ndraws=1000
data_test<-fit.stan.cgd.exp.f.full$data
data_test$coxlp.f.full<-predict(m.cox.full,newdata = data_test,"lp")
data_test$coxsurv.f.full<-predict(m.cox.full,newdata = data_test,"survival")

# type = "loghaz" returns the posterior log hazard, used here as the Bayesian counterpart of Cox's lp
data_test$explp.f.full<-posterior_survfit(fit.stan.cgd.exp.f.full,
                      newdata = data_test,
                      extrapolate = F,
                      type="loghaz",
                      draws = ndraws,return_matrix = F,
                      times       = "tstart",
                      last_time   = "tstop")$median
data_test$expsurv.f.full<-posterior_survfit(fit.stan.cgd.exp.f.full,
                      newdata = data_test,
                      extrapolate = F,
                      type="surv",
                      draws = ndraws,return_matrix = F,
                      times       = "tstart",
                      last_time   = "tstop")$median
#M-splines
data_test$ms10lp.f.full<-posterior_survfit(fit.stan.cgd.ms10.f.full,
                      newdata = data_test,
                      extrapolate = F,
                      type="loghaz",
                      draws = ndraws,return_matrix = F,
                      times       = "tstart",
                      last_time   = "tstop")$median
data_test$ms10surv.f.full<-posterior_survfit(fit.stan.cgd.ms10.f.full,
                      newdata = data_test,
                      extrapolate = F,
                      type="surv",
                      draws = ndraws,return_matrix = F,
                      times       = "tstart",
                      last_time   = "tstop")$median
# Pairs
pairs(~coxlp.f.full+explp.f.full+ms10lp.f.full, data_test,
      upper.panel = panel.cor,    # Correlation panel
      lower.panel = panel.smooth)

pairs(~coxsurv.f.full+expsurv.f.full+ms10surv.f.full, data_test,
      upper.panel = panel.cor,    # Correlation panel
      lower.panel = panel.smooth)

Graphically, the Cox and Bayesian predictions are correlated. Let’s estimate the C-index for each set of predictions:

y_test <- Surv(data_test$tstart,
               data_test$tstop,
               data_test$infect)

# cindex for linear predictor (log hazard)
concordance(y_test~data_test$coxlp.f.full,reverse = T) # reverse = TRUE: a higher lp means a higher risk
## Call:
## concordance.formula(object = y_test ~ data_test$coxlp.f.full, 
##     reverse = T)
## 
## n= 203 
## Concordance= 0.652 se= 0.03269
## concordant discordant     tied.x     tied.y    tied.xy 
##       4066       1758       1767          6          0
concordance(y_test~data_test$explp.f.full,data = data_test,reverse = T)
## Call:
## concordance.formula(object = y_test ~ data_test$explp.f.full, 
##     data = data_test, reverse = T)
## 
## n= 203 
## Concordance= 0.652 se= 0.03269
## concordant discordant     tied.x     tied.y    tied.xy 
##       4066       1758       1767          6          0
concordance(y_test~data_test$ms10lp.f.full,data = data_test,reverse = T)
## Call:
## concordance.formula(object = y_test ~ data_test$ms10lp.f.full, 
##     data = data_test, reverse = T)
## 
## n= 203 
## Concordance= 0.5639 se= 0.03532
## concordant discordant     tied.x     tied.y    tied.xy 
##       3793       2823        975          6          0
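To make explicit what `concordance()` counts, here is a base-R sketch of Harrell's C on hypothetical toy values. A pair is comparable when the shorter observed time ends in an event. A comparable pair is concordant when that subject also has the higher predicted risk, which is why `reverse = TRUE` is used for log-hazard predictions. This simplified version handles plain right-censored data only and skips the counting-process (tstart, tstop] structure used above.

```r
library(survival)

# Harrell's C: share of comparable pairs in which the higher-risk subject fails first
harrell_c <- function(time, status, risk) {
  concordant <- 0
  comparable <- 0
  n <- length(time)
  for (i in seq_len(n)) {
    if (status[i] != 1) next                   # only an event can anchor a comparable pair
    for (j in seq_len(n)) {
      if (time[j] > time[i]) {                 # j is still event-free when i fails
        comparable <- comparable + 1
        if (risk[i] > risk[j]) {
          concordant <- concordant + 1
        } else if (risk[i] == risk[j]) {
          concordant <- concordant + 0.5       # tied predictions count as half
        }
      }
    }
  }
  concordant / comparable
}

# Hypothetical toy data: times, event indicators and predicted log hazards
toy_time   <- c(2, 4, 6, 8)
toy_status <- c(1, 1, 0, 1)
toy_risk   <- c(0.9, 0.5, 0.7, 0.1)

harrell_c(toy_time, toy_status, toy_risk)                                        # 4 of 5 pairs
concordance(Surv(toy_time, toy_status) ~ toy_risk, reverse = TRUE)$concordance   # same value
```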

4.5 Calibration plots

How close are the predicted survival probabilities to the observed ones?

#fixed time
times = as.double(seq(5, 439, 100))

summary(data_test$tstart)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##     0.0     0.0     0.0    69.5   121.0   373.0
#most common time slots
times2<-data_test %>% 
  dplyr::filter(tstart>0.0) %>% 
  mutate(ints = cut(tstart ,
                    breaks = seq(0, 439, 100),
                    include.lowest = FALSE,
                    right = FALSE)) %>% 
  dplyr::group_by(ints,tstart) %>% 
  dplyr::summarise(myn=n()) %>%  
  slice_max(myn, with_ties = F)
times2
## # A tibble: 4 × 3
## # Groups:   ints [4]
##   ints      tstart   myn
##   <fct>      <dbl> <int>
## 1 [0,100)       65     2
## 2 [100,200)    121     2
## 3 [200,300)    265     2
## 4 [300,400)    373     2
times2<-times2%>% 
  pull(tstart) 
times
## [1]   5 105 205 305 405
(times<-times2)
## [1]  65 121 265 373
y_test.f.full <- filter(data_test, tstart %in% times) %>% 
  ungroup() %>% 
  select(c("tstart","tstop","infect")) 
res<-calibrate(data = data_test,times = times,y = y_test.f.full,
               tstart_col = "tstart",tstop_col ="tstop",status_col = "infect",
               n_groups = 10,surv_col = "coxsurv.f.full" )
autoplot(res)+labs(subtitle = "Cox PH")

res<-calibrate(data = data_test,times = times,y = y_test.f.full,
               tstart_col = "tstart",tstop_col ="tstop",status_col = "infect",
               n_groups = 10,surv_col = "expsurv.f.full" )

autoplot(res)+labs(subtitle = "Constant (exp)")

res<-calibrate(data = data_test,times = times,y = y_test.f.full,
               tstart_col = "tstart",tstop_col ="tstop",status_col = "infect",
               n_groups = 10,surv_col = "ms10surv.f.full" )

autoplot(res)+labs(subtitle = "M-spline")
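The `calibrate()` and `autoplot()` calls above come from code that is not shown here. The idea behind such a plot can be sketched with base R and survival. Subjects are grouped by their predicted survival at a time t0, and the mean prediction in each group is compared with the Kaplan-Meier estimate at t0. `calib_table()` is a hypothetical name, and the sketch assumes simple right-censored data. The simulated exponential model is calibrated by construction, so the observed and predicted columns should roughly agree.

```r
library(survival)

# Hypothetical helper: grouped calibration of predicted S(t0) against Kaplan-Meier
calib_table <- function(time, status, pred_surv, t0, n_groups = 5) {
  breaks <- unique(quantile(pred_surv, probs = seq(0, 1, length.out = n_groups + 1)))
  grp    <- cut(pred_surv, breaks = breaks, include.lowest = TRUE)
  rows <- lapply(split(seq_along(time), grp), function(idx) {
    km <- survfit(Surv(time[idx], status[idx]) ~ 1)
    data.frame(n         = length(idx),
               predicted = mean(pred_surv[idx]),
               observed  = summary(km, times = t0, extend = TRUE)$surv)
  })
  do.call(rbind, rows)
}

# Simulated data: exponential event times with subject-specific rates
set.seed(2023)
n       <- 5000
rate    <- 0.004 * exp(rnorm(n, sd = 0.7))
event_t <- rexp(n, rate)
cens_t  <- runif(n, 0, 600)
obs_t   <- pmin(event_t, cens_t)
status  <- as.integer(event_t <= cens_t)
t0      <- 100
pred    <- exp(-rate * t0)                # true S(t0): a perfectly calibrated model

calib_table(obs_t, status, pred, t0)
```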

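We obtain mixed validation results across all approaches, and the reasons are manifold. Survival analysis is hard, especially with tricky datasets such as recurrent events. For instance, the M-spline model's C-index (0.56) is lower than that of the Cox and exponential models (0.65). This is likely because its log hazard is evaluated at each interval's own start time, which mixes the time-varying baseline hazard with the covariate effects. I suggest selecting a model based on a combination of validation techniques (e.g. HR + elpd, or HR + C-index). I also suggest evidence-driven covariate selection, with extra caution for non-normally distributed covariates.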

5 Deploying your survival model

After sufficient validation with internal and external datasets, you can make your model available to more people in your research or clinical organization by deploying it with Shiny. Shiny is a reactive…
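At its core, such an app is a function that maps user inputs to a predicted survival curve, plus a thin Shiny wrapper around it. The sketch below is a minimal version under stated assumptions:

- A frequentist Cox model on `survival::cgd` stands in for the tutorial's validated rstanarm model.
- `predict_survival()`, `fit_app` and `treat_levels` are hypothetical names.
- The Shiny part only runs in an interactive session where shiny is installed.

```r
library(survival)

# Assumption: a Cox model on survival::cgd stands in for the validated
# rstanarm model; only treatment and age are exposed to app users.
fit_app <- coxph(Surv(tstart, tstop, status) ~ treat + age, data = cgd)
treat_levels <- levels(factor(cgd$treat))

# Hypothetical helper: user inputs -> predicted survival curve (survfit object)
predict_survival <- function(treat, age) {
  newdata <- data.frame(treat = factor(treat, levels = treat_levels), age = age)
  survfit(fit_app, newdata = newdata)
}

# Shiny wrapper: built only in an interactive session with shiny installed
if (interactive() && requireNamespace("shiny", quietly = TRUE)) {
  ui <- shiny::fluidPage(
    shiny::titlePanel("Predicted infection-free survival"),
    shiny::selectInput("treat", "Treatment", choices = treat_levels),
    shiny::numericInput("age", "Age (years)", value = 15, min = 1, max = 45),
    shiny::plotOutput("curve")
  )
  server <- function(input, output, session) {
    output$curve <- shiny::renderPlot({
      shiny::req(input$age)  # wait until a valid age has been entered
      plot(predict_survival(input$treat, input$age),
           xlab = "Days", ylab = "Probability of no infection yet")
    })
  }
  shiny::runApp(shiny::shinyApp(ui, server))
}
```

Keeping the prediction logic in a plain function means it can be unit-tested without launching the app. It also means swapping in the Bayesian model later only touches that function.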

sessionInfo()
## R version 4.2.1 (2022-06-23)
## Platform: aarch64-apple-darwin20 (64-bit)
## Running under: macOS Monterey 12.6.5
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] data.table_1.14.8 gtsummary_1.7.1   loo_2.6.0         bayesplot_1.10.0 
##  [5] patchwork_1.1.2   forcats_1.0.0     stringr_1.5.0     dplyr_1.1.3      
##  [9] purrr_1.0.2       readr_2.1.2       tidyr_1.3.0       tibble_3.2.1     
## [13] ggplot2_3.4.4     tidyverse_1.3.2   survival_3.5-5    rstanarm_2.26.1  
## [17] Rcpp_1.0.11      
## 
## loaded via a namespace (and not attached):
##   [1] readxl_1.4.1         backports_1.4.1      systemfonts_1.0.4   
##   [4] plyr_1.8.9           igraph_1.5.1         splines_4.2.1       
##   [7] crosstalk_1.2.0      rstantools_2.3.1.1   inline_0.3.19       
##  [10] digest_0.6.33        htmltools_0.5.6.1    fansi_1.0.5         
##  [13] magrittr_2.0.3       checkmate_2.3.0      googlesheets4_1.0.1 
##  [16] tzdb_0.3.0           modelr_0.1.9         RcppParallel_5.1.7  
##  [19] matrixStats_1.0.0    svglite_2.1.0        xts_0.13.1          
##  [22] timechange_0.2.0     prettyunits_1.2.0    colorspace_2.1-0    
##  [25] rvest_1.0.3          haven_2.5.1          xfun_0.41           
##  [28] callr_3.7.3          crayon_1.5.2         jsonlite_1.8.7      
##  [31] lme4_1.1-34          zoo_1.8-12           glue_1.6.2          
##  [34] kableExtra_1.3.4     survminer_0.4.9      gtable_0.3.4        
##  [37] gargle_1.2.1         webshot_0.5.4        V8_4.2.1            
##  [40] distributional_0.3.2 car_3.1-0            pkgbuild_1.4.2      
##  [43] rstan_2.32.3         abind_1.4-5          scales_1.2.1        
##  [46] DBI_1.1.3            rstatix_0.7.2        miniUI_0.1.1.1      
##  [49] viridisLite_0.4.2    xtable_1.8-4         km.ci_0.5-6         
##  [52] stats4_4.2.1         splines2_0.5.1       StanHeaders_2.26.28 
##  [55] DT_0.30              htmlwidgets_1.6.2    httr_1.4.7          
##  [58] threejs_0.3.3        posterior_1.5.0      ellipsis_0.3.2      
##  [61] pkgconfig_2.0.3      farver_2.1.1         sass_0.4.7          
##  [64] dbplyr_2.2.1         utf8_1.2.4           labeling_0.4.3      
##  [67] tidyselect_1.2.0     rlang_1.1.1          reshape2_1.4.4      
##  [70] later_1.3.1          munsell_0.5.0        cellranger_1.1.0    
##  [73] tools_4.2.1          cachem_1.0.8         cli_3.6.1           
##  [76] generics_0.1.3       ggridges_0.5.4       broom_1.0.1         
##  [79] evaluate_0.23        fastmap_1.1.1        yaml_2.3.7          
##  [82] processx_3.8.2       knitr_1.45           fs_1.6.3            
##  [85] survMisc_0.5.6       nlme_3.1-157         mime_0.12           
##  [88] xml2_1.3.3           compiler_4.2.1       shinythemes_1.2.0   
##  [91] rstudioapi_0.14      curl_5.1.0           ggsignif_0.6.3      
##  [94] gt_0.9.0             reprex_2.0.2         broom.helpers_1.13.0
##  [97] bslib_0.5.1          stringi_1.7.12       highr_0.10          
## [100] ps_1.7.5             lattice_0.20-45      Matrix_1.5-1        
## [103] commonmark_1.9.0     nloptr_2.0.3         markdown_1.11       
## [106] KMsurv_0.1-5         shinyjs_2.1.0        tensorA_0.36.2      
## [109] vctrs_0.6.4          pillar_1.9.0         lifecycle_1.0.3     
## [112] jquerylib_0.1.4      httpuv_1.6.12        QuickJSR_1.0.7      
## [115] R6_2.5.1             promises_1.2.1       gridExtra_2.3       
## [118] codetools_0.2-18     boot_1.3-28          colourpicker_1.3.0  
## [121] MASS_7.3-60          gtools_3.9.4         assertthat_0.2.1    
## [124] withr_2.5.2          shinystan_2.6.0      parallel_4.2.1      
## [127] hms_1.1.2            labelled_2.10.0      grid_4.2.1          
## [130] minqa_1.2.6          rmarkdown_2.25       carData_3.0-5       
## [133] googledrive_2.0.0    ggpubr_0.6.0         shiny_1.7.5.1       
## [136] lubridate_1.9.2      base64enc_0.1-3      dygraphs_1.1.1.6

6 References and resources

  1. Brilleman SL, Elçi EM, Novik JB, Wolfe R. Bayesian Survival Analysis Using the rstanarm R Package. arXiv preprint arXiv:2002.09633. 2020. https://arxiv.org/abs/2002.09633. Accessed July 1, 2022.

  2. Fleming TR, Harrington DP. Counting Processes and Survival Analysis. New York: Wiley; 1991. Appendix D.2.

  3. Therneau T, Crowson C, Atkinson E. Using Time Dependent Covariates and Time Dependent Coefficients in the Cox Model. Survival package vignette. Accessed May 1, 2023.

  4. Therneau T. A Package for Survival Analysis in R. R package version 3.5-5. 2023. https://CRAN.R-project.org/package=survival.

  5. Ramsay JO. Monotone Regression Splines in Action. Statistical Science. 1988;3(4):425-441. https://doi.org/10.1214/ss/1177012761. Accessed July 1, 2022.

  6. Royston P, Parmar MKB. Restricted Mean Survival Time: An Alternative to the Hazard Ratio for the Design and Analysis of Randomized Trials with a Time-to-Event Outcome. BMC Medical Research Methodology. 2013;13(1):152.

  7. Kassambara A, Kosinski M, Biecek P. survminer: Drawing Survival Curves using ‘ggplot2’. R package version 0.4.9. 2021. https://CRAN.R-project.org/package=survminer.

  8. Vehtari A, Gelman A, Gabry J. Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing. 2017;27(5):1413-1432. doi:10.1007/s11222-016-9696-4.

  9. Harrell FE Jr, Lee KL, Mark DB. Multivariable prognostic models: issues in developing models, evaluating assumptions and adequacy, and measuring and reducing errors. Statistics in Medicine. 1996;15(4):361-387.