使用TensorFlow.js和Flask部署MobileNet

JavaScriptJavaScriptBeginner
立即练习

💡 本教程由 AI 辅助翻译自英文原版。如需查看原文,您可以 切换至英文原版

简介

本项目将指导你完成在 Flask 网络应用程序中使用 TensorFlow.js 部署预训练的 MobileNetV2 模型的过程。MobileNetV2 是一个轻量级深度神经网络,主要用于图像分类。TensorFlow.js 能够直接在浏览器中运行机器学习模型,从而实现交互式网络应用程序。Python 网络框架 Flask 将作为后端来托管我们的应用程序。在本项目结束时,你将拥有一个能够使用 MobileNetV2 模型即时对图像进行分类的运行中的网络应用程序。

👀 预览

🎯 任务

在本项目中,你将学习:

  • 如何将预训练的 MobileNetV2 模型从 Keras 导出为与 TensorFlow.js 兼容的格式。
  • 如何创建一个简单的 Flask 应用程序来提供网页内容和模型。
  • 如何设计一个 HTML 页面来上传和显示用于分类的图像。
  • 如何使用 TensorFlow.js 在浏览器中加载导出的模型。
  • 如何在浏览器中预处理图像以匹配 MobileNetV2 的输入要求。
  • 如何在浏览器中运行模型以对图像进行分类并显示结果。

🏆 成果

完成本项目后,你将能够:

  • 将预训练的 Keras 模型转换为可与 TensorFlow.js 一起使用的格式,使机器学习模型能够在浏览器中运行。
  • 设置一个 Flask 应用程序并提供 HTML 内容和静态文件。
  • 将 TensorFlow.js 集成到网络应用程序中以在客户端执行机器学习任务。
  • 在 JavaScript 中预处理图像以使其与深度学习模型的输入要求兼容。
  • 使用浏览器中的深度学习模型进行预测并在网页上动态显示结果。

准备项目环境和文件

在开始编码之前,正确设置项目环境非常重要。这包括安装必要的软件包,并了解已有的项目文件结构。

首先,熟悉初始项目文件结构。工作目录中已经提供了以下文件和文件夹:

tree

输出:

.
├── app.py
├── model_convert.py
├── static
│ ├── imagenet_classes.js
│ ├── tfjs.css
│ └── tfjs.js
└── templates
└── tfjs.html

2个目录,6个文件

项目结构由几个关键组件组成,每个组件在使用 MobileNetV2 模型、TensorFlow.js 和 Flask 部署用于图像分类的 Web 应用程序中都起着至关重要的作用。以下是项目中每个目录和文件的概述:

  • app.py:这是 Flask 应用程序的主要 Python 文件。它初始化 Flask 应用,设置网页的路由,并包含提供 TensorFlow.js 模型和网页内容所需的任何后端逻辑。
  • model_convert.py:这个 Python 脚本负责加载预训练的 MobileNetV2 模型,并将其转换为与 TensorFlow.js 兼容的格式。这种转换对于使模型能够在网页浏览器中运行至关重要。
  • static/:此目录存储 Web 应用程序所需的静态文件。这些文件包括:
    • imagenet_classes.js:一个包含 ImageNet 类别的 JavaScript 文件。该文件用于将模型的数值预测映射为人类可读的类名。
    • tfjs.css:一个新增的层叠样式表(CSS)文件,用于设置 Web 应用程序用户界面的样式。它定义了应用程序的视觉方面,如布局、颜色和字体,确保界面更具吸引力且用户友好。
    • tfjs.js:另一个新增文件,这个 JavaScript 文件可能包含在浏览器中加载 TensorFlow.js 模型、处理图像和执行预测的逻辑。此脚本是应用程序交互性的核心,处理与 TensorFlow.js 模型相关的客户端操作。
  • templates/:此目录包含定义 Web 应用程序结构和布局的 HTML 文件。在这种情况下,它包括:
    • tfjs.html:应用程序的主要 HTML 模板,tfjs.html 包含显示图像、预测结果以及可能的用户交互元素(如文件上传按钮)所需的必要标记。它集成了 TensorFlow.js 模型,利用 tfjs.js 脚本实现与模型相关的功能,并使用 tfjs.css 进行样式设置。

这种结构旨在分离关注点,使项目模块化且更易于管理。statictemplates 目录在 Flask 应用程序中是标准的,分别有助于组织静态资产和 HTML 模板。将模型转换脚本(model_convert.py)与主应用程序逻辑(app.py)分离可增强代码的模块化和可维护性。

接下来,安装所需的软件包:

