How-to: classify generation-one Pokemon
A practice deep learning image classification problem
I am a fan of Pokemon and remember playing generation one when I was a kid. I could call out the name of every Pokemon, its type, and many other stats. I thought it would be interesting to see how accurately a deep learning model can classify the 150 gen-one Pokemon. I feel human accuracy should be above 99%.
from fastai.vision.all import *
Download and clean the dataset
I searched online and found this labeled dataset of about 7,000 images on Kaggle; credit to Lance Zhang for preparing and cleaning it.
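If you grab the archive from Kaggle manually, a minimal sketch to unpack it could look like this (the archive filename below is an assumption; adjust it to whatever Kaggle actually gives you):
import zipfile
from pathlib import Path

# Hypothetical archive name from the Kaggle download; once extracted,
# it should contain the PokemonData/ directory used below.
archive = Path('pokemon-images-dataset.zip')
with zipfile.ZipFile(archive) as zf:
    zf.extractall('.')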
path = Path('PokemonData')
path.ls()
Path.BASE_PATH = path
(path/"Weezing").ls()
Let's take some examples to see what the images look like.
example = Image.open((path/"Weezing").ls()[0])
example
Note that the shapes of the images are different.
example.shape
example1 = Image.open((path/"Weezing").ls()[1])
example1.shape
fns = get_image_files(path)
fns
I found there are some SVG files that the fastai model can't take, so we need to unlink those failing images.
failed = verify_images(fns)
failed.map(Path.unlink);
Prepare the DataBlock
Since the image sizes differ and we want data augmentation, I adopt the presizing methodology mentioned in fast.ai notebook 5: 05_pet_breeds.
Basically, it resizes the images to a relatively large dimension (here 460) that is significantly larger than the target training dimension, then composes all of the common augmentation operations into one and performs the combined operation on the GPU only once at the end of processing, rather than performing the operations individually and interpolating multiple times.
pokemons = DataBlock(
    blocks=(ImageBlock, CategoryBlock),   # inputs are images, targets are categories
    splitter=RandomSplitter(seed=1),      # random train/validation split with a fixed seed
    get_items=get_image_files,
    get_y=parent_label,                   # the label is the parent folder name
    item_tfms=Resize(460),                # presize each item to 460 on the CPU
    batch_tfms=aug_transforms(size=224, min_scale=0.75)  # augment and resize to 224 on the GPU
)
dls = pokemons.dataloaders(path)
dls.show_batch(nrows=1, ncols=4)
dls.show_batch(nrows=1, ncols=4, unique=True)
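Now let's train a quick baseline: a pretrained resnet18, fine-tuned for two epochs, and then look at which classes it confuses most.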
resnet18_learn = cnn_learner(dls, resnet18, metrics=error_rate)
resnet18_learn.fine_tune(2)
interp = ClassificationInterpretation.from_learner(resnet18_learn)
interp.most_confused()
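As an optional extra check, we can also look at the individual images the model got most wrong, using fastai's interpretation object:
# Show the highest-loss images with predicted label, actual label, loss, and probability
interp.plot_top_losses(5, nrows=1)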
Let's first examine ('Mankey', 'Primeape', 4)
Image.open((path/"Mankey").ls()[0])
Image.open((path/"Primeape").ls()[0])
The above two Pokemon are quite similar to each other. According to the wiki, Primeape evolves from Mankey starting at level 28.
Let's take a look at another confusion pair: ('Ponyta', 'Rapidash', 4)
Image.open((path/"Ponyta").ls()[0])
Image.open((path/"Rapidash").ls()[0])
It's another evolution pair. According to the wiki, Rapidash evolves from Ponyta starting at level 40.
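Next, let's try a larger pretrained model, resnet34, and use the learning rate finder to pick a good learning rate.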
resnet34_learn = cnn_learner(dls, resnet34, metrics=error_rate)
lr_min,lr_steep = resnet34_learn.lr_find()
Print out lr_min/10 and lr_steep:
print(f"Minimum/10: {lr_min:.2e}, steepest point: {lr_steep:.2e}")
In the learning rate plot, it appears that a learning rate around 1e-2 would be appropriate, so let's choose that.
lr=1e-2
Second step: use this learning rate and manually train just the head (the last layers) for a few epochs (3).
resnet34_learn = cnn_learner(dls, resnet34, metrics=error_rate)
resnet34_learn.fit_one_cycle(3, lr)
Then, we unfreeze the model so that we can train all the layers (including the layers pretrained on the ImageNet dataset). Before that, we need to find a better learning rate, since we've already trained for 3 epochs.
resnet34_learn.unfreeze()
resnet34_learn.lr_find()
From the plot above, we can pick a learning rate slice of [1e-6, 1e-4].
resnet34_learn.fit_one_cycle(12, lr_max=slice(1e-6,1e-4))
Let's take a look at the loss plot. As you can see, the loss sometimes increases and eventually reaches a plateau.
resnet34_learn.recorder.plot_loss()
Now we have a model with around 93% accuracy. There are several ways to further improve the model:
- In terms of accuracy, we can try a larger model, e.g. resnet50.
- In terms of training time, we can use mixed-precision training (to_fp16()), i.e. half floating-point precision, to speed up training by 2-3x.
resnet50_learn = cnn_learner(dls, resnet50, metrics=error_rate).to_fp16()
resnet50_learn.fine_tune(12, freeze_epochs=3)
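Before interpreting the new model, we can sanity-check the final validation metrics; this is a quick optional step:
# Returns the validation loss and the error_rate metric for the resnet50 model
resnet50_learn.validate()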
resnet50_interp = ClassificationInterpretation.from_learner(resnet50_learn)
resnet50_interp.most_confused()
Image.open((path/"Kadabra").ls()[0])
Image.open((path/"Alakazam").ls()[0])
Image.open((path/"Kingler").ls()[0])
Image.open((path/"Krabby").ls()[0])
Image.open((path/"Marowak").ls()[0])
Image.open((path/"Cubone").ls()[0])
All of the above are pretty confusing evolution pairs. For example, Kingler and Krabby are really hard to tell apart; it seems that Kingler has one pincer bigger than the other. I feel that if you are new to the Pokemon world, there is a non-trivial chance you can't tell which is which!
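As a small sanity check, we could ask the final model to classify one of these confusing images. A minimal sketch (the particular Krabby file picked here is arbitrary):
# predict() returns the decoded label, the class index, and the per-class probabilities
img = (path/"Krabby").ls()[0]
pred_class, pred_idx, probs = resnet50_learn.predict(img)
print(pred_class, probs[pred_idx])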