How to extract feature information for tree-based Apache SparkML pipeline models

Learn how to extract feature information for tree-based ML pipeline models in Databricks.

Written by Adam Pavlacka

Last published at: May 16th, 2022

When you fit a tree-based model, such as a decision tree, random forest, or gradient-boosted tree, it is helpful to review the feature importances alongside the feature names. In SparkML, the model is typically fit as the last stage of a pipeline. To get at the feature information, you must therefore extract the correct stage from the fitted pipeline model. The tree model is the last stage, and the feature names come from the VectorAssembler stage that precedes it:

%python

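from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml import Pipeline

# indexer (StringIndexer), assembler (VectorAssembler), decision_tree
# (DecisionTreeClassifier), and the training DataFrame train are assumed
# to be defined earlier in the notebook.
pipeline = Pipeline(stages=[indexer, assembler, decision_tree])
DTmodel = pipeline.fit(train)
va = DTmodel.stages[-2]    # VectorAssembler (second-to-last stage)
tree = DTmodel.stages[-1]  # fitted DecisionTreeClassificationModel (last stage)

display(tree)  # visualize the decision tree model
print(tree.toDebugString)  # print the nodes of the decision tree model

# Pair each assembler input column with its importance score
list(zip(va.getInputCols(), tree.featureImportances))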
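Zipping `va.getInputCols()` with `featureImportances` lines up correctly only when every assembler input is a single numeric column. If an input is itself a vector, such as the output of a `OneHotEncoder`, the assembled vector has more slots than there are input columns. In that case the pairs no longer match. One way around this is to read the slot names from the `ml_attr` metadata that VectorAssembler attaches to its output column.

The sketch below is a minimal pure-Python helper. It assumes the metadata dict has the shape returned by `df.schema["features"].metadata`. The column name `features`, the helper name `feature_names_from_metadata`, and the sample metadata are hypothetical; substitute your assembler's `outputCol`.

```python
from typing import Dict, List


def feature_names_from_metadata(metadata: Dict) -> List[str]:
    """Return vector slot names ordered by slot index.

    Assumes `metadata` looks like df.schema["features"].metadata in PySpark,
    i.e. {"ml_attr": {"attrs": {<type>: [{"idx": ..., "name": ...}, ...]}}}.
    """
    attrs = metadata["ml_attr"]["attrs"]
    # Attributes are grouped by type (for example "numeric" or "binary");
    # flatten the groups, then sort by the vector slot index.
    slots = [attr for group in attrs.values() for attr in group]
    slots.sort(key=lambda attr: attr["idx"])
    # Fall back to a positional placeholder when a slot has no name.
    return [attr.get("name", f"slot_{attr['idx']}") for attr in slots]


# Hypothetical, hand-written metadata shaped like VectorAssembler output.
example_metadata = {
    "ml_attr": {
        "attrs": {
            "numeric": [{"idx": 0, "name": "age"}, {"idx": 3, "name": "income"}],
            "binary": [{"idx": 1, "name": "city_NYC"}, {"idx": 2, "name": "city_SF"}],
        },
        "num_attrs": 4,
    }
}
print(feature_names_from_metadata(example_metadata))
# ['age', 'city_NYC', 'city_SF', 'income']
```

In the notebook, you might call it as `names = feature_names_from_metadata(DTmodel.transform(train).schema["features"].metadata)` and then use `list(zip(names, tree.featureImportances))`. Check that `len(names)` equals `tree.numFeatures` before relying on the pairing.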

You can also tune a tree-based model by placing a CrossValidator as the last stage of the pipeline. In that case, the last stage of the fitted pipeline is a CrossValidatorModel. To visualize the decision tree and print the feature importances, you first extract its bestModel:

%python

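from pyspark.ml import Pipeline
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

# paramGrid (built with ParamGridBuilder) and evaluator (for example, a
# MulticlassClassificationEvaluator) are assumed to be defined earlier,
# along with indexer, assembler, decision_tree, and train.
cv = CrossValidator(estimator=decision_tree, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=3)
pipelineCV = Pipeline(stages=[indexer, assembler, cv])
DTmodelCV = pipelineCV.fit(train)
va = DTmodelCV.stages[-2]                # VectorAssembler
treeCV = DTmodelCV.stages[-1].bestModel  # best tree model chosen by the CrossValidatorModel

display(treeCV)  # visualize the best decision tree model
print(treeCV.toDebugString)  # print the nodes of the best decision tree model

# Pair each assembler input column with its importance score
list(zip(va.getInputCols(), treeCV.featureImportances))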

The display function visualizes decision tree models only. See Machine learning visualizations (AWS | Azure | GCP).
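The `list(zip(...))` calls in the examples above return features in assembler order, which is hard to scan when there are many inputs. A small pure-Python helper can sort the pairs by importance and keep the top k. This is a minimal sketch. The name `top_features` and the `k` parameter are hypothetical. In the notebook, you would pass `va.getInputCols()` and `treeCV.featureImportances.toArray()`, or the same attributes of `tree`:

```python
from typing import Iterable, List, Tuple


def top_features(
    names: Iterable[str], importances: Iterable[float], k: int = 10
) -> List[Tuple[str, float]]:
    """Pair names with importances and return the k largest, highest first."""
    pairs = list(zip(names, importances))
    # Sort descending by importance; ties keep their original column order.
    pairs.sort(key=lambda pair: pair[1], reverse=True)
    return pairs[:k]


# Hypothetical example values, not output from a real model.
names = ["age", "income", "tenure"]
importances = [0.2, 0.5, 0.3]
print(top_features(names, importances, k=2))
# [('income', 0.5), ('tenure', 0.3)]
```

Python's sort is stable even with `reverse=True`, so features with equal importance keep their assembler order. This makes the output deterministic for a given model.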