## 安装所需的 Python 软件包
pip install tensorflow==2.14.0 tensorflowjs==4.17.0 flask==3.0.2 flask-cors==4.0.0

这些软件包包括用于机器学习模型的 TensorFlow、用于将模型转换为在 Web 环境中使用的 TensorFlow.js、用于创建 Web 服务器的 Flask 以及用于处理跨域请求的 Flask-CORS,跨域请求在 Web 应用程序中很常见。

将预训练的 MobileNetV2 模型导出为 TensorFlow.js 格式

要在浏览器中使用 MobileNetV2 模型,我们首先需要将其从 Keras 导出为 TensorFlow.js 能够理解的格式。

## 完成 model_convert.py

## 导出 MobileNetV2 模型
import tensorflowjs as tfjs
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2

## 加载预训练的 MobileNetV2 模型
model = MobileNetV2(weights='imagenet')

## 转换并以 TensorFlow.js 格式保存模型
tfjs.converters.save_keras_model(model,'static/model/')

在这一步中,你利用 TensorFlow 和 TensorFlow.js 库加载预训练的 MobileNetV2 模型,并将其转换为与 TensorFlow.js 兼容的格式。

选择 MobileNetV2 是因为它效率高且尺寸相对较小,适合进行网页部署。这种转换是必要的,因为基于 Python 的 TensorFlow 中使用的原始模型格式不能直接在网页环境中使用。tfjs.converters.save_keras_model 函数接受 Keras 模型,并将其保存在一个目录中,该目录的结构使 TensorFlow.js 稍后能够在网页应用程序中轻松加载它。

然后你可以运行:

python model_convert.py

转换后的模型将保存到 static/model/ 文件夹中:

ls static/model

## group1-shard1of4.bin  group1-shard2of4.bin  group1-shard3of4.bin  group1-shard4of4.bin  model.json

这个过程包括分别将模型的权重和架构保存在一系列分片文件和一个 model.json 文件中。

✨ 查看解决方案并练习

创建 Flask 应用程序

现在,我们将设置一个简单的 Flask 应用程序来提供我们的网页和 TensorFlow.js 模型。

## 完成 app.py

## 设置 Flask 应用程序
from flask import Flask, render_template
from flask_cors import CORS

app = Flask(__name__)
cors = CORS(app)  ## 启用跨域资源共享

@app.route("/")
def hello():
    ## 提供 HTML 页面
    return render_template('tfjs.html')

if __name__ == '__main__':
    app.run(host='0.0.0.0', port='8080', debug=True)

这一步涉及设置一个基本的 Flask 网络服务器。Flask 是一个用 Python 编写的微型网络框架,以其简单易用而闻名。

你首先从各自的库中导入 Flask 和 CORS(跨域资源共享)。对于从不同域请求资源的网络应用程序来说,CORS 至关重要,它确保你的网络应用程序能够安全地向 Flask 服务器发出请求。

你定义了一个简单的路由("/"),它提供一个 HTML 页面(tfjs.html),该页面将包含你的客户端 TensorFlow.js 代码。Flask 应用程序被配置为在你的本地机器上运行(host='0.0.0.0')并监听端口 8080。debug=True 设置在开发过程中很有帮助,因为它提供详细的错误消息,并在检测到代码更改时自动重新加载服务器。

现在,你可以运行 Flask 网络应用程序:

python app.py
## * 提供 Flask 应用 'app'
## * 调试模式:开启
## 警告:这是一个开发服务器。不要在生产部署中使用它。请改用生产 WSGI 服务器。
## * 在所有地址 (0.0.0.0) 上运行
## * 在 http://127.0.0.1:8080 上运行
## * 在 http://172.18.0.7:8080 上运行
## 按 CTRL+C 退出
✨ 查看解决方案并练习

准备 HTML 结构

现在,我们需要在 templates/tfjs.html 中创建应用程序的 HTML 结构。这包括图像上传、预览以及预测结果显示的布局。

