On-Device Machine Learning: Train And Run TensorFlow Lite Models In Your Flutter Apps

Build, Train, and Run an Image Classifier Neural Network using Deep Transfer Learning in TensorFlow and Keras

Yashwardhan Deshmukh
Google Cloud - Community



NOTE: As of when I am writing this, the latest version of Python is 3.9. However, TensorFlow is currently only compatible with Python versions 3.5–3.8. If you have an incompatible version of Python and do not wish to downgrade, you can use Google Colab, which is a free online cloud-based Jupyter Notebook environment.

This article will be divided into 2 modules:

  • Training a convolutional neural network using transfer learning to detect more than thirty different types of fruits and vegetables.
  • Creating Flutter UI and Importing the .tflite file into the project, then testing it and fixing possible errors.

Module 1: Training Neural Networks using TensorFlow and Keras

What is TensorFlow? → TensorFlow is an open-source library that can train and run deep neural networks for tasks such as image recognition, word embeddings, and natural language processing. Download → https://www.tensorflow.org/install

By the end of this article, you will be able to make this app and run it on your phone

Start by first downloading a dataset containing all the images. Our neural network will train, learn, and validate on these images. You can download the one I will be using from kaggle.com by clicking here. Then extract the .zip file and create a Python Jupyter Notebook in the same directory.

The entire, completed project can be found and cloned for personal use from my GitHub Repo → https://git.io/JLhpv

Let's start Programming (Python Jupyter Notebook):

Code Cell 1 (importing modules)
  • The os module will provide us functions for fetching contents of and writing to a directory.
  • Set the base_dir variable to the location of the dataset containing the training images.
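
A minimal sketch of Code Cell 1 might look like this (the folder name dataset/train is an assumption; adjust it to wherever you extracted the training images, with one sub-folder per class):

# Code Cell 1 (sketch): import the required modules and point to the dataset
import os
import tensorflow as tf

# Folder holding the training images, one sub-folder per class
# ('apple', 'banana', ...). Adjust to match your extracted dataset.
base_dir = 'dataset/train'

print(os.listdir(base_dir))  # quick sanity check of the class folders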
Code Cell 2 (Preprocessing, where we prepare raw data to be suitable for building and training models)
  • IMAGE_SIZE = 224 → the image size that we are going to set the images in the dataset to.
  • BATCH_SIZE = 64 → the number of images we are inputting into the neural network at once.
  • rescale=1./255 normalizes the pixel values from the 0–255 range down to 0–1, which helps the network train faster and more stably.
  • Datasets are normally split into a training set and a test set. The validation set is used during training to measure how well the neural network performs on images it has not learned from. With validation_split=0.2, we tell Keras to use 20% of the images for validation and 80% for training.
  • Then we have two generators (train_generator and val_generator), which take the path to the directory and generate batches of (optionally augmented) data. In this case they print: Found 2872 images belonging to 36 classes and Found 709 images belonging to 36 classes.
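
Under those assumptions, a sketch of Code Cell 2 using Keras' ImageDataGenerator could look like this:

# Code Cell 2 (sketch): prepare the data generators
from tensorflow.keras.preprocessing.image import ImageDataGenerator

IMAGE_SIZE = 224
BATCH_SIZE = 64

datagen = ImageDataGenerator(
    rescale=1./255,          # scale pixel values from 0-255 down to 0-1
    validation_split=0.2)    # hold out 20% of the images for validation

train_generator = datagen.flow_from_directory(
    base_dir,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    subset='training')

val_generator = datagen.flow_from_directory(
    base_dir,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    subset='validation')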
Code Cell 3 (creating a labels.txt file that will hold all our labels)
  • Print all keys and classes (labels) of the dataset to re-check if everything is working fine.
  • Flutter requires two files: labels.txt and model.tflite.
  • Opening the file with ‘w’ creates a new file called labels.txt containing the labels; if the file already exists, it is overwritten.
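
A sketch of Code Cell 3, which prints the class indices and writes the labels out:

# Code Cell 3 (sketch): inspect the classes and create labels.txt
print(train_generator.class_indices)   # e.g. {'apple': 0, 'banana': 1, ...}

labels = '\n'.join(sorted(train_generator.class_indices.keys()))

with open('labels.txt', 'w') as f:     # 'w' creates the file, overwriting any old one
    f.write(labels)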
The output of Code Cell 3

Now that we have successfully pre-processed our raw data, it's time to start building our actual neural network using transfer learning.


Transfer Learning → It is a machine learning method in which we build a neural network off of an already pre-trained neural network.

Basically, the neural network has already been trained on some other task, so the patterns it has learned help it pick up a second task more efficiently. Think of it as teaching first-grade math to a newborn child (traditional ML), versus a grown adult (transfer learning).

Here, we will be using MobileNetV2, a pre-trained CNN (convolutional neural network) architecture designed to perform well for on-device machine learning; with its ImageNet weights it can already classify images into 1,000 different classes.

Code Cell 4 (Creating a base model for Transfer Learning)
  • We start by grabbing MobileNetV2.
  • Since we don't want to re-train the pre-trained CNN, but rather build on top of it, include_top=False is used: the fully-connected output layers MobileNetV2 normally uses to make its own predictions are not loaded, which allows a new output layer to be added and trained for our 36 classes. (The pre-trained weights themselves are frozen in the next cell.)
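
A sketch of Code Cell 4 (the 'imagenet' weights are the Keras default for MobileNetV2, written out here for clarity):

# Code Cell 4 (sketch): load MobileNetV2 without its classification head
IMG_SHAPE = (IMAGE_SIZE, IMAGE_SIZE, 3)

base_model = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SHAPE,
    include_top=False,     # leave out MobileNetV2's own output layers
    weights='imagenet')    # start from the weights pre-trained on ImageNet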
Code Cell 5 (Adding Layers to NN)
  • base_model.trainable=False Freezes all the neurons for our base model.
  • Neural Networks act in a sequence of layers, hence now we will add our own layers:
  1. The Conv2D is a 2D convolution layer that creates a convolution kernel which is convolved with the layer's input to produce a tensor of outputs. In effect, it learns the image's spatial patterns. ‘relu’ stands for the rectified linear unit activation function.
  2. The Dropout layer helps prevent the neural network from overfitting, i.e. becoming so precise that it can only recognize the images present in the dataset and no other images.
  3. GlobalAveragePooling2D layer calculates the average output of each feature map in the previous layer, thus reducing the data significantly and preparing the model for the final layer.
  4. The Dense layer is a densely connected layer in which each neuron receives input from every neuron of the previous layer. ‘36’ here stands for the number of classes. ‘softmax’ converts the output vector into a vector of categorical probabilities.
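
Putting those four layers together, a sketch of Code Cell 5 (the 32 filters and the 0.2 dropout rate are illustrative values):

# Code Cell 5 (sketch): freeze the base model and stack our own layers on top
base_model.trainable = False   # freeze all the pre-trained neurons

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(36, activation='softmax')   # one output per class
])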
Code Cell 6 (Compiling the model before leaving it for training)
  • We use model.compile, which defines the loss function, the optimizer, and the metrics, because a compiled model is needed to train (training uses the loss function and the optimizer).
  • We will use Adam which is a popular optimizer, designed specifically for training deep neural networks.
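
A sketch of Code Cell 6 (categorical cross-entropy is an assumption for the loss, matching the one-hot labels the generators produce):

# Code Cell 6 (sketch): compile the model with the Adam optimizer
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss='categorical_crossentropy',   # multi-class loss for one-hot labels
    metrics=['accuracy'])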
Code Cell 7 (Training!)
  • epochs → the number of times the learning algorithm will work through the entire training dataset. The higher the number, the more accurate the neural network can become, BUT… setting it too high can cause overfitting, i.e. the NN becomes so tuned to the training images that it can no longer recognize images outside the dataset.
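
A sketch of Code Cell 7 (10 epochs is just a starting point; watch the validation accuracy and adjust):

# Code Cell 7 (sketch): train the network
epochs = 10

history = model.fit(
    train_generator,
    epochs=epochs,
    validation_data=val_generator)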
The Output for Code Cell 7 (Training Process, takes some time, go do yoga)
Code Cell 8 (Converting the Trained neural network into a Tensorflow Lite file)
  • saved_model_dir = ‘’ where ‘’ means the current directory.
  • tf.saved_model.save(model, saved_model_dir) saves the model to the current directory.
  • The next two lines convert our model into a .tflite model, which Flutter uses for on-device ML.
  • Finally, we write the converted model to the directory; since it is binary data, we open the file with ‘wb’ instead of just ‘w’. (A sketch of the full cell follows below.)
  • If you are using Google Colab then add this as an extra code cell:
# use this code to download the files locally if using Google Colab
from google.colab import files
files.download('model.tflite')
files.download('labels.txt')
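
Taken together, a sketch of Code Cell 8 (following the convention above of using '' for the current directory):

# Code Cell 8 (sketch): export the model and convert it to TensorFlow Lite
saved_model_dir = ''                         # '' = the current directory
tf.saved_model.save(model, saved_model_dir)

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()

with open('model.tflite', 'wb') as f:        # 'wb' because .tflite is binary
    f.write(tflite_model)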

Now that we have both the ‘model.tflite’ and ‘labels.txt’ files, we can import them into a Flutter Project! If you somehow messed up, you could download these files from the assets folder of my GitHub Repo → https://git.io/JLhpv

Module 2: Importing and using model.tflite in a Flutter app.

Go to ‘Visual Studio Code > View > Command Palette > Flutter New Application Project’ or just go to the terminal, navigate to the directory, and type ‘flutter create project_name’.

Next, head over to ‘pubspec.yaml’, add the following dependencies, and save:

dependencies:
  flutter:
    sdk: flutter
  tflite: ^1.1.1
  image_picker: ^0.6.7+4

For tflite to work, add the following setting inside the android block of android/app/build.gradle:

aaptOptions {
    noCompress 'tflite'
    noCompress 'lite'
}

For image_picker to work on iOS, add the following keys to /ios/Runner/Info.plist:

<key>NSCameraUsageDescription</key>
<string>Need Camera Access</string>
<key>NSMicrophoneUsageDescription</key>
<string>Need Microphone Access</string>
<key>NSPhotoLibraryUsageDescription</key>
<string>Need Photo Access</string>

It's now time to create a simple app that displays the user's image and the prediction, as shown below:

It predicts bananas too?!
  • In the ‘main.dart’ file, return MaterialApp that has a parameter home: Home(),
  • Then create a new ‘home.dart’ file containing the stateful class Home(). This will be our homepage. Let's start building the functional Flutter app:
Assuming you have imported material, tflite, image_picker, and dart:io

_loading → used to check whether an image has been chosen, _image → the chosen image, _output → the prediction made, picker → allows us to pick an image from the gallery or camera.

Next, we will write 6 different methods for the class:

First 2 methods
  • The first 2 methods :
  1. initState() → This is the first method called when the Home widget is created, i.e. when the app launches and navigates to Home(). Anything inside initState() runs before the widget itself is built, so this is where we load our model using loadModel() (another method written later) and then update the state.
  2. dispose() → This method disposes of the widget and frees the memory it was using.
Last 4 methods
  • The last 4 methods:

3. classifyImage() → this method runs the classification model on the image. numResults is the number of classes we have; we then call setState to save the changes.

4. loadModel() → this function loads our model, hence we call it inside the initState() method.

5. pickImage() → this function is used to grab the image from the camera.

6. pickGalleryImage() → this function is used to grab the image from the user’s gallery.

Flutter UI Time!

This part of the code is basic UI. Those who understand how it works can skip it and write their own!

The First Part of the UI

First, we have to make the AppBar and then the Container that holds the image. Here we will use ternary operators (condition ? true statement : false statement): the image is shown only once a picture has been selected (_loading is no longer true), and the prediction text is shown only when _output is not null; otherwise we show nothing.

@override
Widget build(BuildContext context) {
  return Scaffold(
    appBar: AppBar(
      backgroundColor: Colors.black,
      title: Text(
        'Fruits and Veggies Neural Network',
        style: TextStyle(
            color: Colors.white,
            fontWeight: FontWeight.w200,
            fontSize: 20,
            letterSpacing: 0.8),
      ),
    ),

That was the AppBar. Lets make a container now to hold an image that the user has selected!


    body: Container(
      color: Colors.black.withOpacity(0.9),
      padding: EdgeInsets.symmetric(horizontal: 35, vertical: 50),
      child: Container(
        alignment: Alignment.center,
        padding: EdgeInsets.all(30),
        decoration: BoxDecoration(
          color: Color(0xFF2A363B),
          borderRadius: BorderRadius.circular(30),
        ),
        child: Column(
          mainAxisAlignment: MainAxisAlignment.center,
          children: [
            Container(
              child: Center(
                child: _loading == true
                    ? null // show nothing if no picture is selected
                    : Container(
                        child: Column(
                          children: [
                            Container(
                              height: 250,
                              width: 250,
                              child: ClipRRect(
ClipRRect is used to give the image rounded corners.

                                borderRadius: BorderRadius.circular(30),
                                child: Image.file(
                                  _image,
                                  fit: BoxFit.fill,
                                ),
                              ),
                            ),
                            Divider(
                              height: 25,
                              thickness: 1,
                            ),
                            _output != null
                                ? Text(
                                    'The object is: ${_output[0]['label']}!',
                                    style: TextStyle(
                                        color: Colors.white,
                                        fontSize: 18,
                                        fontWeight: FontWeight.w400),
                                  )
                                : Container(),
                            Divider(
                              height: 25,
                              thickness: 1,
                            ),
                          ],
                        ),
                    ),
              ),
            ),
            // code contd..

Next, we make two GestureDetectors whose onTap: parameters refer to the pickImage and pickGalleryImage functions respectively.

NOTE: pickImage (without parentheses) inside onTap is a function reference (a callback), which means it is not executed immediately; it is executed when the user taps the specific widget.

pickImage() (with parentheses) would be a function call, executed immediately.

The second part of the UI, Gesture Detectors!
            Container(
              child: Column(
                children: [
                  GestureDetector(
                    onTap: pickImage, // no parentheses
                    child: Container(
                      width: MediaQuery.of(context).size.width - 200,
                      alignment: Alignment.center,
                      padding: EdgeInsets.symmetric(
                          horizontal: 24, vertical: 17),
                      decoration: BoxDecoration(
                          color: Colors.blueGrey[600],
                          borderRadius: BorderRadius.circular(15)),
                      child: Text(
                        'Take A Photo',
                        style: TextStyle(color: Colors.white, fontSize: 16),
                      ),
                    ),
                  ),
                  SizedBox(
                    height: 30,
                  ),
                  GestureDetector(
                    onTap: pickGalleryImage, // no parentheses
                    child: Container(
                      width: MediaQuery.of(context).size.width - 200,
                      alignment: Alignment.center,
                      padding: EdgeInsets.symmetric(
                          horizontal: 24, vertical: 17),
                      decoration: BoxDecoration(
                          color: Colors.blueGrey[600],
                          borderRadius: BorderRadius.circular(15)),
                      child: Text(
                        'Pick From Gallery',
                        style: TextStyle(color: Colors.white, fontSize: 16),
                      ),
                    ),
                  ),
                ],
              ),
            ),
          ],
        ),
      ),
    ),
  );
}
}

Done! Now Save and Run it on a simulator or a real phone!

Test it out yourself: whenever an image resembling any of the categories the neural network has learnt is either chosen from the camera roll or captured in real time, the app should output its prediction!

This is just one of the very basic uses of deep learning. It really shows us what a blessing Artificial Neural Networks are to humanity, and if mastered could achieve things never imagined by humanity… maybe flying suits or cars someday?

Debugging time!

Some of the common errors (including the ones I personally encountered), and how to fix them:

  1. ‘Lexical or Preprocessor Issue ‘…’ file not found’
  2. ‘IOS Xcode Build error : ‘metal_delegate.h’ file not found’
  3. ‘vector’ file not found
  4. ‘tensorflow/lite/kernels/register.h’ file not found

As of when I'm writing this, these errors are very common while trying to sign and build the app for an iOS device using Xcode.

The fix:

  1. Navigate to project_file/ios, open Podfile.lock with any text editor, and set the value of TensorFlowLiteC to (2.2.0), as shown below.

2. Then launch terminal and type the following:

cd project_directory_here/
cd ios/
pod install

3. Finally, go to the project_file/ios folder, launch Runner.xcworkspace in Xcode, then click Runner > Targets > Runner > Build Settings, search for Compile Sources As, and change the value to Objective-C++.

This should fix the ‘vector’ file not found error.

To fix the fourth error, uncomment //#define CONTRIB_PATH in the TflitePlugin.mm file shown in the errors section of Xcode.

Thanks for reading, I hope you learnt something! Any comments, doubts or suggestions are highly valuable to me.

