JAGS and Stan
[This article was first published on Wiekvoet, and kindly contributed to R-bloggers].
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.
During the last year I have been running some estimations in both JAGS and Stan. In that period I have seen one example where JAGS could not get me decent samples (in the sense of low Rhat and a high number of effective samples), but that was data which I could not blog about. When, two weeks ago, part of my model did not converge well in JAGS, I wondered how Stan would fare. Hence this post. It appears that Stan did not really do much better. What did become apparent is that results in this kind of difficult problem can vary depending on the inits and the random samples used in the chain. This probably means more samples would help, but that is not the topic of this post.
Programs
I expect most readers of this blog to know about both Stan and JAGS, but a few lines about them seem not amiss. Stan and JAGS can be used for the same kind of problems, but they are quite different. JAGS is a variation on BUGS, similar to WinBUGS and OpenBUGS, where a model just states relations between variables. In Stan, on the other hand, a model consists of clearly defined blocks, and the order of statements matters. Stan models are compiled, which takes some time by itself. Both Stan and JAGS can be run by themselves, but I find it most convenient to run them from R. R is then used for pre-processing data, setting up the model and finally summarizing the samples. Because JAGS and Stan are so different, they need very different numbers of MCMC samples. Stan is supposed to be more efficient, hence needing fewer samples to obtain a posterior result of similar quality.
From a model development point of view, JAGS (rjags, R2jags) is slightly more integrated in R than Stan (rstan), mostly because JAGS models can be written as R functions, so my editor lends a hand, while rstan has its model in just a character string. In addition, JAGS has no compilation time. The plus of Stan, though, is highly organized model code.
Models
The model describes the number of shootings per state, hierarchically under regions. Each state has its own binomial probability of interest, and these state probabilities follow a beta distribution per region. The beta distributions have uninformative hyperpriors. After some tweaking the models should be equivalent, which means that the JAGS model is slightly different from previous posts. The number of samples chosen is 250000 with 100000 burn-in for JAGS, and 4000 with half burn-in for Stan. I chose ten chains. Usually I would use four, but since I suspected some chains would misbehave, I opted for a larger number. The inits were either around 1000, which means that a number of parameters have to shift quite a bit to get the beta mean near 1 in 100000, or close to the target distribution, which means that mostly the region and state parameters have to converge to the correct values. In terms of model behavior I only look at the priors and hyperpriors. In particular a and b of the beta distributions (one pair per region) are difficult to estimate, while their ratio and the state level estimates are quite easy.
Results
What I expected to write here is that Stan coped a bit better, especially when the inits are a bit off, which is what happened in the first version of this post. But then I did an additional calculation and Stan got worse results too. So part of the conclusion is that the outcome depends strongly on the inits and on the random sampling in the MCMC chain.
Speed
In terms of speed, different Stan chains can run at markedly different speeds: one chain can take 90 seconds while the next takes 250 seconds. JAGS does not display the progress of individual chains, so there is no information there. Overall the speeds were about the same, roughly 1000 to 2000 seconds elapsed per run (see the timings in the output below). If that seems large, this was on a Celeron processor with one core used. MCMC chains are embarrassingly parallel, so gains can be made easily.
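Because chains are independent, they can simply be distributed over cores. The sketch below uses base R's parallel::mclapply() to run several chains at once. The sampler is a hypothetical run_chain(), a toy random-walk Metropolis sampler with a standard normal target, standing in for one chain; it is not the JAGS or Stan sampler, and the seeds and chain length are arbitrary. For the real runs, rstan's stan() has a cores argument and R2jags provides jags.parallel().

```r
library(parallel)

# Hypothetical stand-in for one MCMC chain: random-walk Metropolis
# targeting a standard normal distribution.
run_chain <- function(seed, n_iter = 5000, step = 1) {
  set.seed(seed)                       # each chain gets its own reproducible stream
  x <- numeric(n_iter)
  current <- rnorm(1, 0, 5)            # dispersed starting value, like deliberately 'bad' inits
  for (i in seq_len(n_iter)) {
    proposal <- current + rnorm(1, 0, step)
    # accept with probability min(1, target(proposal) / target(current))
    if (log(runif(1)) < dnorm(proposal, log = TRUE) - dnorm(current, log = TRUE)) {
      current <- proposal
    }
    x[i] <- current
  }
  x
}

# Forking is not available on Windows, so fall back to a single core there.
n_cores <- if (.Platform$OS.type == "windows") 1L else 2L
seeds <- 1:4                           # one seed per chain
chains <- mclapply(seeds, run_chain, mc.cores = n_cores)
str(chains)
```

Setting the seed inside the worker keeps each chain reproducible regardless of how many cores are used.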
Gelman diagnostic
The diagnostic is calculated with coda for all four runs, so the values are comparable. There are eight series of values in the figure, a pair for each run: coda gives both the point estimate and the upper limit of the 95% confidence interval. 'Smart' refers to the better inits, '1000' to the inits around 1000. The x axis refers to the parameters estimated. There is something odd at variables 19, 20 and 28, 29. Just prior to posting I discovered that the variables, especially aa and bb, are sorted differently in Stan than in JAGS. In hindsight I should have put my parameters in alphabetical order. The plot shows that Stan is actually doing a bit worse than JAGS, especially with the inits which should have made correct results easier to obtain.
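For reference, a minimal base-R sketch of the basic Gelman-Rubin point estimate (potential scale reduction factor) for a single parameter. It assumes an iterations x chains matrix of draws, and the function name rhat_basic() is hypothetical. coda's gelman.diag() additionally applies a degrees-of-freedom correction and reports the upper confidence limit, and rstan computes Rhat on split chains, so their values differ somewhat from this simplified version.

```r
# Basic Gelman-Rubin PSRF for one parameter; 'draws' is iterations x chains.
rhat_basic <- function(draws) {
  n <- nrow(draws)
  chain_means <- colMeans(draws)
  W <- mean(apply(draws, 2, var))      # mean within-chain variance
  B <- n * var(chain_means)            # between-chain variance
  var_plus <- (n - 1) / n * W + B / n  # pooled estimate of the posterior variance
  sqrt(var_plus / W)
}

set.seed(2)
# Four chains drawn from the same distribution: Rhat close to 1
mixed <- matrix(rnorm(4000), ncol = 4)
# Same, but the fourth chain is shifted by 3, as if it got stuck elsewhere
stuck <- sweep(matrix(rnorm(4000), ncol = 4), 2, c(0, 0, 0, 3), "+")
rhat_basic(mixed)
rhat_basic(stuck)
```

A single chain sitting in a different region is enough to push the value well above 1, which is why running more chains makes misbehaving ones easier to spot.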
Effective number of samples
Again the plot is made using calculations in coda, so the numbers are comparable. In number of effective samples, Stan seems to be doing a bit better for the more difficult parameters. For the easy parameters JAGS is a bit better; here the large number of samples for JAGS pays off.
Conclusion
Neither JAGS nor Stan came out clearly on top, which was not what I expected. Nevertheless, it still seems to me that JAGS is the tool for simple models, while Stan is the choice for more complex models. Conversion from JAGS to Stan was not difficult.
Output from runs
Stan & inits1
    user   system  elapsed
1992.725    2.666 2016.079
Inference for Stan model: model1.
10 chains, each with iter=4000; warmup=2000; thin=1;
post-warmup draws per chain=2000, total post-warmup draws=20000.
mean se_mean sd n_eff Rhat
a[1] 53.44 4.38 44.95 NA 105 1.10
a[2] 65.07 7.55 53.74 NA 51 1.10
a[3] 66.92 6.93 54.54 NA 62 1.10
a[4] 67.28 7.00 55.38 NA 63 1.11
a[5] 62.51 4.99 50.49 NA 102 1.10
a[6] 70.32 8.18 59.19 NA 52 1.11
a[7] 62.44 7.46 53.44 NA 51 1.10
a[8] 59.37 6.37 49.08 NA 59 1.11
a[9] 63.54 5.29 53.66 NA 103 1.10
b[1] 44661405.38 3638329.39 37093869.18 NA 104 1.10
b[2] 39099064.06 4007140.91 32542035.41 NA 66 1.10
b[3] 38335546.03 4378451.81 31386394.46 NA 51 1.11
b[4] 38034644.61 3133779.95 31056909.88 NA 98 1.10
b[5] 39443027.62 4038448.85 32265357.99 NA 64 1.10
b[6] 35355902.63 2933955.81 29177099.95 NA 99 1.10
b[7] 41237192.06 4561471.29 33706574.62 NA 55 1.10
b[8] 42188390.23 3364049.93 34756182.10 NA 107 1.10
b[9] 39656654.17 4301318.67 32155276.73 NA 56 1.10
betamean[1] 0.08 0.00 0.02 NA 793 1.01
betamean[2] 0.11 0.00 0.02 NA 1253 1.01
betamean[3] 0.11 0.00 0.01 NA 1436 1.01
betamean[4] 0.11 0.00 0.02 NA 1608 1.01
betamean[5] 0.10 0.00 0.02 NA 947 1.01
betamean[6] 0.12 0.00 0.02 NA 683 1.01
betamean[7] 0.09 0.00 0.02 NA 759 1.01
betamean[8] 0.09 0.00 0.01 NA 717 1.01
betamean[9] 0.10 0.00 0.02 NA 737 1.02
aa 63.61 6.74 52.03 NA 60 1.11
bb 39230839.48 3945047.13 31100418.06 NA 62 1.11
sda 11.71 1.18 13.79 NA 137 1.07
sdb 6716646.80 546142.87 8203368.43 NA 226 1.04
lp__ -7555.39 2.81 24.19 NA 74 1.10
Samples were drawn using NUTS(diag_e) at Sun Aug 24 10:32:29 2014.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
Stan & inits2
    user   system  elapsed
1039.834    1.779 1042.681
Inference for Stan model: model1.
10 chains, each with iter=4000; warmup=2000; thin=1;
post-warmup draws per chain=2000, total post-warmup draws=20000.
mean se_mean sd n_eff Rhat
a[1] 46.65 2.96 40.52 NA 188 1.04
a[2] 56.89 3.51 47.84 NA 186 1.03
a[3] 59.52 3.83 50.16 NA 171 1.04
a[4] 58.93 3.70 50.71 NA 188 1.04
a[5] 54.94 3.32 45.85 NA 190 1.03
a[6] 61.79 3.96 53.71 NA 184 1.04
a[7] 53.75 3.40 46.37 NA 186 1.04
a[8] 50.94 3.22 44.15 NA 187 1.04
a[9] 54.71 3.52 47.86 NA 185 1.04
b[1] 38295125.50 2367547.52 32822125.24 NA 192 1.04
b[2] 34100563.26 2135975.40 29277394.26 NA 188 1.04
b[3] 33605901.88 2089826.28 28387357.36 NA 185 1.04
b[4] 33282843.68 2064977.85 27983832.94 NA 184 1.04
b[5] 34582670.15 2189638.25 29708184.70 NA 184 1.04
b[6] 31320186.63 1969915.12 26631419.99 NA 183 1.04
b[7] 35911817.86 2246565.24 30569577.66 NA 185 1.04
b[8] 36698348.20 2277791.00 31350133.96 NA 189 1.03
b[9] 34452421.07 2110620.24 28624582.07 NA 184 1.04
betamean[1] 0.08 0.00 0.02 NA 531 1.02
betamean[2] 0.11 0.00 0.02 NA 935 1.01
betamean[3] 0.11 0.00 0.01 NA 377 1.02
betamean[4] 0.11 0.00 0.02 NA 1667 1.01
betamean[5] 0.10 0.00 0.02 NA 512 1.02
betamean[6] 0.12 0.00 0.02 NA 791 1.01
betamean[7] 0.09 0.00 0.01 NA 1239 1.01
betamean[8] 0.09 0.00 0.01 NA 426 1.02
betamean[9] 0.10 0.00 0.01 NA 704 1.01
aa 55.36 3.47 46.59 NA 181 1.04
bb 34503256.80 2129396.22 28613633.55 NA 181 1.04
sda 10.07 0.71 11.42 NA 256 1.03
sdb 5272208.87 372540.69 6542979.18 NA 308 1.02
lp__ -7557.13 1.74 23.38 NA 181 1.03
Samples were drawn using NUTS(diag_e) at Sun Aug 24 10:56:17 2014.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
JAGS & inits1
    user   system  elapsed
1784.454    1.222 1791.519
Inference for Bugs model at "/tmp/Rtmp2VW032/model75d4f74fe01.txt", fit using jags,
10 chains, each with 250000 iterations (first 1e+05 discarded), n.thin = 150
n.sims = 10000 iterations saved
mu.vect sd.vect int.matrix Rhat n.eff
a[1] 52.059 43.877 NA 1.099 65
a[2] 62.548 51.117 NA 1.102 63
a[3] 64.949 53.088 NA 1.104 62
a[4] 65.120 54.238 NA 1.104 62
a[5] 61.225 50.014 NA 1.103 63
a[6] 67.498 56.420 NA 1.102 63
a[7] 59.181 49.137 NA 1.102 63
a[8] 57.787 47.902 NA 1.103 63
a[9] 61.403 51.773 NA 1.102 63
aa 60.755 49.814 NA 1.099 65
b[1] 42704592.037 35622401.674 NA 1.104 62
b[2] 37598265.554 31204530.763 NA 1.100 64
b[3] 36937885.481 30057853.515 NA 1.104 62
b[4] 36766568.933 30153113.256 NA 1.103 63
b[5] 38140646.333 31358488.234 NA 1.103 63
b[6] 34271877.336 28059552.425 NA 1.103 63
b[7] 39747380.703 32646090.587 NA 1.103 63
b[8] 40262327.803 33127265.445 NA 1.102 63
b[9] 38041673.881 30717741.484 NA 1.104 62
bb 37696463.522 30374989.608 NA 1.104 62
betamean[1] 0.079 0.018 NA 1.001 6100
betamean[2] 0.108 0.017 NA 1.005 1200
betamean[3] 0.111 0.012 NA 1.001 10000
betamean[4] 0.111 0.018 NA 1.002 4300
betamean[5] 0.103 0.015 NA 1.002 3200
betamean[6] 0.122 0.016 NA 1.004 1700
betamean[7] 0.094 0.015 NA 1.001 10000
betamean[8] 0.090 0.014 NA 1.002 4300
betamean[9] 0.098 0.015 NA 1.015 420
sda 10.819 15.196 NA 1.043 140
sdb 6136772.765 8870298.491 NA 1.046 130
deviance 266.189 18.534 NA 1.083 76
For each parameter, n.eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor (at convergence, Rhat=1).
DIC info (using the rule, pD = var(deviance)/2)
pD = 151.1 and DIC = 417.3
DIC is an estimate of expected predictive error (lower deviance is better).
JAGS & inits2
    user   system  elapsed
1854.296    0.863 1856.972
Inference for Bugs model at "/tmp/Rtmp2VW032/model75d1dda784d.txt", fit using jags,
10 chains, each with 250000 iterations (first 1e+05 discarded), n.thin = 150
n.sims = 10000 iterations saved
mu.vect sd.vect int.matrix Rhat n.eff
a[1] 62.153 46.590 NA 1.063 98
a[2] 74.562 53.977 NA 1.063 98
a[3] 77.974 56.111 NA 1.067 92
a[4] 77.895 56.781 NA 1.067 93
a[5] 73.379 53.374 NA 1.065 94
a[6] 80.658 59.219 NA 1.065 95
a[7] 71.378 52.600 NA 1.062 99
a[8] 69.329 50.894 NA 1.065 95
a[9] 73.640 55.117 NA 1.064 97
aa 72.389 52.305 NA 1.062 99
b[1] 51494466.291 38425814.261 NA 1.065 95
b[2] 44959657.895 33022305.640 NA 1.062 99
b[3] 44386799.990 31879332.379 NA 1.066 94
b[4] 43857406.601 31539586.787 NA 1.066 94
b[5] 45956962.745 33781483.665 NA 1.066 94
b[6] 40856777.369 29453791.246 NA 1.065 95
b[7] 47575125.880 34573426.507 NA 1.063 98
b[8] 48634592.709 35656103.487 NA 1.066 94
b[9] 45324100.091 32444527.795 NA 1.063 97
bb 44759762.836 31404548.768 NA 1.062 99
betamean[1] 0.078 0.018 NA 1.002 4500
betamean[2] 0.107 0.016 NA 1.004 1600
betamean[3] 0.111 0.012 NA 1.001 9700
betamean[4] 0.112 0.017 NA 1.002 4200
betamean[5] 0.102 0.014 NA 1.002 3500
betamean[6] 0.123 0.015 NA 1.004 1800
betamean[7] 0.094 0.015 NA 1.001 10000
betamean[8] 0.090 0.014 NA 1.001 7800
betamean[9] 0.099 0.014 NA 1.013 470
sda 14.300 21.320 NA 1.027 220
sdb 8080234.258 12065450.388 NA 1.027 220
deviance 269.866 17.631 NA 1.053 120
For each parameter, n.eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor (at convergence, Rhat=1).
DIC info (using the rule, pD = var(deviance)/2)
pD = 143.4 and DIC = 413.2
DIC is an estimate of expected predictive error (lower deviance is better).
Code
Reading data
r13 <- readLines('raw13.txt')
r14 <- readLines('raw14.txt')
r1 <- c(r13,r14)
# remove footnote markers such as [1] or [a]; the brackets must be escaped twice in R
r2 <- gsub('\\[[a-zA-Z0-9]*\\]','',r1)
r3 <- gsub('^ *$','',r2)
r4 <- r3[r3!='']
r5 <- gsub('\t$','',r4)
r6 <- gsub('\t References$','',r5)
r7 <- read.table(textConnection(r6),
sep='\t',
header=TRUE,
stringsAsFactors=FALSE)
r7$Location[r7$Location=='Washington DC'] <-
'WashingtonDC, DC'
r8 <- read.table(textConnection(as.character(r7$Location)),
sep=',',
col.names=c('Location','State'),
stringsAsFactors=FALSE)
r8$State <- gsub(' ','',r8$State)
r8$State[r8$State=='Tennessee'] <- 'TN'
r8$State[r8$State=='Ohio'] <- 'OH'
r8$State[r8$State %in% c('Kansas','KA')] <- 'KS'
r8$State[r8$State=='Louisiana'] <- 'LA'
r8$State[r8$State=='Illinois'] <- 'IL'
r8$State <- toupper(r8$State)
table(r8$State)
r7$StateAbb <- r8$State
r7$Location <- r8$Location
# drop Puerto Rico, which is not one of the states
r7 <- r7[! (r7$StateAbb %in% c('PUERTORICO','PR')),]
r7$Date <- gsub('/13$','/2013',r7$Date)
r7$date <- as.Date(r7$Date,format="%m/%d/%Y")
states <- data.frame(
StateAbb=as.character(state.abb),
StateRegion=state.division,
State=as.character(state.name)
)
states <- rbind(states,data.frame(StateAbb='DC',
State='District of Columbia',
StateRegion='Middle Atlantic'))
# http://www.census.gov/popest/data/state/totals/2013/index.html
inhabitants <- read.csv('NST-EST2013-01.treated.csv')
#put it all together
states <- merge(states,inhabitants)
r9 <- merge(r7,states)
#########################
r10 <- merge(as.data.frame(xtabs(~StateAbb,data=r9)),states,all=TRUE)
r10$Freq[is.na(r10$Freq)] <- 0
r10$Incidence <- r10$Freq*100000*365/r10$Population/
as.numeric((max(r7$date)-min(r7$date)))
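As a worked example of the incidence formula: a hypothetical state with 3 shootings and 5 million inhabitants, observed over 500 days, has an incidence of 3 * 100000 * 365 / 5000000 / 500 = 0.0438 per 100000 inhabitants per year. The same number results from multiplying the raw proportion by the scale factor that is later passed to the models as datain$scale. The function name incidence() and the numbers are illustrative only.

```r
# Incidence per 100000 inhabitants per year, the same arithmetic as r10$Incidence.
incidence <- function(freq, population, days) {
  freq * 100000 * 365 / population / days
}

# Hypothetical state: 3 shootings, 5 million inhabitants, 500 days of data
incidence(3, 5e6, 500)   # 0.0438

# Equivalent form using a scale factor built like datain$scale
scale <- 100000 * 365 / 500
3 / 5e6 * scale          # 0.0438
```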
Common for modelling
datain <- list(count=r10$Freq,
Population = r10$Population,
n=nrow(r10),
nregion =nlevels(r10$StateRegion),
Region=(1:nlevels(r10$StateRegion))[r10$StateRegion],
scale=100000*365/
as.numeric((max(r7$date)-min(r7$date))))
parameters <- c('a','b','betamean','aa','bb','sda','sdb')
inits1 <- function()
list(a=rnorm(datain$nregion,100,10),
b=rnorm(datain$nregion,1e8,1e7),
aa=rnorm(1,100,10),
sda=rnorm(1,100,10),
sdb=rnorm(1,1e7,1e6),
bb=rnorm(1,1e7,1e6),
p1=rnorm(datain$n,1e-7,1e-8))
inits2 <- function()
list(a=rnorm(datain$nregion,1000,100),
b=rnorm(datain$nregion,1000,100),
aa= rnorm(1, 1000,100),
sda=rnorm(1,1000,100),
sdb=rnorm(1,1000,100),
bb= rnorm(1,1000,100),
p1=rnorm(datain$n,1e-7,1e-8))
inits1l<- lapply(1:10,function(x) inits1())
inits2l<- lapply(1:10,function(x) inits2())
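To make the structure behind datain concrete, a sketch that simulates data from the same hierarchy: region-level a and b, a state-level probability p1 drawn from the beta distribution of the state's region, and binomial counts given the population. All populations, region assignments and parameter values are hypothetical, chosen only to be of a similar order as the estimates in the output above; just the structure follows the models.

```r
set.seed(1)

region     <- c(1, 1, 2, 2, 3, 3)             # hypothetical region index of each state
population <- c(5e6, 1e6, 8e6, 2e6, 3e6, 6e5) # hypothetical inhabitants per state
days       <- 500                             # hypothetical observation window in days
scale      <- 100000 * 365 / days             # built like datain$scale

# Region level: parameters of each region's beta distribution (hypothetical)
a <- c(60, 55, 70)
b <- c(4e7, 3.5e7, 4.5e7)

# State level: probability drawn from the beta distribution of the state's region
p1 <- rbeta(length(region), a[region], b[region])

# Observation level: binomial number of shootings per state
count <- rbinom(length(region), size = population, prob = p1)

# Region-level mean incidence per 100000 per year, as 'betamean' in the models
betamean <- scale * a / (a + b)
print(round(betamean, 3))
print(count)
```

Because a and b enter betamean only through their ratio, very different (a, b) pairs give almost the same betamean, which fits with the ratio being easy to estimate while a and b themselves are not.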
Stan model
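library(rstan)
model1 <- '
data {
  int<lower=0> n;                               // number of states
  int<lower=1> nregion;                         // number of regions
  array[n] int<lower=0> count;                  // shootings per state
  array[n] int<lower=1, upper=nregion> Region;  // region index of each state
  array[n] int<lower=0> Population;             // inhabitants per state
  real scale;                                   // probability -> incidence per 100000 per year
}
parameters {
  vector<lower=0>[nregion] a;
  vector<lower=0>[nregion] b;
  vector<lower=0, upper=1>[n] p1;
  // bounds match the uniform priors below
  real<lower=0, upper=1e5> aa;
  real<lower=0, upper=1e8> bb;
  real<lower=0, upper=1e5> sda;
  real<lower=0, upper=1e8> sdb;
}
model {
  for (i in 1:n) {
    p1[i] ~ beta(a[Region[i]], b[Region[i]]);
  }
  count ~ binomial(Population, p1);
  a ~ normal(aa, sda);
  b ~ normal(bb, sdb);
  aa ~ uniform(0, 1e5);
  bb ~ uniform(0, 1e8);
  sda ~ uniform(0, 1e5);
  sdb ~ uniform(0, 1e8);
}
generated quantities {
  vector[n] pp;
  vector[nregion] betamean;
  for (i in 1:nregion) {
    betamean[i] = scale * a[i] / (a[i] + b[i]);
  }
  pp = p1 * scale;
}
'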
system.time(fits1 <- stan(model_code = model1,
data = datain,
pars=parameters,
init=inits1l,
iter = 4000,
chains = 10))
print(fits1,probs=NA)
system.time(fits2 <- stan(model_code = model1,
data = datain,
pars=parameters,
init=inits2l,
iter = 4000,
chains = 10))
print(fits2,probs=NA)
JAGS model
library(R2jags)
model.jags <- function() {
for (i in 1:n) {
count[i] ~ dbin(p1[i],Population[i])
p1[i] ~ dbeta(a[Region[i]],b[Region[i]])
pp[i] <- p1[i]*scale
}
for (i in 1:nregion) {
a[i] ~ dnorm(aa,tauaa) %_% T(0,)
b[i] ~ dnorm(bb,taubb) %_% T(0,)
betamean[i] <- scale * a[i]/(a[i]+b[i])
}
tauaa <- pow(sda,-2)
sda ~dunif(0,1e5)
taubb <- pow(sdb,-2)
sdb ~dunif(0,1e8)
aa ~ dunif(0,1e5)
bb ~ dunif(0,1e8)
}
system.time(jags1 <- jags(datain,
model.file=model.jags,
inits=inits1l,
parameters.to.save=parameters,
n.iter=250000,
n.burnin=100000,
n.chains=10))
print(jags1,intervals=NA)
system.time(jags2 <- jags(datain,
model.file=model.jags,
inits=inits2l,
parameters.to.save=parameters,
n.iter=250000,
n.burnin=100000,
n.chains=10))
print(jags2,intervals=NA)
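One detail to watch when converting between the two programs is the parameterization of the normal distribution: JAGS's dnorm() takes a precision, which is why the JAGS model computes tauaa <- pow(sda,-2), while Stan's normal() takes a standard deviation. A small base-R check that both forms give the same density; the helper dnorm_precision(), the standard deviation and the evaluation points are arbitrary illustration choices.

```r
# Normal density written with a precision tau, as JAGS parameterizes dnorm()
dnorm_precision <- function(x, mu, tau) {
  sqrt(tau / (2 * pi)) * exp(-tau * (x - mu)^2 / 2)
}

sd_value <- 12.5          # hypothetical standard deviation
tau <- sd_value^-2        # the same conversion as pow(sda,-2) in the JAGS model
x <- c(-10, 0, 3, 40)

# Stan-style (sd) and JAGS-style (precision) densities side by side
cbind(jags_style = dnorm_precision(x, 0, tau),
      stan_style = dnorm(x, 0, sd_value))
```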
Post processing
library(coda)
# as.array() on a stanfit gives iterations x chains x parameters;
# turn each chain into an mcmc object so coda treats Stan and JAGS alike
st1arr <- as.array(fits1)
st1coda <- mcmc.list(lapply(1:dim(st1arr)[2],
function(x) mcmc(st1arr[,x,])))
st2arr <- as.array(fits2)
st2coda <- mcmc.list(lapply(1:dim(st2arr)[2],
function(x) mcmc(st2arr[,x,])))
jg1coda <- as.mcmc(jags1)
jg2coda <- as.mcmc(jags2)
options(width=100)
gdiag <- cbind(gelman.diag(st1coda)[[1]],
gelman.diag(st2coda)[[1]],
gelman.diag(jg1coda)[[1]],
gelman.diag(jg2coda)[[1]])
png('Gelman Diagnostic.png')
matplot(gdiag,ylim=c(1,1.5),ylab='Gelman Diagnostic')
legend(x='topleft',pch=format(c(1:8)),
c('1 Stan smart point',
'2 Stan smart upper',
'3 Stan 1000 point',
'4 Stan 1000 upper',
'5 Jags smart point',
'6 Jags smart upper',
'7 Jags 1000 point',
'8 Jags 1000 upper'),
col=c(1:6,1,2),ncol=2)
dev.off()
efs <- cbind(effectiveSize(st1coda),
effectiveSize(st2coda),
effectiveSize(jg1coda),
effectiveSize(jg2coda))
png('Efsz.png')
matplot(efs,log='y',ylab='Effective sample size')
legend(x='topleft',pch=format(c(1:4)),
c('1 Stan smart',
'2 Stan 1000',
'3 Jags smart',
'4 Jags 1000'),
col=c(1:4))
dev.off()
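coda's effectiveSize(), used above, estimates the spectral density at frequency zero from an autoregressive model fitted to each chain and, for an mcmc.list, sums the per-chain values. A minimal base-R sketch of that idea for a single chain follows; the function name ess_ar() and the simulated iid and AR(1) series are assumptions for illustration, and unlike coda this sketch does not special-case constant chains.

```r
# Effective sample size from an AR-model estimate of the spectral density at zero
ess_ar <- function(x) {
  fit <- ar(x, aic = TRUE)                      # order chosen by AIC
  spec0 <- fit$var.pred / (1 - sum(fit$ar))^2   # spectral density at frequency zero
  length(x) * var(x) / spec0
}

set.seed(3)
iid <- rnorm(5000)                                       # independent draws
ar1 <- as.numeric(arima.sim(list(ar = 0.9), n = 5000))   # strongly autocorrelated chain
ess_ar(iid)   # close to the 5000 draws
ess_ar(ar1)   # far fewer; theory gives about 5000 * 0.1 / 1.9
```

For an AR(1) chain with coefficient phi the effective sample size is about n(1 - phi)/(1 + phi), which shows why a slowly mixing parameter such as a or b needs many more draws than betamean.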
To leave a comment for the author, please follow the link and comment on their blog: Wiekvoet.