I am a fan of Pokémon and remember playing Generation One as a kid. I could name every Pokémon, its type, and many of its other stats. I think it would be interesting to see how accurately a deep learning model can classify the 150 Gen-One Pokémon. I would expect a human fan to score above 99%.

from fastai.vision.all import *

Download and clean the dataset

I searched online and found this dataset of about 7,000 labeled images on Kaggle; credit to Lance Zhang for preparing and cleaning it.

path = Path('PokemonData')
path.ls()
(#150) [Path('PokemonData/Weezing'),Path('PokemonData/Magneton'),Path('PokemonData/Horsea'),Path('PokemonData/Rhydon'),Path('PokemonData/Meowth'),Path('PokemonData/Weedle'),Path('PokemonData/Machoke'),Path('PokemonData/Ivysaur'),Path('PokemonData/Vulpix'),Path('PokemonData/Snorlax')...]
Path.BASE_PATH = path
(path/"Weezing").ls()
(#50) [Path('Weezing/d7759ca041a54e40bcd0e5983593a398.jpg'),Path('Weezing/91e6d87ec73143f6870b49711c64916b.jpg'),Path('Weezing/df575a5993254fbca9cdea90dd91588d.jpg'),Path('Weezing/dbb9cf5baf0d4db7aa4106fe500713fc.jpg'),Path('Weezing/55da5238b2a04623b61594716a59bf6e.jpg'),Path('Weezing/2341ef30aa36401f8abefcb4cf41556d.jpg'),Path('Weezing/51bfdcd289a04e5db15f5c4ba883c17c.jpg'),Path('Weezing/1cdbf5d1e44840a88af213c3d5db7e65.jpg'),Path('Weezing/6c898f7248c6485d9a23efc7196a96ca.jpg'),Path('Weezing/ec8c7d0c9f39409e89feb10ec5332f0f.jpg')...]

Let's look at a few examples to see what the images look like.

example = Image.open((path/"Weezing").ls()[0])
example

Note that the images have different shapes.

example.shape
(360, 413)
example1 = Image.open((path/"Weezing").ls()[1])
example1.shape
(881, 984)
fns = get_image_files(path)
fns
(#6820) [Path('Weezing/d7759ca041a54e40bcd0e5983593a398.jpg'),Path('Weezing/91e6d87ec73143f6870b49711c64916b.jpg'),Path('Weezing/df575a5993254fbca9cdea90dd91588d.jpg'),Path('Weezing/dbb9cf5baf0d4db7aa4106fe500713fc.jpg'),Path('Weezing/55da5238b2a04623b61594716a59bf6e.jpg'),Path('Weezing/2341ef30aa36401f8abefcb4cf41556d.jpg'),Path('Weezing/51bfdcd289a04e5db15f5c4ba883c17c.jpg'),Path('Weezing/1cdbf5d1e44840a88af213c3d5db7e65.jpg'),Path('Weezing/6c898f7248c6485d9a23efc7196a96ca.jpg'),Path('Weezing/ec8c7d0c9f39409e89feb10ec5332f0f.jpg')...]

I found that some of the files are SVGs, which fastai cannot load as images. We therefore use verify_images to find every file that cannot be opened and delete (unlink) it.

failed = verify_images(fns)
failed.map(Path.unlink);
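verify_images works by trying to open each file as an image. As a complementary, purely illustrative pre-check (not what fastai does), the sketch below flags files whose first bytes do not match a JPEG or PNG signature, which catches SVG text saved under an image extension. The helper names `looks_like_raster` and `find_non_raster` are hypothetical, and the assumption that the dataset contains only JPEG and PNG files is mine; other valid formats (e.g. GIF) would be flagged too.

```python
from pathlib import Path

# Magic bytes of the raster formats assumed to be in this dataset.
RASTER_SIGNATURES = (
    b"\xff\xd8\xff",        # JPEG
    b"\x89PNG\r\n\x1a\n",   # PNG
)

def looks_like_raster(fn):
    "Hypothetical helper: True if the file starts with a JPEG or PNG signature."
    with open(fn, "rb") as f:
        head = f.read(8)
    return head.startswith(RASTER_SIGNATURES)

def find_non_raster(fns):
    "Hypothetical helper: return the files (e.g. SVGs) that fail the signature check."
    return [Path(fn) for fn in fns if not looks_like_raster(fn)]
```

Usage would be along the lines of `for fn in find_non_raster(fns): fn.unlink()`; verify_images remains the more thorough check, since it also catches truncated or corrupt JPEGs.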

Prepare the DataBlock

Since the images differ in size and we want data augmentation, I adopt the presizing approach described in fast.ai notebook 5 (05_pet_breeds).

Basically, presizing first resizes each image to a relatively large size (here 460 pixels), significantly larger than the target training size (224). It then composes all of the common augmentation operations, together with the final resize to 224, into one operation that runs on the GPU only once at the end of processing, rather than performing each operation individually and interpolating multiple times.
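The reason composing helps is that affine augmentations (rotation, zoom, flip) can each be written as a 3x3 matrix in homogeneous coordinates, and multiplying the matrices gives a single transform that maps output pixels to input coordinates in one step, so the image only needs to be interpolated once. The plain-Python sketch below illustrates that idea; it is not fastai's implementation, and all function names in it are hypothetical.

```python
import math

def matmul(a, b):
    "Multiply two 3x3 matrices given as nested lists."
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]

def rotation(deg):
    t = math.radians(deg)
    return [[math.cos(t), -math.sin(t), 0.0],
            [math.sin(t), math.cos(t), 0.0],
            [0.0, 0.0, 1.0]]

def zoom(s):
    return [[s, 0.0, 0.0], [0.0, s, 0.0], [0.0, 0.0, 1.0]]

def flip_x():
    return [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

def transform_point(m, x, y):
    "Map a point (x, y) through matrix m using homogeneous coordinates."
    return (m[0][0] * x + m[0][1] * y + m[0][2],
            m[1][0] * x + m[1][1] * y + m[1][2])

def compose(*mats):
    "Combine transforms, applied in the given order, into a single matrix."
    out = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    for m in mats:
        out = matmul(m, out)  # each later transform multiplies on the left
    return out
```

Applying `compose(flip_x(), rotation(10), zoom(1.2))` to a point gives the same result as applying the three transforms one after another, but when resampling an image the composed version interpolates pixels once instead of three times.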

pokemons = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    splitter=RandomSplitter(seed=1),
    get_items=get_image_files,
    get_y=parent_label,
    item_tfms=Resize(460),
    batch_tfms=aug_transforms(size=224, min_scale=0.75)
)
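In aug_transforms, min_scale=0.75 controls the random crop: each training image is cropped to a region covering a random fraction (at least 75%) of its area, with a random aspect ratio, before being resized to 224. The sketch below is a simplified, hypothetical version of that standard random-resized-crop sampling; fastai's GPU implementation follows the same idea, but details such as the aspect-ratio range (assumed here to be 3/4 to 4/3) and the fallback behaviour may differ.

```python
import math
import random

def sample_crop(w, h, min_scale=0.75, ratio=(3 / 4, 4 / 3), rng=random, tries=10):
    """Pick a crop (left, top, crop_w, crop_h) covering a random fraction
    in [min_scale, 1] of the image area, with a random aspect ratio."""
    area = w * h
    for _ in range(tries):
        target_area = rng.uniform(min_scale, 1.0) * area
        # Sample the aspect ratio on a log scale so 3/4 and 4/3 are equally likely.
        aspect = math.exp(rng.uniform(math.log(ratio[0]), math.log(ratio[1])))
        cw = int(round(math.sqrt(target_area * aspect)))
        ch = int(round(math.sqrt(target_area / aspect)))
        if 0 < cw <= w and 0 < ch <= h:  # reject crops that do not fit
            return rng.randint(0, w - cw), rng.randint(0, h - ch), cw, ch
    # Fallback after repeated rejections: the largest centered square.
    side = min(w, h)
    return (w - side) // 2, (h - side) // 2, side, side
```

Because Resize(460) runs first, every crop here is taken from a 460x460 image, so even the smallest crop (about 75% of the area, roughly 398 pixels on a side) is still larger than the 224-pixel training size.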
dls = pokemons.dataloaders(path)
dls.show_batch(nrows=1, ncols=4)
dls.show_batch(nrows=1, ncols=4, unique=True)
resnet18_learn = cnn_learner(dls, resnet18, metrics=error_rate)
resnet18_learn.fine_tune(2)
epoch train_loss valid_loss error_rate time
0 4.632047 1.978062 0.450147 00:25
epoch train_loss valid_loss error_rate time
0 1.664783 0.759323 0.175953 00:27
1 0.921666 0.581305 0.139296 00:28
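fine_tune prints two tables because it trains in two phases: first with the pretrained body frozen, then with all layers unfrozen. The sketch below summarises those phases under the defaults as I understand them from fastai 2.x (base_lr=2e-3, freeze_epochs=1, lr_mult=100, with base_lr halved before unfreezing); `fine_tune_plan` is a hypothetical helper, and the defaults should be checked against your installed version.

```python
def fine_tune_plan(epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100):
    """Hypothetical summary of fine_tune's two phases as
    (name, n_epochs, lowest_lr, highest_lr) tuples."""
    phases = []
    # Phase 1: body frozen, only the new head trains at base_lr.
    phases.append(("frozen", freeze_epochs, base_lr, base_lr))
    # Phase 2: halve base_lr, unfreeze, and spread learning rates across
    # layer groups from base_lr / lr_mult (early layers) to base_lr (head).
    base_lr /= 2
    phases.append(("unfrozen", epochs, base_lr / lr_mult, base_lr))
    return phases
```

For `fine_tune(2)` this gives one frozen epoch followed by two unfrozen epochs, matching the 1-row and 2-row tables above.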

Interpret the baseline model

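The baseline reaches around 86% accuracy (a final error rate of 0.139). Let's look at some of the wrong predictions. Each tuple returned by most_confused() is (actual label, predicted label, number of images), sorted by count.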
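The error_rate metric in the training tables is the fraction of validation images whose highest-probability class is not the true class, so accuracy is simply 1 - error_rate. A plain-Python sketch of that computation (not fastai's code, which works on tensors):

```python
def error_rate(probs, targets):
    """Fraction of samples whose highest-probability class differs from the
    target class; accuracy is 1 - error_rate."""
    wrong = 0
    for p, t in zip(probs, targets):
        pred = max(range(len(p)), key=p.__getitem__)  # argmax over classes
        wrong += pred != t
    return wrong / len(targets)
```

With the baseline's final error rate of 0.139296, accuracy is 1 - 0.139296, or about 86.1%.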

interp = ClassificationInterpretation.from_learner(resnet18_learn)
interp.most_confused()

[('Kingler', 'Krabby', 5),
 ('Mankey', 'Primeape', 5),
 ('Electrode', 'Voltorb', 4),
 ('Growlithe', 'Arcanine', 4),
 ('Golbat', 'Zubat', 3),
 ('Ivysaur', 'Venusaur', 3),
 ('Poliwhirl', 'Poliwrath', 3),
 ('Ponyta', 'Rapidash', 3),
 ('Rhyhorn', 'Onix', 3),
 ('Kadabra', 'Alakazam', 2),
 ('Machoke', 'Rhydon', 2),
 ('Nidoqueen', 'Blastoise', 2),
 ('Pidgeot', 'Pidgeotto', 2),
 ('Pidgeot', 'Pidgey', 2),
 ('Slowpoke', 'Slowbro', 2),
 ('Abra', 'Kadabra', 1),
 ('Aerodactyl', 'Mewtwo', 1),
 ('Aerodactyl', 'Nidoking', 1),
 ('Alakazam', 'Kadabra', 1),
 ('Arcanine', 'Rapidash', 1),
 ('Blastoise', 'Wartortle', 1),
 ('Chansey', 'Venonat', 1),
 ('Charizard', 'Charmeleon', 1),
 ('Charmander', 'Lickitung', 1),
 ('Charmander', 'Psyduck', 1),
 ('Charmander', 'Vulpix', 1),
 ('Clefairy', 'Clefable', 1),
 ('Cloyster', 'MrMime', 1),
 ('Cloyster', 'Tangela', 1),
 ('Cubone', 'Wartortle', 1),
 ('Dewgong', 'Dragonair', 1),
 ('Dewgong', 'Ninetales', 1),
 ('Diglett', 'Dugtrio', 1),
 ('Ditto', 'Grimer', 1),
 ('Ditto', 'Metapod', 1),
 ('Doduo', 'Dodrio', 1),
 ('Dragonair', 'Dratini', 1),
 ('Dragonair', 'Vaporeon', 1),
 ('Dratini', 'Raichu', 1),
 ('Dratini', 'Vaporeon', 1),
 ('Ekans', 'Dratini', 1),
 ('Exeggutor', 'Victreebel', 1),
 ('Flareon', 'Dragonite', 1),
 ('Flareon', 'Farfetchd', 1),
 ('Gengar', 'Haunter', 1),
 ('Geodude', 'Golem', 1),
 ('Geodude', 'Kingler', 1),
 ('Geodude', 'Machamp', 1),
 ('Geodude', 'Muk', 1),
 ('Geodude', 'Rhyhorn', 1),
 ('Gloom', 'Beedrill', 1),
 ('Golduck', 'Rhydon', 1),
 ('Golduck', 'Seadra', 1),
 ('Golem', 'Graveler', 1),
 ('Graveler', 'Onix', 1),
 ('Grimer', 'Omastar', 1),
 ('Growlithe', 'Ponyta', 1),
 ('Gyarados', 'Alolan Sandslash', 1),
 ('Gyarados', 'Electabuzz', 1),
 ('Haunter', 'Abra', 1),
 ('Hitmonchan', 'Machamp', 1),
 ('Horsea', 'Nidoqueen', 1),
 ('Hypno', 'Alakazam', 1),
 ('Hypno', 'Kadabra', 1),
 ('Hypno', 'Magmar', 1),
 ('Ivysaur', 'Tentacool', 1),
 ('Jigglypuff', 'Clefairy', 1),
 ('Jolteon', 'Zapdos', 1),
 ('Jynx', 'Aerodactyl', 1),
 ('Kabuto', 'Diglett', 1),
 ('Kabuto', 'Voltorb', 1),
 ('Kabutops', 'Gyarados', 1),
 ('Kingler', 'Gyarados', 1),
 ('Koffing', 'Weezing', 1),
 ('Krabby', 'Kingler', 1),
 ('Krabby', 'Omastar', 1),
 ('Lapras', 'Graveler', 1),
 ('Lapras', 'Mew', 1),
 ('Lickitung', 'Dragonite', 1),
 ('Lickitung', 'Vileplume', 1),
 ('Machamp', 'Machoke', 1),
 ('Machamp', 'Onix', 1),
 ('Machamp', 'Pinsir', 1),
 ('Machoke', 'Arcanine', 1),
 ('Machoke', 'Machamp', 1),
 ('Machoke', 'Machop', 1),
 ('Magikarp', 'Krabby', 1),
 ('Magmar', 'Kingler', 1),
 ('Magmar', 'Moltres', 1),
 ('Magneton', 'MrMime', 1),
 ('Mankey', 'Arcanine', 1),
 ('Mankey', 'Jynx', 1),
 ('Marowak', 'Cubone', 1),
 ('Marowak', 'Farfetchd', 1),
 ('Mew', 'Lickitung', 1),
 ('Mew', 'Squirtle', 1),
 ('Mewtwo', 'Dewgong', 1),
 ('Moltres', 'Ponyta', 1),
 ('MrMime', 'Slowpoke', 1),
 ('Muk', 'Grimer', 1),
 ('Nidoking', 'Nidorino', 1),
 ('Nidoqueen', 'Lapras', 1),
 ('Nidoqueen', 'Machamp', 1),
 ('Nidorina', 'Machop', 1),
 ('Nidorina', 'Nidorino', 1),
 ('Ninetales', 'Scyther', 1),
 ('Onix', 'Golem', 1),
 ('Onix', 'Weedle', 1),
 ('Paras', 'Tangela', 1),
 ('Parasect', 'Hitmonchan', 1),
 ('Parasect', 'Moltres', 1),
 ('Parasect', 'Vileplume', 1),
 ('Persian', 'Ninetales', 1),
 ('Pidgeot', 'Bellsprout', 1),
 ('Pidgeot', 'Spearow', 1),
 ('Pidgeotto', 'Spearow', 1),
 ('Pidgey', 'Pidgeot', 1),
 ('Pidgey', 'Sandslash', 1),
 ('Pidgey', 'Spearow', 1),
 ('Pikachu', 'Drowzee', 1),
 ('Pinsir', 'Charmeleon', 1),
 ('Poliwrath', 'Poliwhirl', 1),
 ('Poliwrath', 'Tangela', 1),
 ('Ponyta', 'Goldeen', 1),
 ('Primeape', 'Cubone', 1),
 ('Psyduck', 'Blastoise', 1),
 ('Raichu', 'Dragonite', 1),
 ('Raticate', 'Arcanine', 1),
 ('Raticate', 'Primeape', 1),
 ('Raticate', 'Spearow', 1),
 ('Rhydon', 'Kangaskhan', 1),
 ('Rhyhorn', 'Rhydon', 1),
 ('Rhyhorn', 'Scyther', 1),
 ('Sandshrew', 'Kakuna', 1),
 ('Sandslash', 'Drowzee', 1),
 ('Seadra', 'Articuno', 1),
 ('Seel', 'Dewgong', 1),
 ('Seel', 'Magnemite', 1),
 ('Slowbro', 'Clefable', 1),
 ('Slowbro', 'Grimer', 1),
 ('Slowpoke', 'Charmeleon', 1),
 ('Slowpoke', 'Exeggcute', 1),
 ('Staryu', 'Starmie', 1),
 ('Tentacool', 'Porygon', 1),
 ('Vaporeon', 'Seadra', 1),
 ('Venonat', 'Staryu', 1),
 ('Venusaur', 'Gastly', 1),
 ('Venusaur', 'Ivysaur', 1),
 ('Venusaur', 'Scyther', 1),
 ('Vulpix', 'Koffing', 1),
 ('Vulpix', 'Weedle', 1),
 ('Wartortle', 'Kadabra', 1),
 ('Wartortle', 'Nidoqueen', 1),
 ('Weedle', 'Hypno', 1),
 ('Weedle', 'Psyduck', 1),
 ('Weepinbell', 'Weedle', 1),
 ('Weezing', 'Graveler', 1),
 ('Weezing', 'Koffing', 1),
 ('Zubat', 'Aerodactyl', 1),
 ('Zubat', 'Diglett', 1)]
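fastai derives this list from the confusion matrix by reading off its non-zero off-diagonal cells. The sketch below shows the same idea working directly on lists of actual and predicted labels; it is an illustrative re-implementation, with ties broken alphabetically to resemble the listing above.

```python
from collections import Counter

def most_confused(actuals, preds, min_val=1):
    """Count (actual, predicted) pairs where the prediction was wrong and
    return (actual, predicted, count) tuples, most frequent first."""
    counts = Counter((a, p) for a, p in zip(actuals, preds) if a != p)
    pairs = [(a, p, n) for (a, p), n in counts.items() if n >= min_val]
    # Sort by count (descending), then alphabetically by actual and predicted.
    return sorted(pairs, key=lambda x: (-x[2], x[0], x[1]))
```

Passing a larger min_val (e.g. 2) hides the long tail of one-off mistakes, which makes lists like the one above much easier to read.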

Let's first examine ('Mankey', 'Primeape', 5), i.e. Mankey images predicted as Primeape.

Image.open((path/"Mankey").ls()[0])
Image.open((path/"Primeape").ls()[0])

These two Pokémon look quite similar. According to the wiki, Primeape evolves from Mankey starting at level 28.

Let's take a look at another confused pair: ('Ponyta', 'Rapidash', 3).

Image.open((path/"Ponyta").ls()[0])
Image.open((path/"Rapidash").ls()[0])

This is another evolution pair. According to the wiki, Rapidash evolves from Ponyta starting at level 40.

Let's fine-tune our model

First, pick a deeper model, resnet34, and find a suitable learning rate.

resnet34_learn = cnn_learner(dls, resnet34, metrics=error_rate)
lr_min,lr_steep = resnet34_learn.lr_find()

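Print out lr_min (fastai already reports this as the learning rate at the minimum loss divided by 10) and lr_steep (the learning rate where the loss falls fastest).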

print(f"Minimum/10: {lr_min:.2e}, steepest point: {lr_steep:.2e}")
Minimum/10: 1.20e-02, steepest point: 1.10e-02

In the learning rate plot, a learning rate of around 1e-2 appears appropriate, so let's choose that.
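The learning rate finder trains briefly while increasing the learning rate exponentially and records the loss at each step. The sketch below shows, in simplified form, how the two suggestions can be read off that curve: the rate at the lowest loss divided by 10, and the rate where the loss drops fastest per step in log(lr). This follows my understanding of fastai's older lr_find suggestions; the real implementation also smooths the loss and trims the ends of the curve, so treat `suggest_lrs` as a hypothetical illustration.

```python
import math

def suggest_lrs(lrs, losses):
    """Return (lr_min, lr_steep) from a learning-rate sweep:
    lr_min   = learning rate at the lowest loss, divided by 10
    lr_steep = learning rate where the loss falls fastest per log-lr step."""
    i_min = min(range(len(losses)), key=losses.__getitem__)
    lr_min = lrs[i_min] / 10
    grads = [(losses[i + 1] - losses[i]) / (math.log(lrs[i + 1]) - math.log(lrs[i]))
             for i in range(len(lrs) - 1)]
    i_steep = min(range(len(grads)), key=grads.__getitem__)  # most negative slope
    return lr_min, lrs[i_steep]
```

Dividing the minimum by 10 backs off from the point where training is about to diverge, which is why both suggestions here land close to 1e-2.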

lr=1e-2

Second, use this learning rate to train only the newly added head for 3 epochs. cnn_learner freezes the pretrained body by default, so fit_one_cycle updates only the last layers.

resnet34_learn = cnn_learner(dls, resnet34, metrics=error_rate)
resnet34_learn.fit_one_cycle(3, lr)
epoch train_loss valid_loss error_rate time
0 2.437599 0.898877 0.222874 00:29
1 1.162741 0.594926 0.159824 00:29
2 0.522703 0.386840 0.096774 00:29
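fit_one_cycle does not use lr as a constant: it warms the learning rate up from a small value to lr (the maximum) and then anneals it down to nearly zero, using cosine curves. The sketch below reproduces that schedule with what I understand to be fastai 2.x's defaults (div=25, div_final=1e5, pct_start=0.25); the function names are hypothetical, and fastai also varies momentum inversely to the learning rate, which is omitted here.

```python
import math

def cos_anneal(start, end, pos):
    "Cosine interpolation from start (pos=0) to end (pos=1)."
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_lr(pct, lr_max, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress pct in [0, 1]: warm up from
    lr_max / div to lr_max, then anneal down to lr_max / div_final."""
    if pct < pct_start:
        return cos_anneal(lr_max / div, lr_max, pct / pct_start)
    return cos_anneal(lr_max, lr_max / div_final, (pct - pct_start) / (1 - pct_start))
```

With lr=1e-2, the rate starts at 4e-4, peaks at 1e-2 a quarter of the way through training, and ends near 1e-7; the low final rate is one reason the last epoch usually has the lowest loss.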

Then we unfreeze the model so that all layers are trained, including those pretrained on ImageNet. Before training further, we need to find a new learning rate, since the model has changed after 3 epochs and every layer is now trainable.

resnet34_learn.unfreeze()
resnet34_learn.lr_find()
SuggestedLRs(lr_min=5.248074739938602e-06, lr_steep=6.309573450380412e-07)

Based on this, we pick a learning-rate slice from 1e-6 to 1e-4: the earliest layers train at 1e-6 and the final layers at 1e-4 (discriminative learning rates).
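With slice(1e-6, 1e-4), fastai gives each parameter group its own learning rate, spaced evenly on a log scale between the two ends. As far as I know, a cnn_learner ResNet has three parameter groups, so the rates become 1e-6, 1e-5 and 1e-4. fastai has a helper called even_mults for this spacing; below is a plain-Python re-implementation for illustration, not the library function itself.

```python
def even_mults(start, stop, n):
    "n values spaced geometrically (evenly on a log scale) from start to stop."
    if n == 1:
        return [stop]
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]
```

Early layers detect generic features such as edges and textures that transfer well from ImageNet, so they get the smallest rate; the head, which is specific to the 150 Pokémon classes, gets the largest.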

resnet34_learn.fit_one_cycle(12, lr_max=slice(1e-6,1e-4))
epoch train_loss valid_loss error_rate time
0 0.303374 0.374824 0.097507 00:35
1 0.303054 0.356334 0.087243 00:36
2 0.256385 0.336976 0.084311 00:35
3 0.213164 0.317025 0.081378 00:35
4 0.190199 0.310731 0.078446 00:35
5 0.174076 0.304764 0.074047 00:35
6 0.151117 0.297080 0.073314 00:36
7 0.139543 0.291337 0.071114 00:36
8 0.141836 0.292920 0.073314 00:36
9 0.128678 0.290009 0.070381 00:36
10 0.125571 0.291345 0.074780 00:36
11 0.109895 0.287401 0.071848 00:36

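Let's take a look at the loss plot. The error rate fluctuates from epoch to epoch (it rises at epochs 8 and 10, for example), and both it and the validation loss (around 0.29) appear to have reached a plateau.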

resnet34_learn.recorder.plot_loss()
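One way to make "reached a plateau" concrete is to check whether the best error rate in the last few epochs improved on the best earlier value by more than some small margin. The sketch below applies that check to the error rates from the 12-epoch table above; `has_plateaued`, the patience of 3 epochs and the 0.005 margin are arbitrary, hypothetical choices. fastai's EarlyStoppingCallback (with monitor, min_delta and patience arguments) can stop training automatically based on a similar rule.

```python
def has_plateaued(values, patience=3, min_delta=0.005):
    """True if the best value in the last `patience` epochs is not lower than
    the best earlier value by at least `min_delta` (lower is better)."""
    if len(values) <= patience:
        return False
    best_before = min(values[:-patience])
    best_recent = min(values[-patience:])
    return best_before - best_recent < min_delta

# error_rate column from the 12-epoch resnet34 run above
errs = [0.097507, 0.087243, 0.084311, 0.081378, 0.078446, 0.074047,
        0.073314, 0.071114, 0.073314, 0.070381, 0.074780, 0.071848]
print(has_plateaued(errs))      # True
print(has_plateaued(errs[:6]))  # False
```

Over the last three epochs the best error rate improved by less than 0.001, whereas in the first six epochs it was still dropping by about 0.01 per three epochs.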

We now have a model with around 93% accuracy. There are several ways to improve it further:

  1. For accuracy, we can try a larger model, e.g. resnet50.
  2. For training time, we can use mixed-precision training with to_fp16(), which runs most operations in half-precision floating point and can make training 2-3x faster.
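Half precision (float16) has far less range and resolution than float32, which is why mixed-precision training, as I understand fastai's to_fp16, keeps a float32 master copy of the weights and scales the loss up before backpropagation. The sketch below uses Python's struct module, whose "e" format is IEEE 754 half precision, to show both problems; `to_half` is a hypothetical helper and the scale factor of 1024 is an arbitrary example.

```python
import struct

def to_half(x):
    "Round a Python float to IEEE 754 half precision (float16) and back."
    return struct.unpack("<e", struct.pack("<e", x))[0]

# A tiny gradient underflows to zero in float16 ...
print(to_half(1e-8))               # 0.0
# ... but survives if the loss (and hence the gradient) is scaled up first.
print(to_half(1e-8 * 1024) != 0)   # True
# A small weight update is lost when added to a weight of 1.0 in float16,
# which is why the master weights are kept in float32.
print(to_half(1.0 + 1e-4))         # 1.0
```

After backpropagation the scaled gradients are divided by the same factor in float32 before the weight update, so the scaling does not change the optimisation itself.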
resnet50_learn = cnn_learner(dls, resnet50, metrics=error_rate).to_fp16()
resnet50_learn.fine_tune(12, freeze_epochs=3)
epoch train_loss valid_loss error_rate time
0 4.912830 2.322846 0.519795 00:48
1 2.299072 0.941021 0.229472 00:46
2 1.181747 0.640837 0.164956 00:46
epoch train_loss valid_loss error_rate time
0 0.435982 0.416226 0.110704 01:06
1 0.273867 0.355325 0.085777 01:04
2 0.230902 0.362856 0.092375 01:03
3 0.194805 0.355526 0.087243 01:04
4 0.171032 0.329496 0.079912 01:04
5 0.128437 0.276524 0.072581 01:03
6 0.073753 0.252845 0.063783 01:04
7 0.053434 0.255001 0.060117 01:03
8 0.037929 0.249069 0.049853 01:04
9 0.025870 0.246209 0.049120 01:04
10 0.018314 0.240860 0.049120 01:03
11 0.015480 0.238500 0.048387 01:03

Result

Using resnet50 pushes the accuracy to about 95% (a final error rate of 0.048), which is pretty good. Let's take a final look at what the model gets wrong.

resnet50_interp = ClassificationInterpretation.from_learner(resnet50_learn)
resnet50_interp.most_confused()

[('Kadabra', 'Alakazam', 4),
 ('Kingler', 'Krabby', 3),
 ('Marowak', 'Cubone', 3),
 ('Pidgeot', 'Pidgeotto', 3),
 ('Cubone', 'Marowak', 2),
 ('Mankey', 'Primeape', 2),
 ('Abra', 'Hypno', 1),
 ('Blastoise', 'Nidorina', 1),
 ('Charmander', 'Charmeleon', 1),
 ('Dewgong', 'Mew', 1),
 ('Diglett', 'Dugtrio', 1),
 ('Doduo', 'Dodrio', 1),
 ('Dratini', 'Dragonair', 1),
 ('Geodude', 'Vulpix', 1),
 ('Golem', 'Graveler', 1),
 ('Graveler', 'Geodude', 1),
 ('Graveler', 'Kadabra', 1),
 ('Grimer', 'Muk', 1),
 ('Gyarados', 'Machamp', 1),
 ('Gyarados', 'Rhydon', 1),
 ('Haunter', 'Venonat', 1),
 ('Hitmonchan', 'Machamp', 1),
 ('Jynx', 'Rattata', 1),
 ('Kabutops', 'Electabuzz', 1),
 ('Koffing', 'Weezing', 1),
 ('Krabby', 'Kingler', 1),
 ('Lapras', 'Tangela', 1),
 ('Machoke', 'Machamp', 1),
 ('Mankey', 'Rapidash', 1),
 ('Mewtwo', 'Dewgong', 1),
 ('Muk', 'Grimer', 1),
 ('Nidoqueen', 'Golduck', 1),
 ('Omastar', 'Dragonair', 1),
 ('Parasect', 'Paras', 1),
 ('Parasect', 'Vileplume', 1),
 ('Pidgeot', 'Fearow', 1),
 ('Pidgey', 'Sandslash', 1),
 ('Poliwhirl', 'Poliwrath', 1),
 ('Poliwrath', 'Poliwhirl', 1),
 ('Ponyta', 'Rapidash', 1),
 ('Rapidash', 'Ponyta', 1),
 ('Rhydon', 'Krabby', 1),
 ('Rhyhorn', 'Machoke', 1),
 ('Rhyhorn', 'Nidorina', 1),
 ('Rhyhorn', 'Onix', 1),
 ('Seadra', 'Jolteon', 1),
 ('Seel', 'Dewgong', 1),
 ('Slowpoke', 'Slowbro', 1),
 ('Squirtle', 'Wartortle', 1),
 ('Tentacool', 'Tentacruel', 1),
 ('Tentacruel', 'Tentacool', 1),
 ('Venusaur', 'Ivysaur', 1),
 ('Wartortle', 'Cubone', 1),
 ('Zubat', 'Golbat', 1),
 ('Zubat', 'Haunter', 1)]

Image.open((path/"Kadabra").ls()[0])
Image.open((path/"Alakazam").ls()[0])
Image.open((path/"Kingler").ls()[0])
Image.open((path/"Krabby").ls()[0])
Image.open((path/"Marowak").ls()[0])
Image.open((path/"Cubone").ls()[0])

All of the above are confusing evolution pairs. Kingler and Krabby, for example, are really hard to tell apart; the main difference seems to be that Kingler has one claw much bigger than the other. If you are new to the Pokémon world, there is a good chance you could not tell which is which!
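To check the observation that the top resnet50 confusions fall within evolution lines, the sketch below weights each confusion by its count and measures the share where the actual and predicted Pokémon belong to the same family. The `FAMILIES` table is a small hand-written subset covering only the top pairs (not the full Gen-One list), and `same_family_share` is a hypothetical helper.

```python
# Hand-written evolution lines for the Pokémon in the top resnet50 confusions
# (a small illustrative subset, not the full Gen-One list).
FAMILIES = [
    ["Abra", "Kadabra", "Alakazam"],
    ["Krabby", "Kingler"],
    ["Cubone", "Marowak"],
    ["Pidgey", "Pidgeotto", "Pidgeot"],
    ["Mankey", "Primeape"],
]
FAMILY_OF = {name: i for i, fam in enumerate(FAMILIES) for name in fam}

def same_family_share(confusions):
    """Share of misclassified images (weighted by count) whose actual and
    predicted Pokémon belong to the same evolution line."""
    total = sum(n for _, _, n in confusions)
    same = sum(n for a, p, n in confusions
               if a in FAMILY_OF and FAMILY_OF.get(a) == FAMILY_OF.get(p))
    return same / total

# The top six (actual, predicted, count) tuples from the resnet50 listing above
top = [("Kadabra", "Alakazam", 4), ("Kingler", "Krabby", 3), ("Marowak", "Cubone", 3),
       ("Pidgeot", "Pidgeotto", 3), ("Cubone", "Marowak", 2), ("Mankey", "Primeape", 2)]
print(same_family_share(top))  # 1.0
```

All 17 images behind the six most frequent resnet50 mistakes are confusions within a single evolution line.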