Introduction
You're using a compute policy to restrict the Databricks Runtime version for your clusters. Your current configuration for standard compute allows the following runtime versions through the spark_version attribute.
"spark_version": {
  "defaultValue": "15.4.x-scala2.12",
  "type": "allowlist",
  "values": [
    "13.3.x-scala2.12",
    "14.3.x-scala2.12",
    "15.4.x-scala2.12"
  ]
}
However, you now want to extend this configuration to support ML computes, both standard and those with GPUs, and you're unsure about the correct Spark version values to use for ML computes.
Instructions
ML compute requires a Databricks Runtime version that includes ML support, and those versions use a different naming convention from the standard runtime versions. ML runtime keys carry an ML marker before the Scala version: for instance, 15.4.x-cpu-ml-scala2.12 for standard (CPU) ML compute or 15.4.x-gpu-ml-scala2.12 for GPU-enabled ML compute.
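As a rough illustration of this naming convention, the sketch below classifies a runtime key by looking for the -cpu-ml- or -gpu-ml- segment. The function name classify_runtime is hypothetical and not part of any Databricks API; it is plain string matching and assumes keys follow the pattern described above.

```python
def classify_runtime(key: str) -> str:
    """Classify a runtime key by its ML marker (hypothetical helper)."""
    if "-gpu-ml-" in key:
        return "gpu-ml"
    if "-cpu-ml-" in key:
        return "cpu-ml"
    # Anything without an ML marker is treated as a non-ML runtime here.
    return "non-ml"


for key in (
    "15.4.x-scala2.12",
    "15.4.x-cpu-ml-scala2.12",
    "15.4.x-gpu-ml-scala2.12",
):
    print(key, "->", classify_runtime(key))
```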
To extend your compute policy to support both standard and GPU ML computes, include the appropriate Databricks Runtime ML versions in your spark_version allowlist.
1. Identify the ML runtime versions that correspond to the runtime versions you're currently using. For example, if you use 13.3.x-scala2.12, look for 13.3.x-cpu-ml-scala2.12 (standard) and 13.3.x-gpu-ml-scala2.12 (GPU).
2. Modify your spark_version configuration to include both the standard runtime versions and their corresponding ML versions. The following code provides an example.
"spark_version": {
  "type": "allowlist",
  "values": [
    "13.3.x-scala2.12",
    "13.3.x-cpu-ml-scala2.12",
    "13.3.x-gpu-ml-scala2.12",
    "14.3.x-scala2.12",
    "14.3.x-cpu-ml-scala2.12",
    "14.3.x-gpu-ml-scala2.12",
    "15.4.x-scala2.12",
    "15.4.x-cpu-ml-scala2.12",
    "15.4.x-gpu-ml-scala2.12"
  ],
  "defaultValue": "15.4.x-cpu-ml-scala2.12"
}
The defaultValue shown here is only one option. You can use any other suitable version from the values list. Policy definitions are JSON, so do not leave comments inside the definition.
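If you maintain several runtime versions, the two steps above can be sketched as a small script that derives the ML keys from the standard keys and prints the resulting policy fragment. This is a minimal sketch, not an official tool: the helper names ml_variants and build_spark_version_policy are hypothetical, and the code assumes every base key has the form <release>-scala<version> and that matching ML runtimes exist for each release.

```python
import json


def ml_variants(base_version: str) -> list:
    """Return the base key followed by its CPU and GPU ML counterparts."""
    # Assumes keys look like "13.3.x-scala2.12".
    release, scala = base_version.split("-scala", 1)
    return [
        base_version,
        f"{release}-cpu-ml-scala{scala}",
        f"{release}-gpu-ml-scala{scala}",
    ]


def build_spark_version_policy(base_versions, default):
    """Build the spark_version allowlist fragment (hypothetical helper)."""
    values = [key for base in base_versions for key in ml_variants(base)]
    if default not in values:
        raise ValueError(f"default {default!r} is not in the allowlist")
    return {
        "spark_version": {
            "type": "allowlist",
            "values": values,
            "defaultValue": default,
        }
    }


policy = build_spark_version_policy(
    ["13.3.x-scala2.12", "14.3.x-scala2.12", "15.4.x-scala2.12"],
    default="15.4.x-cpu-ml-scala2.12",
)
print(json.dumps(policy, indent=2))
```

Before saving the policy, confirm that each generated key is actually offered in your workspace, because the script only applies the naming pattern and does not check availability.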
If needed, you can run the following Python code in a notebook to list all the spark_version values that are available in your workspace and can be used in your compute policy.
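from databricks.sdk import WorkspaceClient

# In a Databricks notebook, WorkspaceClient() uses the notebook's own authentication.
w = WorkspaceClient()

# Each entry's key is the string to use as a spark_version value in the policy.
runtimes = w.clusters.spark_versions().versions
for r in runtimes:
    print(r.key)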
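Once you have the list of keys, you may want to check your allowlist against it before saving the policy. The sketch below shows one way to do that with plain Python. The function name find_unavailable is hypothetical, and the available_keys list is an illustrative stand-in for the keys printed by the code above; in a notebook you would build it from the SDK response instead.

```python
def find_unavailable(allowlist, available_keys):
    """Return allowlist entries that are not offered by the workspace."""
    available = set(available_keys)
    return [key for key in allowlist if key not in available]


# Illustrative stand-in for [r.key for r in runtimes] from the SDK call.
available_keys = [
    "15.4.x-scala2.12",
    "15.4.x-cpu-ml-scala2.12",
    "15.4.x-gpu-ml-scala2.12",
]

# The last entry is deliberately malformed to show what gets reported.
allowlist = [
    "15.4.x-scala2.12",
    "15.4.x-cpu-ml-scala2.12",
    "15.4.x-ml-scala2.12",
]

unavailable = find_unavailable(allowlist, available_keys)
print("Not available in this workspace:", unavailable)
```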