Merkmalswichtigkeit mit Zufälligem Wald

Beginner

This tutorial is from open-source community. Access the source code

Einführung

In diesem Lab verwenden wir einen Zufälligen Wald, um die Wichtigkeit von Merkmalen bei einer künstlichen Klassifizierungsaufgabe zu bewerten. Wir werden einen synthetischen Datensatz mit nur 3 informativen Merkmalen generieren. Die Merkmalswichtigkeiten des Walds werden geplottet, zusammen mit ihrer Variabilität zwischen den Bäumen, die durch die Fehlerbalken dargestellt wird.

VM-Tipps

Nachdem der VM-Start abgeschlossen ist, klicken Sie in der oberen linken Ecke, um zur Registerkarte Notebook zu wechseln und Jupyter Notebook für die Übung zu nutzen.

Manchmal müssen Sie einige Sekunden warten, bis Jupyter Notebook vollständig geladen ist. Die Validierung von Vorgängen kann aufgrund von Einschränkungen in Jupyter Notebook nicht automatisiert werden.

Wenn Sie während des Lernens Probleme haben, können Sie Labby gerne fragen. Geben Sie nach der Sitzung Feedback, und wir werden das Problem für Sie prompt beheben.

Bibliotheken importieren

Wir werden die erforderlichen Bibliotheken für dieses Lab importieren.

import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
import pandas as pd
import numpy as np
import time

Daten generieren

Wir werden einen synthetischen Datensatz mit nur 3 informativen Merkmalen generieren. Wir werden ausdrücklich den Datensatz nicht mischen, um sicherzustellen, dass die informativen Merkmale den drei ersten Spalten von X entsprechen. Darüber hinaus werden wir unseren Datensatz in Trainings- und Testuntergruppen aufteilen.

X, y = make_classification(
    n_samples=1000,
    n_features=10,
    n_informative=3,
    n_redundant=0,
    n_repeated=0,
    n_classes=2,
    random_state=0,
    shuffle=False,
)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42)

Zufälligen Wald anpassen

Wir werden einen Zufälligen Wald-Klassifizierer anpassen, um die Merkmalswichtigkeiten zu berechnen.

feature_names = [f"feature {i}" for i in range(X.shape[1])]
forest = RandomForestClassifier(random_state=0)
forest.fit(X_train, y_train)

Merkmalswichtigkeit basierend auf der durchschnittlichen Verringerung der Unreinheit

Die Merkmalswichtigkeiten werden durch das angepasste Attribut feature_importances_ bereitgestellt und berechnet sich als Mittelwert und Standardabweichung der Akkumulation der Unreinheitsverringerung innerhalb jedes Baumes. Wir werden die auf Unreinheit basierende Wichtigkeit plotten.

start_time = time.time()
importances = forest.feature_importances_
std = np.std([tree.feature_importances_ for tree in forest.estimators_], axis=0)
elapsed_time = time.time() - start_time

print(f"Elapsed time to compute the importances: {elapsed_time:.3f} seconds")

forest_importances = pd.Series(importances, index=feature_names)

fig, ax = plt.subplots()
forest_importances.plot.bar(yerr=std, ax=ax)
ax.set_title("Feature importances using MDI")
ax.set_ylabel("Mean decrease in impurity")
fig.tight_layout()

Merkmalswichtigkeit basierend auf Merkmalspermutation

Die Merkmalspermutationswichtigkeit überwindet die Einschränkungen der auf Unreinheit basierenden Merkmalswichtigkeit: Sie haben keinen Bias gegenüber Merkmalen mit hoher Kardinalität und können auf einem separat gehaltenen Testset berechnet werden. Wir werden die volle Permutationswichtigkeit berechnen. Die Merkmale werden n-mal durchgemischt und das Modell erneut angepasst, um ihre Wichtigkeit zu schätzen. Wir werden die Wichtigkeitsrangfolge plotten.

start_time = time.time()
result = permutation_importance(
    forest, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2
)
elapsed_time = time.time() - start_time
print(f"Elapsed time to compute the importances: {elapsed_time:.3f} seconds")

forest_importances = pd.Series(result.importances_mean, index=feature_names)

fig, ax = plt.subplots()
forest_importances.plot.bar(yerr=result.importances_std, ax=ax)
ax.set_title("Feature importances using permutation on full model")
ax.set_ylabel("Mean accuracy decrease")
fig.tight_layout()
plt.show()

Zusammenfassung

In diesem Lab haben wir einen synthetischen Datensatz mit nur 3 informativen Merkmalen erzeugt und einen Zufälligen Wald verwendet, um die Wichtigkeit der Merkmale zu evaluieren. Wir haben die Merkmalswichtigkeiten des Walds geplottet, zusammen mit ihrer Variabilität zwischen den Bäumen, die durch die Fehlerbalken dargestellt wird. Wir haben die auf Unreinheit basierende Wichtigkeit und die Merkmalspermutationswichtigkeit verwendet, um die Merkmalswichtigkeiten zu berechnen. Die gleichen Merkmale wurden mit beiden Methoden als die wichtigsten erkannt.