事前計算済み辞書を用いた疎な符号化

Beginner

This tutorial is from open-source community. Access the source code

はじめに

この実験では、疎な符号化手法を用いて、信号をリッカー小波の疎な組み合わせとして変換する方法を学びます。リッカー(メキシコ帽子またはガウス関数の 2 階微分とも呼ばれます)は、このような断片的に一定な信号を表すのに特に良いカーネルではありません。したがって、異なる幅の原子を追加することがどれほど重要であるかがわかり、それが信号の種類に最適に適合する辞書を学習する動機付けとなります。

SparseCoder 推定器を使用して、視覚的に異なる疎な符号化手法を比較します。右のより豊富な辞書はサイズが大きくないため、同じオーダーの大きさに留まるために、より重いサブサンプリングが行われます。

VM のヒント

VM の起動が完了したら、左上隅をクリックして ノートブック タブに切り替え、Jupyter Notebook を使って練習しましょう。

時々、Jupyter Notebook が読み込み終了するまで数秒待つ必要がある場合があります。Jupyter Notebook の制限により、操作の検証は自動化できません。

学習中に問題に遭遇した場合は、Labby にお問い合わせください。セッション後にフィードバックを提供してください。すぐに問題を解決いたします。

ライブラリのインポート

必要なライブラリをインポートして始めましょう。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import SparseCoder

リッカー小波関数の定義

リッカー小波とリッカー小波の辞書を生成する関数を定義します。

def ricker_function(resolution, center, width):
    """Discrete sub-sampled Ricker (Mexican hat) wavelet"""
    x = np.linspace(0, resolution - 1, resolution)
    x = (
        (2 / (np.sqrt(3 * width) * np.pi**0.25))
        * (1 - (x - center) ** 2 / width**2)
        * np.exp(-((x - center) ** 2) / (2 * width**2))
    )
    return x


def ricker_matrix(width, resolution, n_components):
    """Dictionary of Ricker (Mexican hat) wavelets"""
    centers = np.linspace(0, resolution - 1, n_components)
    D = np.empty((n_components, resolution))
    for i, center in enumerate(centers):
        D[i] = ricker_function(resolution, center, width)
    D /= np.sqrt(np.sum(D**2, axis=1))[:, np.newaxis]
    return D

信号の生成

Matplotlib を使って信号を生成し、可視化します。

resolution = 1024
subsampling = 3  ## subsampling factor
width = 100
n_components = resolution // subsampling

## Generate a signal
y = np.linspace(0, resolution - 1, resolution)
first_quarter = y < resolution / 4
y[first_quarter] = 3.0
y[np.logical_not(first_quarter)] = -1.0

## Visualize the signal
plt.figure()
plt.plot(y)
plt.title("Original Signal")
plt.show()

小波辞書の計算

Matplotlib を使って小波辞書を計算し、可視化します。

## Compute a wavelet dictionary
D_fixed = ricker_matrix(width=width, resolution=resolution, n_components=n_components)
D_multi = np.r_[
    tuple(
        ricker_matrix(width=w, resolution=resolution, n_components=n_components // 5)
        for w in (10, 50, 100, 500, 1000)
    )
]

## Visualize the wavelet dictionary
plt.figure(figsize=(10, 5))
for i, D in enumerate((D_fixed, D_multi)):
    plt.subplot(1, 2, i + 1)
    plt.imshow(D, cmap=plt.cm.gray, interpolation="nearest")
    plt.title("Wavelet Dictionary (%s)" % ("fixed width" if i == 0 else "multiple widths"))
    plt.axis("off")
plt.show()

疎な符号化

異なる方法を使って信号に対して疎な符号化を行い、結果を可視化します。

## List the different sparse coding methods in the following format:
## (title, transform_algorithm, transform_alpha,
##  transform_n_nozero_coefs, color)
estimators = [
    ("OMP", "omp", None, 15, "navy"),
    ("Lasso", "lasso_lars", 2, None, "turquoise"),
]
lw = 2

plt.figure(figsize=(13, 6))
for subplot, (D, title) in enumerate(
    zip((D_fixed, D_multi), ("fixed width", "multiple widths"))
):
    plt.subplot(1, 2, subplot + 1)
    plt.title("Sparse coding against %s dictionary" % title)
    plt.plot(y, lw=lw, linestyle="--", label="Original signal")
    ## Do a wavelet approximation
    for title, algo, alpha, n_nonzero, color in estimators:
        coder = SparseCoder(
            dictionary=D,
            transform_n_nonzero_coefs=n_nonzero,
            transform_alpha=alpha,
            transform_algorithm=algo,
        )
        x = coder.transform(y.reshape(1, -1))
        density = len(np.flatnonzero(x))
        x = np.ravel(np.dot(x, D))
        squared_error = np.sum((y - x) ** 2)
        plt.plot(
            x,
            color=color,
            lw=lw,
            label="%s: %s nonzero coefs,\n%.2f error" % (title, density, squared_error),
        )

    ## Soft thresholding debiasing
    coder = SparseCoder(
        dictionary=D, transform_algorithm="threshold", transform_alpha=20
    )
    x = coder.transform(y.reshape(1, -1))
    _, idx = np.where(x!= 0)
    x[0, idx], _, _, _ = np.linalg.lstsq(D[idx, :].T, y, rcond=None)
    x = np.ravel(np.dot(x, D))
    squared_error = np.sum((y - x) ** 2)
    plt.plot(
        x,
        color="darkorange",
        lw=lw,
        label="Thresholding w/ debiasing:\n%d nonzero coefs, %.2f error"
        % (len(idx), squared_error),
    )
    plt.axis("tight")
    plt.legend(shadow=False, loc="best")
plt.subplots_adjust(0.04, 0.07, 0.97, 0.90, 0.09, 0.2)
plt.show()

まとめ

この実験では、疎な符号化手法と SparseCoder 推定器を使って、信号をリッカー小波の疎な組み合わせとして変換する方法を学びました。また、異なる疎な符号化手法を比較し、結果を可視化しました。