Best practice for cache(), count(), and take()

Learn best practices for using `cache()`, `count()`, and `take()` with a Spark DataFrame.

Written by ram.sankarasubramanian

Last published at: August 15th, 2023

cache() is a lazy Apache Spark operation that you can call on a DataFrame, Dataset, or RDD when you want to perform more than one action on it. cache() caches the specified DataFrame, Dataset, or RDD in the memory of your cluster’s workers. Because cache() is lazy, the caching operation takes place only when a Spark action (for example, count(), show(), take(), or write()) is subsequently called on the same DataFrame, Dataset, or RDD.
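For example, nothing is materialized when cache() itself is called; the first action both computes its result and populates the cache. The following is a minimal sketch using a hypothetical DataFrame built with spark.range():

%python

df = spark.range(1000000)  # Hypothetical example DataFrame
df.cache()                 # No Spark job runs here; df is only marked for caching
df.count()                 # First action: computes the count and caches all partitions
df.count()                 # Second action: served from the cached data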

Calling cache() and count() separately

%python

df1 = spark.read.parquet(input_path1)
df2 = spark.read.parquet(input_path2)
df1.cache()                                          # Mark DataFrame df1 for caching

joined_df = df1.join(df2, df1.id == df2.id, 'inner') # Join DataFrame df1 and df2
filtered_df = joined_df.filter("name == 'John'")     # Filter the joined DataFrame for the name 'John'
df1.count()                                          # Call count() on the cached DataFrame
filtered_df.show()                                   # Show the filtered DataFrame filtered_df

In this example, DataFrame df1 is cached into memory when df1.count() is executed. df1.cache() does not initiate the caching operation on DataFrame df1.

Calling take() on a cached DataFrame

%python

df = spark.table("input_table_name")
df.cache().take(5)                 # Call take(5) on the DataFrame df, while also marking it for caching
df.count()                         # Call count() on the DataFrame df

In this example, DataFrame df is cached into memory when take(5) is executed, but only partially. Because take(5) only needs to return 5 records, Spark processes only the partition those records are fetched from, and only that processed partition is cached. The other partitions of DataFrame df are not cached.

As a result, when df.count() is called, DataFrame df has to be recomputed, because only one of its partitions is available in the cluster’s cache.

Calling take(5) in the example only caches 14% of the DataFrame.

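If a DataFrame has already been partially cached this way, one option is to drop the partial cache and re-cache it in full. A minimal sketch, assuming the same df as above:

%python

df.unpersist(blocking=True)  # Remove the partially cached data from memory
df.cache()                   # Mark the DataFrame for caching again
df.count()                   # Full action: all partitions are processed and cached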

Calling count() on a cached DataFrame

%python

df = spark.table("input_table_name")
df.cache().count()                  # Call count() on the DataFrame df, while also marking it for caching
df.count()                          # Call count() on the DataFrame df
df.filter("name == 'John'").count() # Filter df for the name 'John', then count the result

In this example, DataFrame df is cached into memory when df.cache().count() is executed. To return the count of the DataFrame, all of its partitions are processed, which means that all of the partitions are cached.

As a result, when df.count() and df.filter("name == 'John'").count() are called as subsequent actions, DataFrame df is fetched from the cluster’s cache, rather than being recomputed.

Calling count() in the example caches 100% of the DataFrame.


Summary

You should call count() or write() immediately after calling cache() so that the entire DataFrame is processed and cached in memory. If you only cache part of the DataFrame, the entire DataFrame may be recomputed when a subsequent action is performed on the DataFrame.
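As a sketch of this pattern, reusing the input_table_name table from the examples above: cache the DataFrame, materialize it with count(), and call unpersist() once it is no longer needed so the workers’ memory is freed:

%python

df = spark.table("input_table_name")
df.cache()      # Mark the DataFrame for caching
df.count()      # Materialize the cache: every partition is processed and stored
# ... subsequent actions on df are served from the cache ...
df.unpersist()  # Release the cached data when df is no longer needed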


Info

The advice for cache() also applies to persist().
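For example, a minimal sketch with persist() and an explicit storage level (MEMORY_AND_DISK is only an illustration; choose the level that fits your workload):

%python

from pyspark import StorageLevel

df = spark.table("input_table_name")
df.persist(StorageLevel.MEMORY_AND_DISK)  # Like cache(), persist() is lazy
df.count()                                # The action materializes the persisted data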