<!-- 图像预测应用程序的 HTML 结构 -->
<!doctype html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>图像预测</title>
    <link
      href="https://cdn.jsdelivr.net/npm/[email protected]/dist/tailwind.min.css"
      rel="stylesheet"
    />
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <link rel="stylesheet" href="static/tfjs.css" />
  </head>
  <body class="flex flex-col items-center justify-center min-h-screen">
    <div class="card bg-white p-6 rounded-lg max-w-sm">
      <h1 class="text-xl font-semibold mb-4 text-center">图像预测</h1>
      <div class="flex flex-col items-center">
        <label
          for="imageUpload"
          class="button-custom cursor-pointer mb-4 flex items-center justify-center"
        >
          <span>上传图像</span>
          <input
            type="file"
            id="imageUpload"
            class="file-input"
            accept="image/*"
          />
        </label>
        <div
          id="imagePreviewContainer"
          class="mb-4 w-56 h-56 border border-dashed border-gray-300 flex items-center justify-center"
        >
          <img
            id="imagePreview"
            class="max-w-full max-h-full"
            style="display: none"
          />
        </div>
        <h5 id="output" class="text-md text-gray-700">上传图像以开始预测</h5>
        <script type="module" src="static/tfjs.js"></script>
      </div>
    </div>
  </body>
</html>

这一步涉及使用 HTML 创建我们的 Web 应用程序的基本结构。我们将文档类型定义为 HTML,并将语言属性设置为英语。

head 部分,我们包含元数据,如字符集和用于响应式设计的视口设置。我们还链接到 Tailwind CSS 库,以利用其实用类来设计应用程序的样式。body 部分包含一个类为 carddiv 元素,它用作我们应用程序内容的容器。

在这个容器中,我们有一个用于应用程序标题的 h1 标签、一个用于图像上传按钮的 label 元素(使用 Tailwind CSS 和自定义类将其样式设计为看起来像一个按钮),以及一个用作图像预览容器的 div 元素。类型为 fileinput 元素是隐藏的,当点击标签时会触发,允许用户上传图像。idimagePreviewContainerdiv 元素将显示上传的图像,h5 标签将用于向用户显示消息。

✨ 查看解决方案并练习

添加样式

在这一步中,我们将使用 Tailwind CSS 在 static/tfjs.css 中添加一些基本的 CSS 样式,用于布局和美观,同时为我们的应用程序添加一些自定义样式。

