Datenaufteilung in Trainings- und Testsets mit train_test_split aus sklearn.model_selection
In diesem Schritt bereiten wir unsere Daten für den Trainingsprozess vor. Ein entscheidender Teil des maschinellen Lernens ist die Bewertung des Modells anhand von Daten, die es noch nie zuvor gesehen hat. Dazu teilen wir unseren Datensatz in zwei Teile auf: ein Trainingsset und ein Testset. Das Modell lernt aus dem Trainingsset, und wir verwenden das Testset, um zu sehen, wie gut es abschneidet.
Zuerst müssen wir unsere Features (die Eingabevariablen, X) von unserem Ziel (dem Wert, den wir vorhersagen möchten, y) trennen. In unserem Fall wird X aus allen Spalten außer MedHouseVal bestehen, und y wird die Spalte MedHouseVal sein.
Anschließend verwenden wir die Funktion train_test_split aus sklearn.model_selection, um die Aufteilung durchzuführen.
Fügen Sie den folgenden Code zu Ihrer Datei main.py hinzu.
from sklearn.model_selection import train_test_split
## Prepare the data
X = california_df.drop('MedHouseVal', axis=1) ## Features (input variables)
y = california_df['MedHouseVal'] ## Target variable (what we want to predict)
## Split the data into training and testing sets
## test_size=0.2: Reserve 20% of data for testing, 80% for training
## random_state=42: Ensures reproducible splits (same result every run)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
## Print the shapes of the new datasets to confirm the split
print("\n--- Data Split ---")
print("X_train shape:", X_train.shape) ## Training features
print("X_test shape:", X_test.shape) ## Test features
print("y_train shape:", y_train.shape) ## Training target values
print("y_test shape:", y_test.shape) ## Test target values
Führen Sie das Skript nun erneut im Terminal aus:
python3 main.py
Sie sehen die Formen der neu erstellten Trainings- und Testsets unter dem DataFrame ausgegeben. Dies bestätigt, dass die Daten korrekt aufgeteilt wurden.
--- Data Split ---
X_train shape: (16512, 8)
X_test shape: (4128, 8)
y_train shape: (16512,)
y_test shape: (4128,)