Explaining Keras image classification models with lime
What I did not show in last week’s post was how to use the model for making predictions. This, I will do here. But predictions alone are boring, so I’m adding explanations for the predictions using the lime package.
I have already written a few blog posts (here, here and here) about LIME and have given talks (here and here) about it, too.
None of them applies LIME to image classification models, though. And with the new(ish) March release of Thomas Lin Pedersen’s lime package, lime is now not only on CRAN but also natively supports Keras and image classification models.
Thomas wrote a very nice article about how to use keras and lime in R! Here, I am following this article to use Imagenet (VGG16) to make and explain predictions of fruit images; then I extend the analysis to last week’s model and compare it with the pretrained net.
Loading libraries and models
library(keras)   # for working with neural nets
library(lime)    # for explaining models
library(magick)  # for preprocessing images
library(ggplot2) # for additional plotting
- Loading the pretrained Imagenet model
model <- application_vgg16(weights = "imagenet", include_top = TRUE)
model
## Model
## ___________________________________________________________________________
## Layer (type)                       Output Shape                Param #
## ===========================================================================
## input_1 (InputLayer)               (None, 224, 224, 3)         0
## ___________________________________________________________________________
## block1_conv1 (Conv2D)              (None, 224, 224, 64)        1792
## ___________________________________________________________________________
## block1_conv2 (Conv2D)              (None, 224, 224, 64)        36928
## ___________________________________________________________________________
## block1_pool (MaxPooling2D)         (None, 112, 112, 64)        0
## ___________________________________________________________________________
## block2_conv1 (Conv2D)              (None, 112, 112, 128)       73856
## ___________________________________________________________________________
## block2_conv2 (Conv2D)              (None, 112, 112, 128)       147584
## ___________________________________________________________________________
## block2_pool (MaxPooling2D)         (None, 56, 56, 128)         0
## ___________________________________________________________________________
## block3_conv1 (Conv2D)              (None, 56, 56, 256)         295168
## ___________________________________________________________________________
## block3_conv2 (Conv2D)              (None, 56, 56, 256)         590080
## ___________________________________________________________________________
## block3_conv3 (Conv2D)              (None, 56, 56, 256)         590080
## ___________________________________________________________________________
## block3_pool (MaxPooling2D)         (None, 28, 28, 256)         0
## ___________________________________________________________________________
## block4_conv1 (Conv2D)              (None, 28, 28, 512)         1180160
## ___________________________________________________________________________
## block4_conv2 (Conv2D)              (None, 28, 28, 512)         2359808
## ___________________________________________________________________________
## block4_conv3 (Conv2D)              (None, 28, 28, 512)         2359808
## ___________________________________________________________________________
## block4_pool (MaxPooling2D)         (None, 14, 14, 512)         0
## ___________________________________________________________________________
## block5_conv1 (Conv2D)              (None, 14, 14, 512)         2359808
## ___________________________________________________________________________
## block5_conv2 (Conv2D)              (None, 14, 14, 512)         2359808
## ___________________________________________________________________________
## block5_conv3 (Conv2D)              (None, 14, 14, 512)         2359808
## ___________________________________________________________________________
## block5_pool (MaxPooling2D)         (None, 7, 7, 512)           0
## ___________________________________________________________________________
## flatten (Flatten)                  (None, 25088)               0
## ___________________________________________________________________________
## fc1 (Dense)                        (None, 4096)                102764544
## ___________________________________________________________________________
## fc2 (Dense)                        (None, 4096)                16781312
## ___________________________________________________________________________
## predictions (Dense)                (None, 1000)                4097000
## ===========================================================================
## Total params: 138,357,544
## Trainable params: 138,357,544
## Non-trainable params: 0
## ___________________________________________________________________________
- loading my own model from last week’s post
model2 <- load_model_hdf5(filepath = "/Users/shiringlander/Documents/Github/DL_AI/Tutti_Frutti/fruits-360/keras/fruits_checkpoints.h5")
model2
## Model
## ___________________________________________________________________________
## Layer (type)                       Output Shape                Param #
## ===========================================================================
## conv2d_1 (Conv2D)                  (None, 20, 20, 32)          896
## ___________________________________________________________________________
## activation_1 (Activation)          (None, 20, 20, 32)          0
## ___________________________________________________________________________
## conv2d_2 (Conv2D)                  (None, 20, 20, 16)          4624
## ___________________________________________________________________________
## leaky_re_lu_1 (LeakyReLU)          (None, 20, 20, 16)          0
## ___________________________________________________________________________
## batch_normalization_1 (BatchNorm   (None, 20, 20, 16)          64
## ___________________________________________________________________________
## max_pooling2d_1 (MaxPooling2D)     (None, 10, 10, 16)          0
## ___________________________________________________________________________
## dropout_1 (Dropout)                (None, 10, 10, 16)          0
## ___________________________________________________________________________
## flatten_1 (Flatten)                (None, 1600)                0
## ___________________________________________________________________________
## dense_1 (Dense)                    (None, 100)                 160100
## ___________________________________________________________________________
## activation_2 (Activation)          (None, 100)                 0
## ___________________________________________________________________________
## dropout_2 (Dropout)                (None, 100)                 0
## ___________________________________________________________________________
## dense_2 (Dense)                    (None, 16)                  1616
## ___________________________________________________________________________
## activation_3 (Activation)          (None, 16)                  0
## ===========================================================================
## Total params: 167,300
## Trainable params: 167,268
## Non-trainable params: 32
## ___________________________________________________________________________
Load and prepare images
Here, I am loading and preprocessing two images of fruits (and yes, I am cheating a bit because I am choosing images where I expect my model to work as they are similar to the training images…).
- Banana
test_image_files_path <- "/Users/shiringlander/Documents/Github/DL_AI/Tutti_Frutti/fruits-360/Test"

img <- image_read('https://upload.wikimedia.org/wikipedia/commons/thumb/8/8a/Banana-Single.jpg/272px-Banana-Single.jpg')
img_path <- file.path(test_image_files_path, "Banana", 'banana.jpg')
image_write(img, img_path)
#plot(as.raster(img))
- Clementine
img2 <- image_read('https://cdn.pixabay.com/photo/2010/12/13/09/51/clementine-1792_1280.jpg')
img_path2 <- file.path(test_image_files_path, "Clementine", 'clementine.jpg')
image_write(img2, img_path2)
#plot(as.raster(img2))
Superpixels
The segmentation of an image into superpixels is an important step in generating explanations for image models. It is important both that the segmentation is correct and follows meaningful patterns in the picture, and that the size/number of superpixels is appropriate. If the important features in the image are chopped into too many segments, the permutations will probably damage the picture beyond recognition in almost all cases, leading to a poor or failing explanation model. As the size of the object of interest varies, it is impossible to set up hard rules for the number of superpixels to segment into: the larger the object is relative to the size of the image, the fewer superpixels should be generated. Using plot_superpixels it is possible to evaluate the superpixel parameters before starting the time-consuming explanation function. (help(plot_superpixels))
plot_superpixels(img_path, n_superpixels = 35, weight = 10)
plot_superpixels(img_path2, n_superpixels = 50, weight = 20)
From the superpixel plots we can see that the clementine image has a higher resolution than the banana image.
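To make the role of the permutations a bit more concrete, here is a minimal base R sketch of the general LIME idea on superpixels: switch random subsets of superpixels off, score each perturbed "image" with the black-box model, and fit a weighted linear surrogate on the on/off masks. This is not lime's actual implementation; the scoring function, the number of superpixels, the kernel width and all variable names are assumptions made up for illustration.

```r
set.seed(42)

n_superpixels  <- 6    # hypothetical number of segments
n_permutations <- 200  # hypothetical number of perturbed images

# Hypothetical black-box score: only superpixels 2 and 5 matter
black_box <- function(mask) 0.6 * mask[2] + 0.3 * mask[5] + 0.05

# Each row is one perturbed image: 1 = superpixel kept, 0 = blanked out
masks <- matrix(rbinom(n_permutations * n_superpixels, 1, 0.5),
                nrow = n_permutations)
scores <- apply(masks, 1, black_box)

# Weight perturbed images by how close they are to the original
# (fraction of superpixels removed, passed through an exponential kernel)
distance       <- 1 - rowMeans(masks)
sample_weights <- exp(-(distance^2) / 0.25)

# Weighted linear surrogate: one coefficient per superpixel
surrogate <- lm(scores ~ masks, weights = sample_weights)
round(coef(surrogate)[-1], 2)
```

Because the toy black box is itself linear, the surrogate recovers the two relevant superpixels exactly. With a real network the fit is only a local approximation, and it gets worse when the object of interest is split across many small segments.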
Prepare images for Imagenet
image_prep <- function(x) {
  arrays <- lapply(x, function(path) {
    img <- image_load(path, target_size = c(224,224))
    x <- image_to_array(img)
    x <- array_reshape(x, c(1, dim(x)))
    x <- imagenet_preprocess_input(x)
  })
  do.call(abind::abind, c(arrays, list(along = 1)))
}
- test predictions
res <- predict(model, image_prep(c(img_path, img_path2)))
imagenet_decode_predictions(res)
## [[1]]
##   class_name class_description        score
## 1  n07753592            banana 0.9929747581
## 2  n03532672              hook 0.0013420776
## 3  n07747607            orange 0.0010816186
## 4  n07749582             lemon 0.0010625814
## 5  n07716906  spaghetti_squash 0.0009176208
##
## [[2]]
##   class_name class_description      score
## 1  n07747607            orange 0.78233224
## 2  n07753592            banana 0.04653566
## 3  n07749582             lemon 0.03868873
## 4  n03134739      croquet_ball 0.03350329
## 5  n07745940        strawberry 0.01862431
- load labels and train explainer
model_labels <- readRDS(system.file('extdata', 'imagenet_labels.rds', package = 'lime'))
explainer <- lime(c(img_path, img_path2), as_classifier(model, model_labels), image_prep)
Training the explainer (explain() function) can take quite a while. It will be much faster with the smaller images in my own model, but with the bigger Imagenet it takes a few minutes to run.
explanation <- explain(c(img_path, img_path2), explainer,
                       n_labels = 2, n_features = 35,
                       n_superpixels = 35, weight = 10,
                       background = "white")
plot_image_explanation() only supports showing one case at a time.
plot_image_explanation(explanation)
clementine <- explanation[explanation$case == "clementine.jpg",]
plot_image_explanation(clementine)
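With more than two images, subsetting by hand for every case gets tedious. A small sketch of one way around this is to split the explanation by its case column and loop over the pieces. The data frame below is a made-up stand-in for the output of explain() (only the case and feature_weight columns are mimicked, with invented values), and the plotting call is left as a comment because I have not run it in this form.

```r
# Toy stand-in for the data frame returned by explain(); values are invented
toy_explanation <- data.frame(
  case           = c("banana.jpg", "banana.jpg", "clementine.jpg"),
  feature_weight = c(0.30, 0.05, 0.12),
  stringsAsFactors = FALSE
)

# One data frame per image
per_case <- split(toy_explanation, toy_explanation$case)

for (case_name in names(per_case)) {
  cat(case_name, ":", nrow(per_case[[case_name]]), "rows\n")
  # With a real explanation, the plot for this case would presumably be:
  # print(plot_image_explanation(per_case[[case_name]]))
}
```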
Prepare images for my own model
- test predictions (analogous to training and validation images)
test_datagen <- image_data_generator(rescale = 1/255)

test_generator <- flow_images_from_directory(
  test_image_files_path,
  test_datagen,
  target_size = c(20, 20),
  class_mode = 'categorical')

predictions <- as.data.frame(predict_generator(model2, test_generator, steps = 1))

load("/Users/shiringlander/Documents/Github/DL_AI/Tutti_Frutti/fruits-360/fruits_classes_indices.RData")
fruits_classes_indices_df <- data.frame(indices = unlist(fruits_classes_indices))
fruits_classes_indices_df <- fruits_classes_indices_df[order(fruits_classes_indices_df$indices), , drop = FALSE]
colnames(predictions) <- rownames(fruits_classes_indices_df)

t(round(predictions, digits = 2))
##             [,1] [,2]
## Kiwi           0 0.00
## Banana         1 0.11
## Apricot        0 0.00
## Avocado        0 0.00
## Cocos          0 0.00
## Clementine     0 0.87
## Mandarine      0 0.00
## Orange         0 0.00
## Limes          0 0.00
## Lemon          0 0.00
## Peach          0 0.00
## Plum           0 0.00
## Raspberry      0 0.00
## Strawberry     0 0.01
## Pineapple      0 0.00
## Pomegranate    0 0.00

for (i in 1:nrow(predictions)) {
  cat(i, ":")
  print(unlist(which.max(predictions[i, ])))
}
## 1 :Banana
##      2
## 2 :Clementine
##          6
This seems to be incompatible with lime, though (or if someone knows how it works, please let me know) - so I prepared the images similarly to the Imagenet images.
image_prep2 <- function(x) {
  arrays <- lapply(x, function(path) {
    img <- image_load(path, target_size = c(20, 20))
    x <- image_to_array(img)
    x <- reticulate::array_reshape(x, c(1, dim(x)))
    x <- x / 255
  })
  do.call(abind::abind, c(arrays, list(along = 1)))
}
- prepare labels
fruits_classes_indices_l <- rownames(fruits_classes_indices_df)
names(fruits_classes_indices_l) <- unlist(fruits_classes_indices)
fruits_classes_indices_l
##             9            10             8             2            11
##        "Kiwi"      "Banana"     "Apricot"     "Avocado"       "Cocos"
##             3            13            14             7             6
##  "Clementine"   "Mandarine"      "Orange"       "Limes"       "Lemon"
##             1             5             0             4            15
##       "Peach"        "Plum"   "Raspberry"  "Strawberry"   "Pineapple"
##            12
## "Pomegranate"
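The label bookkeeping is easy to get wrong, so here is a small self-contained sketch of the idea on toy data: order the class names by their (0-based) index, then map each row of a prediction matrix to the label with the highest score. The class indices and prediction values below are invented for illustration and are not taken from my model; max.col() is just a vectorised alternative to calling which.max() row by row.

```r
# Hypothetical class indices in the form a Keras generator reports them (0-based)
toy_class_indices <- list(Banana = 1, Apricot = 0, Clementine = 2)

# Labels ordered by index, so position i corresponds to output unit i - 1
toy_labels <- names(sort(unlist(toy_class_indices)))

# Toy prediction matrix: one row per image, one column per class
toy_predictions <- matrix(c(0.05, 0.90, 0.05,
                            0.02, 0.11, 0.87),
                          nrow = 2, byrow = TRUE,
                          dimnames = list(NULL, toy_labels))

# Column with the highest score per row, translated into a label
predicted <- toy_labels[max.col(toy_predictions, ties.method = "first")]
predicted
```

Here the first toy image comes out as Banana and the second as Clementine.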
- train explainer
explainer2 <- lime(c(img_path, img_path2), as_classifier(model2, fruits_classes_indices_l), image_prep2)
explanation2 <- explain(c(img_path, img_path2), explainer2,
                        n_labels = 1, n_features = 20,
                        n_superpixels = 35, weight = 10,
                        background = "white")
- plot feature weights to find a good threshold for plotting with display = 'block' (see below)
explanation2 %>%
  ggplot(aes(x = feature_weight)) +
    facet_wrap(~ case, scales = "free") +
    geom_density()
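Reading a threshold off a density plot is a judgment call. As a rough starting point, a per-case quantile of the feature weights can be computed alongside the plot. The sketch below uses an invented toy data frame in place of explanation2 and the median as cut-off. Both are assumptions for illustration, not how I picked the thresholds used further down.

```r
# Toy stand-in for explanation2; the weights are invented for illustration
toy_explanation <- data.frame(
  case           = rep(c("banana.jpg", "clementine.jpg"), each = 5),
  feature_weight = c(0.01, 0.02, 0.03, 0.04, 0.10,
                     0.02, 0.08, 0.16, 0.21, 0.35)
)

# Candidate threshold per image: the median feature weight of that case
thresholds <- tapply(toy_explanation$feature_weight,
                     toy_explanation$case,
                     median)
thresholds
```

Because the weights of the two images can live on very different scales, a separate threshold per case tends to be more useful than one global value.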
- plot predictions
plot_image_explanation(explanation2, display = 'block', threshold = 5e-07)
clementine2 <- explanation2[explanation2$case == "clementine.jpg",]
plot_image_explanation(clementine2, display = 'block', threshold = 0.16)
sessionInfo()
## R version 3.5.0 (2018-04-23)
## Platform: x86_64-apple-darwin15.6.0 (64-bit)
## Running under: macOS High Sierra 10.13.5
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRlapack.dylib
##
## locale:
## [1] de_DE.UTF-8/de_DE.UTF-8/de_DE.UTF-8/C/de_DE.UTF-8/de_DE.UTF-8
##
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base
##
## other attached packages:
## [1] ggplot2_2.2.1 magick_1.9    lime_0.4.0    keras_2.1.6
##
## loaded via a namespace (and not attached):
##  [1] stringdist_0.9.5.1 reticulate_1.8     xfun_0.2
##  [4] lattice_0.20-35    colorspace_1.3-2   htmltools_0.3.6
##  [7] yaml_2.1.19        base64enc_0.1-3    rlang_0.2.1
## [10] pillar_1.2.3       later_0.7.3        foreach_1.4.4
## [13] plyr_1.8.4         tensorflow_1.8     stringr_1.3.1
## [16] munsell_0.5.0      blogdown_0.6       gtable_0.2.0
## [19] htmlwidgets_1.2    codetools_0.2-15   evaluate_0.10.1
## [22] labeling_0.3       knitr_1.20         httpuv_1.4.4.1
## [25] tfruns_1.3         parallel_3.5.0     curl_3.2
## [28] Rcpp_0.12.17       xtable_1.8-2       scales_0.5.0
## [31] backports_1.1.2    promises_1.0.1     jsonlite_1.5
## [34] abind_1.4-5        mime_0.5           digest_0.6.15
## [37] stringi_1.2.3      bookdown_0.7       shiny_1.1.0
## [40] grid_3.5.0         rprojroot_1.3-2    tools_3.5.0
## [43] magrittr_1.5       lazyeval_0.2.1     shinythemes_1.1.1
## [46] glmnet_2.0-16      tibble_1.4.2       whisker_0.3-2
## [49] zeallot_0.1.0      Matrix_1.2-14      gower_0.1.2
## [52] assertthat_0.2.0   rmarkdown_1.10     iterators_1.0.9
## [55] R6_2.2.2           compiler_3.5.0