How to perform group K-fold cross validation with Apache Spark
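Cross validation randomly splits the training data into a specified number of folds. Data leakage can occur when related rows, such as repeated measurements of the same subject, end up in more than one fold. To prevent it, you can assign each row to a group. scikit-learn supports group K-fold cross validation, which keeps every row of a given group in the same fold. As a result, no group appears in both the training and validation sets of any split.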
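The following minimal pure-Python sketch shows how a group K-fold splitter can assign whole groups to folds. The function name `group_k_fold` and the toy group labels are hypothetical. The greedy strategy, which places the largest groups first and puts each one in the currently lightest fold, is similar in spirit to scikit-learn's `GroupKFold`, but the exact fold assignments may differ from the library's.

```python
from collections import Counter


def group_k_fold(groups, n_splits):
    """Hypothetical helper (not scikit-learn's API).

    Yields (train_idx, test_idx) pairs such that every group lands entirely
    in exactly one test fold.
    """
    sizes = Counter(groups)
    if n_splits > len(sizes):
        raise ValueError("n_splits cannot exceed the number of distinct groups")

    fold_sizes = [0] * n_splits
    fold_of_group = {}
    # Place the largest groups first, each into the fold with the fewest rows,
    # so folds stay roughly balanced while no group is ever split.
    for group, size in sorted(sizes.items(), key=lambda kv: kv[1], reverse=True):
        lightest = fold_sizes.index(min(fold_sizes))
        fold_of_group[group] = lightest
        fold_sizes[lightest] += size

    for fold in range(n_splits):
        test_idx = [i for i, g in enumerate(groups) if fold_of_group[g] == fold]
        train_idx = [i for i, g in enumerate(groups) if fold_of_group[g] != fold]
        yield train_idx, test_idx


groups = ["a", "a", "b", "b", "b", "c", "d", "d"]
for train_idx, test_idx in group_k_fold(groups, n_splits=3):
    train_groups = {groups[i] for i in train_idx}
    test_groups = {groups[i] for i in test_idx}
    # A group never appears on both sides of a split.
    assert train_groups.isdisjoint(test_groups)
    print(sorted(test_groups))  # prints ['b'], then ['a', 'c'], then ['d']
```

Group `b` has the most rows, so it fills a fold by itself. The two smaller groups `a` and `c` share another fold.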
On Spark, you can use the spark-sklearn library to take advantage of this method. The library distributes the tuning of scikit-learn models across the cluster. This example tunes a scikit-learn random forest model with the group K-fold method on Spark. It uses a `grp` variable that holds the group label for each row:
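```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GroupKFold
from spark_sklearn import GridSearchCV

# Assumes sc (the SparkContext) and X1, y1, grp (pandas objects on the driver,
# for example produced with spark_df.toPandas()) are already defined.
param_grid = {"max_depth": [8, 12, None],
              # max_features=10 requires X1 to have at least 10 feature columns
              "max_features": [1, 3, 10],
              # min_samples_split must be at least 2 in scikit-learn
              "min_samples_split": [2, 3, 10],
              "min_samples_leaf": [1, 3, 10],
              "bootstrap": [True, False],
              "criterion": ["gini", "entropy"],
              "n_estimators": [20, 40, 80]}

# GroupKFold keeps all rows of a group inside a single fold.
group_kfold = GroupKFold(n_splits=3)
gs = GridSearchCV(sc, estimator=RandomForestClassifier(random_state=42),
                  param_grid=param_grid, cv=group_kfold)
# grp holds one group label per row of X1 and is passed on to GroupKFold.
gs.fit(X1, y1, grp)
```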
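This grid is large, which is why distributing the search is worthwhile. The quick pure-Python sketch below counts the candidates and the resulting model fits. It assumes the same parameter lists as the example and uses `itertools.product` in place of scikit-learn's `ParameterGrid`. Only the lengths of the lists affect the count.

```python
from itertools import product

# Same parameter lists as the grid search above.
param_grid = {"max_depth": [8, 12, None],
              "max_features": [1, 3, 10],
              "min_samples_split": [2, 3, 10],
              "min_samples_leaf": [1, 3, 10],
              "bootstrap": [True, False],
              "criterion": ["gini", "entropy"],
              "n_estimators": [20, 40, 80]}

# Expand the grid into one dict per candidate (the same count ParameterGrid
# would produce, although ParameterGrid may order the candidates differently).
keys = list(param_grid)
candidates = [dict(zip(keys, values)) for values in product(*param_grid.values())]

n_splits = 3  # matches GroupKFold(n_splits=3)
n_fits = len(candidates) * n_splits
print(len(candidates), n_fits)  # 972 2916
```

Each candidate is trained once per fold, so the search performs 972 × 3 = 2,916 fits. If `refit` is left at scikit-learn's default of `True`, the search then runs one more fit of the best candidate on the full data.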
Note
- The grid search runs through spark-sklearn rather than scikit-learn's own `GridSearchCV`. As a result, you must pass the Spark context (`sc` parameter) as the first argument.
- The `X1` and `y1` parameters must be pandas DataFrames. This grid search option only works on data that fits on the driver.
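The whole dataset must sit on the driver, so it can help to sanity-check the inputs before you launch the search. The helper `check_group_kfold_inputs` below is hypothetical. It relies only on `len()` and iteration, so it accepts pandas DataFrames and Series as well as plain lists. It encodes two rules. First, `grp` needs one label per row. Second, `GroupKFold` needs at least `n_splits` distinct groups, and scikit-learn raises an error otherwise.

```python
def check_group_kfold_inputs(X, y, groups, n_splits=3):
    """Hypothetical pre-flight check to run before gs.fit(X1, y1, grp).

    Returns the number of distinct groups, or raises ValueError.
    """
    n_rows = len(X)  # for a pandas DataFrame, len() is the number of rows
    if len(y) != n_rows or len(groups) != n_rows:
        raise ValueError(
            f"X has {n_rows} rows, but y has {len(y)} and groups has {len(groups)}"
        )
    # Iterating a pandas Series yields its values, so set() works for it too.
    n_groups = len(set(groups))
    if n_groups < n_splits:
        raise ValueError(
            f"group K-fold needs at least {n_splits} distinct groups, got {n_groups}"
        )
    return n_groups


# Usage with plain lists; pandas objects behave the same way.
print(check_group_kfold_inputs([[0.1], [0.2], [0.3], [0.4]], [0, 1, 0, 1],
                               ["a", "b", "c", "c"]))  # 3
```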