TensorFlow.js と Flask を使った MobileNet の展開

JavaScriptBeginner
オンラインで実践に進む

はじめに

このプロジェクトでは、Flask ウェブアプリケーション内で TensorFlow.js を使って事前学習済みの MobileNetV2 モデルを展開するプロセスを案内します。MobileNetV2 は主に画像分類に使用される軽量の深層ニューラルネットワークです。TensorFlow.js を使うことで、機械学習モデルを直接ブラウザ内で実行でき、インタラクティブなウェブアプリケーションが可能になります。Python のウェブフレームワークである Flask が、アプリケーションをホストするバックエンドとして機能します。このプロジェクトが終了すると、MobileNetV2 モデルを使って画像を即座に分類する動作するウェブアプリケーションを持つようになります。

👀 プレビュー

🎯 タスク

このプロジェクトで学ぶことは以下の通りです。

  • 事前学習済みの MobileNetV2 モデルを Keras から TensorFlow.js 互換形式にエクスポートする方法。
  • ウェブコンテンツとモデルを提供するためのシンプルな Flask アプリケーションを作成する方法。
  • 分類用の画像をアップロードして表示するための HTML ページをデザインする方法。
  • ブラウザ内でエクスポートされたモデルを読み込むために TensorFlow.js を使う方法。
  • MobileNetV2 の入力要件に合うようにブラウザ内で画像を前処理する方法。
  • ブラウザ内でモデルを実行して画像を分類し、結果を表示する方法。

🏆 成果

このプロジェクトを完了すると、以下のことができるようになります。

  • 事前学習済みの Keras モデルを TensorFlow.js で使用できる形式に変換し、ML モデルをブラウザ内で実行できるようにする。
  • Flask アプリケーションをセットアップし、HTML コンテンツと静的ファイルを提供する。
  • ウェブアプリケーションに TensorFlow.js を統合して、クライアントサイドで機械学習タスクを実行する。
  • JavaScript で画像を前処理して、深層学習モデルの入力要件に合うようにする。
  • ブラウザ内で深層学習モデルを使って予測を行い、結果をウェブページ上で動的に表示する。

プロジェクト環境とファイルの準備

コーディングを始める前に、正しくプロジェクト環境をセットアップすることが重要です。これには必要なパッケージをインストールし、既に存在するプロジェクトファイル構造を理解することが含まれます。

まず、初期のプロジェクトファイル構造に慣れてください。作業ディレクトリには既に以下のファイルとフォルダが用意されています。

tree

出力:

.
├── app.py
├── model_convert.py
├── static
│ ├── imagenet_classes.js
│ ├── tfjs.css
│ └── tfjs.js
└── templates
└── tfjs.html

2 directories, 6 files

プロジェクトの構造はいくつかの重要なコンポーネントで構成されており、それぞれが TensorFlow.js と Flask を使った MobileNetV2 モデルを用いた画像分類のウェブアプリケーションの展開において重要な役割を果たします。以下はプロジェクト内の各ディレクトリとファイルの概要です。

  • app.py:これは Flask アプリケーションのメインの Python ファイルです。Flask アプリを初期化し、ウェブページのルーティングを設定し、TensorFlow.js モデルとウェブコンテンツを提供するために必要なバックエンドロジックを含みます。
  • model_convert.py:この Python スクリプトは、事前学習済みの MobileNetV2 モデルを読み込み、TensorFlow.js と互換性のある形式に変換する責任があります。この変換は、モデルをウェブブラウザで実行できるようにするために重要です。
  • static/:このディレクトリは、ウェブアプリケーションに必要な静的ファイルを格納します。これには以下が含まれます。
    • imagenet_classes.js:ImageNet クラスを含む JavaScript ファイル。このファイルは、モデルの数値予測を人間が読みやすいクラス名にマッピングするために使用されます。
    • tfjs.css:新しく追加されたこのケースケーディング スタイル シート (CSS) ファイルは、ウェブアプリケーションのユーザー インターフェイスをスタイリッシュにするために使用されます。レイアウト、色、フォントなど、アプリケーションの視覚的な側面を定義し、より魅力的でユーザーフレンドリーなインターフェイスを保証します。
    • tfjs.js:もう 1 つの新しいファイルであるこの JavaScript ファイルは、おそらく TensorFlow.js モデルを読み込み、画像を処理し、ブラウザ内で予測を実行するロジックを含んでいます。このスクリプトは、アプリケーションのインタラクティビティの中心であり、TensorFlow.js モデルに関連するクライアントサイド操作を処理します。
  • templates/:このディレクトリには、ウェブアプリケーションの構造とレイアウトを定義する HTML ファイルが含まれています。この場合、それには以下が含まれます。
    • tfjs.html:アプリケーションの主要な HTML テンプレートである tfjs.html には、画像の表示、予測結果、おそらくファイルアップロード ボタンなどのユーザーインタラクション要素を表示するための必要なマークアップが含まれています。tfjs.js スクリプトを使ってモデル関連の機能を活用し、tfjs.css を使ってスタイリングすることで、TensorFlow.js モデルを統合しています。

