Einführung
In diesem Lab verwenden wir einen Zufälligen Wald, um die Wichtigkeit von Merkmalen bei einer künstlichen Klassifizierungsaufgabe zu bewerten. Wir werden einen synthetischen Datensatz mit nur 3 informativen Merkmalen generieren. Die Merkmalswichtigkeiten des Walds werden geplottet, zusammen mit ihrer Variabilität zwischen den Bäumen, die durch die Fehlerbalken dargestellt wird.
VM-Tipps
Nachdem der VM-Start abgeschlossen ist, klicken Sie in der oberen linken Ecke, um zur Registerkarte Notebook zu wechseln und Jupyter Notebook für die Übung zu nutzen.
Manchmal müssen Sie einige Sekunden warten, bis Jupyter Notebook vollständig geladen ist. Die Validierung von Vorgängen kann aufgrund von Einschränkungen in Jupyter Notebook nicht automatisiert werden.
Wenn Sie während des Lernens Probleme haben, können Sie Labby gerne fragen. Geben Sie nach der Sitzung Feedback, und wir werden das Problem für Sie prompt beheben.
Bibliotheken importieren
Wir werden die erforderlichen Bibliotheken für dieses Lab importieren.
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
import pandas as pd
import numpy as np
import time
Daten generieren
Wir werden einen synthetischen Datensatz mit nur 3 informativen Merkmalen generieren. Wir werden ausdrücklich den Datensatz nicht mischen, um sicherzustellen, dass die informativen Merkmale den drei ersten Spalten von X entsprechen. Darüber hinaus werden wir unseren Datensatz in Trainings- und Testuntergruppen aufteilen.
X, y = make_classification(
n_samples=1000,
n_features=10,
n_informative=3,
n_redundant=0,
n_repeated=0,
n_classes=2,
random_state=0,
shuffle=False,
)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42)
Zufälligen Wald anpassen
Wir werden einen Zufälligen Wald-Klassifizierer anpassen, um die Merkmalswichtigkeiten zu berechnen.
feature_names = [f"feature {i}" for i in range(X.shape[1])]
forest = RandomForestClassifier(random_state=0)
forest.fit(X_train, y_train)
Merkmalswichtigkeit basierend auf der durchschnittlichen Verringerung der Unreinheit
Die Merkmalswichtigkeiten werden durch das angepasste Attribut feature_importances_ bereitgestellt und berechnet sich als Mittelwert und Standardabweichung der Akkumulation der Unreinheitsverringerung innerhalb jedes Baumes. Wir werden die auf Unreinheit basierende Wichtigkeit plotten.
start_time = time.time()
importances = forest.feature_importances_
std = np.std([tree.feature_importances_ for tree in forest.estimators_], axis=0)
elapsed_time = time.time() - start_time
print(f"Elapsed time to compute the importances: {elapsed_time:.3f} seconds")
forest_importances = pd.Series(importances, index=feature_names)
fig, ax = plt.subplots()
forest_importances.plot.bar(yerr=std, ax=ax)
ax.set_title("Feature importances using MDI")
ax.set_ylabel("Mean decrease in impurity")
fig.tight_layout()
Merkmalswichtigkeit basierend auf Merkmalspermutation
Die Merkmalspermutationswichtigkeit überwindet die Einschränkungen der auf Unreinheit basierenden Merkmalswichtigkeit: Sie haben keinen Bias gegenüber Merkmalen mit hoher Kardinalität und können auf einem separat gehaltenen Testset berechnet werden. Wir werden die volle Permutationswichtigkeit berechnen. Die Merkmale werden n-mal durchgemischt und das Modell erneut angepasst, um ihre Wichtigkeit zu schätzen. Wir werden die Wichtigkeitsrangfolge plotten.
start_time = time.time()
result = permutation_importance(
forest, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2
)
elapsed_time = time.time() - start_time
print(f"Elapsed time to compute the importances: {elapsed_time:.3f} seconds")
forest_importances = pd.Series(result.importances_mean, index=feature_names)
fig, ax = plt.subplots()
forest_importances.plot.bar(yerr=result.importances_std, ax=ax)
ax.set_title("Feature importances using permutation on full model")
ax.set_ylabel("Mean accuracy decrease")
fig.tight_layout()
plt.show()
Zusammenfassung
In diesem Lab haben wir einen synthetischen Datensatz mit nur 3 informativen Merkmalen erzeugt und einen Zufälligen Wald verwendet, um die Wichtigkeit der Merkmale zu evaluieren. Wir haben die Merkmalswichtigkeiten des Walds geplottet, zusammen mit ihrer Variabilität zwischen den Bäumen, die durch die Fehlerbalken dargestellt wird. Wir haben die auf Unreinheit basierende Wichtigkeit und die Merkmalspermutationswichtigkeit verwendet, um die Merkmalswichtigkeiten zu berechnen. Die gleichen Merkmale wurden mit beiden Methoden als die wichtigsten erkannt.