
info

This blog post is still a work in progress. If you require further clarifications before the contents are finalized, please get in touch with me here, on LinkedIn, or Twitter.

🔥 Motivation

You finally got into a Kaggle competition. You found a getting-started notebook written by a Kaggle Grandmaster and immediately trained a state-of-the-art (SOTA) image classification model.

After some fiddling, you found yourself at the top of the leaderboard with 99.9851247% accuracy on the test set 😎!

Proud of your achievement, you treat yourself to some rest and a good night’s sleep.

And then…

With high-level libraries like Keras, Transformers, and Fastai, the barrier to training SOTA models has never been lower.

On top of that, with platforms like Google Colab and Kaggle, pretty much anyone can train a reasonably good model using an old laptop or even a mobile phone (with some patience).

The question is no longer “can we train a SOTA model?”, but “what happens after that?”

Unfortunately, once the model is trained, the majority of data scientists wash their hands of it, claiming their model works. But what good are SOTA models if they live only in notebooks and Kaggle leaderboards?

Unless the model is deployed and put to use, it’s of little benefit to anyone out there.

But deployment is painful. Running a model on a mobile phone?

Forget it 🤷‍♂️.

The frustration is real. I remember spending nights exporting models into ONNX, only for them to fail me anyway.

Mobile deployment doesn’t need to be complicated. In this post I’m going to show you how you can pick from the hundreds of SOTA models on TIMM and deploy one on Android, for free.

tip

⚡ By the end of this post you will learn how to:

  • Train a SOTA model using TIMM and Fastai.
  • Export the trained model into TorchScript.
  • Create a beautiful Flutter app and run the model inference on your Android device.

💡NOTE: If you already have a trained TIMM model, feel free to jump straight to the Exporting to TorchScript section.

But, if you’d like to discover how I train a model using some of the best techniques on Kaggle, read on 👇

🥇 PyTorch Image Models

PyTorch Image Models, or TIMM, is an open-source computer vision library by Ross Wightman.

The TIMM repository hosts hundreds of recent SOTA models maintained by Ross. At the time of writing, TIMM offers 964 pretrained models, and the number keeps growing.

You can install TIMM by simply:

pip install timm

The TIMM repo provides various utility functions and training scripts. Feel free to use them. In this post I’m going to show you an easy way to train a TIMM model using Fastai 👇

🏋️‍♀️ Training with Fastai

Fastai is a deep learning library that provides practitioners with high-level components that can quickly deliver SOTA results. Under the hood Fastai uses PyTorch, but it abstracts away the details and incorporates various best practices for training a model.

Install Fastai with:

pip install fastai

You can access all TIMM models from within Fastai. For example, we can search for model architectures using a wildcard. Since we will be running the model on a mobile device, let’s search for models whose names contain the word edge.

import timm
timm.list_models('*edge*')

This outputs all models that match the wildcard.

['cs3edgenet_x',
 'cs3se_edgenet_x',
 'edgenext_base',
 'edgenext_small',
 'edgenext_small_rw',
 'edgenext_x_small',
 'edgenext_xx_small']

Since we’ll run our model on a mobile device, let’s select the smallest model available, edgenext_xx_small. Now let’s use Fastai to quickly train the model.

First, import all the necessary packages:

from fastai.vision.all import *

Next, load the images into a DataLoader.

trn_path = Path('../data/train_images')
dls = ImageDataLoaders.from_folder(trn_path, seed=316, 
                                   valid_pct=0.2, bs=128,
                                   item_tfms=[Resize((224, 224))], 
                                   batch_tfms=aug_transforms(min_scale=0.75))

note

Parameters for the from_folder method:

  • trn_path – A Path to the training images.
  • valid_pct – The percentage of dataset to allocate as the validation set.
  • bs – Batch size to use during training.
  • item_tfms – Transformation applied to each item.
  • batch_tfms – Random transformations applied to each batch to augment the dataset.
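Note that from_folder infers the labels from the directory structure: it expects one sub-folder per class under trn_path, along these lines (the class names here are hypothetical):

```
train_images/
├── class_a/
│   ├── img_001.jpg
│   └── ...
└── class_b/
    ├── img_101.jpg
    └── ...
```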

You can show a batch of the images loaded into the DataLoader with:

dls.train.show_batch(max_n=8, nrows=2)

Next create a Learner object which combines the model and data into one object for training.

learn = vision_learner(dls, 'edgenext_xx_small', metrics=accuracy).to_fp16()

Find the best learning rate.

learn.lr_find()

Now train the model.

learn.fine_tune(5, base_lr=1e-2, cbs=[ShowGraphCallback()])

Optionally export the Learner.

learn.export("../../train/export.pkl")

tip

View and fork my training notebook here.

📀 Exporting to TorchScript

Now that we are done training the model, it’s time to export it into a form suitable for a mobile device. We can do that easily with TorchScript.

TorchScript is a way to create serializable and optimizable models from PyTorch code.

TorchScript Docs

All the models on TIMM can be exported to TorchScript with the following code snippet.

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

learn.model.cpu()
learn.model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(learn.model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model._save_for_lite_interpreter("model.pt")
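Before shipping the file, it’s worth confirming that the traced graph reproduces the eager model’s outputs. Here is a minimal, self-contained sketch of that check, using a small dummy network as a stand-in for learn.model:

```python
import torch
from torch import nn
from torch.utils.mobile_optimizer import optimize_for_mobile

# Dummy stand-in for the trained TIMM model; substitute learn.model in practice.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 5),
)
model.eval()

example = torch.rand(1, 3, 224, 224)
traced = torch.jit.trace(model, example)
optimized = optimize_for_mobile(traced)

# The optimized, traced graph should match the eager model's outputs.
with torch.no_grad():
    eager_out = model(example)
    traced_out = optimized(example)
assert torch.allclose(eager_out, traced_out, atol=1e-5)
```

If the assertion passes, the traced model is numerically equivalent to the original on that input and is safe to save with _save_for_lite_interpreter.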

📲 Inference in Flutter

We will be using the pytorch_lite Flutter package, which supports image classification and object detection with TorchScript models.

Link to my GitHub repo.

The screen capture shows the Flutter app in action. The clip runs in real-time and not sped up.

🙏 Comments & Feedback

I hope you’ve learned a thing or two from this blog post. If you have any questions, comments, or feedback, please leave them on the following Twitter/LinkedIn post or drop me a message.