この構造は、関心事を分離するように設計されており、プロジェクトをモジュール化して管理しやすくします。statictemplates ディレクトリは Flask アプリケーションにおいて標準的であり、それぞれ静的アセットと HTML テンプレートを整理するのに役立ちます。モデル変換スクリプト (model_convert.py) とメイン アプリケーションロジック (app.py) の分離は、コードのモジュラリティと保守性を高めます。

次に、必要なパッケージをインストールします。

## Install the required Python packages
pip install tensorflow==2.14.0 tensorflowjs==4.17.0 flask==3.0.2 flask-cors==4.0.0

これらのパッケージには、機械学習モデル用の TensorFlow、ウェブ環境で使用するためにモデルを変換する TensorFlow.js、ウェブサーバーを作成する Flask、およびウェブアプリケーションで一般的なクロスオリジン要求を処理する Flask-CORS が含まれます。

事前学習済みの MobileNetV2 モデルを TensorFlow.js 形式にエクスポートする

ブラウザで MobileNetV2 モデルを使用するには、まずそれを Keras から TensorFlow.js が理解できる形式にエクスポートする必要があります。

## Complete the model_convert.py

## Exporting MobileNetV2 model
import tensorflowjs as tfjs
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2

## Load the pre-trained MobileNetV2 model
model = MobileNetV2(weights='imagenet')

## Convert and save the model in TensorFlow.js format
tfjs.converters.save_keras_model(model,'static/model/')

このステップでは、TensorFlow と TensorFlow.js ライブラリを利用して、事前学習済みの MobileNetV2 モデルを読み込み、TensorFlow.js と互換性のある形式に変換します。

MobileNetV2 はその効率性と比較的小さなサイズのために選ばれ、ウェブ展開に適しています。この変換は必要ですが、Python ベースの TensorFlow で使用される元のモデル形式は、ウェブ環境で直接使用できません。tfjs.converters.save_keras_model 関数は Keras モデルを受け取り、TensorFlow.js が後でウェブアプリケーションで簡単に読み込めるように構造化されたディレクトリに保存します。

次に、以下を実行できます。

python model_convert.py

変換されたモデルは static/model/ フォルダに保存されます。

ls static/model

## group1-shard1of4.bin  group1-shard2of4.bin  group1-shard3of4.bin  group1-shard4of4.bin  model.json

このプロセスには、モデルの重みとアーキテクチャをそれぞれ一連のシャード ファイルと model.json ファイルに保存することが含まれます。

Flask アプリケーションの作成

次に、ウェブページと TensorFlow.js モデルを提供するためのシンプルな Flask アプリケーションをセットアップします。

## Complete the app.py

## Setting up the Flask application
from flask import Flask, render_template
from flask_cors import CORS

app = Flask(__name__)
cors = CORS(app)  ## Enable Cross-Origin Resource Sharing

@app.route("/")
def hello():
    ## Serve the HTML page
    return render_template('tfjs.html')

