How to improve performance of Delta Lake MERGE INTO queries using partition pruning
This article explains how to trigger partition pruning in Delta Lake MERGE INTO queries from Databricks.
Partition pruning is an optimization technique to limit the number of partitions that are inspected by a query.
Discussion
MERGE INTO is an expensive operation when used with Delta tables. If you don’t partition the underlying data appropriately and take advantage of that partitioning in your queries, performance can be severely impacted.
The main lesson is this: if you know which partitions a MERGE INTO query needs to inspect, you should specify them in the query so that partition pruning is performed.
Demonstration: no partition pruning
Here is an example of a poorly performing MERGE INTO query without partition pruning.
Start by creating the following Delta table, called delta_merge_into:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.IntegerType
import spark.implicits._

// Create a Delta table with 30 million rows, split across 1,000 partitions on `par`.
spark.range(30000000)
  .withColumn("par", ($"id" % 1000).cast(IntegerType))
  .withColumn("ts", current_timestamp())
  .write
  .format("delta")
  .mode("overwrite")
  .partitionBy("par")
  .saveAsTable("delta_merge_into")
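As an optional sanity check (not part of the original walkthrough), you can confirm that the table is partitioned on par by inspecting its metadata with DESCRIBE DETAIL:
// DESCRIBE DETAIL returns one row of table metadata, including partitionColumns.
spark.sql("DESCRIBE DETAIL delta_merge_into")
  .select("partitionColumns", "numFiles")
  .show(truncate = false)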
Then create a DataFrame of updates and register it as a temp view called update:
// Generate 100 rows of updates with random ids; `par` is always 0 or 1.
val updatesTableName = "update"
val targetTableName = "delta_merge_into"

val updates = spark.range(100)
  .withColumn("id", (rand() * 30000000 * 2).cast(IntegerType))
  .withColumn("par", ($"id" % 2).cast(IntegerType))
  .withColumn("ts", current_timestamp())
  .dropDuplicates("id")

updates.createOrReplaceTempView(updatesTableName)
The update table has 100 rows with three columns: id, par, and ts. The value of par is always either 1 or 0.
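You can verify this, and see exactly which partition values the merge will need, by listing the distinct values of par in the source DataFrame (an optional check, not part of the original walkthrough):
// List the partition values that the update data actually touches.
// For this example it prints only 0 and 1.
updates.select("par").distinct().orderBy("par").show()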
Let’s say you run the following simple MERGE INTO query:
spark.sql(s"""
|MERGE INTO $targetTableName
|USING $updatesTableName
|ON $targetTableName.id = $updatesTableName.id
|WHEN MATCHED THEN
| UPDATE SET $targetTableName.ts = $updatesTableName.ts
|WHEN NOT MATCHED THEN
| INSERT (id, par, ts) VALUES ($updatesTableName.id, $updatesTableName.par, $updatesTableName.ts)
""".stripMargin)
The query takes 13.16 minutes to complete:

The physical plan for this query contains PartitionCount: 1000, as shown below. This means Apache Spark scans all 1,000 partitions in order to execute the query. That is inefficient, because the update data only contains the partition values 1 and 0:
== Physical Plan ==
*(5) HashAggregate(keys=[], functions=[finalmerge_count(merge count#8452L) AS count(1)#8448L], output=[count#8449L])
+- Exchange SinglePartition
+- *(4) HashAggregate(keys=[], functions=[partial_count(1) AS count#8452L], output=[count#8452L])
+- *(4) Project
+- *(4) Filter (isnotnull(count#8440L) && (count#8440L > 1))
+- *(4) HashAggregate(keys=[_row_id_#8399L], functions=[finalmerge_sum(merge sum#8454L) AS sum(cast(one#8434 as bigint))#8439L], output=[count#8440L])
+- Exchange hashpartitioning(_row_id_#8399L, 200)
+- *(3) HashAggregate(keys=[_row_id_#8399L], functions=[partial_sum(cast(one#8434 as bigint)) AS sum#8454L], output=[_row_id_#8399L, sum#8454L])
+- *(3) Project [_row_id_#8399L, UDF(_file_name_#8404) AS one#8434]
+- *(3) BroadcastHashJoin [cast(id#7514 as bigint)], [id#8390L], Inner, BuildLeft, false
:- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)))
: +- *(2) HashAggregate(keys=[id#7514], functions=[], output=[id#7514])
: +- Exchange hashpartitioning(id#7514, 200)
: +- *(1) HashAggregate(keys=[id#7514], functions=[], output=[id#7514])
: +- *(1) Filter isnotnull(id#7514)
: +- *(1) Project [cast(((rand(8188829649009385616) * 3.0E7) * 2.0) as int) AS id#7514]
: +- *(1) Range (0, 100, step=1, splits=36)
+- *(3) Filter isnotnull(id#8390L)
+- *(3) Project [id#8390L, _row_id_#8399L, input_file_name() AS _file_name_#8404]
+- *(3) Project [id#8390L, monotonically_increasing_id() AS _row_id_#8399L]
+- *(3) Project [id#8390L, par#8391, ts#8392]
+- *(3) FileScan parquet [id#8390L,ts#8392,par#8391] Batched: true, DataFilters: [], Format: Parquet, Location: TahoeBatchFileIndex[dbfs:/user/hive/warehouse/delta_merge_into], PartitionCount: 1000, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:bigint,ts:timestamp>
Solution
Rewrite the query to specify the partitions. This MERGE INTO query specifies the partition values directly in the ON clause:
spark.sql(s"""
|MERGE INTO $targetTableName
|USING $updatesTableName
|ON $targetTableName.par IN (1,0) AND $targetTableName.id = $updatesTableName.id
|WHEN MATCHED THEN
| UPDATE SET $targetTableName.ts = $updatesTableName.ts
|WHEN NOT MATCHED THEN
| INSERT (id, par, ts) VALUES ($updatesTableName.id, $updatesTableName.par, $updatesTableName.ts)
""".stripMargin)
Now the query takes just 20.54 seconds to complete on the same cluster.

The physical plan for this query contains PartitionCount: 2, as shown below. With only minor changes, the query is now roughly 38X faster:
== Physical Plan ==
*(5) HashAggregate(keys=[], functions=[finalmerge_count(merge count#7892L) AS count(1)#7888L], output=[count#7889L])
+- Exchange SinglePartition
+- *(4) HashAggregate(keys=[], functions=[partial_count(1) AS count#7892L], output=[count#7892L])
+- *(4) Project
+- *(4) Filter (isnotnull(count#7880L) && (count#7880L > 1))
+- *(4) HashAggregate(keys=[_row_id_#7839L], functions=[finalmerge_sum(merge sum#7894L) AS sum(cast(one#7874 as bigint))#7879L], output=[count#7880L])
+- Exchange hashpartitioning(_row_id_#7839L, 200)
+- *(3) HashAggregate(keys=[_row_id_#7839L], functions=[partial_sum(cast(one#7874 as bigint)) AS sum#7894L], output=[_row_id_#7839L, sum#7894L])
+- *(3) Project [_row_id_#7839L, UDF(_file_name_#7844) AS one#7874]
+- *(3) BroadcastHashJoin [cast(id#7514 as bigint)], [id#7830L], Inner, BuildLeft, false
:- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)))
: +- *(2) HashAggregate(keys=[id#7514], functions=[], output=[id#7514])
: +- Exchange hashpartitioning(id#7514, 200)
: +- *(1) HashAggregate(keys=[id#7514], functions=[], output=[id#7514])
: +- *(1) Filter isnotnull(id#7514)
: +- *(1) Project [cast(((rand(8188829649009385616) * 3.0E7) * 2.0) as int) AS id#7514]
: +- *(1) Range (0, 100, step=1, splits=36)
+- *(3) Project [id#7830L, _row_id_#7839L, _file_name_#7844]
+- *(3) Filter (par#7831 IN (1,0) && isnotnull(id#7830L))
+- *(3) Project [id#7830L, par#7831, _row_id_#7839L, input_file_name() AS _file_name_#7844]
+- *(3) Project [id#7830L, par#7831, monotonically_increasing_id() AS _row_id_#7839L]
+- *(3) Project [id#7830L, par#7831, ts#7832]
+- *(3) FileScan parquet [id#7830L,ts#7832,par#7831] Batched: true, DataFilters: [], Format: Parquet, Location: TahoeBatchFileIndex[dbfs:/user/hive/warehouse/delta_merge_into], PartitionCount: 2, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:bigint,ts:timestamp>
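Hard-coding IN (1,0) works here because the partition values are known up front. If they are not, one approach (a sketch, not shown in the original article, and assuming the set of values is small enough to collect to the driver) is to compute the distinct partition values from the source DataFrame and build the predicate at runtime:
// Collect the partition values present in the update data and build the IN list.
val parValues = updates.select("par").distinct().collect().map(_.getInt(0)).mkString(", ")

spark.sql(s"""
|MERGE INTO $targetTableName
|USING $updatesTableName
|ON $targetTableName.par IN ($parValues) AND $targetTableName.id = $updatesTableName.id
|WHEN MATCHED THEN
|  UPDATE SET $targetTableName.ts = $updatesTableName.ts
|WHEN NOT MATCHED THEN
|  INSERT (id, par, ts) VALUES ($updatesTableName.id, $updatesTableName.par, $updatesTableName.ts)
""".stripMargin)
This keeps the pruning predicate in sync with the source data without manual edits.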