How to perform group K-fold cross validation with Apache Spark

Learn how to perform group K-fold cross validation with Apache Spark on Databricks.

Written by Adam Pavlacka

Last published at: May 16th, 2022

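Cross validation splits the training data into a specified number of folds. If related rows (for example, several records from the same customer) end up in both the training and validation folds, information leaks between them and the validation score is overly optimistic. To prevent this data leakage, you can assign each row a group label. scikit-learn supports group K-fold cross validation (GroupKFold), which keeps all rows from the same group in a single fold, so the groups in the training and validation sets never overlap.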
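To illustrate what the grouping guarantees, the following minimal sketch runs scikit-learn's GroupKFold on a small made-up dataset. The arrays X, y, and groups are illustrative assumptions, not data from this article. The sketch checks that no group appears on both sides of any split.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Illustrative toy data: 8 samples that belong to 4 groups.
X = np.arange(16).reshape(8, 2)
y = np.array([0, 1, 0, 1, 0, 1, 0, 1])
groups = np.array([1, 1, 2, 2, 3, 3, 4, 4])

group_kfold = GroupKFold(n_splits=2)
splits = list(group_kfold.split(X, y, groups))

for fold, (train_idx, test_idx) in enumerate(splits):
    train_groups = {int(g) for g in groups[train_idx]}
    test_groups = {int(g) for g in groups[test_idx]}
    # Each group falls entirely on one side of the split.
    assert train_groups.isdisjoint(test_groups)
    print(f"fold {fold}: train groups {sorted(train_groups)}, "
          f"test groups {sorted(test_groups)}")
```

Across all folds, every sample lands in a validation fold exactly once, just as with ordinary K-fold. The only difference is that samples are assigned by group rather than individually.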

On Spark, you can use this method with the spark-sklearn library, which distributes the tuning of scikit-learn models across the cluster. This example tunes a scikit-learn random forest model with group K-fold cross validation on Spark. The grp variable holds the group label for each row:


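from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GroupKFold
from spark_sklearn import GridSearchCV

# Hyperparameter grid to search (min_samples_split must be at least 2).
param_grid = {"max_depth": [8, 12, None],
              "max_features": [1, 3, 10],
              "min_samples_split": [2, 3, 10],
              "min_samples_leaf": [1, 3, 10],
              "bootstrap": [True, False],
              "criterion": ["gini", "entropy"],
              "n_estimators": [20, 40, 80]}

# GroupKFold keeps all rows that share a group label in the same fold.
group_kfold = GroupKFold(n_splits=3)

# spark-sklearn's GridSearchCV takes the SparkContext (sc) as its first argument.
gs = GridSearchCV(sc, estimator=RandomForestClassifier(random_state=42),
                  param_grid=param_grid, cv=group_kfold)

# X1 (features), y1 (labels), and grp (group labels) must already be in driver memory.
gs.fit(X1, y1, grp)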


  • The grid search is run by spark-sklearn's GridSearchCV, so you must pass in the Spark context (the sc parameter) as the first argument.
  • X1 and y1 must be pandas DataFrames, and grp must contain one group label per row of X1. This grid search option only works on data that fits in memory on the driver.
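Before you launch the full distributed search, it can help to check the inputs locally. The sketch below assumes a small synthetic pandas DataFrame. The column names f0 to f3, label, and grp are hypothetical. It runs a reduced grid with scikit-learn's own GridSearchCV, which also accepts group labels in fit(). This is not the spark-sklearn API. It only shows the expected shape of X1, y1, and grp: one feature row, one label, and one group label per sample. Here y1 is a single-column pandas Series.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, GroupKFold

# Hypothetical data: 60 rows, 4 features, a binary label, and 6 groups of 10 rows.
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.normal(size=(60, 4)), columns=["f0", "f1", "f2", "f3"])
df["label"] = rng.integers(0, 2, size=60)
df["grp"] = np.repeat(np.arange(6), 10)

# Split into the three driver-side inputs used by the grid search.
X1 = df[["f0", "f1", "f2", "f3"]]
y1 = df["label"]
grp = df["grp"]

# Reduced grid so the local check runs quickly.
param_grid = {"max_depth": [4, None],
              "min_samples_split": [2, 10],
              "n_estimators": [10]}

gs = GridSearchCV(RandomForestClassifier(random_state=42),
                  param_grid=param_grid,
                  cv=GroupKFold(n_splits=3))
# groups is forwarded to GroupKFold.split so whole groups stay together.
gs.fit(X1, y1, groups=grp)
print(gs.best_params_)
```

If this runs cleanly, the same X1, y1, and grp objects can be passed to the spark-sklearn search.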