Fitting GAMs with brms: part 1
Regular readers will know that I have a somewhat unhealthy relationship with GAMs and the mgcv package. I use these models all the time in my research but recently we’ve been hitting the limits of the range of models that mgcv can fit. So I’ve been looking into alternative ways to fit the GAMs I want to fit but which can handle the kinds of data or distributions that have been cropping up in our work. The brms package (Bürkner, 2017) is an excellent resource for modellers, providing a high-level R front end to a vast array of model types, all fitted using Stan. brms is the perfect package to go beyond the limits of mgcv because brms even uses the smooth functions provided by mgcv, making the transition easier. In this post I take a look at how to fit a simple GAM in brms and compare it with the same model fitted using mgcv.
In this post we’ll use the following packages. If you don’t know schoenberg, it’s a package I’m writing to provide ggplot versions of plots that can be produced by mgcv from fitted GAM objects. schoenberg is in early development, but it currently works well enough to plot the models we fit here. If you’ve never come across this package before, you can install it from GitHub using devtools::install_github('gavinsimpson/schoenberg').
To illustrate brms’ GAM-fitting chops, we’ll use the mcycle data set that comes with the MASS package. It contains a set of measurements of the acceleration force on a rider’s head during a simulated motorcycle collision and the time, in milliseconds, post collision. The data are loaded using data() and we take a look at the first few rows.
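As a minimal sketch of that step, loading only the data set from MASS rather than attaching the whole package (a choice made here for illustration):

```r
# mcycle ships with MASS; load just the data set without attaching the package
data(mcycle, package = "MASS")

# First few rows: `times` (ms after impact) and `accel` (head acceleration, in g)
head(mcycle)
```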
The aim is to model the acceleration force (accel) as a function of time post collision (times). The plot below shows the data.
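A scatterplot of this kind could be drawn with ggplot2 along the following lines. The axis labels and title are my own wording, chosen for illustration, and the object name p is arbitrary.

```r
library(ggplot2)
data(mcycle, package = "MASS")

# Acceleration against time since impact, one point per measurement
p <- ggplot(mcycle, aes(x = times, y = accel)) +
  geom_point() +
  labs(x = "Milliseconds post impact",
       y = "Acceleration (g)",
       title = "Simulated motorcycle accident")
p
```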
We’ll model acceleration as a smooth function of time using a GAM and the default thin plate regression spline basis. This can be done using the gam() function in mgcv and, for comparison with the fully Bayesian model we’ll fit shortly, we use method = "REML" to estimate the smoothness parameter for the spline, treating the model in its mixed model form and fitting it by REML.
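A sketch of the mgcv fit described above; the object name m1 is an assumption used throughout these examples.

```r
library(mgcv)
data(mcycle, package = "MASS")

# s(times) uses the default thin plate regression spline basis;
# method = "REML" selects the smoothness parameter by REML
m1 <- gam(accel ~ s(times), data = mcycle, method = "REML")
summary(m1)
```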
As we can see from the model summary, the estimated smooth uses about 8.5 effective degrees of freedom and in the test of zero effect, the null hypothesis is strongly rejected. The fitted spline explains about 80% of the variance or deviance in the data.
To plot the fitted smooth we could use the plot() method provided by mgcv, but this uses base graphics. Instead we can use the draw() method from schoenberg, which can currently handle most of the univariate smooths in mgcv plus 2-d tensor product smooths.
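The two plotting routes can be sketched as follows. Because schoenberg is installed from GitHub and may not be available, this example falls back to mgcv’s plot() method when it is missing; that guard is my addition, not part of the workflow described in the text.

```r
library(mgcv)
data(mcycle, package = "MASS")
m1 <- gam(accel ~ s(times), data = mcycle, method = "REML")

if (requireNamespace("schoenberg", quietly = TRUE)) {
  # ggplot2-based plot of the estimated smooth
  print(schoenberg::draw(m1))
} else {
  # Fallback: mgcv's base-graphics plot method, with a shaded confidence band
  plot(m1, shade = TRUE)
}
```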
The equivalent model can be estimated using a fully Bayesian approach via the brm() function in the brms package. In fact, brm() will use the smooth specification functions from mgcv, making our lives much easier. The major difference though is that you can’t use te() or ti() smooths in brm() models; you need to use t2() tensor product smooths instead. This is because the smooths in the model are going to be treated as random effects and the model estimated as a GLMM, which exploits the duality of splines as random effects. In this representation, the wiggly parts of the spline basis are treated as a random effect and their associated variance parameter controls the degree of wiggliness of the fitted spline. The perfectly smooth parts of the basis are treated as a fixed effect. In this form, the GAM can be estimated using standard GLMM software; it’s what allows the gamm4() function, for example, to fit GAMMs using the lme4 package. This is also the reason why we can’t use te() or ti() smooths; those smooths do not have nicely separable penalties, which means they can’t be written in the form required to be fitted using typical mixed model software.
The brm() version of the GAM is fitted using the code below. Note that I have changed a few things from their default values as
- the model required more than the default number of MCMC samples: iter = 4000,
- the samples needed thinning to deal with some strong autocorrelation in the Markov chains: thin = 10,
- the adapt_delta parameter, a tuning parameter in the NUTS sampler for Hamiltonian Monte Carlo, potentially needed raising: there was a warning about a potential divergent transition but I should have looked to see if it was one or not; instead I just increased the tuning parameter to 0.99,
- four chains are fitted by default but I wanted these to be fitted using 4 CPU cores (the cores argument),
- seed sets the internal random number generator seed, which allows reproducibility of models, and
- for this post I didn’t want to print out the progress of the sampler, hence refresh = 0; typically you won’t want to do this so you can see how the sampling is progressing.
The rest of the model is pretty similar to the gam() version we fitted earlier. The main difference is that I use the bf() function to create a special brms formula specifying the model. You don’t actually need to do this for such a simple model, but in a later post we’ll use this to fit distributional GAMs. Note that I’m leaving all the priors in the model at the default values. I’ll look at defining priors in a later post; for now I’m just going to use the default priors that brm() uses.
Once the model has finished compiling and sampling we can output the model summary
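Putting the settings listed above together, the fit and its summary might look like the sketch below. The object name m2 and the particular seed value are assumptions made for illustration; any integer seed gives a reproducible fit. Compiling and sampling takes a little while.

```r
library(brms)
data(mcycle, package = "MASS")

m2 <- brm(bf(accel ~ s(times)),               # brms formula with an mgcv-style smooth
          data = mcycle, family = gaussian(),
          iter = 4000,                        # more than the default number of iterations
          thin = 10,                          # thin to reduce autocorrelation
          control = list(adapt_delta = 0.99), # NUTS tuning parameter
          cores = 4,                          # run the four default chains in parallel
          seed = 17,                          # arbitrary value, for reproducibility
          refresh = 0)                        # suppress sampler progress output

summary(m2)
```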
This outputs details of the model fitted plus parameter estimates (as posterior means), standard errors, (by default) 95% credible intervals and two other diagnostics:
- Eff.Sample is the effective sample size of the posterior samples in the model, and
- Rhat is the potential scale reduction factor or Gelman-Rubin diagnostic; it is a measure of how well the chains have converged and ideally should be equal to 1.
The summary includes two entries for the smooth of times:
- sds(stimes_1) is the variance parameter (reported as a standard deviation), which has the effect of controlling the wiggliness of the smooth: the larger this value the more wiggly the smooth. We can see that the credible interval doesn’t include 0 so there is evidence that a smooth is required over and above a linear parametric effect of times, details of which are given next,
- stimes_1 is the fixed effect part of the spline, which is the linear function that is perfectly smooth.
The final parameter table includes information on the variance of the data about the conditional mean of the response.
How does this model compare with the one fitted using gam()? We can use the gam.vcomp() function to compute the variance component representation of the smooth estimated via gam(). To make it comparable with the value shown for the brms model, we don’t undo the rescaling of the penalty matrix that gam() performs to help with numeric stability during model fitting.
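A sketch of that call, refitting the mgcv model so the example stands alone. The rescale = FALSE argument is what leaves gam()’s penalty rescaling in place; the object names m1 and vc are assumptions.

```r
library(mgcv)
data(mcycle, package = "MASS")
m1 <- gam(accel ~ s(times), data = mcycle, method = "REML")

# Variance component (as a standard deviation) with a 95% confidence interval,
# without undoing the penalty rescaling applied during fitting
vc <- gam.vcomp(m1, rescale = FALSE)
vc
```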
This gives an estimate of 807.89 with a 95% confidence interval of 480.66–1357.88, which compares well with the posterior mean and credible interval of the brm() version of 722.44 (450.17–1150.27).
The marginal_smooths() function is used to extract the marginal effect of the spline. This function extracts enough information about the estimated spline to plot it using the plot() method.
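A sketch of extracting and plotting the smooth. More recent brms releases rename marginal_smooths() to conditional_smooths(), so this example uses whichever the installed version provides. The model is refitted only if no object called m2 (an assumed name, as are the seed and msms) already exists in the workspace.

```r
library(brms)
data(mcycle, package = "MASS")

# Reuse an existing fit called `m2` if present; otherwise refit the model
if (!exists("m2")) {
  m2 <- brm(bf(accel ~ s(times)), data = mcycle, family = gaussian(),
            iter = 4000, thin = 10, control = list(adapt_delta = 0.99),
            cores = 4, seed = 17, refresh = 0)
}

# Pick the extractor available in the installed brms version
has_new_name <- exists("conditional_smooths", where = asNamespace("brms"),
                       inherits = FALSE)
smooth_fun <- if (has_new_name) brms::conditional_smooths else brms::marginal_smooths

msms <- smooth_fun(m2)
plot(msms)
```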
Given the similarity in the variance components of the two models it is not surprising that the two estimated smooths also look similar. The marginal_smooths() function is effectively the equivalent of the plot() method for mgcv-based GAMs.
There’s a lot that we can and should do to check the model fit. For now, we’ll look at two posterior predictive check plots that brms, via the bayesplot package (Gabry and Mahr, 2018), makes very easy to produce using the pp_check() function.
The default produces a density plot overlay of the original response values (the thick black line) with 10 draws from the posterior predictive distribution of the model. If the model is a good fit to the data, data sets simulated from it at the observed values of the covariate(s) should look similar to the observed data.
Another type of posterior predictive check plot is the empirical cumulative distribution function of the observations and random draws from the model posterior, which we can produce with type = "ecdf_overlay".
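Both checks can be sketched as below. As before, the model is refitted only if an object called m2 (an assumed name, with an arbitrary seed) is not already in the workspace.

```r
library(brms)
data(mcycle, package = "MASS")

# Reuse an existing fit called `m2` if present; otherwise refit the model
if (!exists("m2")) {
  m2 <- brm(bf(accel ~ s(times)), data = mcycle, family = gaussian(),
            iter = 4000, thin = 10, control = list(adapt_delta = 0.99),
            cores = 4, seed = 17, refresh = 0)
}

# Default check: density of the observed response overlaid with posterior draws
pp_check(m2)

# Empirical CDF version of the same comparison
pp_check(m2, type = "ecdf_overlay")
```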
Both plots show significant deviations between the posterior simulations and the observed data. The poor posterior predictive check results are in large part due to the non-constant variance of the acceleration data conditional upon the covariate. Both models assumed that the observations are Gaussian distributed with means equal to the fitted values (the estimated expectation of the response) and the same variance (\(\sigma^2\)). The observations appear to have different variances, which we can model with a distributional model, which allows all parameters of the distribution of the response to be modelled with linear predictors. We’ll take a look at these models in a future post.
References
Bürkner, P.-C. (2017). brms: An R package for Bayesian multilevel models using Stan. Journal of Statistical Software 80, 1–28. doi:10.18637/jss.v080.i01.
Gabry, J., and Mahr, T. (2018). bayesplot: Plotting for Bayesian models. Available at: https://CRAN.R-project.org/package=bayesplot.