This analysis introduces the K-Nearest Neighbor (KNN) machine learning algorithm using the familiar Pokemon dataset. By the end of this blog post you should have an understanding of the following:
- What the KNN machine learning algorithm is
- How to program the algorithm in R
- A bit more about Pokemon
If you would like to follow along, you can download the dataset from Kaggle.
What are Pokemon?
Pokemon are used to battle one another and come with a variety of different battle stats like: attack, defense, speed, special attack, special defense, and hit points. Pokemon can also be further classified into a variety of different types (such as: fire, water, grass, electric, ice, ghost, dark, fairy, ground, etc…). The stronger a pokemon is, the more likely it will win in battle. There are many nuances to Pokemon (and this is only a basic summary) but hopefully the specifics of what Pokemon are do not hinder your understanding of KNN (the central point of this post).
Question of interest
Do Pokemon with similar battle stats (attack, defense, speed, etc…) tend to congregate by type (fire, water, grass, etc…) and if so, can we guess a Pokemon’s type based on its battle stats alone?
Disclaimer: This is a pretty quick analysis used for illustrative purposes.
K-Nearest Neighbors (KNN) – by example
For the purpose of this blog post, we define KNN as a popular supervised machine learning algorithm for classifying data. Colloquially, the algorithm uses similarity among data points to determine the group membership of a new data point.
An example: Let’s say you have 5 pokemon and let’s suppose in the realm of Pokemon there are only two types: fire and water. It turns out (from our sampled data) that all our fire type pokemon (the first three: growlithe, flareon, and ponyta) are red and the remaining water pokemon (poliwag, vaporeon) are blue. The pokemon (and a crude line plot) are as follows:
These five pokemon will now become our training data. This means we accept them as they are and know that their classification is truthful (those labeled fire are actually fire and those labeled water are actually water). If I now showed you a new pokemon (vulpix), which we will call the test data, could you figure out which group it belonged to? Let’s add vulpix to our line plot of Pokemon based on its coloring:
Intuitively you might classify vulpix as fire because it is red. That’s great logic! Now, I challenge you to slightly tweak your thinking just a bit: what if you decided to classify this vulpix as fire because it is most similar to its closest neighbor(s) on the line plot (maybe that’s what you did). In this case, vulpix is closest to growlithe, flareon, and ponyta and furthest from poliwag and vaporeon. Because its neighbors are fire, vulpix is now fire. This is the base thinking behind KNN – how close am I to my neighbors and whatever they are classified as is what I’m classified as.
Seems straightforward, right? Now what if I asked you to now classify the following Pokemon (tentacool):
This is tricky, because tentacool here is both red and blue; what do we do? Well, it looks like there’s a bit more blue on our Pokemon than red, so let’s plot tentacool and nudge it slightly toward the blue side of our plot:
So now let’s think: who are the closest neighbors to tentacool? Well, if we only look to one neighbor, then it’s poliwag and if we look at two neighbors, then it’s both poliwag and ponyta. Extending to three neighbors then introduces vaporeon and four neighbors brings in flareon. A visual of this can be seen here with one neighbor (red), two neighbors (green) and three neighbors (blue):
In this instance, how many neighbors we poll will determine how we classify our tentacool. If we choose one neighbor, we will say this pokemon is water. If we choose two, then we are split 50/50 and flip a coin (and let chance decide!). If three neighbors are selected, then we have two votes for water and one vote for fire. Let’s say we choose 3 neighbors; in this instance, we will classify tentacool as water.
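The voting above can be sketched in a few lines of R. The positions on the line are made-up numbers, chosen only so that the neighbor order matches the example (poliwag, then ponyta, then vaporeon, then flareon); they do not come from the dataset, and the helper name vote is hypothetical.

```r
# Hypothetical positions on the red-to-blue line (assumed values, for illustration only)
training <- data.frame(
  name     = c("growlithe", "flareon", "ponyta", "poliwag", "vaporeon"),
  position = c(1.0, 2.0, 3.0, 5.0, 6.0),
  type     = c("fire", "fire", "fire", "water", "water")
)
tentacool <- 4.2  # assumed position, nudged toward the blue side

# Tally the types of the k training pokemon closest to new_position
vote <- function(new_position, k) {
  distances <- abs(training$position - new_position)
  nearest   <- order(distances)[1:k]
  table(training$type[nearest])
}

vote(tentacool, k = 1)  # one vote for water
vote(tentacool, k = 2)  # one vote each: a tie
vote(tentacool, k = 3)  # two votes for water, one for fire
```

With k = 2 the tally is tied, which is exactly the coin-flip situation described above; an odd k avoids ties when there are only two classes.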
Note: At this point you may be wondering… what happened to our vulpix (the first pokemon we classified)? Why did we not include them on our new plot to help classify tentacool? This speaks to a larger topic of training and testing data but for now, understand that we don’t actually know the classification of vulpix. We made an educated guess to its classification. The classification of vulpix (for our example) is not 100% known and as such, will not be included as training data for making decisions about future pokemon (like tentacool).
In this basic example we saw that if we first plotted our training data based on an attribute (color) we could make a guess at the type of an unknown pokemon (test data). Obviously color alone may not be a great indicator for pokemon type (though, it does surprisingly well) so it might be worth looking into other factors as well: speed, attack, defense, hit points, etc… What we are going to figure out in the following analysis is if we can use two or more attributes of a pokemon along with KNN to help determine pokemon type.
You might have questions about the following:
- What determines “how close” neighbors are?
- How do I pick the attributes to compare pokemon on?
- How many neighbors do I select?
That’s OK! These are questions that I hope to cover throughout this post.
Analysis
For this analysis we’ll need the following libraries and set the seed:
library("dplyr")
library("class")
library("ggplot2")
library("GGally")
library("caret")
set.seed(12345)
First, we will import the dataset:
pokemon_data <- read.csv("../../blog_datasets/knn_pokemon/pokemon.csv")
There are a bunch of columns in this dataset (41). So that we can focus on KNN, let’s select only a few of them (name, attack, defense, sp_attack, sp_defense, speed, hp, type1, type2). These are simply the attributes discussed above. We can preview the data frame as follows:
reduced_pokemon_data <- pokemon_data %>%
  select(name, attack, defense, sp_attack, sp_defense, speed, hp, type1, type2)
head(reduced_pokemon_data)
##         name attack defense sp_attack sp_defense speed hp type1  type2
## 1  Bulbasaur     49      49        65         65    45 45 grass poison
## 2    Ivysaur     62      63        80         80    60 60 grass poison
## 3   Venusaur    100     123       122        120    80 80 grass poison
## 4 Charmander     52      43        60         50    65 39  fire
## 5 Charmeleon     64      58        80         65    80 58  fire
## 6  Charizard    104      78       159        115   100 78  fire flying
For now, these are the columns that we will be working with. In our dataset, there are 801 unique pokemon and 18 different types. To make the analysis a bit more straightforward, we will only be looking at 5 different types of pokemon: bug, dragon, fighting, electric, and normal. I’ll further reduce the dataset to only look at pokemon with just one type (note: pokemon can have up to two types: for instance: fire/psychic, ground/dragon, etc…):
final_dataset <- reduced_pokemon_data %>%
  filter(type2 == "" &
           type1 %in% c("bug", "dragon", "fighting", "electric", "normal"))
These adjustments drastically bring our dataset down from 801 instances to only 139. Alas, I’ll remind the reader that the purpose of this blog post is to explain KNN, not provide a rigorous analysis of the pokemon dataset. With that said, let’s drop our type2 column (since all pokemon in our data are just one type) and factorize our type1 column. We’re then going to break our dataset into a training, validation (to be covered below), and test set:
final_dataset <- final_dataset %>%
  select(-c(type2)) %>%
  mutate_at(vars(type1), ~(factor(type1)))
random_rows <- sample(1:nrow(final_dataset), nrow(final_dataset) * .75)
training_data <- final_dataset[random_rows, ]
# Renumber the rows
row.names(training_data) <- 1:nrow(training_data)
testing_data <- final_dataset[-random_rows, ]
random_rows_validation <- sample(1:nrow(training_data), nrow(training_data) * .30)
validation_data <- training_data[random_rows_validation, ]
training_data <- training_data[-random_rows_validation, ]
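As a sanity check on the split, here is a small sketch that reproduces only the row counts, assuming the 139 rows stated above. It rounds down explicitly with floor() rather than relying on how sample() treats a fractional size, and the n_* variable names are made up for this sketch.

```r
set.seed(12345)
n <- 139  # rows left after filtering, as stated above

# 75% of the rows go to training; round down to a whole number of rows
random_rows <- sample(seq_len(n), floor(n * 0.75))
n_training  <- length(random_rows)   # 104
n_testing   <- n - n_training        # 35

# 30% of the training rows are then set aside for validation
n_validation <- floor(n_training * 0.30)   # 31
n_training   <- n_training - n_validation  # 73

c(training = n_training, validation = n_validation, testing = n_testing)
```

The 31 validation rows and 35 test rows agree with the sample counts reported in the model output later in the post.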
Specifically for this first analysis we are interested in looking at the attack and speed stats of our pokemon; a plot may look as follows:
ggplot(data = training_data, aes(x = attack, y = speed, col = type1)) +
  geom_point(size = 3) +
  ggtitle("Attack versus Speed for Pokemon of a Single Type (training)")
From the plot we can see a few pockets where it looks like clusters of pokemon are gathering together (for instance, bug pokemon tend to congregate with a lower attack and lower speed, and electric pokemon tend to have low attack and high speed). Let’s include our test pokemon (the ones we want to classify into the various groupings):
combined_training_test <- rbind(training_data, testing_data)
# Replace the test labels with NA so the test pokemon are drawn in gray
combined_training_test$type1 <- factor(c(as.character(training_data$type1),
                                         rep(NA, nrow(testing_data))))
ggplot(data = combined_training_test, aes(x = attack, y = speed, col = type1)) +
  geom_point(size = 3) +
  ggtitle("Attack versus Speed for Pokemon of a Single Type (training and testing)")
To remind the reader of our goal here: we would like to classify each of these gray points to one of the 5 labeled Pokemon types. The algorithm will do this by finding the k-closest neighbors to each of the gray data points and classify them based on the voting method described earlier.
To answer our first question above (what determines “how close” neighbors are?): the KNN algorithm simply uses the Euclidean distance between an unknown test point and each training point to determine which training points are closest to it.
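To make the distance concrete, here is a minimal sketch using the attack and speed values of two pokemon from the data preview earlier (they are not part of the filtered dataset; they are used here only for the arithmetic). The helper name euclidean_distance is my own.

```r
# Euclidean distance between two points given as numeric vectors
euclidean_distance <- function(a, b) {
  sqrt(sum((a - b)^2))
}

# Attack and speed taken from the data preview
bulbasaur  <- c(attack = 49, speed = 45)
charmander <- c(attack = 52, speed = 65)

# sqrt((49 - 52)^2 + (45 - 65)^2) = sqrt(409), roughly 20.2
euclidean_distance(bulbasaur, charmander)

# Base R's dist() computes the same quantity for every pair of rows in a matrix
dist(rbind(bulbasaur, charmander))
```

The same formula extends to any number of attributes: each extra attribute simply adds one more squared difference under the square root.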
Next, we will tackle the second question: “How do I pick the attributes in my model?” Truth be told, KNN can take any number of attributes, not just two! To make the plotting easier, I decided to only look at two attributes. To choose them, I looked at a series of scatter plots and determined which two attributes (visually) produced the best groupings of data. The GGally package is a nice way to visualize pairs of columns; the call below uses base R’s pairs() function for the same purpose. Let’s take a look at a series of scatter plots for pairwise comparisons of variables:
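# Plot the six numeric stats (columns 2 to 7) of the training rows, colored by type,
# so that the data and the colors come from the same rows
pairs(training_data[, 2:7], col = training_data$type1, lower.panel = NULL)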
As you can see above, the speed/attack scatter plot appeared to show the best separation among points. You could technically use any combination of these attributes with the KNN algorithm. At the end of this analysis I’ll show results when using all attributes. But first, let me subset our training and testing data to show the KNN algorithm with only speed and attack used:
training_data_speed_attack <- training_data %>% select(c(speed, attack))
testing_data_speed_attack <- testing_data %>% select(c(speed, attack))
Great, now that we have our attributes selected, we will program the algorithm. As far as machine learning algorithms go, KNN is quite simple in terms of the function call:
knn_attack_speed <- knn(train = training_data_speed_attack,
                        test = testing_data_speed_attack,
                        cl = training_data$type1,
                        k = 5)
Here is what the various parameters to our function call mean:
- train: A matrix or data frame of only the attributes for the training data
- test: A matrix or data frame of only the attributes for the testing data
- cl: A factor of the true classifications for the training data (remember, we are trying to predict the pokemon types for our test data)
- k: The number of neighbors we’d like to “consult” to determine the type of the new test data pokemon (I will get to why I chose 5 below)
- knn_attack_speed: The name we gave to the result of the function call. knn() returns only the predicted class labels for our testing data.
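To see what comes back from the call, here is a small sketch on made-up data (the toy_* objects and the "slow"/"fast" labels are invented for illustration and are not from the Pokemon dataset).

```r
library(class)

# Tiny made-up training set: two attributes, two well-separated classes
toy_train  <- data.frame(speed  = c(10, 12, 11, 50, 52, 51),
                         attack = c(10, 11, 12, 50, 51, 52))
toy_labels <- factor(c("slow", "slow", "slow", "fast", "fast", "fast"))

# Two unlabeled points, one near each group
toy_test <- data.frame(speed = c(11, 49), attack = c(11, 50))

toy_pred <- knn(train = toy_train, test = toy_test, cl = toy_labels, k = 3)

toy_pred          # a factor with one predicted label per test row: slow, fast
length(toy_pred)  # 2, the same as nrow(toy_test)
```

Note that there is no separate "fit" step: knn() takes the training data and the test data in the same call and hands back only the predictions.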
Now that our algorithm has run, one of the best ways to check how our algorithm performed is to construct a confusion matrix. The confusion matrix is a nice visual that will allow us to see how our predicted class labels compared to our actual class labels. The confusion matrix is as follows:
confusionMatrix(knn_attack_speed, testing_data$type1)
## Confusion Matrix and Statistics
##
##           Reference
## Prediction bug dragon electric fighting normal
##   bug        2      0        1        0      1
##   dragon     0      0        0        1      0
##   electric   0      0        1        0      1
##   fighting   1      0        0        1      1
##   normal     2      2        3        7     11
##
## Overall Statistics
##
##                Accuracy : 0.4286
##                  95% CI : (0.2632, 0.6065)
##     No Information Rate : 0.4
##     P-Value [Acc > NIR] : 0.4272
##
##                   Kappa : 0.1422
##
##  Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
##                      Class: bug Class: dragon Class: electric Class: fighting
## Sensitivity             0.40000       0.00000         0.20000         0.11111
## Specificity             0.93333       0.96970         0.96667         0.92308
## Pos Pred Value          0.50000       0.00000         0.50000         0.33333
## Neg Pred Value          0.90323       0.94118         0.87879         0.75000
## Prevalence              0.14286       0.05714         0.14286         0.25714
## Detection Rate          0.05714       0.00000         0.02857         0.02857
## Detection Prevalence    0.11429       0.02857         0.05714         0.08571
## Balanced Accuracy       0.66667       0.48485         0.58333         0.51709
##                      Class: normal
## Sensitivity                 0.7857
## Specificity                 0.3333
## Pos Pred Value              0.4400
## Neg Pred Value              0.7000
## Prevalence                  0.4000
## Detection Rate              0.3143
## Detection Prevalence        0.7143
## Balanced Accuracy           0.5595
The confusion matrix can be read as follows: When the actual class of the pokemon was bug (5 instances), KNN predicted: bug (2 times), fighting (1 time), and normal (2 times).
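The headline numbers in that output can be recovered by hand from the table itself. The sketch below re-enters the counts from the confusion matrix above into a plain matrix (the object names are my own) and recomputes the accuracy and the No Information Rate.

```r
types <- c("bug", "dragon", "electric", "fighting", "normal")

# Counts copied from the confusion matrix above (rows = predicted, columns = actual)
confusion <- matrix(
  c(2, 0, 1, 0, 1,
    0, 0, 0, 1, 0,
    0, 0, 1, 0, 1,
    1, 0, 0, 1, 1,
    2, 2, 3, 7, 11),
  nrow = 5, byrow = TRUE,
  dimnames = list(Prediction = types, Reference = types)
)

# Accuracy: correctly classified pokemon (the diagonal) over all test pokemon
accuracy <- sum(diag(confusion)) / sum(confusion)

# No Information Rate: the share of the most common actual class (normal)
no_information_rate <- max(colSums(confusion)) / sum(confusion)

round(accuracy, 4)   # 0.4286, i.e. 15 correct out of 35
no_information_rate  # 0.4, i.e. 14 normal pokemon out of 35
```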
Overall, from the output we can see our model did just OK! With an accuracy of 42.86%, the model performed better than rolling a 5-sided die (which would have resulted in an accuracy of 20.00%, i.e. 100% / 5 classes), though only slightly better than the No Information Rate of 40% (always guessing the most common type, normal). However, before we go any further, let’s talk about how we determined what K should be. To do this, we used cross validation (to be covered in another blog post). You’ll remember we partitioned our dataset into training, validation, and testing sets; this is where the validation set comes into play. We use the validation set to select K, and then evaluate a model built with that K on our test data.
trControl <- trainControl(method = "cv", number = 10)
fit <- train(type1 ~ speed + attack,
             method = "knn",
             tuneGrid = expand.grid(k = 1:5),
             trControl = trControl,
             metric = "Accuracy",
             data = validation_data)
fit
## k-Nearest Neighbors
##
## 31 samples
##  2 predictor
##  5 classes: 'bug', 'dragon', 'electric', 'fighting', 'normal'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 28, 29, 29, 27, 26, 28, ...
## Resampling results across tuning parameters:
##
##   k  Accuracy   Kappa
##   1  0.3350000  -0.01774147
##   2  0.4433333   0.13273810
##   3  0.3516667   0.05808913
##   4  0.4566667   0.19089390
##   5  0.5066667   0.24827264
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was k = 5.
As we can see, the optimal number of neighbors (k) on our validation set was 5, based on the measurement of accuracy across 10 cross-validation folds.
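If the caret call feels like a black box, the idea behind picking k can be sketched by hand. The version below is a simpler hold-out scheme, not the 10-fold cross validation that caret runs, and it uses R's built-in iris data as a stand-in for the Pokemon data (both are assumptions made for this sketch).

```r
library(class)
set.seed(12345)

# Built-in iris data stands in for the Pokemon data in this sketch
features <- iris[, 1:4]
labels   <- iris$Species

# Hold out 30% of the rows for validation
validation_rows <- sample(seq_len(nrow(iris)), floor(nrow(iris) * 0.30))

# Try each candidate k and record its accuracy on the held-out rows
candidate_k <- 1:5
validation_accuracy <- sapply(candidate_k, function(k) {
  predicted <- knn(train = features[-validation_rows, ],
                   test  = features[validation_rows, ],
                   cl    = labels[-validation_rows],
                   k     = k)
  mean(predicted == labels[validation_rows])
})

# Keep the k with the highest validation accuracy (first one wins on ties)
best_k <- candidate_k[which.max(validation_accuracy)]
best_k
```

Cross validation repeats this kind of hold-out evaluation over several different splits and averages the results, which makes the choice of k less dependent on one lucky (or unlucky) split.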
Analysis – Part 2
Just for fun, what if we decided to use all our attributes of the Pokemon rather than only looking at speed and attack? How would that change our analysis? First we determine our optimal k:
# Determine the optimal K
fit_all_attributes <- train(type1 ~ speed + attack + defense + hp + sp_attack + sp_defense,
                            method = "knn",
                            tuneGrid = expand.grid(k = 1:5),
                            trControl = trControl,
                            metric = "Accuracy",
                            data = validation_data)
fit_all_attributes
## k-Nearest Neighbors
##
## 31 samples
##  6 predictor
##  5 classes: 'bug', 'dragon', 'electric', 'fighting', 'normal'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 29, 29, 28, 27, 27, 28, ...
## Resampling results across tuning parameters:
##
##   k  Accuracy   Kappa
##   1  0.4833333  0.1915584
##   2  0.5000000  0.2445887
##   3  0.5416667  0.2743590
##   4  0.5583333  0.2772727
##   5  0.5333333  0.2300000
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was k = 4.
Then we run the model:
# Perform the KNN
training_data_all <- training_data %>%
  select(c(speed, attack, defense, hp, sp_attack, sp_defense))
testing_data_all <- testing_data %>%
  select(c(speed, attack, defense, hp, sp_attack, sp_defense))
knn_all <- knn(train = training_data_all,
               test = testing_data_all,
               cl = training_data$type1,
               k = 4)
Finally, we show the confusion matrix:
# What does our confusion matrix look like?
confusionMatrix(knn_all, testing_data$type1)
## Confusion Matrix and Statistics
##
##           Reference
## Prediction bug dragon electric fighting normal
##   bug        2      0        0        0      0
##   dragon     2      1        0        0      0
##   electric   0      0        3        0      0
##   fighting   1      1        0        1      0
##   normal     0      0        2        8     14
##
## Overall Statistics
##
##                Accuracy : 0.6
##                  95% CI : (0.4211, 0.7613)
##     No Information Rate : 0.4
##     P-Value [Acc > NIR] : 0.01326
##
##                   Kappa : 0.4103
##
##  Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
##                      Class: bug Class: dragon Class: electric Class: fighting
## Sensitivity             0.40000       0.50000         0.60000         0.11111
## Specificity             1.00000       0.93939         1.00000         0.92308
## Pos Pred Value          1.00000       0.33333         1.00000         0.33333
## Neg Pred Value          0.90909       0.96875         0.93750         0.75000
## Prevalence              0.14286       0.05714         0.14286         0.25714
## Detection Rate          0.05714       0.02857         0.08571         0.02857
## Detection Prevalence    0.05714       0.08571         0.08571         0.08571
## Balanced Accuracy       0.70000       0.71970         0.80000         0.51709
##                      Class: normal
## Sensitivity                 1.0000
## Specificity                 0.5238
## Pos Pred Value              0.5833
## Neg Pred Value              1.0000
## Prevalence                  0.4000
## Detection Rate              0.4000
## Detection Prevalence        0.6857
## Balanced Accuracy           0.7619
And just like that, we’ve increased our accuracy to 60%! That’s pretty good for a 5-class classification! In this case, including more attributes helped with the classification of our pokemon (though I advise you to look at our classification of fighting pokemon: our model categorized 8 out of 9 fighting pokemon as type normal, and that’s a problem!).
Ideally, we would want our model to have 100% accuracy on our test set; however, this is rarely (if ever?) the case. Now, with our trained model, we could go out into the wild and classify a new pokemon’s type based on its attributes (knowing we have achieved 60% accuracy on test data). Pretty cool, right?
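The fighting problem shows up clearly when the per-type hit rate (the Sensitivity row in the output) is recomputed by hand. The sketch below re-enters the counts from the second confusion matrix above; the object names are my own.

```r
types <- c("bug", "dragon", "electric", "fighting", "normal")

# Counts copied from the second confusion matrix (rows = predicted, columns = actual)
confusion_all <- matrix(
  c(2, 0, 0, 0, 0,
    2, 1, 0, 0, 0,
    0, 0, 3, 0, 0,
    1, 1, 0, 1, 0,
    0, 0, 2, 8, 14),
  nrow = 5, byrow = TRUE,
  dimnames = list(Prediction = types, Reference = types)
)

# Overall accuracy: 21 correct out of 35
sum(diag(confusion_all)) / sum(confusion_all)

# Sensitivity per type: correct predictions divided by the number of actual pokemon of that type
sensitivity <- diag(confusion_all) / colSums(confusion_all)
round(sensitivity, 3)
# bug 0.4, dragon 0.5, electric 0.6, fighting 0.111, normal 1
```

Every normal pokemon is found, but only 1 of the 9 fighting pokemon is, so a single accuracy number hides quite a bit of variation between types.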
Conclusion
In this blog post we’ve looked at the KNN algorithm using the Pokemon dataset. We covered the following in this blog post:
- What Pokemon are
- What KNN is
- How the KNN algorithm works
- How to determine attributes (or features) for measuring KNN
- Splitting the data into: training, testing, and validation sets
- How to pick the best K
There’s lots more information to unpack in this blog post (other metrics of “success” such as: precision, recall, positive predictive value, etc…, sample size, filtering data, using more class labels, etc…) however, this blog post provides a gentle introduction to the KNN algorithm. Perhaps a future blog post will use all of the pokemon types and try to generate a model with even larger predictive accuracy.