How-to: Build a pet detector model in an hour
A quick-start guide to building an image classifier for your pet with deep learning
- Import your libraries
- Prepare your data
- Train your model
- Check out how your model performs
- Export your model
- Make predictions
- Build some simple UI widgets
This is mostly inspired by the fast.ai MOOC Practical Deep Learning for Coders, which takes a top-down approach to teaching deep learning: Lessons 1 and 2 directly build several state-of-the-art deep learning models with just a few lines of code.
In this notebook, my goal is to train a model that can tell my cat Albus, a silver shaded British Shorthair, apart from all other British Shorthairs. The model has little practical utility on its own, though it could be extended into something like a pet finder that uses cameras (CCTV, phone cameras, etc.) to help people find their missing pets.
As a reader, you just need a basic understanding of Python (or a similar language) and only limited coding experience: variable assignment, importing libraries, and calling functions. I will explain what each step in this notebook achieves.
Import your libraries
This is some basic Python code that imports the fastai/fastbook libraries (built by the fast.ai team) into your Jupyter notebook so that you can use all of the classes and functions they provide. You might notice that `import *` is generally considered a bad habit, but according to the fast.ai MOOC they pay extra attention to this: the libraries only export the necessary artifacts into the notebook, so you don't have to worry too much about it.
from fastbook import *
from fastai.vision.widgets import *
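If you'd rather see explicit imports, here is a minimal sketch covering most of the names this notebook relies on, assuming fastai v2's module layout; note that search_images_bing ships with fastbook, not fastai itself.
# A sketch of explicit imports (assuming fastai v2's module layout);
# the wildcard imports above pull in these names and more.
from fastai.vision.all import (
    DataBlock, ImageBlock, CategoryBlock, RandomSplitter, parent_label,
    RandomResizedCrop, aug_transforms, get_image_files, verify_images,
    download_images, cnn_learner, resnet18, error_rate,
    ClassificationInterpretation, load_learner, PILImage,
)
from fastbook import search_images_bing  # Bing helper lives in fastbook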
Prepare your data
In this step, you will prepare your training data:
- For my pet Albus, I upload around 150 pictures into the folder `british_shorthair/albus`.
- For the not_albus part, I use the Bing Image Search API to find around 150 silver shaded British Shorthair images and download them into `british_shorthair/not_albus`. To download images with Bing Image Search, sign up for a free account at Microsoft Azure. You will be given a key, which you can copy and enter in the cell below (replacing 'xx' with your key before executing it).
- Once you download data from the internet, it's possible that some of the images are corrupted. Thus, I use `verify_images`, provided by fast.ai, to verify all of the images and unlink the failures.
# Download ~150 "not Albus" images with the Bing Image Search API
silver_british_shorthair_path = Path('british_shorthair')
key = os.environ.get('AZURE_SEARCH_KEY', 'xx')  # replace 'xx' with your Azure key
results = search_images_bing(key, 'silver shaded british shorthair')
if not silver_british_shorthair_path.exists():
    silver_british_shorthair_path.mkdir()
dest = silver_british_shorthair_path/'not_albus'
dest.mkdir(exist_ok=True)
download_images(dest, urls=results.attrgot('contentUrl'))
# Verify every downloaded image and unlink (delete) any that are corrupted
fns = get_image_files(silver_british_shorthair_path)
failed = verify_images(fns)
failed.map(Path.unlink)
(silver_british_shorthair_path/'not_albus').ls()
# Preview one of the downloaded "not Albus" images
not_albus_img = Image.open(silver_british_shorthair_path/'not_albus/00000099.jpeg')
not_albus_img.to_thumb(128, 128)
(silver_british_shorthair_path/'albus').ls()
# Preview one of Albus's photos
albus_img = Image.open(silver_british_shorthair_path/'albus/IMG_20200726_223920.jpg')
albus_img.to_thumb(128, 128)
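Before moving on, it's worth a quick sanity check that both classes ended up with a similar number of usable images; a small sketch using the same helpers as above:
# Count the images that survived verification in each class folder
for folder in ('albus', 'not_albus'):
    n = len(get_image_files(silver_british_shorthair_path/folder))
    print(f'{folder}: {n} images')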
Train your model
The following code cell contains only 4 lines of code, but it does quite a lot:
- Line 1: It prepares your data for training. It returns a `DataBlock`, which is a fast.ai class.
  - `blocks=(ImageBlock, CategoryBlock)`: tells the DataBlock that the input data is images and the labels are categories.
  - `get_items=get_image_files`: when loading input items, the `get_image_files` function will be used to load the data.
  - `splitter=RandomSplitter(valid_pct=0.2, seed=42)`: performs a random split of your data into a training set (80%) and a validation set (20%). The seed guarantees the split comes out the same every time.
  - `get_y=parent_label`: when loading labels, the image's parent folder name is used as the label name, which is a common way to organize data.
  - `item_tfms=RandomResizedCrop(224, min_scale=0.5)`: applies an item-wise transformation that randomly resizes and crops each image to 224 px, keeping at least 50% of the original image.
  - `batch_tfms=aug_transforms()`: applies image augmentation techniques to batches of images. It flips, warps, and adjusts various properties of your images to generate "new" images for training, so the model can learn from different perspectives.
- Line 2: It loads the actual data from your path.
- Line 3: We create a CNN learner with our data, the resnet18 architecture, and error rate as our metric. The architecture in our case is not super important; it's a deep residual network pretrained on ImageNet (more at https://pytorch.org/hub/pytorch_vision_resnet/). The error rate shows, at each epoch (iteration) of training, what percentage of the validation images the model labels incorrectly.
- Line 4: It uses the transfer learning technique, so you don't need to train the model from scratch. Instead, you just call fine_tune for 4 epochs on your training data, which saves a lot of time and reduces the number of images needed to train a good model. (A simplified sketch of what fine_tune does appears after the code cell.)
# Using transfer learning since I don't have enough pictures of Albus to train from scratch
british_sh = DataBlock(
blocks=(ImageBlock, CategoryBlock),
get_items=get_image_files,
splitter=RandomSplitter(valid_pct=0.2, seed=42),
get_y=parent_label,
item_tfms=RandomResizedCrop(224, min_scale=0.5),
batch_tfms=aug_transforms()
)
dls = british_sh.dataloaders(silver_british_shorthair_path)
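# Optional sanity check (a sketch): show a batch from the training set;
# unique=True repeats one image so the aug_transforms variations are visible
dls.train.show_batch(max_n=8, nrows=2, unique=True)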
learn = cnn_learner(dls, resnet18, metrics=error_rate)
learn.fine_tune(4)
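For the curious, fine_tune is roughly equivalent to the following steps, shown as a simplified sketch of fastai's default behaviour (the real implementation also picks discriminative learning rates for the unfrozen stage):
# Roughly what learn.fine_tune(4) does under the hood, simplified
learn.freeze()           # train only the new, randomly initialized head
learn.fit_one_cycle(1)   # one warm-up epoch for the head
learn.unfreeze()         # then make the pretrained resnet18 body trainable too
learn.fit_one_cycle(4)   # and run the requested 4 epochs on the full model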
Check out how your model performs
To see how the model does on the validation set, we build a ClassificationInterpretation from the learner and plot a confusion matrix, which shows how many images of each class were predicted correctly and how many were confused with the other class.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
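If you want to dig deeper than the confusion matrix, the same interpretation object can also show the validation images the model was most wrong about; a quick sketch:
# Show the 5 validation images with the highest loss, i.e. where the model
# was most confidently wrong or least sure
interp.plot_top_losses(5, nrows=1)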
Export your model
Exporting saves the trained model (architecture and weights) to a pickle file so it can be loaded later for inference without retraining.
learn.export('albus-detector.pkl')
Make predictions
With the trained learner, predict classifies a single image. Here I try it on a held-out photo of Albus and on an image that is not Albus.
print(learn.predict('albus_test/IMG_20190328_193913.jpg'))
print(learn.predict('albus_test/not_albus_1.jpeg'))
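Note that predict returns a 3-tuple: the predicted label, its index in the model's vocabulary, and a tensor of per-class probabilities. A small sketch of unpacking it, reusing the test image above:
pred, idx, probs = learn.predict('albus_test/IMG_20190328_193913.jpg')
print(f'Prediction: {pred} with probability {probs[idx]:.4f}')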
Build some simple UI widgets
This part is totally optional. It's fun to have some interactive UI widgets to:
- Upload a photo from your computer
- Show the prediction as well as how confident the model is
As you can see in the following cell, I uploaded a pretty "tricky" image, and the model says with high confidence (0.9993) that it's Albus!
learn_inf = load_learner('albus-detector.pkl', cpu=True)
learn_inf.predict('albus_test/IMG_20190328_193913.jpg')
btn_upload = widgets.FileUpload()
output = widgets.Output()
label = widgets.Label()
def on_click(change):
    # Grab the most recently uploaded file and show a thumbnail of it
    img = PILImage.create(btn_upload.data[-1])
    output.clear_output()
    with output:
        display(img.to_thumb(128, 128))
    # Run inference and show the predicted label with its probability
    pred, index, prob = learn_inf.predict(img)
    label.value = f'Prediction: {pred} with probability {prob[index]:.04f}'
btn_upload.observe(on_click, names=['data'])
display(VBox([
widgets.Label("Upload your picture to tell if it's albus"),
btn_upload,
output,
label
]))
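One caveat: the handler above relies on the ipywidgets 7 FileUpload API (the .data attribute and its 'data' trait). If you are on ipywidgets 8, where .data was removed, a hedged sketch of the changed pieces based on my reading of the ipywidgets migration notes:
# ipywidgets 8: uploads live in FileUpload.value, a tuple of dict-like
# entries whose 'content' key holds the raw bytes as a memoryview
def on_click_v8(change):
    img = PILImage.create(btn_upload.value[-1]['content'].tobytes())
    output.clear_output()
    with output:
        display(img.to_thumb(128, 128))
    pred, index, prob = learn_inf.predict(img)
    label.value = f'Prediction: {pred} with probability {prob[index]:.04f}'

btn_upload.observe(on_click_v8, names=['value'])  # the trait is 'value', not 'data'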