Multi-label Classification of Pokemon Types with TensorFlow
Create a Pokedex with deep learning
In honor of the release of the new Pokemon game, Pokemon Legends: Arceus, I thought it would be fun to do another Pokemon-themed data science project: train a neural network that takes images of Pokemon as input and outputs their types.
Primer
Before I go further, a brief primer on Pokemon for those unfamiliar. Pokemon are animal-like creatures that can be captured and trained to battle against other Pokemon. Each Pokemon has one or two elemental types that determine which kinds of moves it uses most effectively and which types of Pokemon it has an advantage or disadvantage against in combat. For example, Pikachu, the iconic mascot of the Pokemon franchise, is an Electric type. This means it is weak to Ground-type Pokemon and has an advantage against Water- and Flying-type Pokemon. In the games, you can capture wild Pokemon and record data on their types and other information using an encyclopedia-like device called a Pokedex. One of the objectives of most games is to capture and record data on every Pokemon in the land, but how can we determine a Pokemon’s type just by looking at it? That’s the focus of this article.
Data
We will use sprite images of Pokemon across the various iterations, or generations, of Pokemon games from a GitHub repository maintained by PokeAPI. The repository is free to use, although the images themselves are the property of Nintendo. We will also use a dataset from Kaggle that lists each Pokemon’s type and number in the national Pokedex, which we will use to index and label our images. There are 18 types in total. Previous readers may recall we used this dataset to determine the best team of Pokemon to use in two of the recent games.
The images are organized as follows:
sprites
- pokemon
  - other
    - dream world (SVGs)
    - official artwork (PNGs)
  - versions
    - generation i
      - red and blue (PNGs with back, gray, transparent, back-gray variants)
      - yellow (PNGs with back, gbc, gray, transparent, back-gbc, back-gray, back-transparent variants)
    - generation ii
      - crystal (PNGs with back, shiny, back-shiny, transparent, transparent-shiny, back-transparent, back-transparent-shiny variants)
      - gold (PNGs with back, shiny, transparent, back-shiny variants)
      - silver (PNGs with back, shiny, transparent, back-shiny variants)
    - generation iii
      - emerald (PNGs with shiny variants)
      - fire red and leaf green (PNGs with back, shiny, back-shiny variants)
      - ruby and sapphire (PNGs with back, shiny, back-shiny variants)
    - generation iv
      - diamond and pearl (PNGs with back, female, shiny, back-female, back-shiny, shiny-female variants)
      - heart gold and soul silver (PNGs with back, female, shiny, back-female, back-shiny, shiny-female variants)
      - platinum (PNGs with back, female, shiny, back-female, back-shiny, shiny-female variants)
    - generation v
      - black and white (PNGs with back, female, shiny, back-female, back-shiny, shiny-female, animated variants)
    - generation vi
      - omega ruby and alpha sapphire (PNGs with female, shiny, shiny-female variants)
      - x and y (PNGs with female, shiny, shiny-female variants)
    - generation vii
      - ultra sun and ultra moon (PNGs with female, shiny, shiny-female variants)
      - icons (PNGs)
    - generation viii
      - icons (PNGs with female variants)
  - default PokeAPI sprites (PNGs with back, female, shiny, back-female, back-shiny, shiny-female variants)
- items
  - default PokeAPI items (PNGs)
Some key terms to clarify: "shiny" refers to Pokemon with a coloring that differs from their normal appearance. For example, instead of the iconic yellow Pikachu, shiny Pikachu has a sun-burnt orange coloring. "Female" is self-explanatory: some Pokemon have different appearances based on their biological sex. Female Pikachu, for instance, have a heart-shaped tail instead of a lightning bolt tail. Finally, "back" refers to the back sprite of the Pokemon that the player uses during a battle against an opponent, whose front sprite is shown. These different views of the same Pokemon are very useful for training because there are only 898 Pokemon as of 2021, which doesn’t make for a large dataset on its own. By including sprites from different games and generations, and by taking into account front and back sprites, shiny Pokemon, and sexually dimorphic Pokemon, we obtain a much larger dataset. On that note, let’s transition into how we will use this data to train our classifier.
Machine Learning Approach
Since each Pokemon can have 1 or 2 typings, we will model this as a multi-label classification problem, in which more than one label can be assigned to a single input, for example classifying an image of a pair of denim pants as both ‘blue’ and ‘jeans’. This requires a different setup than the simpler multi-class classification problem, where each input is assigned exactly one of several mutually exclusive classes (e.g., determining whether an animal in an image is a cat or a dog).
In the multi-label scenario, we define a set of labels (in this case, the 18 typings). Our model assigns a probability to each label and can classify an input under multiple labels, provided the label probabilities surpass a certain threshold (in our case, 0.5). In many neural network architectures, we obtain a vector of raw output values. Since we have 18 possible types for Pokemon, our classifier will produce 18 outputs for each Pokemon image we feed into the model. These values are transformed into probabilities, allowing us to make predictions on the Pokemon’s type based on these values and the threshold. In cases where we’re just assigning a single class to each input (e.g., taking an image of a pet and determining whether it’s a cat or a dog), we often use a softmax function to transform these raw outputs.
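For reference, softmax is conventionally written as follows. The notation is assumed here rather than taken from the original figure: z_i is the i-th raw output and K is the number of classes (18 in our case).

```latex
\[
\operatorname{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}, \qquad i = 1, \dots, K
\]
```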
This essentially applies an exponential transformation to the outputs and calculates each probability by placing the exponentiated target output in the numerator and the sum of all the exponentiated raw outputs in the denominator. This is similar to how we typically calculate the probability of a single event, by considering the chance of all the other events. In our case, however, since a Pokemon can have more than one typing and each typing is assigned independently of the others, we need to use a different approach. This is where the sigmoid function comes in.
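In its standard form, the sigmoid is applied to each raw output on its own. As above, z_i is assumed notation for the i-th raw output:

```latex
\[
\sigma(z_i) = \frac{1}{1 + e^{-z_i}}
\]
```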
As with softmax, the raw outputs undergo an exponential transformation, but note that these probabilities are calculated independently of one another (i.e., they need not sum to 1). So the probability of, say, Charizard (a Fire/Flying-type Pokemon) being a Fire type has no bearing on its probability of being a Flying type, and vice versa. Thus, we could have an 83% chance of it being Fire and a 74% chance of it being Flying. Therefore, we will use this function when evaluating our predictions and fine-tuning the model during the training phase.
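The contrast between the two functions can be sketched in a few lines of plain Python. The three raw outputs below are hypothetical values, chosen only so that the sigmoid reproduces the 83% and 74% figures from the Charizard example; a real model would emit 18 of them.

```python
import math

def softmax(logits):
    """Probabilities for mutually exclusive classes; they always sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(logits):
    """Independent per-label probabilities; they need not sum to 1."""
    return [1.0 / (1.0 + math.exp(-z)) for z in logits]

# Hypothetical raw outputs for three of the 18 labels
labels = ["Fire", "Flying", "Water"]
logits = [1.6, 1.05, -2.0]

softmax_probs = softmax(logits)
sigmoid_probs = sigmoid(logits)

# Multi-label decision: keep every label whose probability exceeds 0.5
predicted = [name for name, p in zip(labels, sigmoid_probs) if p > 0.5]

print([round(p, 2) for p in softmax_probs])
print([round(p, 2) for p in sigmoid_probs])
print(predicted)
```

With softmax, Fire and Flying compete for the same probability mass, so only one of them can exceed 0.5. With the sigmoid, both can clear the threshold at once.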
The other consideration is how we will evaluate our model. For much the same reason that we can’t use the softmax function to derive probabilities, we can’t simply use the traditional accuracy metric to judge how well our model is performing during training and validation. Accuracy is simply the fraction of cases we classify correctly.
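Written out in a standard form (the wording of the terms is assumed here, since the original equation is not reproduced):

```latex
\[
\text{accuracy} = \frac{\text{number of correctly classified cases}}{\text{total number of cases}}
\]
```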
Like softmax, this is useful in cases where each input has just one correct answer, but in scenarios where an input can have more than one, like our Pokemon model, it can paint a misleadingly poor picture of performance by heavily penalizing the model whenever it doesn’t get the typing exactly right. So back to our Charizard example, it wouldn’t award partial credit for guessing one of its types correctly: it’s all or nothing. This can be especially problematic when the distribution of labels is not uniform, as the bar plot of type counts for our dataset shows.
Instead, we will use the F1-score metric. This is the harmonic mean of precision and recall, which measure the fraction of correctly predicted positive cases out of all predicted positives and the fraction of correctly predicted positive cases out of all actual positive cases, respectively.
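In the usual notation (assumed here), with TP, FP, and FN denoting the counts of true positives, false positives, and false negatives for a given type, these quantities are:

```latex
\[
\text{precision} = \frac{TP}{TP + FP}, \qquad
\text{recall} = \frac{TP}{TP + FN}, \qquad
F_1 = \frac{2 \cdot \text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}} = \frac{2\,TP}{2\,TP + FP + FN}
\]
```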
By using information from both these metrics, we can derive a more reliable measure of how well our model is performing.
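As a minimal sketch of how these quantities are computed for a single type, the snippet below counts true positives, false positives, and false negatives from 0/1 labels. The six "is this sprite a Flying type?" labels and predictions are made up for illustration, and the function name is hypothetical.

```python
def precision_recall_f1(y_true, y_pred):
    """Precision, recall, and F1 for one label, given 0/1 truths and predictions."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# Hypothetical labels for six sprites: 1 = Flying type, 0 = not Flying
flying_true = [1, 1, 1, 0, 0, 0]
flying_pred = [1, 1, 0, 1, 0, 0]

precision, recall, f1 = precision_recall_f1(flying_true, flying_pred)
print(precision, recall, f1)
```

Here two of the three predicted Flying types are correct, and two of the three actual Flying types are found, so precision, recall, and F1 all come out to two thirds.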
Two more things about the F1-score before we transition to the coding part of this article. First, the above F1 score is specific to each type. To obtain a global sense of how well our classifier is doing across all these classes, we can take an average of F1 scores. There are several ways to do this, which you can read more about [here](https://towardsdatascience.com/micro-macro-weighted-averages-of-f1-score-clearly-explained-b603420b292f). For our case, we will take the unweighted mean of the per-type F1 scores, called a macro F1 score. This metric weights all types equally and is thus less influenced by class imbalance, which is a challenge with our data: we have an overrepresentation of Water and Normal types, for example, but few Ghost and Ice-type Pokemon. Second, we will use a loss function called the soft F1 score to train our model. This is a modified version of the F1-score that’s made to be differentiable (an important prerequisite for a loss function, since back-propagation relies on gradients). For the sake of brevity, I won’t go into too much detail here, but you can read more about it [here](https://towardsdatascience.com/the-unknown-benefits-of-using-a-soft-f1-loss-in-classification-systems-753902c0105d) and [here](https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric).
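To make the difference between the two concrete, here is a plain-Python sketch of a macro F1 score and a macro soft F1 loss. It follows the common formulation in which the hard 0/1 predictions are replaced by probabilities when counting true positives, false positives, and false negatives. The function names and the tiny two-label batch are hypothetical, and a real training loop would use a tensor implementation instead.

```python
EPS = 1e-16  # guards against division by zero when a label never appears

def macro_f1(y_true, y_prob, threshold=0.5):
    """Unweighted mean of per-label F1 scores after thresholding."""
    n_labels = len(y_true[0])
    scores = []
    for k in range(n_labels):
        tp = fp = fn = 0
        for truth, probs in zip(y_true, y_prob):
            pred = 1 if probs[k] > threshold else 0
            tp += pred * truth[k]
            fp += pred * (1 - truth[k])
            fn += (1 - pred) * truth[k]
        scores.append(2 * tp / (2 * tp + fp + fn + EPS))
    return sum(scores) / n_labels

def macro_soft_f1_loss(y_true, y_prob):
    """1 minus the mean soft F1, using probabilities instead of hard predictions."""
    n_labels = len(y_true[0])
    losses = []
    for k in range(n_labels):
        tp = sum(p[k] * t[k] for t, p in zip(y_true, y_prob))
        fp = sum(p[k] * (1 - t[k]) for t, p in zip(y_true, y_prob))
        fn = sum((1 - p[k]) * t[k] for t, p in zip(y_true, y_prob))
        soft_f1 = 2 * tp / (2 * tp + fp + fn + EPS)
        losses.append(1 - soft_f1)
    return sum(losses) / n_labels

# Hypothetical mini-batch: 3 sprites, 2 labels (Fire, Flying)
y_true = [[1, 1], [1, 0], [0, 1]]
y_prob = [[0.83, 0.74], [0.60, 0.20], [0.40, 0.30]]

score = macro_f1(y_true, y_prob)
loss = macro_soft_f1_loss(y_true, y_prob)
print(round(score, 4), round(loss, 4))
```

Because the soft version never thresholds, a small change in any probability changes the loss smoothly, which is what gradient-based training needs.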
Model Implementation
We will use the TensorFlow library to program our neural network. I should note up front that much of the model implementation was adapted from Ashref Maiza’s tutorial for multi-label classification with TensorFlow, in addition to other great resources that are linked at the end of this article.
Let’s start by loading our libraries, which will be quite a few:
Next we will do some data cleaning on our dictionary of Pokemon types and their indices within the Pokedex.
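The cleaning code itself is not reproduced here. As a rough sketch of what such a step might look like, the snippet below builds a lookup from Pokedex number to a list of types. Everything in it is an assumption: the column names ("#", "Name", "Type 1", "Type 2"), the inline sample rows, the helper name `build_type_lookup`, and the choice to keep only the first row per Pokedex number when alternate forms share a number.

```python
import csv
import io

# Tiny hypothetical stand-in for the Kaggle CSV (column names are assumed)
raw_csv = """#,Name,Type 1,Type 2
6,Charizard,Fire,Flying
6,Mega Charizard X,Fire,Dragon
25,Pikachu,Electric,
125,Electabuzz,Electric,
"""

def build_type_lookup(csv_text):
    """Map each Pokedex number to its list of one or two types."""
    lookup = {}
    for row in csv.DictReader(io.StringIO(csv_text)):
        number = int(row["#"])
        if number in lookup:
            continue  # assumption: keep the first (base) form only
        types = [row["Type 1"]]
        if row["Type 2"]:  # an empty string means the Pokemon has a single type
            types.append(row["Type 2"])
        lookup[number] = types
    return lookup

type_lookup = build_type_lookup(raw_csv)
print(type_lookup)
```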
Now, we can begin creating our training dataset. We will load in our images of Pokemon sprites from generations 1 through 5 using the get_images() function, which takes as input the folder for each generation, so we will run this function 5 times in total:
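The body of get_images() is not shown here, so the following is only a hypothetical stand-in, deliberately given a different name (`collect_sprite_paths`). It assumes each sprite file is named after its Pokedex number (as in the paths listed later, such as 125.png) and walks a generation folder recursively so that back, shiny, and female subfolders are included.

```python
from pathlib import Path

def collect_sprite_paths(generation_dir):
    """Return (file path, Pokedex number) pairs for every numbered PNG under a folder."""
    pairs = []
    for path in sorted(Path(generation_dir).rglob("*.png")):
        # Assumption: only purely numeric file names map cleanly to a Pokedex number
        if path.stem.isdigit():
            pairs.append((str(path), int(path.stem)))
    return pairs

# Example usage (the folder name is illustrative):
# pairs = collect_sprite_paths("sprites/sprites/pokemon/versions/generation-i")
```

Calling a helper like this once per generation folder and concatenating the results would give the combined list of image paths, each of which can then be paired with its types via the Pokedex number.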
Recall that since the number of individual Pokemon is pretty small, we’re using Pokemon sprites from multiple generations and games, in addition to back sprites, shiny variants, and female Pokemon, as a built-in way of augmenting our training data. As a result, our training dataset consists of more than 15,000 images.
Next, we will binarize our types into 18-element vectors of 0s and 1s (one-hot encoding, extended so that a dual-type Pokemon has two 1s) and convert our data to a format amenable to multi-label classification using the tf.data API.
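The binarization step can be illustrated without any libraries. This sketch assumes the 18 labels are ordered alphabetically, which matches the index listing in this article; the helper name `encode_types` is hypothetical.

```python
# The 18 types in alphabetical order, so that index 3 is Electric
TYPES = [
    "Bug", "Dark", "Dragon", "Electric", "Fairy", "Fighting",
    "Fire", "Flying", "Ghost", "Grass", "Ground", "Ice",
    "Normal", "Poison", "Psychic", "Rock", "Steel", "Water",
]

def encode_types(types):
    """Turn a list of one or two type names into an 18-element 0/1 vector."""
    return [1 if name in types else 0 for name in TYPES]

electabuzz = encode_types(["Electric"])
charizard = encode_types(["Fire", "Flying"])
print(electabuzz)
print(charizard)
```

A single-type Pokemon such as Electabuzz gets exactly one 1, while a dual-type Pokemon such as Charizard gets two.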
Labels:
0. Bug
1. Dark
2. Dragon
3. Electric
4. Fairy
5. Fighting
6. Fire
7. Flying
8. Ghost
9. Grass
10. Ground
11. Ice
12. Normal
13. Poison
14. Psychic
15. Rock
16. Steel
17. Water
sprites/sprites/pokemon/versions/generation-iv/diamond-pearl/125.png [0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
sprites/sprites/pokemon/versions/generation-i/yellow/back/gbc/67.png [0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0]
sprites/sprites/pokemon/versions/generation-ii/silver/180.png [0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
The Pokemon under sprites/sprites/pokemon/versions/generation-iv/diamond-pearl/125.png is an Electric-type Pokemon called Electabuzz, so its binarized type vector has a 1 at index 3, corresponding to that typing. Similarly, sprites/sprites/pokemon/versions/generation-i/yellow/back/gbc/67.png and sprites/sprites/pokemon/versions/generation-ii/silver/180.png are the Fighting-type Machoke and the Electric-type Flaaffy, respectively, with their label vectors formatted accordingly. Finally, we will format this dataset such that the images and labels are linked in a manner that TensorFlow can process.
Now that our training and validation datasets have been prepared and formatted, we’re finally ready to create and train our model! Even though we were able to augment our dataset, we will use transfer learning to improve performance and decrease the training time. With transfer learning, we take a model that was pre-trained on a much larger dataset (one that need not be related to our own) and use it as the backbone of our model to identify classes in a new context. One of the nice things about TensorFlow is that it offers a rich collection of pre-trained models you can easily import, especially for computer vision tasks like ours. The only things we need to add are a dense hidden layer and an output layer with one unit for each of the 18 Pokemon types, and then we can compile and train the model.
This took a little over 15 minutes to run on my PC, which has an NVIDIA GTX 1660 Super GPU. Looking at the training and validation performance, we don’t seem to have any significant overfitting, with final macro F1-scores of 0.5413 and 0.4612 for training and validation, respectively. Performance shows diminishing returns during the last few epochs for both the training and validation datasets, but the results are nevertheless pretty good. Likewise, the final loss values for training and validation are 0.54 and 0.5976, respectively.
Now let’s test the model. I purposely trained it on the first 5 generations of Pokemon because the dataset from Kaggle goes up to generation 6, so I was curious how well it would perform on newer Pokemon without having seen their images in either the training or the validation phase. I will also add that from generation 6 onwards, the games shifted from regular 2D sprites to 3D models of Pokemon, so it will be all the more interesting to see how our model fares with data it hasn’t seen before.
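Turning the model’s 18 sigmoid outputs into a list of type names comes down to the 0.5 threshold described earlier. The sketch below assumes the alphabetical label order used in this article; the helper name `decode_prediction` and the probability vectors are made up for illustration.

```python
# The 18 types in the same (alphabetical) order as the label indices
TYPES = [
    "Bug", "Dark", "Dragon", "Electric", "Fairy", "Fighting",
    "Fire", "Flying", "Ghost", "Grass", "Ground", "Ice",
    "Normal", "Poison", "Psychic", "Rock", "Steel", "Water",
]

def decode_prediction(probabilities, threshold=0.5):
    """Return every type whose predicted probability exceeds the threshold."""
    return [name for name, p in zip(TYPES, probabilities) if p > threshold]

# Hypothetical model outputs: mostly low probabilities, with a few raised by hand
confident = [0.05] * 18
confident[6] = 0.91   # Fire
confident[14] = 0.67  # Psychic

unsure = [0.30] * 18  # no label clears the threshold

print(decode_prediction(confident))
print(decode_prediction(unsure))
```

When no probability exceeds the threshold, the result is an empty list, which is how a prediction of [] can arise.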
Delphox
Type
['Fire', 'Psychic']
Prediction
['Fire', 'Psychic']
Clauncher
Type
['Water']
Prediction
['Water']
Noivern
Type
['Flying', 'Dragon']
Prediction
['Dragon', 'Flying', 'Psychic', 'Water']
Quilladin
Type
['Grass']
Prediction
['Grass']
Gogoat
Type
['Grass']
Prediction
[]
Hawlucha
Type
['Fighting', 'Flying']
Prediction
['Flying', 'Psychic']
Goomy
Type
['Dragon']
Prediction
['Water']
Sylveon
Type
['Fairy']
Prediction
['Flying', 'Water']
We see that our classifier does a fairly good job of predicting the types of these Pokemon: Delphox, Clauncher, and Quilladin are predicted exactly, and both of Noivern’s types are recovered, albeit alongside two spurious ones. I was particularly impressed that it got the types of Delphox (a fox-like Pokemon) perfectly. It struggled with Sylveon’s Fairy type and gave up on Gogoat, since none of the type probabilities exceeded 0.5. (I will note, however, that Fairy was first added as the 18th type in this generation, with a few old Pokemon either retroactively having their type changed from Normal to Fairy or having it tacked on to their initial mono-typing.) Let’s take a look at some older Pokemon like Pikachu and Charizard:
Pikachu
Type
['Electric']
Prediction
['Water']
Charizard
Type
['Fire', 'Flying']
Prediction
['Flying', 'Normal', 'Psychic']
Interestingly, Pikachu is predicted to be a Water-type, likely owing to the fact that Water is a very common type, as shown in the type barplot, while Charizard’s Fire typing doesn’t pass muster with our model, which predicts Normal and Psychic alongside Flying instead (Psychic Pokemon vary substantially in color compared to other types). This could also be due to the fact that we didn’t use any 3D models of older Pokemon in training, so the different format could have thrown our model off a bit.
In this article, we learned how to train a neural network to predict Pokemon types, taking advantage of the diversity of sprites and forms across various iterations of games over the years. We obtained fairly good predictions on Pokemon from newer games with 3D models. There’s still plenty of room for improvement, nonetheless. For example, subsampling could be performed to curb the class imbalance problem and further reduce our loss. We could also evaluate model performance using different loss functions, such as binary cross-entropy, which is more commonly employed in multi-label classification problems than the soft F1 loss. There is also a recent preprint on arXiv that introduces another loss function derived from the F1 score for multi-label classification, called sigmoidF1, that may hold promise.
Nevertheless, I hope this was a fun introduction to multi-label classification with Pokemon. Let me know in the comments any thoughts or questions you have about the model, or any suggestions to improve performance. Happy coding!
References
[1] https://towardsdatascience.com/multi-label-image-classification-in-tensorflow-2-0-7d4cf8a4bc72
[2] https://github.com/PokeAPI/sprites
[3] https://www.kaggle.com/abcsds/pokemon
[4] https://glassboxmedicine.com/2019/05/26/classification-sigmoid-vs-softmax/
[5] https://medium.com/arteos-ai/the-differences-between-sigmoid-and-softmax-activation-function-12adee8cf322
[6] https://github.com/ashrefm/multi-label-soft-f1
[7] https://towardsdatascience.com/the-unknown-benefits-of-using-a-soft-f1-loss-in-classification-systems-753902c0105d
[8] https://arxiv.org/pdf/2108.10566.pdf