Spark ML to ONNX Model Conversion does not produce the same model - predictions differ

Define the TARGET_OPSET and then pass it as the target_opset parameter of the convert_sparkml function.

Written by jessica.santos

Last published at: January 16th, 2025

Problem

You have successfully converted your Apache Spark ML model or Spark ML pipeline to the ONNX format in a notebook. Although no errors were raised, the predictions produced by the ONNX model differ from those produced by the original model on the same input dataset.


Important

MLeap is no longer available as of Databricks Runtime 15.1 ML and above. Databricks recommends using the ONNX format to package models for deployment on JVM-based frameworks. For more information, please review the Databricks Runtime 15.1 for Machine Learning (EoS) (AWS | Azure | GCP) documentation.


Cause

The target_opset value was not provided when calling the convert_sparkml function that produces the ONNX model.


Context

The target_opset parameter is used when converting machine learning models to the ONNX format. It specifies the opset version that the converter should use when translating the model's operations into ONNX operators.


When converting a model from a framework like SparkML to ONNX, each operation in the model (for example, a decision tree split or activation function) must be mapped to its ONNX equivalent. The target_opset dictates which version of these operator definitions the converter should adhere to during this mapping.
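
As a quick sanity check, you can inspect which opset version your environment supports and which opset an already-converted model was built against. The following sketch is illustrative only; the model path /dbfs/onnx_tmp/model.onnx is a placeholder.

import onnx
from onnx.defs import onnx_opset_version
from onnxconverter_common.onnx_ex import DEFAULT_OPSET_NUMBER

# Opset supported by the locally installed onnx package.
print("onnx package opset:", onnx_opset_version())

# Default opset assumed by the converter tooling.
print("converter default opset:", DEFAULT_OPSET_NUMBER)

# Inspect the opset recorded in a converted model (placeholder path).
model = onnx.load("/dbfs/onnx_tmp/model.onnx")
for opset in model.opset_import:
    print(opset.domain or "ai.onnx", opset.version)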


Solution

Define the TARGET_OPSET and then pass it as the target_opset parameter of the convert_sparkml function. This ensures that the converted model's operators are mapped to a consistent opset version, so the ONNX model behaves equivalently to the original model.


import onnxmltools
from onnx.defs import onnx_opset_version
from onnxconverter_common.onnx_ex import DEFAULT_OPSET_NUMBER
from onnxmltools.convert.common.data_types import FloatTensorType, Int64TensorType

# Use the highest opset supported by both the installed onnx package and the converter.
TARGET_OPSET = min(DEFAULT_OPSET_NUMBER, onnx_opset_version())

# Temporary DFS path for intermediate model data written during conversion.
spark.conf.set("ONNX_DFS_PATH", "file:///dbfs/onnx_tmp")

onx_r_forest_model = onnxmltools.convert_sparkml(
    model=spark_model_object,  # your fitted Spark ML model or pipeline
    name="onnx_model",
    initial_types=[("features", FloatTensorType([None, 3]))],
    spark_session=spark,
    target_opset=TARGET_OPSET,
)
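
After conversion, you can confirm that the ONNX model reproduces the Spark model's predictions by scoring the same rows with ONNX Runtime and comparing the outputs. The sketch below is an assumption-laden example: test_df is a hypothetical test DataFrame with an assembled features vector column, and the first ONNX output is assumed to be the prediction (output names and order can vary by model type, so inspect them first).

import numpy as np
import onnxruntime as rt

# Open an inference session on the converted model's serialized bytes.
sess = rt.InferenceSession(
    onx_r_forest_model.SerializeToString(),
    providers=["CPUExecutionProvider"],
)
print([o.name for o in sess.get_outputs()])  # inspect the model's output names

# Build a float32 matrix from the same feature vectors the Spark model scores.
features = np.array(
    [row.features.toArray() for row in test_df.select("features").collect()],
    dtype=np.float32,
)

# Score with ONNX Runtime; the first output is assumed to be the prediction here.
input_name = sess.get_inputs()[0].name
onnx_pred = np.ravel(sess.run(None, {input_name: features})[0])

# Score with the original Spark model on the same rows.
spark_pred = np.array(
    [r.prediction for r in spark_model_object.transform(test_df).select("prediction").collect()]
)

# Predictions should match within a small numerical tolerance.
print(np.allclose(onnx_pred, spark_pred, atol=1e-5))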