... print("Mittelwert:", scores.mean())
... print("Standardabweichung:", scores.std())
...
>>> display_scores(tree_rmse_scores)
Scores: [70194.33680785 66855.16363941 72432.58244769 70758.73896782
71115.88230639 75585.14172901 70262.86139133 70273.6325285
75366.87952553 71231.65726027]
Mittelwert: 71407.68766037929
Standardabweichung: 2439.4345041191004
Nun steht der Entscheidungsbaum nicht mehr so gut da wie zuvor. Tatsächlich scheint er schlechter als das lineare Regressionsmodell abzuschneiden! Beachten Sie, dass wir mit der Kreuzvalidierung nicht nur eine Schätzung der Leistung unseres Modells erhalten, sondern auch eine Angabe darüber, wie präzise diese Schätzung ist (d.h. die Standardabweichung). Der Entscheidungsbaum hat einen RMSE von etwa 71407, ±2439. Nur mit einem Validierungsdatensatz würden Sie diese Information nicht erhalten. Allerdings erfordert die Kreuzvalidierung, dass das Modell mehrmals trainiert wird. Deshalb ist sie nicht immer praktikabel.
Berechnen wir, um auf Nummer sicher zu gehen, die gleichen Scores für das lineare Regressionsmodell:
>>> lin_scores = cross_val_score(lin_reg, housing_prepared, housing_labels,
... scoring="neg_mean_squared_error", cv=10)
...
>>> lin_rmse_scores = np.sqrt(-lin_scores)
>>> display_scores(lin_rmse_scores)
Scores: [66782.73843989 66960.118071 70347.95244419 74739.57052552
68031.13388938 71193.84183426 64969.63056405 68281.61137997
71552.91566558 67665.10082067]
Mittelwert: 69052.46136345083
Standardabweichung: 2731.674001798348
Unsere Vermutung war richtig: Das Overfitting im Entscheidungsbaum ist so stark, dass dieses Modell ungenauer ist als die lineare Regression.
Probieren wir noch ein letztes Modell aus: den RandomForestRegressor. Wie wir in Kapitel 7 sehen werden, trainiert ein Random Forest viele Entscheidungsbäume auf zufällig ausgewählten Teilmengen der Merkmale und mittelt deren Vorhersagen. Ein Modell aus vielen anderen Modellen zusammenzusetzen, nennt man Ensemble Learning. Es ist oft eine gute Möglichkeit, ML-Algorithmen noch besser zu nutzen. Wir werden uns mit dem Code nicht groß beschäftigen, da er fast der gleiche ist wie bei den anderen beiden Modellen:
>>> from sklearn.ensemble import RandomForestRegressor
>>> forest_reg = RandomForestRegressor()
>>> forest_reg.fit(housing_prepared, housing_labels)
>>> [...]
>>> forest_rmse
18603.515021376355
>>> display_scores(forest_rmse_scores)
Scores: [49519.80364233 47461.9115823 50029.02762854 52325.28068953
49308.39426421 53446.37892622 48634.8036574 47585.73832311
53490.10699751 50021.5852922 ]
Mittelwert: 50182.303100336096
Standardabweichung: 2097.0810550985693
Wow, das ist viel besser: Random Forests wirken sehr vielversprechend. Allerdings ist der Score auf dem Trainingsdatensatz noch immer viel geringer als auf den Validierungsdatensätzen. Dies deutet darauf hin, dass das Modell die Trainingsdaten noch immer overfittet. Gegenmaßnahmen zum Overfitting sind, das Modell zu vereinfachen, Restriktionen einzuführen (es zu regularisieren) oder deutlich mehr Trainingsdaten zu beschaffen. Bevor Sie sich aber eingehender mit Random Forests beschäftigen, sollten Sie mehr Modelle aus anderen Familien von Machine-Learning-Algorithmen ausprobieren (mehrere Support Vector Machines mit unterschiedlichen Kernels, ein neuronales Netz und so weiter), ohne jedoch zu viel Zeit mit dem Einstellen der Hyperparameter zu verbringen. Das Ziel ist, eine engere Auswahl (zwei bis fünf) der vielversprechendsten Modelle zu treffen.
|
Sie sollten alle Modelle, mit denen Sie experimentieren, abspeichern, sodass Sie später zu jedem beliebigen Modell zurückkehren können. Stellen Sie sicher, dass Sie sowohl die Hyperparameter als auch die trainierten Parameter abspeichern, ebenso die Scores der Kreuzvalidierung und eventuell sogar die Vorhersagen. Damit können Sie leichter Vergleiche der Scores und der Arten der Fehler zwischen unterschiedlichen Modellen ziehen. Sie können in Scikit-Learn erstellte Modelle mit dem Python-Modul pickle oder der Bibliothek joblib speichern, wobei Letztere große NumPy-Arrays effizienter serialisiert: import joblib joblib.dump(my_model, "my_model.pkl") # und später ... my_model_loaded = joblib.load("my_model.pkl") |
Optimiere das Modell
Nehmen wir an, Sie hätten inzwischen eine engere Auswahl Erfolg versprechender Modelle. Nun müssen Sie diese optimieren. Betrachten wir dazu einige Alternativen.
Gittersuche
Eine Möglichkeit wäre, von Hand an den Hyperparametern herumzubasteln, bis Sie eine gute Kombination finden. Dies wäre sehr mühselig, und Sie hätten nicht die Zeit, viele Kombinationen auszuprobieren.
Stattdessen sollten Sie die Scikit-Learn-Klasse GridSearchCV die Suche für Sie erledigen lassen. Sie müssen ihr lediglich sagen, mit welchen Hyperparametern Sie experimentieren möchten und welche Werte ausprobiert werden sollen. Dann werden alle möglichen Kombinationen von Hyperparametern über eine Kreuzvalidierung evaluiert. Der folgende Code sucht die beste Kombination der Hyperparameter für den RandomForestRegressor:
from sklearn.model_selection import GridSearchCV
param_grid = [
{'n_estimators': [3, 10, 30], 'max_features': [2, 4, 6, 8]},
{'bootstrap': [False], 'n_estimators': [3, 10], 'max_features': [2, 3, 4]},
]
forest_reg = RandomForestRegressor()
grid_search = GridSearchCV(forest_reg, param_grid, cv=5,
scoring='neg_mean_squared_error',
return_train_score=True)
grid_search.fit(housing_prepared, housing_labels)
|
Wenn Sie keine Ahnung haben, welchen Wert ein Hyperparameter haben soll, können Sie einfach Zehnerpotenzen ausprobieren (oder für eine feinkörnigere Suche Potenzen einer kleineren Zahl, wie im Beispiel beim Hyperparameter n_estimators). |
Das param_grid weist Scikit-Learn an, zuerst alle Kombinationen von 3 × 4 = 12 der Hyperparameter n_estimators und max_features mit den im ersten dict angegebenen Werten auszuprobieren (keine Sorge, wenn Sie die Bedeutung der Hyperparameter noch nicht kennen; diese werden in Kapitel 7 erklärt). Anschließend werden alle Kombinationen 2 × 3 = 6 der Hyperparameter im zweiten dict ausprobiert, diesmal ist jedoch der Hyperparameter bootstrap auf False statt auf True (den Standardwert) gesetzt.