body {
  background-color: #f0f2f5;
}
.card {
  box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.button-custom {
  background-color: #4f46e5; /* 靛蓝色 600 */
  color: white;
  padding: 0.5rem 1.5rem;
  border-radius: 0.375rem; /* 圆角-md */
  transition: background-color 0.2s;
}
.button-custom:hover {
  background-color: #4338ca; /* 靛蓝色 700 */
}
.file-input {
  opacity: 0;
  position: absolute;
  z-index: -1;
}

在这一步中,我们添加自定义 CSS 样式以增强 Web 应用程序的外观。

我们将主体的背景颜色设置为浅灰色,以获得中性背景。card 类添加了一个盒阴影,为容器创建类似卡片的效果。button-custom 类为上传按钮设置样式,背景颜色为靛蓝色,文本为白色,有内边距和圆角。

我们还添加了悬停效果,当鼠标悬停时稍微改变按钮的背景颜色,为用户提供视觉反馈。file-input 类用于隐藏实际的文件输入元素,使自定义样式的标签成为文件上传的主要交互元素。

✨ 查看解决方案并练习

加载 TensorFlow 模型

在接下来的步骤中,我们将完成 templates/tfjs.js 文件。

现在,让我们将 TensorFlow.js 添加到项目中,然后编写 JavaScript 代码来加载我们的 TensorFlow 模型。我们将使用一个异步函数在进行任何预测之前等待模型加载完成。

import { IMAGENET_CLASSES } from "./imagenet_classes.js";

const outputDiv = document.getElementById("output");
let model;

async function loadModel() {
  try {
    outputDiv.textContent = "Loading TF Model...";
    model = await tf.loadLayersModel(
      "https://****.labex.io/static/model/model.json"
    ); // 更新 URL
    outputDiv.textContent = "TF Model Loaded.";
  } catch (error) {
    outputDiv.textContent = `Error loading model: ${error}`;
  }
}

loadModel();

注意:你应该将 tf.loadLayersModel 中的 URL 替换为当前环境的 URL。你可以通过切换到 Web 8080 标签页找到它。

TensorFlow 模型加载示意图

这一步涉及通过在 HTML 文档中包含其脚本标签将 TensorFlow.js 添加到我们的项目中。

然后,我们编写 JavaScript 代码来异步加载预训练的 TensorFlow 模型。我们使用 tf.loadLayersModel 函数,并为其提供模型 JSON 文件的 URL。这个函数返回一个 Promise,在模型加载完成时会解析。我们更新 outputDiv 的文本内容,以告知用户模型的加载状态。如果模型加载成功,我们显示“TF Model Loaded.”。

否则,我们捕获任何错误并向用户显示错误消息。这一步对于使我们的应用程序能够进行预测至关重要,因为在进行任何图像处理之前,模型需要加载并准备好。

✨ 查看解决方案并练习

处理图像上传和预览

这一步涉及创建功能,让用户在进行预测之前上传图像并查看预览。我们将使用 FileReader 来读取上传的文件并显示它。

// 在 static/tfjs.js 中继续
const imageUpload = document.getElementById("imageUpload");
const imagePreview = document.getElementById("imagePreview");

imageUpload.addEventListener("change", async (e) => {
  const file = e.target.files[0];
  if (file) {
    const reader = new FileReader();
    reader.onload = (e) => {
      const img = new Image();
      img.src = e.target.result;
      img.onload = async () => {
        imagePreview.src = img.src;
        imagePreview.style.display = "block";
        const processedImage = await preprocessImage(img);
        makePrediction(processedImage);
      };
    };
    reader.readAsDataURL(file);
  }
});

这一步涉及为图像上传输入设置一个事件监听器。

当用户选择一个文件时,我们使用 FileReader 将文件读取为 Data URL。然后我们创建一个 Image 对象,并将其 src 属性设置为 FileReader 的结果,从而有效地将图像加载到浏览器中。一旦图像加载完成,我们通过将 imagePreviewsrc 属性设置为图像的 src 并使 imagePreview 元素可见,将其显示在 imagePreview 容器中。

然后,图像由我们的模型进行预处理和预测,我们将在后续步骤中完成这些功能。

✨ 查看解决方案并练习

预处理图像以进行预测

在进行预测之前,我们需要对上传的图像进行预处理,以使其符合模型的输入要求。

// 在 static/tfjs.js 中继续
async function preprocessImage(imageElement) {
  try {
    let img = tf.browser.fromPixels(imageElement).toFloat();
    img = tf.image.resizeBilinear(img, [224, 224]);
    const offset = tf.scalar(127.5);
    const normalized = img.sub(offset).div(offset);
    const batched = normalized.reshape([1, 224, 224, 3]);
    return batched;
  } catch (error) {
    outputDiv.textContent = `Error in model prediction: ${error}`;
  }
}

在进行预测之前,我们需要对上传的图像进行预处理,以使其符合我们的 TensorFlow 模型预期的输入格式。这种预处理包括将图像调整为所需的尺寸(在这种情况下为 224x224 像素)并对像素值进行归一化。

我们使用 TensorFlow.js 操作,如 tf.browser.fromPixels 将图像转换为张量,tf.image.resizeBilinear 进行调整大小,并使用算术运算对像素值进行归一化。然后将预处理后的图像重塑为一批一个(以匹配模型预期的输入形状),使其准备好进行预测。

✨ 查看解决方案并练习

进行预测

一旦图像预处理完成,我们就准备好进行预测了。makePrediction 函数将处理后的图像作为输入,通过模型进行传递,并解释输出以确定最可能的类别标签。

// 在 static/tfjs.js 中继续
async function makePrediction(processedImage) {
  try {
    const prediction = model.predict(processedImage);
    const highestPredictionIndex = await tf.argMax(prediction, 1).data();
    const label = IMAGENET_CLASSES[highestPredictionIndex];
    outputDiv.textContent = `Prediction: ${label}`;
  } catch (error) {
    outputDiv.textContent = `Error making prediction: ${error}`;
  }
}

在这一步中,我们使用 model.predict(processedImage) 函数将预处理后的图像输入到 TensorFlow 模型中。tf.argMax(prediction, 1).data() 函数用于在预测数组中找到最高值的索引,该索引对应于图像最可能的类别标签。然后将此标签显示给用户。

切换到“Web 8080”标签页并重新加载网页以查看以下效果。

✨ 查看解决方案并练习

总结

在这个项目中,你学习了如何在使用 TensorFlow.js 的 Flask Web 应用程序中使用预训练的 MobileNetV2 模型。我们首先安装了必要的依赖项。然后,我们将 MobileNetV2 模型导出为与 TensorFlow.js 兼容的格式,并设置了一个 Flask 应用程序来提供我们的网页服务。最后,我们创建了一个 HTML 页面,该页面使用 TensorFlow.js 通过我们的 MobileNetV2 模型在浏览器中对图像进行分类。通过遵循这些步骤,你创建了一个使用深度学习进行实时图像分类的 Web 应用程序。

这个项目展示了将传统 Web 技术与先进的机器学习模型相结合以创建交互式和智能 Web 应用程序的强大功能。你可以通过添加更多功能来扩展这个项目,例如上传不同图像的能力、改进用户界面,甚至为不同任务集成更复杂的模型。

您可能感兴趣的其他 JavaScript 教程