Benchmarking time series models
[This article was first published on Modern Toolmaking, and kindly contributed to R-bloggers.]
This is a quick post on the importance of benchmarking time-series forecasts. First, we need to reload the functions from my last few posts on time-series cross-validation. (I copied the relevant code to the bottom of this post so you don’t have to find it.)
Next, we need to load data for the S&P 500. To simplify things, and allow us to explore seasonality effects, I’m going to load monthly data, back to 1980.
#Setup
set.seed(1)
library(quantmod)
library(forecast)

#Load data
getSymbols('^GSPC', from='1980-01-01')

#Simplify to monthly level
GSPC <- to.monthly(GSPC)
Data <- as.ts(Cl(GSPC), start=1980)
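The cross-validation code later in this post relies on the series being a ts object with a known frequency, because it slices training windows by time using tsp() and window(). Here is a minimal sketch of those base R calls, using simulated values as a stand-in for the downloaded prices (the numbers are made up, not S&P 500 data):

```r
# Simulated monthly values standing in for closing prices (not real data)
set.seed(1)
x <- ts(cumsum(rnorm(36)), start = c(1980, 1), frequency = 12)

# A monthly series has 12 observations per year
print(frequency(x))

# tsp() returns the start time, end time, and frequency
print(tsp(x))

# window() selects observations by time; here, the first 12 months
first_year <- window(x, start = c(1980, 1), end = c(1980, 12))
print(length(first_year))
```

If frequency(Data) is not 12 after the conversion above, any seasonal terms that auto.arima considers would not correspond to a yearly cycle, so it is worth checking before running the models.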
The object “Data” holds monthly closing prices for the S&P 500 back to 1980. Next, we cross-validate 3 time series forecasting models: auto.arima, from the forecast package; a mean forecast, which returns the mean value over the last year; and a naive forecast, which assumes the next value of the series will be equal to the present value. The last 2 forecasts serve as benchmarks to help determine whether auto.arima would be useful for forecasting the S&P 500. Also note that I’m using BIC as the criterion for selecting ARIMA models, and I have trace on so you can see the results of the model selection process.
#Setup model cross-validation
myControl <- list( minObs=12,
                   stepSize=1,
                   maxHorizon=12,
                   fixedWindow=TRUE,
                   preProcess=FALSE,
                   summaryFunc=tsSummary
)

#Cross validate 3 models (model 3 is SLOW!)
model1 <- cv.ts(Data, meanForecast, myControl)
model2 <- cv.ts(Data, naiveForecast, myControl)
model3 <- cv.ts(Data, auto.arimaForecast, myControl, ic='bic', trace=TRUE)

#Find the RMSE for each model and create a matrix with our results
models <- list(model1,model2,model3)
models <- lapply(models, function(x) x[1:12,'RMSE'])
results <- do.call(cbind,models)
colnames(results) <- c('mean','naive','ar')

#Order by average RMSE for the 1st 3 months
results <- t(results)
order <- rowMeans(results[,1:3])
results <- results[order(order),]
print(results)
After the 3 models finish cross-validating, it is useful to plot their forecast errors at different horizons. As you can see, auto.arima performs much better than the mean model, but is consistently worse than the naive model. This illustrates the importance of benchmarking forecasts. If you can’t consistently beat a naive forecast, there’s no reason to waste processing power on a useless model.
#Plot
color <- 1:nrow(results)
plot(results[1,], col=1, type='l', ylim=c(min(results),max(results)))
for (i in color) {
  lines(results[i,],col=i)
}
legend("topleft",legend=row.names(results),col=color,lty=1)
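The benchmarking idea can also be tried without any packages. The sketch below is a simplified, base R version of fixed-window, rolling-origin evaluation for the two benchmark forecasts, run on a simulated random walk rather than real prices. The helper names (mean_fc, naive_fc, rolling_rmse) are hypothetical and are not part of the forecast package or of cv.ts. Unlike cv.ts, this version only scores origins that have a full 12 months of future values available.

```r
# Rolling-origin evaluation of two benchmark forecasts, base R only.
# mean_fc, naive_fc and rolling_rmse are hypothetical helpers for illustration.

set.seed(42)
# Simulated random walk standing in for a price series (not real S&P 500 data)
y <- cumsum(rnorm(240))

# Mean forecast: repeat the mean of the training window h times
mean_fc <- function(train, h) rep(mean(train), h)

# Naive forecast: repeat the last observed value h times
naive_fc <- function(train, h) rep(train[length(train)], h)

rolling_rmse <- function(y, fc, window = 12, h = 12) {
  # Each origin is the index of the last observation in a training window
  origins <- window:(length(y) - h)
  # errors[i, k] is the forecast error at horizon k from the i-th origin
  errors <- t(sapply(origins, function(end) {
    train  <- y[(end - window + 1):end]   # fixed-length training window
    actual <- y[(end + 1):(end + h)]      # the next h observed values
    fc(train, h) - actual
  }))
  # RMSE at each horizon, averaged over all origins
  sqrt(colMeans(errors^2))
}

rmse <- rbind(mean  = rolling_rmse(y, mean_fc),
              naive = rolling_rmse(y, naive_fc))
print(round(rmse, 2))
```

For a pure random walk the naive forecast is hard to beat, because the last value is the best available predictor of the next one. Any candidate model can be dropped into rolling_rmse as another fc function and compared row by row against these two.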
Finally, here is all the code in one place. Note that you can parallelize the cv.ts function by loading your favorite foreach backend.
#Function to cross-validate a time series.
cv.ts <- function(x, FUN, tsControl, xreg=NULL, ...) {

  #Load required packages
  stopifnot(is.ts(x))
  stopifnot(is.data.frame(xreg) | is.matrix(xreg) | is.null(xreg))
  stopifnot(require(forecast))
  stopifnot(require(foreach))
  stopifnot(require(plyr))

  #Load parameters from the tsControl list
  stepSize <- tsControl$stepSize
  maxHorizon <- tsControl$maxHorizon
  minObs <- tsControl$minObs
  fixedWindow <- tsControl$fixedWindow
  summaryFunc <- tsControl$summaryFunc
  preProcess <- tsControl$preProcess

  #Make sure xreg object is long enough for last set of forecasts
  if (! is.null(xreg)) {
    xreg <- as.matrix(xreg)

    if (nrow(xreg)<length(x)+maxHorizon) {
      warning('xreg object too short to forecast beyond the length of the time series.
              Appending NA values to xreg')
      nRows <- (length(x)+maxHorizon)-nrow(xreg)
      nCols <- dim(xreg)[2]
      addRows <- matrix(rep(NA,nCols*nRows),nrow=nRows, ncol=nCols)
      colnames(addRows) <- colnames(xreg)
      xreg <- rbind(xreg,addRows)
    }
  }

  #Define additional parameters
  freq <- frequency(x)
  n <- length(x)
  st <- tsp(x)[1]+(minObs-2)/freq

  #Create a matrix of actual values.
  #X is the point in time, Y is the forecast horizon
  #http://stackoverflow.com/questions/8140577/creating-a-matrix-of-future-values-for-a-time-series
  formatActuals <- function(x,maxHorizon) {
    actuals <- outer(seq_along(x), seq_len(maxHorizon), FUN="+")
    actuals <- apply(actuals,2,function(a) x[a])
    actuals
  }

  actuals <- formatActuals(x,maxHorizon)
  actuals <- actuals[minObs:(length(x)-1),,drop=FALSE]

  #Create a list of training windows
  #Each entry of this list will be the same length, if fixed=TRUE
  #At each point in time, calculate 'maxHorizon' forecasts ahead
  steps <- seq(1,(n-minObs),by=stepSize)
  forcasts <- foreach(i=steps, .combine=rbind, .multicombine=FALSE) %dopar% {

    if (is.null(xreg)) {
      if (fixedWindow) {
        xshort <- window(x, start=st+(i-minObs+1)/freq, end=st+i/freq)
      } else {
        xshort <- window(x, end=st + i/freq)
      }

      if (preProcess) {
        if (testObject(lambda)) {
          stop("Don't specify a lambda parameter when preProcess==TRUE")
        }
        stepLambda <- BoxCox.lambda(xshort, method='loglik')
        xshort <- BoxCox(xshort, stepLambda)
      }

      out <- FUN(xshort, h=maxHorizon, ...)

      if (preProcess) {
        out <- InvBoxCox(out, stepLambda)
      }

      return(out)

    } else if (! is.null(xreg)) {
      if (fixedWindow) {
        xshort <- window(x, start=st+(i-minObs+1)/freq, end=st+i/freq)
        xregshort <- xreg[((i):(i+minObs-1)),,drop=FALSE]
      } else {
        xshort <- window(x, end=st + i/freq)
        xregshort <- xreg[(1:(i+minObs-1)),,drop=FALSE]
      }
      newxreg <- xreg[(i+minObs):(i+minObs-1+maxHorizon),,drop=FALSE]

      if (preProcess) {
        if (testObject(lambda)) {
          stop("Don't specify a lambda parameter when preProcess==TRUE")
        }
        stepLambda <- BoxCox.lambda(xshort, method='loglik')
        xshort <- BoxCox(xshort, stepLambda)
      }

      out <- FUN(xshort, h=maxHorizon,
                 xreg=xregshort, newxreg=newxreg, ...)

      if (preProcess) {
        out <- InvBoxCox(out, stepLambda)
      }

      return(out)
    }
  }

  #Extract the actuals we actually want to use
  actuals <- actuals[steps,,drop=FALSE]

  #Accuracy at each horizon
  out <- data.frame(
    ldply(1:maxHorizon,
          function(horizon) {
            P <- forcasts[,horizon,drop=FALSE]
            A <- na.omit(actuals[,horizon,drop=FALSE])
            P <- P[1:length(A)]
            P <- na.omit(P)
            A <- A[1:length(P)]
            summaryFunc(P,A)
          }
    )
  )

  #Add average accuracy, across all horizons
  overall <- colMeans(out)
  out <- rbind(out,overall)

  #Add a column for which horizon and output
  return(data.frame(horizon=c(1:maxHorizon,'All'),out))
}

#Summary function for time series cross-validation
stopifnot(require(compiler))
testObject <- function(object){
  exists(as.character(substitute(object)))
}

tsSummary <- cmpfun(function(P,A) {
  data.frame(t(accuracy(P,A)))
})

#Forecasting wrappers
meanForecast <- cmpfun(function(x,h,...) {
  require(forecast)
  meanf(x, h, ..., level=99)$mean
})

naiveForecast <- cmpfun(function(x,h,...) {
  require(forecast)
  naive(x, h, ..., level=99)$mean
})

auto.arimaForecast <- cmpfun(function(x,h,xreg=NULL,newxreg=NULL,...) {
  require(forecast)
  fit <- auto.arima(x, xreg=xreg, ...)
  forecast(fit, h=h, level=99, xreg=newxreg)$mean
})

#Setup
set.seed(1)
library(quantmod)
library(forecast)

#Load data
getSymbols('^GSPC', from='1980-01-01')

#Simplify to monthly level
GSPC <- to.monthly(GSPC)
Data <- as.ts(Cl(GSPC), start=1980)

#Setup model cross-validation
myControl <- list( minObs=12,
                   stepSize=1,
                   maxHorizon=12,
                   fixedWindow=TRUE,
                   preProcess=FALSE,
                   summaryFunc=tsSummary
)

#Cross validate 3 models (model 3 is SLOW!)
model1 <- cv.ts(Data, meanForecast, myControl)
model2 <- cv.ts(Data, naiveForecast, myControl)
model3 <- cv.ts(Data, auto.arimaForecast, myControl, ic='bic')

#Find the RMSE for each model and create a matrix with our results
models <- list(model1,model2,model3)
models <- lapply(models, function(x) x[1:12,'RMSE'])
results <- do.call(cbind,models)
colnames(results) <- c('mean','naive','ar')

#Order by average RMSE for the 1st 3 months
results <- t(results)
order <- rowMeans(results[,1:3])
results <- results[order(order),]
print(results)

#Plot
color <- 1:nrow(results)
plot(results[1,], col=1, type='l', ylim=c(min(results),max(results)))
for (i in color) {
  lines(results[i,],col=i)
}
legend("topleft",legend=row.names(results),col=color,lty=1)
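On the parallelization note: cv.ts runs its main loop with %dopar%, so it uses whatever foreach backend is registered, and foreach falls back to sequential execution (with a warning) when none is. Below is a minimal sketch of registering a backend, assuming the doParallel package is installed. doParallel is only one of several foreach backends and is my choice for illustration, not something the post prescribes. The toy loop stands in for the loop inside cv.ts.

```r
# Assumes the foreach and doParallel packages are installed.
library(foreach)
library(doParallel)

# Start two worker processes; adjust the count to your machine
cl <- makeCluster(2)

# Register the cluster so that %dopar% loops run on it
registerDoParallel(cl)

# Toy loop standing in for the forecasting loop inside cv.ts
squares <- foreach(i = 1:4, .combine = c) %dopar% i^2

# Release the workers and fall back to sequential execution
stopCluster(cl)
registerDoSEQ()

print(squares)
```

With a backend registered this way before calling cv.ts, the slow auto.arima cross-validation should be spread across the workers. Depending on the backend, the worker processes may also need the forecast package and the helper functions made available to them, which I have not checked for this code.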