Tune xgboost models with early stopping to predict shelter animal status
This is the latest in my series of screencasts demonstrating how to use the tidymodels packages, from just getting started to tuning more complex models. I participated in this week’s episode of the SLICED playoffs, a competitive data science streaming show, where we competed to predict the status of shelter animals. I used xgboost’s early stopping feature as I competed, so let’s walk through how and when to try that out!
Here is the code I used in the video, for those who prefer reading instead of or in addition to video.
Explore data
Our modeling goal is to predict the outcome for shelter animals (adoption, transfer, or no outcome) given features about the animal and event. The main data set provided is in a CSV file called train.csv.
library(tidyverse)

train_raw <- read_csv("train.csv")
You can watch this week’s full episode of SLICED to see lots of exploratory data analysis and visualization of this dataset, but let’s just make a few plots to understand it better.
How are outcomes distributed for animals of different ages?
library(lubridate)

train_raw %>%
  mutate(
    age_upon_outcome = as.period(as.Date(datetime) - date_of_birth),
    age_upon_outcome = time_length(age_upon_outcome, unit = "weeks")
  ) %>%
  ggplot(aes(age_upon_outcome, after_stat(density), fill = outcome_type)) +
  geom_histogram(bins = 15, alpha = 0.5, position = "identity") +
  labs(x = "Age in weeks", fill = NULL)
How does adoption rate change with day of the week and week of the year?
train_raw %>%
  mutate(outcome_type = outcome_type == "adoption") %>%
  group_by(
    week = week(datetime),
    wday = wday(datetime)
  ) %>%
  summarise(outcome_type = mean(outcome_type)) %>%
  ggplot(aes(week, wday, fill = outcome_type)) +
  geom_tile(alpha = 0.8) +
  scale_fill_viridis_c(labels = scales::percent) +
  labs(fill = "% adopted", x = "week of the year", y = "week day")
Notice the difference on weekends vs. weekdays especially!
There is certainly lots more to explore (including, for example, learning about the names of the animals, something I spent a good bit of time on during the competition), but let’s move on to modeling.
Build a model
Let’s start our modeling by setting up our “data budget,” as well as the metrics (this challenge was evaluated on multiclass log loss).
library(tidymodels)

set.seed(123)
shelter_split <- train_raw %>%
  mutate(
    age_upon_outcome = as.period(as.Date(datetime) - date_of_birth),
    age_upon_outcome = time_length(age_upon_outcome, unit = "weeks")
  ) %>%
  initial_split(strata = outcome_type)

shelter_train <- training(shelter_split)
shelter_test <- testing(shelter_split)
shelter_metrics <- metric_set(accuracy, roc_auc, mn_log_loss)

set.seed(234)
shelter_folds <- vfold_cv(shelter_train, strata = outcome_type)
shelter_folds

## # 10-fold cross-validation using stratification
## # A tibble: 10 × 2
##    splits               id
##    <list>               <chr>
##  1 <split [36724/4081]> Fold01
##  2 <split [36724/4081]> Fold02
##  3 <split [36724/4081]> Fold03
##  4 <split [36724/4081]> Fold04
##  5 <split [36724/4081]> Fold05
##  6 <split [36725/4080]> Fold06
##  7 <split [36725/4080]> Fold07
##  8 <split [36725/4080]> Fold08
##  9 <split [36725/4080]> Fold09
## 10 <split [36725/4080]> Fold10
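Since this challenge was scored on multiclass log loss, it can help to see what that metric measures. Here is a minimal base R sketch on made-up toy data (the `truth` and `probs` objects are hypothetical, not from the shelter data set): log loss is the mean negative log of the probability the model assigned to the true class, so confident wrong predictions are penalized heavily.

```r
# Hypothetical toy data: 4 observations, 3 classes (not the shelter data)
truth <- factor(
  c("adoption", "transfer", "no outcome", "adoption"),
  levels = c("adoption", "no outcome", "transfer")
)

# One row per observation, one column per class; each row sums to 1
probs <- matrix(
  c(0.7, 0.1, 0.2,
    0.2, 0.2, 0.6,
    0.3, 0.5, 0.2,
    0.4, 0.3, 0.3),
  ncol = 3, byrow = TRUE,
  dimnames = list(NULL, levels(truth))
)

# Pick out the predicted probability of the true class for each row
p_true <- probs[cbind(seq_along(truth), as.integer(truth))]

# Multiclass log loss: mean negative log-probability of the true class
log_loss <- -mean(log(p_true))
log_loss
```

Lower is better, and a model that always put probability 1 on the correct class would score 0. In the tidymodels workflow here, `mn_log_loss` from yardstick plays this role.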
For feature engineering, let’s concentrate on just a handful of predictors, like when the event (adoption, transfer, or “no outcome”) was recorded and features of the animal itself like age, sex, type, etc.
shelter_rec <- recipe(outcome_type ~ age_upon_outcome + animal_type +
  datetime + sex + spay_neuter,
data = shelter_train
) %>%
  step_date(datetime, features = c("year", "week", "dow"), keep_original_cols = FALSE) %>%
  step_dummy(all_nominal_predictors(), one_hot = TRUE) %>%
  step_zv(all_predictors())

## we can `prep()` just to check that it works
prep(shelter_rec)

## Data Recipe
##
## Inputs:
##
##       role #variables
##    outcome          1
##  predictor          5
##
## Training data contained 40805 data points and no missing data.
##
## Operations:
##
## Date features from datetime [trained]
## Dummy variables from animal_type, sex, spay_neuter, datetime_dow [trained]
## Zero variance filter removed no terms [trained]
Now let’s create a tunable xgboost model specification. This is where early stopping comes in; we will keep the number of trees as a constant (and not too terribly high), set stop_iter (the early stopping parameter) to tune(), and then tune a few other parameters. Notice that we need to hold back a validation set (a proportion of each analysis set, actually) to use for deciding when to stop.
We can also create a custom stopping_grid to specify which parameter values we want to try out.
stopping_spec <-
  boost_tree(
    trees = 500,
    mtry = tune(),
    learn_rate = tune(),
    stop_iter = tune()
  ) %>%
  set_engine("xgboost", validation = 0.2) %>%
  set_mode("classification")

stopping_grid <-
  grid_latin_hypercube(
    mtry(range = c(5L, 20L)), ## depends on number of columns in data
    learn_rate(range = c(-5, -1)), ## keep pretty big
    stop_iter(range = c(10L, 50L)), ## bigger than default
    size = 10
  )
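To make the role of stop_iter concrete, here is a minimal base R sketch of the stopping rule. The `early_stop_at()` helper and the `val_loss` values are hypothetical, written only for illustration; this is not how parsnip or xgboost implement it internally, and their exact bookkeeping may differ. The idea is to track the best validation loss seen so far and stop once `stop_iter` consecutive iterations have passed without improvement.

```r
# Hypothetical helper (not part of parsnip or xgboost): given the validation
# loss after each boosting iteration, return the iteration at which training
# would stop for a given patience value (the role played by `stop_iter`).
early_stop_at <- function(val_loss, stop_iter) {
  best <- Inf
  since_best <- 0L
  for (i in seq_along(val_loss)) {
    if (val_loss[i] < best) {
      # New best validation loss: reset the patience counter
      best <- val_loss[i]
      since_best <- 0L
    } else {
      since_best <- since_best + 1L
    }
    if (since_best >= stop_iter) {
      return(i)
    }
  }
  # Stopping was never triggered, so every iteration is used
  length(val_loss)
}

# Made-up validation losses: improving until iteration 4, then getting worse
val_loss <- c(0.90, 0.70, 0.60, 0.55, 0.56, 0.57, 0.58, 0.60)

stopped_small <- early_stop_at(val_loss, stop_iter = 3)
stopped_large <- early_stop_at(val_loss, stop_iter = 10)
c(stopped_small, stopped_large)
```

A small patience value stops soon after the validation loss turns upward, while a large one lets training run through all the available iterations. That trade-off is why stop_iter is worth tuning alongside the learning rate.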
Now we can put these together in a workflow and tune across the grid of parameters and our resamples.
early_stop_wf <- workflow(shelter_rec, stopping_spec)

doParallel::registerDoParallel()
set.seed(345)
stopping_rs <- tune_grid(
  early_stop_wf,
  shelter_folds,
  grid = stopping_grid,
  metrics = shelter_metrics
)
We did it!
Evaluate results
How did these results turn out? We can visualize them.
autoplot(stopping_rs) + theme_light(base_family = "IBMPlexSans")
Or we can look at the top results manually.
show_best(stopping_rs, metric = "mn_log_loss")

## # A tibble: 5 × 9
##    mtry learn_rate stop_iter .metric     .estimator  mean     n std_err .config
##   <int>      <dbl>     <int> <chr>       <chr>      <dbl> <int>   <dbl> <chr>
## 1    12    0.0612         46 mn_log_loss multiclass 0.502    10 0.00319 Preproc…
## 2    18    0.0378         36 mn_log_loss multiclass 0.505    10 0.00279 Preproc…
## 3     7    0.00710        12 mn_log_loss multiclass 0.544    10 0.00246 Preproc…
## 4     9    0.00252        33 mn_log_loss multiclass 0.655    10 0.00145 Preproc…
## 5    11    0.00195        25 mn_log_loss multiclass 0.699    10 0.00122 Preproc…
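One detail that is easy to miss: the learn_rate values in this table are on the natural scale, while the range passed to learn_rate() in the grid is in log10 units, because dials uses a log10 transformation for this parameter. A quick base R sketch of the conversion (the `log10_range` and `lr_values` names are just for illustration):

```r
# The tuning range for learn_rate, as specified in log10 units
log10_range <- c(-5, -1)

# Converted to actual learning rates: from 0.00001 up to 0.1
lr_bounds <- 10^log10_range
lr_bounds

# Five evenly spaced points on the log10 scale, shown on the natural scale
lr_values <- 10^seq(log10_range[1], log10_range[2], length.out = 5)
lr_values
```

In the results above, the best candidates are the ones with larger learning rates, which is plausible for a budget fixed at 500 trees.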
Let’s use last_fit() to fit one final time to the training data and evaluate one final time on the testing data, with the numerically optimal result from stopping_rs.
stopping_fit <- early_stop_wf %>%
  finalize_workflow(select_best(stopping_rs, metric = "mn_log_loss")) %>%
  last_fit(shelter_split)

stopping_fit

## # Resampling results
## # Manual resampling
## # A tibble: 1 × 6
##   splits                id               .metrics  .notes .predictions .workflow
##   <list>                <chr>            <list>    <list> <list>       <list>
## 1 <split [40805/13603]> train/test split <tibble … <tibb… <tibble [13… <workflo…
How did this model perform on the testing data, which was not used in tuning or training?
collect_metrics(stopping_fit)

## # A tibble: 2 × 4
##   .metric  .estimator .estimate .config
##   <chr>    <chr>          <dbl> <chr>
## 1 accuracy multiclass     0.807 Preprocessor1_Model1
## 2 roc_auc  hand_till      0.877 Preprocessor1_Model1
This result is pretty good for a single model; we would expect to do better by incorporating the breed information, perhaps the presence/absence of a name, or moving to an ensembled model.
What features are most important for this xgboost model?
library(vip)

## use this fitted workflow `extract_workflow(stopping_fit)` to predict on new data
extract_workflow(stopping_fit) %>%
  extract_fit_parsnip() %>%
  vip(num_features = 15, geom = "point")
Age, spay/neuter status, animal type, and seasonal information like week of the year or day of the week are important for this model.
We can collect the predictions on the testing set and do whatever we want, like create an ROC curve.
collect_predictions(stopping_fit) %>%
  roc_curve(outcome_type, .pred_adoption:.pred_transfer) %>%
  ggplot(aes(1 - specificity, sensitivity, color = .level)) +
  geom_abline(lty = 2, color = "gray80", size = 1.5) +
  geom_path(alpha = 0.8, size = 1) +
  coord_equal() +
  labs(color = NULL)
We can also look at a confusion matrix.
collect_predictions(stopping_fit) %>%
  conf_mat(outcome_type, .pred_class) %>%
  autoplot()
Early stopping is a great option when you have plenty of data and don’t want to overfit your boosted trees! I will be back on SLICED for the final four next Tuesday, and I plan to use early stopping again because it is a good fit for this kind of situation.