if __name__ == '__main__':
    app.run(host='0.0.0.0', port='8080', debug=True)

このステップでは、基本的な Flask ウェブサーバーをセットアップします。Flask は Python で書かれたマイクロウェブフレームワークで、そのシンプルさと使いやすさで知られています。

まず、それぞれのライブラリから Flask と CORS(クロスオリジン リソース共有)をインポートします。CORS は、異なるドメインからのリソース要求を行うウェブアプリケーションにとって不可欠で、ウェブアプリが Flask サーバーに安全に要求を行えるようにします。

クライアントサイドの TensorFlow.js コードを含む HTML ページ (tfjs.html) を提供する単純なルート ("/") を定義します。Flask アプリケーションは、ローカルマシン (host='0.0.0.0') で実行され、ポート 8080 を監視します。debug=True 設定は、開発中に役立ちます。詳細なエラー メッセージを提供し、コードの変更が検出されると自動的にサーバーを再読み込みします。

次に、Flask ウェブアプリケーションを実行できます。

python app.py
## * Serving Flask app 'app'
## * Debug mode: on
## WARNING: This is a development server. Do not use it in a production ## deployment. Use a production WSGI server instead.
## * Running on all addresses (0.0.0.0)
## * Running on http://127.0.0.1:8080
## * Running on http://172.18.0.7:8080
## Press CTRL+C to quit

HTML 構造の準備

次に、templates/tfjs.html でアプリケーションの HTML 構造を作成する必要があります。これには、画像のアップロード、プレビュー、予測結果の表示のレイアウトが含まれます。

<!-- HTML structure for the Image Prediction application -->
<!doctype html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Image Prediction</title>
    <link
      href="https://cdn.jsdelivr.net/npm/tailwindcss@2.0.2/dist/tailwind.min.css"
      rel="stylesheet"
    />
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <link rel="stylesheet" href="static/tfjs.css" />
  </head>
  <body class="flex flex-col items-center justify-center min-h-screen">
    <div class="card bg-white p-6 rounded-lg max-w-sm">
      <h1 class="text-xl font-semibold mb-4 text-center">Image Prediction</h1>
      <div class="flex flex-col items-center">
        <label
          for="imageUpload"
          class="button-custom cursor-pointer mb-4 flex items-center justify-center"
        >
          <span>Upload Image</span>
          <input
            type="file"
            id="imageUpload"
            class="file-input"
            accept="image/*"
          />
        </label>
        <div
          id="imagePreviewContainer"
          class="mb-4 w-56 h-56 border border-dashed border-gray-300 flex items-center justify-center"
        >
          <img
            id="imagePreview"
            class="max-w-full max-h-full"
            style="display: none"
          />
        </div>
        <h5 id="output" class="text-md text-gray-700">
          Upload an image to start prediction
        </h5>
        <script type="module" src="static/tfjs.js"></script>
      </div>
    </div>
  </body>
</html>

このステップでは、HTML を使ってウェブアプリケーションの基本構造を作成します。文書タイプを HTML と定義し、言語属性を英語に設定します。

head セクションでは、レスポンシブ デザイン用の文字セットやビューポート設定などのメタデータを含めます。また、Tailwind CSS ライブラリにリンクして、アプリケーションのスタイリングにそのユーティリティ クラスを利用します。body セクションには、card クラスの div 要素が含まれており、これがアプリケーションのコンテンツのコンテナとなります。

このコンテナの中には、アプリケーションのタイトル用の h1 タグ、画像アップロード ボタン用の label 要素(Tailwind CSS とカスタム クラスを使ってボタンのようにスタイリングされています)、画像プレビュー用のコンテナとなる div 要素があります。file 型の input 要素は非表示になっており、ラベルをクリックするとトリガーされ、ユーザーが画像をアップロードできるようになっています。idimagePreviewContainerdiv 要素にアップロードされた画像が表示され、h5 タグはユーザーにメッセージを表示するために使用されます。

スタイリングを追加する

このステップでは、static/tfjs.css で Tailwind CSS を使ってレイアウトと美観のための基本的な CSS スタイリングを追加し、アプリケーション用のカスタム スタイルも追加します。

