Gaußsche Prozesse-Klassifikation auf XOR-Datensatz

Machine LearningMachine LearningBeginner
Jetzt üben

This tutorial is from open-source community. Access the source code

💡 Dieser Artikel wurde von AI-Assistenten übersetzt. Um die englische Version anzuzeigen, können Sie hier klicken

Einführung

In diesem Lab werden wir ein Beispiel der Gaußschen Prozesse-Klassifikation (GPC) auf dem XOR-Datensatz mit scikit-learn durchgehen. Wir werden die Ergebnisse vergleichen, die durch die Verwendung eines stationären, isotropen Kerns (RBF) und eines nicht-stationären Kerns (DotProduct) erhalten werden.

Tipps für die VM

Nachdem der VM-Start abgeschlossen ist, klicken Sie in der oberen linken Ecke, um zur Registerkarte Notebook zu wechseln und Jupyter Notebook für die Übung zu nutzen.

Manchmal müssen Sie einige Sekunden warten, bis Jupyter Notebook vollständig geladen ist. Die Validierung von Vorgängen kann aufgrund der Einschränkungen in Jupyter Notebook nicht automatisiert werden.

Wenn Sie bei der Lernphase Probleme haben, können Sie Labby gerne fragen. Geben Sie nach der Sitzung Feedback, und wir werden das Problem für Sie prompt beheben.


Skills Graph

%%%%{init: {'theme':'neutral'}}%%%% flowchart RL sklearn(("Sklearn")) -.-> sklearn/CoreModelsandAlgorithmsGroup(["Core Models and Algorithms"]) ml(("Machine Learning")) -.-> ml/FrameworkandSoftwareGroup(["Framework and Software"]) sklearn/CoreModelsandAlgorithmsGroup -.-> sklearn/gaussian_process("Gaussian Processes") ml/FrameworkandSoftwareGroup -.-> ml/sklearn("scikit-learn") subgraph Lab Skills sklearn/gaussian_process -.-> lab-49142{{"Gaußsche Prozesse-Klassifikation auf XOR-Datensatz"}} ml/sklearn -.-> lab-49142{{"Gaußsche Prozesse-Klassifikation auf XOR-Datensatz"}} end

Bibliotheken importieren

In diesem Schritt werden wir die erforderlichen Bibliotheken für dieses Lab importieren.

import numpy as np
import matplotlib.pyplot as plt

from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF, DotProduct

Erstellen des XOR-Datensatzes

In diesem Schritt werden wir einen XOR-Datensatz mit numpy erstellen. Wir werden die Funktion logical_xor verwenden, um die Labels basierend auf den Eingabefeatures zu erstellen.

xx, yy = np.meshgrid(np.linspace(-3, 3, 50), np.linspace(-3, 3, 50))
rng = np.random.RandomState(0)
X = rng.randn(200, 2)
Y = np.logical_xor(X[:, 0] > 0, X[:, 1] > 0)

Anpassen des Modells

In diesem Schritt werden wir den Gaußschen Prozesse-Klassifizierer an den Datensatz anpassen. Wir werden zwei verschiedene Kerne für den Vergleich verwenden - RBF und DotProduct.

plt.figure(figsize=(10, 5))
kernels = [1.0 * RBF(length_scale=1.15), 1.0 * DotProduct(sigma_0=1.0) ** 2]
for i, kernel in enumerate(kernels):
    clf = GaussianProcessClassifier(kernel=kernel, warm_start=True).fit(X, Y)

    ## plot the decision function for each datapoint on the grid
    Z = clf.predict_proba(np.vstack((xx.ravel(), yy.ravel())).T)[:, 1]
    Z = Z.reshape(xx.shape)

    plt.subplot(1, 2, i + 1)
    image = plt.imshow(
        Z,
        interpolation="nearest",
        extent=(xx.min(), xx.max(), yy.min(), yy.max()),
        aspect="auto",
        origin="lower",
        cmap=plt.cm.PuOr_r,
    )
    contours = plt.contour(xx, yy, Z, levels=[0.5], linewidths=2, colors=["k"])
    plt.scatter(X[:, 0], X[:, 1], s=30, c=Y, cmap=plt.cm.Paired, edgecolors=(0, 0, 0))
    plt.xticks(())
    plt.yticks(())
    plt.axis([-3, 3, -3, 3])
    plt.colorbar(image)
    plt.title(
        "%s\n Log-Marginal-Likelihood:%.3f"
        % (clf.kernel_, clf.log_marginal_likelihood(clf.kernel_.theta)),
        fontsize=12,
    )

plt.tight_layout()
plt.show()

Visualisierung der Ergebnisse

In diesem Schritt werden wir die Ergebnisse visualisieren, die durch das Anpassen des Modells erhalten wurden. Wir werden die Entscheidungsfunktion für jeden Datenpunkt auf dem Gitter und einen Streudiagramm für die Eingabefeatures plotten.

Zusammenfassung

In diesem Lab haben wir ein Beispiel der Gaußschen Prozesse-Klassifikation (GPC) auf einem XOR-Datensatz mit scikit-learn durchlaufen. Wir haben die Ergebnisse verglichen, die durch die Verwendung eines stationären, isotropen Kerns (RBF) und eines nicht-stationären Kerns (DotProduct) erhalten wurden.