Deploying MobileNet With TensorFlow.js and Flask

Introduction

This project guides you through the process of deploying a pre-trained MobileNetV2 model using TensorFlow.js within a Flask web application. MobileNetV2 is a lightweight deep neural network used primarily for image classification. TensorFlow.js enables running machine learning models directly in the browser, allowing for interactive web applications. Flask, a Python web framework, will serve as the backend to host our application. By the end of this project, you will have a running web application that classifies images on the fly using the MobileNetV2 model.

🎯 Tasks

In this project, you will learn:

  • How to export a pre-trained MobileNetV2 model from Keras to a TensorFlow.js compatible format.
  • How to create a simple Flask application to serve your web content and model.
  • How to design an HTML page to upload and display images for classification.
  • How to use TensorFlow.js to load the exported model in the browser.
  • How to preprocess images in the browser to match the input requirements of MobileNetV2.
  • How to run the model in the browser to classify images and display the results.

🏆 Achievements

After completing this project, you will be able to:

  • Convert a pre-trained Keras model into a format that can be used with TensorFlow.js, enabling ML models to run in the browser.
  • Set up a Flask application and serve HTML content and static files.
  • Integrate TensorFlow.js into a web application to perform machine learning tasks client-side.
  • Preprocess images in JavaScript to make them compatible with the input requirements of deep learning models.
  • Make predictions using a deep learning model in the browser and display the results dynamically on the web page.


Preparing the Project Environment and Files

Before we begin coding, it's important to set up our project environment correctly. This includes installing the necessary packages and understanding the project file structure that is already in place.

First, familiarize yourself with the initial project file structure. The following files and folders are already provided in your working directory:

tree

Output:

.
├── app.py
├── model_convert.py
├── static
│   ├── imagenet_classes.js
│   ├── tfjs.css
│   └── tfjs.js
└── templates
    └── tfjs.html

2 directories, 6 files

The structure of your project consists of several key components, each playing a vital role in deploying your web application for image classification using the MobileNetV2 model with TensorFlow.js and Flask. Below is an overview of each directory and file within your project:

  • app.py: This is the main Python file for your Flask application. It initializes the Flask app, sets up routing for your web page, and includes any backend logic necessary for serving your TensorFlow.js model and web content.
  • model_convert.py: This Python script is responsible for loading the pre-trained MobileNetV2 model and converting it to a format that is compatible with TensorFlow.js. This conversion is crucial for enabling the model to run in a web browser.
  • static/: This directory stores static files that are required by your web application. These include:
    • imagenet_classes.js: A JavaScript file containing the ImageNet classes. This file is used to map the numerical predictions of the model to human-readable class names.
    • tfjs.css: A Cascading Style Sheets (CSS) file used to style the web application's user interface. It defines the visual aspects of your application, such as layouts, colors, and fonts, ensuring a more engaging and user-friendly interface.
    • tfjs.js: A JavaScript file that holds the client-side logic for loading the TensorFlow.js model, preprocessing images, and running predictions in the browser. You will complete this script in later steps; it is central to the interactivity of your application.
  • templates/: This directory contains HTML files that define the structure and layout of your web application. In this case, it includes:
    • tfjs.html: The primary HTML template for your application, tfjs.html includes the necessary markup for displaying images, prediction results, and possibly user interaction elements like file upload buttons. It integrates the TensorFlow.js model, leveraging the tfjs.js script for model-related functionalities and tfjs.css for styling.

This structure is designed to separate concerns, making your project modular and easier to manage. The static and templates directories are standard in Flask applications, helping to organize static assets and HTML templates, respectively. The separation of the model conversion script (model_convert.py) from the main application logic (app.py) enhances the modularity and maintainability of your code.

Next, install the required packages:

# Install the required Python packages
pip install tensorflow==2.14.0 tensorflowjs==4.17.0 flask==3.0.2 flask-cors==4.0.0

The packages include TensorFlow for the machine learning model, TensorFlow.js for converting the model to be used in a web environment, Flask for creating the web server, and Flask-CORS for handling cross-origin requests, which are common in web applications.
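
If you want to confirm that the installation succeeded before moving on, a quick import check such as the one-liner below will surface any missing packages (this is just a convenience check, not part of the project files):

# Optional sanity check: all four packages should import without errors
python -c "import tensorflow, tensorflowjs, flask, flask_cors; print('All packages imported successfully')"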

Exporting the Pre-Trained MobileNetV2 Model to TensorFlow.js Format

To use the MobileNetV2 model in the browser, we first need to export it from Keras to a format that TensorFlow.js can understand.

# Complete the model_convert.py

# Exporting MobileNetV2 model
import tensorflowjs as tfjs
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2

# Load the pre-trained MobileNetV2 model
model = MobileNetV2(weights='imagenet')

# Convert and save the model in TensorFlow.js format
tfjs.converters.save_keras_model(model, 'static/model/')

In this step, you leverage the TensorFlow and TensorFlow.js libraries to load a pre-trained MobileNetV2 model and convert it into a format compatible with TensorFlow.js.

MobileNetV2 is chosen for its efficiency and relatively small size, making it suitable for web deployment. This conversion is necessary because the original model format used in Python-based TensorFlow is not directly usable in a web environment. The tfjs.converters.save_keras_model function takes the Keras model and saves it in a directory structured in a way that TensorFlow.js can easily load it later in the web application.
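
For reference, the tensorflowjs package also ships a command-line converter that achieves the same result. The sketch below assumes you first save the Keras model to an HDF5 file (mobilenet_v2.h5 is an arbitrary name chosen here); it is an alternative route, not a required step:

# Alternative: convert via the tensorflowjs_converter CLI
# 1. Save the pre-trained model to an HDF5 file
python -c "from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2; MobileNetV2(weights='imagenet').save('mobilenet_v2.h5')"
# 2. Convert the saved file to the TensorFlow.js layers format
tensorflowjs_converter --input_format keras mobilenet_v2.h5 static/model/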

Then you can run:

python model_convert.py

The converted model will be saved to the static/model/ folder:

ls static/model

# group1-shard1of4.bin  group1-shard2of4.bin  group1-shard3of4.bin  group1-shard4of4.bin  model.json

This process saves the model's weights in a series of binary shard files and its architecture in a model.json file.
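
If you are curious about what the conversion produced, you can pretty-print the beginning of model.json; this inspection step is optional and assumes the conversion above completed successfully:

# Optional: peek at the generated model description
python -m json.tool static/model/model.json | head -n 20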

Creating the Flask Application

Now, we will set up a simple Flask application to serve our web page and the TensorFlow.js model.

# Complete the app.py

# Setting up the Flask application
from flask import Flask, render_template
from flask_cors import CORS

app = Flask(__name__)
cors = CORS(app)  # Enable Cross-Origin Resource Sharing

@app.route("/")
def hello():
    # Serve the HTML page
    return render_template('tfjs.html')

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8080, debug=True)

This step involves setting up a basic Flask web server. Flask is a micro web framework written in Python, known for its simplicity and ease of use.

You start by importing Flask and CORS (Cross-Origin Resource Sharing) from their respective libraries. CORS is essential for web applications that request resources from different domains, ensuring that your web app can safely make requests to your Flask server.

You define a simple route ("/") that serves an HTML page (tfjs.html), which will contain your client-side TensorFlow.js code. The Flask application is configured to listen on all network interfaces (host='0.0.0.0') on port 8080, so it is reachable both locally and through the lab's web preview. The debug=True setting is helpful during development as it provides detailed error messages and automatically reloads the server when code changes are detected.
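
Note that no dedicated route is needed for the converted model: Flask automatically serves everything under static/ at the /static/ URL path. If you ever wanted custom behavior (for example, extra response headers), an explicit route might look like the following sketch; serve_model is a hypothetical helper for illustration, not part of the lab code:

# Hypothetical alternative: serve the model files through an explicit route.
# Flask already exposes static/ automatically, so this is only useful when
# you need custom behavior such as additional headers.
from flask import send_from_directory

@app.route("/model/<path:filename>")
def serve_model(filename):
    return send_from_directory("static/model", filename)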

Now, you can run the Flask web application:

python app.py
# * Serving Flask app 'app'
# * Debug mode: on
# WARNING: This is a development server. Do not use it in a production deployment. Use a production WSGI server instead.
# * Running on all addresses (0.0.0.0)
# * Running on http://127.0.0.1:8080
# * Running on http://172.18.0.7:8080
# Press CTRL+C to quit
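
As an optional check that the server is responding (you can also simply open the Web 8080 tab), you could request the page from a second terminal:

# Optional: confirm the page is served (run in another terminal)
curl -s http://127.0.0.1:8080/ | head -n 5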

Preparing the HTML Structure

Now, we need to create the HTML structure of our application in templates/tfjs.html. This includes the layout for image upload, preview, and display of the prediction results.

<!-- HTML structure for the Image Prediction application -->
<!doctype html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Image Prediction</title>
    <link
      href="https://cdn.jsdelivr.net/npm/tailwindcss@2.0.2/dist/tailwind.min.css"
      rel="stylesheet"
    />
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <link rel="stylesheet" href="static/tfjs.css" />
  </head>
  <body class="flex flex-col items-center justify-center min-h-screen">
    <div class="card bg-white p-6 rounded-lg max-w-sm">
      <h1 class="text-xl font-semibold mb-4 text-center">Image Prediction</h1>
      <div class="flex flex-col items-center">
        <label
          for="imageUpload"
          class="button-custom cursor-pointer mb-4 flex items-center justify-center"
        >
          <span>Upload Image</span>
          <input
            type="file"
            id="imageUpload"
            class="file-input"
            accept="image/*"
          />
        </label>
        <div
          id="imagePreviewContainer"
          class="mb-4 w-56 h-56 border border-dashed border-gray-300 flex items-center justify-center"
        >
          <img
            id="imagePreview"
            class="max-w-full max-h-full"
            style="display: none"
          />
        </div>
        <h5 id="output" class="text-md text-gray-700">
          Upload an image to start prediction
        </h5>
        <script type="module" src="static/tfjs.js"></script>
      </div>
    </div>
  </body>
</html>

This step involves creating the basic structure of our web application using HTML. We define the document type as HTML and set the language attribute to English.

In the head section, we include metadata such as character set and viewport settings for responsive design. We also link to the Tailwind CSS library to utilize its utility classes for styling our application. The body section contains a div element with a class of card, which serves as a container for our application's content.

Inside this container, we have an h1 tag for the application title, a label element for the image upload button (which is styled to look like a button using Tailwind CSS and custom classes), and a div element to serve as a container for the image preview. The input element of type file is hidden and triggered when the label is clicked, allowing the user to upload an image. The div element with an id of imagePreviewContainer will display the uploaded image, and the h5 tag will be used to display messages to the user.
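
The stylesheet and script are referenced with plain relative paths (static/tfjs.css and static/tfjs.js), which works here because the page is served from the site root. In Flask templates it is also common to generate these URLs with the url_for helper; the snippet below is an equivalent alternative, not a required change:

<!-- Equivalent references using Flask's url_for helper -->
<link rel="stylesheet" href="{{ url_for('static', filename='tfjs.css') }}" />
<script type="module" src="{{ url_for('static', filename='tfjs.js') }}"></script>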

Add Styling

In this step, we'll add some basic CSS styling using Tailwind CSS for layout and aesthetics in static/tfjs.css, along with some custom styles for our application.

body {
  background-color: #f0f2f5;
}
.card {
  box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.button-custom {
  background-color: #4f46e5; /* Indigo 600 */
  color: white;
  padding: 0.5rem 1.5rem;
  border-radius: 0.375rem; /* rounded-md */
  transition: background-color 0.2s;
}
.button-custom:hover {
  background-color: #4338ca; /* Indigo 700 */
}
.file-input {
  opacity: 0;
  position: absolute;
  z-index: -1;
}

In this step, we add custom CSS styles to enhance the appearance of our web application.

We set the background color of the body to a light gray for a neutral background. The card class adds a box shadow to create a card-like effect for the container. The button-custom class styles the upload button with an indigo background color, white text, padding, and rounded corners.

We also include a hover effect that darkens the button's background color slightly, providing visual feedback to the user. The file-input class hides the actual file input element, making the custom-styled label the main interactive element for file uploads.

Load the TensorFlow Model

In the remaining steps, we will complete the static/tfjs.js file.

TensorFlow.js is already included in the page through the script tag in tfjs.html, so we can now write the JavaScript code that loads our model. We'll use an async function to await the model's loading before making any predictions.

import { IMAGENET_CLASSES } from "./imagenet_classes.js";

const outputDiv = document.getElementById("output");
let model;

async function loadModel() {
  try {
    outputDiv.textContent = "Loading TF Model...";
    model = await tf.loadLayersModel(
      "https://****.labex.io/static/model/model.json"
    ); // Update URL
    outputDiv.textContent = "TF Model Loaded.";
  } catch (error) {
    outputDiv.textContent = `Error loading model: ${error}`;
  }
}

loadModel();

Note: You should replace the URL in tf.loadLayersModel with the URL of the current environment. You can find it by switching to the Web 8080 tab.


TensorFlow.js itself is already available in the browser through the script tag we included in our HTML document earlier.

Then, we write JavaScript code to asynchronously load a pre-trained TensorFlow model. We use the tf.loadLayersModel function, providing it with a URL to our model's JSON file. This function returns a promise that resolves with the loaded model. We update the text content of the outputDiv to inform the user about the model loading status. If the model loads successfully, we display "TF Model Loaded."

Otherwise, we catch any errors and display an error message to the user. This step is crucial for enabling our application to make predictions, as the model needs to be loaded and ready before any image processing can occur.
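
Because the page and the model are served by the same Flask application, a same-origin relative URL should also work and avoids editing the hostname by hand. The function below is a hedged alternative to loadModel under that assumption; keep the absolute URL if you prefer to follow the lab exactly:

// Hedged alternative: load the model via a relative path on the same origin.
// Assumes Flask serves the converted files from static/model/.
async function loadModelFromRelativePath() {
  try {
    outputDiv.textContent = "Loading TF Model...";
    model = await tf.loadLayersModel("/static/model/model.json");
    outputDiv.textContent = "TF Model Loaded.";
  } catch (error) {
    outputDiv.textContent = `Error loading model: ${error}`;
  }
}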

Handle Image Upload and Preview

This step involves creating functionality for users to upload images and see a preview before prediction is made. We'll use a FileReader to read the uploaded file and display it.

// Continue in static/tfjs.js
const imageUpload = document.getElementById("imageUpload");
const imagePreview = document.getElementById("imagePreview");

imageUpload.addEventListener("change", async (e) => {
  const file = e.target.files[0];
  if (file) {
    const reader = new FileReader();
    reader.onload = (e) => {
      const img = new Image();
      img.src = e.target.result;
      img.onload = async () => {
        imagePreview.src = img.src;
        imagePreview.style.display = "block";
        const processedImage = await preprocessImage(img);
        makePrediction(processedImage);
      };
    };
    reader.readAsDataURL(file);
  }
});

This step involves setting up an event listener for the image upload input.

When a user selects a file, we use a FileReader to read the file as a Data URL. We then create an Image object and set its src attribute to the result from the FileReader, effectively loading the image into the browser. Once the image is loaded, we display it inside the imagePreview container by setting the imagePreview's src attribute to the image's src and making the imagePreview element visible.

The image is then preprocessed and passed to our model for prediction; we will complete those functions in the following steps.
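
For reference, the same preview could be built with URL.createObjectURL instead of a FileReader. The sketch below is an equivalent alternative rather than part of the lab code:

// Alternative: preview the uploaded file with an object URL instead of a FileReader.
imageUpload.addEventListener("change", async (e) => {
  const file = e.target.files[0];
  if (!file) return;
  const img = new Image();
  img.onload = async () => {
    imagePreview.src = img.src;
    imagePreview.style.display = "block";
    const processedImage = await preprocessImage(img);
    makePrediction(processedImage);
    // In a longer-lived app you would eventually call URL.revokeObjectURL(img.src)
    // once nothing needs the temporary blob: URL any more.
  };
  img.src = URL.createObjectURL(file); // Creates a temporary blob: URL for the file
});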

Preprocess the Image for Prediction

Before making a prediction, we need to preprocess the uploaded image to match the input requirements of our model.

// Continue in static/tfjs.js
async function preprocessImage(imageElement) {
  try {
    let img = tf.browser.fromPixels(imageElement).toFloat();
    img = tf.image.resizeBilinear(img, [224, 224]);
    const offset = tf.scalar(127.5);
    const normalized = img.sub(offset).div(offset);
    const batched = normalized.reshape([1, 224, 224, 3]);
    return batched;
  } catch (error) {
    outputDiv.textContent = `Error preprocessing image: ${error}`;
  }
}

Before making a prediction, we need to preprocess the uploaded image to match the input format expected by our TensorFlow model. This preprocessing includes resizing the image to the required dimensions (in this case, 224x224 pixels) and normalizing the pixel values from the [0, 255] range to [-1, 1], which is what MobileNetV2 expects.

We use TensorFlow.js operations like tf.browser.fromPixels to convert the image to a tensor, tf.image.resizeBilinear for resizing, and arithmetic operations to normalize the pixel values. The preprocessed image is then reshaped into a batch of one (to match the model's expected input shape), making it ready for prediction.
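
Tensor operations allocate memory that TensorFlow.js does not garbage-collect on its own, so long sessions with many uploads can accumulate unused tensors. Wrapping the preprocessing in tf.tidy is a common way to release the intermediates; the variant below is an optional sketch under that assumption, not a required change:

// Optional variant: tf.tidy disposes the intermediate tensors created here,
// keeping only the returned batched tensor alive.
function preprocessImageTidy(imageElement) {
  return tf.tidy(() => {
    const img = tf.image.resizeBilinear(
      tf.browser.fromPixels(imageElement).toFloat(),
      [224, 224]
    );
    const offset = tf.scalar(127.5);
    // Map pixel values from [0, 255] to [-1, 1], as MobileNetV2 expects
    return img.sub(offset).div(offset).reshape([1, 224, 224, 3]);
  });
}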

Make a Prediction

Once the image is preprocessed, we're ready to make a prediction. The makePrediction function takes the processed image as input, feeds it through the model, and interprets the output to determine the most likely class label.

// Continue in static/tfjs.js
async function makePrediction(processedImage) {
  try {
    const prediction = model.predict(processedImage);
    const highestPredictionIndex = (await tf.argMax(prediction, 1).data())[0]; // index of the highest score
    const label = IMAGENET_CLASSES[highestPredictionIndex];
    outputDiv.textContent = `Prediction: ${label}`;
  } catch (error) {
    outputDiv.textContent = `Error making prediction: ${error}`;
  }
}

In this step, we use the model.predict(processedImage) function to feed the preprocessed image into the TensorFlow model. The tf.argMax(prediction, 1).data() function is used to find the index of the highest value in the predictions array, which corresponds to the most likely class label for the image. This label is then displayed to the user.
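
If you would like to display more than the single best guess, tf.topk returns the k highest scores together with their indices. The function below is an optional extension sketch that reuses the same IMAGENET_CLASSES mapping; it is not part of the required lab code:

// Optional extension: report the top 3 predictions instead of only the best one.
async function makeTopKPrediction(processedImage, k = 3) {
  try {
    const prediction = model.predict(processedImage);
    const { values, indices } = tf.topk(prediction, k);
    const scores = await values.data();
    const classIdx = await indices.data();
    const lines = Array.from(classIdx).map(
      (idx, i) => `${IMAGENET_CLASSES[idx]} (${(scores[i] * 100).toFixed(1)}%)`
    );
    outputDiv.textContent = `Top ${k}: ${lines.join(", ")}`;
  } catch (error) {
    outputDiv.textContent = `Error making prediction: ${error}`;
  }
}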

Switch to the "Web 8080" tab and reload the web page. Upload an image, and the predicted class will appear below the preview.

Summary

In this project, you learned how to use a pre-trained MobileNetV2 model in a Flask web application using TensorFlow.js. We started by installing the necessary dependencies. We then exported the MobileNetV2 model into a TensorFlow.js compatible format and set up a Flask application to serve our web page. Finally, we created an HTML page that uses TensorFlow.js to classify images in the browser with our MobileNetV2 model. By following these steps, you've created a web application that uses deep learning for real-time image classification.

This project demonstrates the power of combining traditional web technologies with advanced machine learning models to create interactive and intelligent web applications. You can extend this project by adding more features, such as the ability to upload different images, improving the user interface, or even integrating more complex models for different tasks.