body {
  background-color: #f0f2f5;
}
.card {
  box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.button-custom {
  background-color: #4f46e5; /* Indigo 600 */
  color: white;
  padding: 0.5rem 1.5rem;
  border-radius: 0.375rem; /* rounded-md */
  transition: background-color 0.2s;
}
.button-custom:hover {
  background-color: #4338ca; /* Indigo 700 */
}
.file-input {
  opacity: 0;
  position: absolute;
  z-index: -1;
}

このステップでは、ウェブアプリケーションの外観を向上させるためにカスタム CSS スタイルを追加します。

ボディの背景色をニュートラルな背景にするために淡い灰色に設定します。card クラスはボックス シャドウを追加して、コンテナにカードのような効果を与えます。button-custom クラスは、インディゴの背景色、白い文字、パディング、丸い角でアップロード ボタンをスタイリングします。

また、ホバー時にボタンの背景色を少し変更するホバー エフェクトも含めており、ユーザーに視覚的なフィードバックを提供します。file-input クラスは、実際のファイル入力要素を非表示にするために使用され、カスタム スタイル付きのラベルをファイルアップロードの主なインタラクティブ要素にします。

TensorFlow モデルを読み込む

残りの手順では、templates/tfjs.js ファイルを完成させます。

まず、TensorFlow.js をプロジェクトに追加し、その後、JavaScript コードを書いて TensorFlow モデルを読み込みます。予測を行う前に、モデルの読み込みを待つために非同期関数を使用します。

import { IMAGENET_CLASSES } from "./imagenet_classes.js";

const outputDiv = document.getElementById("output");
let model;

async function loadModel() {
  try {
    outputDiv.textContent = "Loading TF Model...";
    model = await tf.loadLayersModel(
      "https://****.labex.io/static/model/model.json"
    ); // Update URL
    outputDiv.textContent = "TF Model Loaded.";
  } catch (error) {
    outputDiv.textContent = `Error loading model: ${error}`;
  }
}

loadModel();

注:tf.loadLayersModel の URL を現在の環境の URL に置き換える必要があります。Web 8080 タブに切り替えることで見つけることができます。

TensorFlow model loading diagram

このステップでは、HTML ドキュメントにスクリプト タグを含めることで、TensorFlow.js をプロジェクトに追加します。

その後、事前学習済みの TensorFlow モデルを非同期で読み込む JavaScript コードを書きます。tf.loadLayersModel 関数を使用し、モデルの JSON ファイルの URL を指定します。この関数は、読み込まれたモデルで解決するプロミスを返します。outputDiv のテキストコンテンツを更新して、ユーザーにモデルの読み込み状態を知らせます。モデルが正常に読み込まれた場合、「TF Model Loaded.」を表示します。

それ以外の場合、エラーをキャッチしてユーザーにエラー メッセージを表示します。このステップは、画像処理が行われる前にモデルを読み込んで準備する必要があるため、アプリケーションが予測を行うために不可欠です。

画像のアップロードとプレビューを処理する

このステップでは、ユーザーが画像をアップロードし、予測を行う前にプレビューを表示する機能を作成します。アップロードされたファイルを読み取り、表示するために FileReader を使用します。

// Continue in static/tfjs.js
const imageUpload = document.getElementById("imageUpload");
const imagePreview = document.getElementById("imagePreview");

imageUpload.addEventListener("change", async (e) => {
  const file = e.target.files[0];
  if (file) {
    const reader = new FileReader();
    reader.onload = (e) => {
      const img = new Image();
      img.src = e.target.result;
      img.onload = async () => {
        imagePreview.src = img.src;
        imagePreview.style.display = "block";
        const processedImage = await preprocessImage(img);
        makePrediction(processedImage);
      };
    };
    reader.readAsDataURL(file);
  }
});

このステップでは、画像アップロード入力に対するイベントリスナーを設定します。

ユーザーがファイルを選択すると、FileReader を使用してファイルを Data URL として読み取ります。その後、Image オブジェクトを作成し、その src 属性を FileReader の結果に設定することで、実際に画像をブラウザに読み込みます。画像が読み込まれると、imagePreviewsrc 属性を画像の src に設定し、imagePreview 要素を表示することで、imagePreview コンテナ内に表示します。

その後、画像を前処理してモデルによる予測を行いますが、それらの関数は後のステップで完成させます。

予測用に画像を前処理する

予測を行う前に、アップロードされた画像を前処理して、モデルの入力要件に合わせる必要があります。

// Continue in static/tfjs.js
async function preprocessImage(imageElement) {
  try {
    let img = tf.browser.fromPixels(imageElement).toFloat();
    img = tf.image.resizeBilinear(img, [224, 224]);
    const offset = tf.scalar(127.5);
    const normalized = img.sub(offset).div(offset);
    const batched = normalized.reshape([1, 224, 224, 3]);
    return batched;
  } catch (error) {
    outputDiv.textContent = `Error in model prediction: ${error}`;
  }
}

予測を行う前に、アップロードされた画像を前処理して、TensorFlow モデルが期待する入力形式に合わせる必要があります。この前処理には、画像を必要なサイズ(この場合は 224x224 ピクセル)にリサイズし、画素値を正規化することが含まれます。

画像をテンソルに変換するために tf.browser.fromPixels のような TensorFlow.js 操作を使用し、リサイズするために tf.image.resizeBilinear を使用し、画素値を正規化するために算術演算を使用します。その後、前処理された画像を 1 つのバッチに整形(モデルが期待する入力形状に合わせる)して、予測の準備が整います。

予測を行う

画像が前処理されると、予測を行う準備が整います。makePrediction 関数は、前処理された画像を入力として受け取り、それをモデルに通して出力を解釈して、最も可能性の高いクラス ラベルを決定します。

// Continue in static/tfjs.js
async function makePrediction(processedImage) {
  try {
    const prediction = model.predict(processedImage);
    const highestPredictionIndex = await tf.argMax(prediction, 1).data();
    const label = IMAGENET_CLASSES[highestPredictionIndex];
    outputDiv.textContent = `Prediction: ${label}`;
  } catch (error) {
    outputDiv.textContent = `Error making prediction: ${error}`;
  }
}

このステップでは、model.predict(processedImage) 関数を使用して、前処理された画像を TensorFlow モデルに入力します。tf.argMax(prediction, 1).data() 関数は、予測配列の中で最も高い値のインデックスを見つけるために使用され、これは画像に対する最も可能性の高いクラス ラベルに対応します。その後、このラベルがユーザーに表示されます。

「Web 8080」タブに切り替えてウェブページを再読み込みすると、以下の効果が見られます。

まとめ

このプロジェクトでは、TensorFlow.js を使用して Flask ウェブアプリケーションで事前学習済みの MobileNetV2 モデルをどのように使用するかを学びました。まず必要な依存関係をインストールしました。その後、MobileNetV2 モデルを TensorFlow.js と互換性のある形式にエクスポートし、ウェブページを提供するための Flask アプリケーションをセットアップしました。最後に、MobileNetV2 モデルを使ってブラウザで画像を分類する TensorFlow.js を使用する HTML ページを作成しました。これらの手順に従うことで、ディープラーニングを使ったリアルタイム画像分類のウェブアプリケーションを作成しました。

このプロジェクトは、従来のウェブ技術と高度な機械学習モデルを組み合わせて、対話型でインテリジェントなウェブアプリケーションを作成する力を示しています。このプロジェクトを拡張するには、より多くの機能を追加することができます。たとえば、異なる画像をアップロードする機能、ユーザーインターフェイスを改善する機能、さらには異なるタスクに対してより複雑なモデルを統合する機能などです。

✨ 解答を確認して練習✨ 解答を確認して練習✨ 解答を確認して練習✨ 解答を確認して練習✨ 解答を確認して練習✨ 解答を確認して練習✨ 解答を確認して練習✨ 解答を確認して練習✨ 解答を確認して練習