Apache Spark Optimization Techniques for High-performance Data Processing

Apache Spark is an analytics engine that can handle very large data sets. This guide reveals strategies to optimize its performance using PySpark.


Toptal authors are vetted experts in their fields and write on topics in which they have demonstrated experience. All of our content is peer reviewed and validated by Toptal experts in the same field.

Necati Demir, PhD
Verified Expert in Engineering
19 Years of Experience

Necati is a software engineer specializing in data science, machine learning, back-end development, and DevOps. He is an AWS Certified Solutions Architect and AWS Certified Machine Learning Specialist with a doctorate in computer engineering. Necati serves as Chief AI Officer and CTO of Datagran, a machine learning automation company that he co-founded.

Previous Role

CTO

PREVIOUSLY AT

Ericsson

Large-scale data analysis has become a transformative tool for many industries, with applications that include fraud detection for the banking industry, clinical research for healthcare, and predictive maintenance and quality control for manufacturing. However, processing such vast amounts of data can be a challenge, even with the power of modern computing hardware. Many tools are now available to address the challenge, with one of the most popular being Apache Spark, an open source analytics engine designed to speed up the processing of very large data sets.

Spark provides a powerful architecture capable of handling immense amounts of data. There are several Spark optimization techniques that streamline processes and data handling, including performing tasks in memory and storing frequently accessed data in a cache, thus reducing latency during retrieval. Spark is also designed for scalability; data processing can be distributed across multiple computers, increasing the available computing power. Spark is relevant to many projects: It supports a variety of programming languages (e.g., Java, Scala, R, and Python) and includes various libraries (e.g., MLlib for machine learning, GraphX for working with graphs, and Spark Streaming for processing streaming data).

While Spark’s default settings provide a good starting point, several adjustments can enhance its performance and allow businesses to use it to its full potential. Optimization techniques in Spark generally address two areas: computational efficiency and communication between nodes.

How Does Spark Work?

Before discussing optimization techniques in detail, it’s helpful to look at how Spark handles data. The fundamental data structure in Spark is the resilient distributed data set, or RDD. Understanding how RDDs work is key when considering how to use Apache Spark. An RDD represents a fault-tolerant, distributed collection of data capable of being processed in parallel across a cluster of computers. RDDs are immutable; their contents cannot be changed once they are created.

Spark’s fast processing speeds are enabled by RDDs. While many frameworks rely on external storage systems such as the Hadoop Distributed File System (HDFS) for reusing and sharing data between computations, RDDs support in-memory computation. Keeping processing and data sharing in memory avoids the substantial overhead that an external storage system imposes through replication, serialization, disk read/write operations, and network latency. Spark is often seen as a successor to MapReduce, the data processing component of Hadoop, an earlier framework from Apache. While the two systems share similar functionality, Spark’s in-memory processing allows it to run some workloads up to 100 times faster than MapReduce, which processes data on disk.

To work with the data in an RDD, Spark provides a rich set of transformations and actions. Transformations produce new RDDs from the data in existing ones using operations such as filter(), join(), or map(). The filter() function creates a new RDD with the elements that satisfy a given condition, while join() creates a new RDD by combining two existing RDDs based on a common key. map() applies a function to each element in a data set; for example, it can apply a mathematical operation, such as calculating a percentage, to every record in an RDD and output the results in a new RDD.

An action, on the other hand, does not create a new RDD but returns the result of a computation on the data set. Actions include operations such as count(), first(), or collect(). The count() action returns the number of elements in an RDD, while first() returns just the first element. collect() retrieves all of the elements in an RDD and returns them to the driver program, so it should be used only when the result is small enough to fit in the driver’s memory.

Transformations further differ from actions in that they are lazy. The execution of transformations is not immediate. Instead, Spark keeps track of the transformations that need to be applied to the base RDD, and the actual computation is triggered only when an action is called.
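To make the lazy model concrete, here is a minimal pure-Python sketch. It is not PySpark code: the LazyDataset class is hypothetical and only mimics the idea that transformations record steps in a plan, while actions execute that plan.

```python
class LazyDataset:
    """Toy model of lazy evaluation (hypothetical; not a PySpark class)."""

    def __init__(self, source, steps=()):
        self._source = source        # base data, analogous to a base RDD
        self._steps = tuple(steps)   # recorded transformations (the lineage)

    # Transformations: return a new LazyDataset and do no work yet.
    def map(self, func):
        return LazyDataset(self._source, self._steps + (("map", func),))

    def filter(self, predicate):
        return LazyDataset(self._source, self._steps + (("filter", predicate),))

    # Actions: run every recorded step over the source data.
    def collect(self):
        result = []
        for item in self._source:
            keep = True
            for kind, func in self._steps:
                if kind == "map":
                    item = func(item)
                elif not func(item):   # a "filter" step rejected the item
                    keep = False
                    break
            if keep:
                result.append(item)
        return result

    def count(self):
        return len(self.collect())


calls = []

def double(x):
    calls.append(x)  # record each time the function actually runs
    return x * 2

ds = LazyDataset(range(5)).map(double).filter(lambda x: x > 4)
print(len(calls))     # 0: nothing has run yet
print(ds.collect())   # [6, 8]
print(len(calls))     # 5: the action triggered the computation
```

Because each transformation returns a new object instead of modifying the existing one, the sketch also mirrors the immutability of RDDs. Note that calling another action reruns the whole plan, which is the cost that caching addresses.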

Understanding RDDs and how they work can provide valuable insight into Spark tuning and optimization; however, even though an RDD is the foundation of Spark’s functionality, it might not be the most efficient data structure for many applications.

Choosing the Right Data Structures

While an RDD is the basic data structure of Spark, it is a lower-level API that requires a more verbose syntax and lacks the optimizations provided by higher-level data structures. Spark shifted toward a more user-friendly and optimized API with the introduction of DataFrames, higher-level abstractions built on top of RDDs. The data in a DataFrame is organized into named columns, structuring it more like the data in a relational database. DataFrame operations also benefit from Catalyst, Spark SQL’s query optimizer, which analyzes the requested operations and generates a more efficient execution plan. Transformations and actions can be run on DataFrames just as they are on RDDs.

Because of their higher-level API and optimizations, DataFrames are typically easier to use and offer better performance; however, due to their lower-level nature, RDDs can still be useful for defining custom operations, as well as debugging complex data processing tasks. RDDs offer more granular control over partitioning and memory usage. When dealing with raw, unstructured data, such as text streams, binary files, or custom formats, RDDs can be more flexible, allowing for custom parsing and manipulation in the absence of a predefined structure.

Following Caching Best Practices

Caching is an essential technique that can lead to significant improvements in computational efficiency. Frequently accessed data and intermediate computations can be cached, or persisted, in a memory location that allows for faster retrieval. Spark provides built-in caching functionality, which can be particularly beneficial for machine learning algorithms, graph processing, and any other application in which the same data must be accessed repeatedly. Without caching, Spark would recompute an RDD or DataFrame and all of its dependencies every time an action was called.

The following Python code block uses PySpark, Spark’s Python API, to cache a DataFrame named df:

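df.cache()	# marks df for caching; nothing is stored yet because cache() is lazy
df.count()	# the first action computes df and stores it in the cache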

It is important to keep in mind that caching requires careful planning, because it consumes the memory of Spark’s worker nodes, which perform such tasks as executing computations and storing data. If the data set is significantly larger than the available memory, cached partitions may be evicted and recomputed, or spilled to disk, depending on the storage level. Caching RDDs or DataFrames that are never reused in subsequent steps also takes memory away from computations. Both situations can introduce performance bottlenecks, so cached data that is no longer needed should be released by calling unpersist().
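The cost of recomputation can be illustrated with a pure-Python analogy, assuming a hypothetical ExpensiveLineage class that stands in for the lineage of an RDD or DataFrame. Here functools.lru_cache plays the role of cache(); Spark itself stores cached partitions in executor memory or on disk, which this sketch does not model.

```python
import functools


class ExpensiveLineage:
    """Stand-in for an expensive lineage (e.g., reading and parsing a file)."""

    def __init__(self):
        self.runs = 0  # how many times the full computation has executed

    def compute(self):
        self.runs += 1
        return tuple(x * x for x in range(1_000))


def run_actions(get_data, num_actions=3):
    # Each "action" requests the data again, as separate Spark actions would.
    return [sum(get_data()) for _ in range(num_actions)]


# Without caching: the lineage is recomputed for every action.
uncached = ExpensiveLineage()
run_actions(uncached.compute)

# With caching: compute once, then reuse the stored result.
cached = ExpensiveLineage()
cached_compute = functools.lru_cache(maxsize=None)(cached.compute)
run_actions(cached_compute)

print(uncached.runs, cached.runs)  # 3 1
cached_compute.cache_clear()  # drop the stored result, loosely analogous to unpersist()
```

After cache_clear(), the next request recomputes the data once and caches it again, just as an unpersisted DataFrame is rebuilt from its lineage if it is used later.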

Optimizing Spark’s Data Partitioning

Spark’s architecture is built around partitioning, the division of large amounts of data into smaller, more manageable units called partitions. Partitioning enables Spark to process large amounts of data in parallel by distributing computation across multiple nodes, each handling a subset of the total data.

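Spark’s default number of partitions depends on the data source and configuration. For example, spark.default.parallelism, which is typically based on the number of available CPU cores, sets the default for RDD operations, and spark.sql.shuffle.partitions (200 by default) sets the number of partitions produced by DataFrame shuffles. Spark also provides options for custom partitioning; for example, users can divide data based on a certain key.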
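Key-based partitioning can be sketched in plain Python as hashing each key and taking the result modulo the number of partitions. The hash_partition helper is hypothetical, and zlib.crc32 is used only because it is deterministic across runs; Spark uses its own hash functions, so actual assignments differ. The property that carries over is that all records sharing a key land in the same partition.

```python
import zlib
from collections import defaultdict


def hash_partition(records, key_func, num_partitions):
    """Assign each record to a partition by hashing its key (simplified model)."""
    partitions = defaultdict(list)
    for record in records:
        key = str(key_func(record)).encode("utf-8")
        partitions[zlib.crc32(key) % num_partitions].append(record)
    # Return a list with one (possibly empty) list per partition.
    return [partitions[i] for i in range(num_partitions)]


orders = [("alice", 10), ("bob", 5), ("alice", 7), ("carol", 3), ("bob", 1)]
parts = hash_partition(orders, key_func=lambda r: r[0], num_partitions=3)
for i, part in enumerate(parts):
    print(i, part)
```

Because every record for a given key ends up in one partition, later per-key operations can run on each partition independently; the flip side is that a very frequent key produces one oversized partition.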

Number of Partitions

One of the most important factors affecting the efficiency of parallel processing is the number of partitions. If there aren’t enough partitions, the available memory and resources may be underutilized. On the other hand, too many partitions can lead to increased overhead from task scheduling and coordination. A common rule of thumb is to set the number of partitions to a small multiple of the total number of cores available in the cluster; Spark’s tuning guide recommends two to three tasks per CPU core.
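As a worked example of the two-to-three-tasks-per-core guideline from Spark’s tuning guide, the hypothetical helper below computes a starting range for a given cluster. The cluster size is an assumption, and the result is a starting point to refine by measuring, not a fixed rule.

```python
def suggested_partition_range(num_executors, cores_per_executor, tasks_per_core=(2, 3)):
    """Return a (low, high) partition count based on tasks per CPU core."""
    total_cores = num_executors * cores_per_executor
    low, high = tasks_per_core
    return total_cores * low, total_cores * high


# Hypothetical cluster: 10 executors with 4 cores each (40 cores in total).
print(suggested_partition_range(10, 4))  # (80, 120)
```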

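Partitions can be set using repartition() and coalesce(). In this example, the DataFrame is first repartitioned into 200 partitions and then coalesced down to 50 partitions:

df = df.repartition(200)	# repartition method: full shuffle into 200 partitions

df = df.coalesce(50)		# coalesce method: merges existing partitions down to 50

print(df.rdd.getNumPartitions())	# 50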

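The repartition() method increases or decreases the number of partitions in an RDD or DataFrame and performs a full shuffle of the data across the cluster, which can be costly in terms of processing and network latency. The coalesce() method only decreases the number of partitions; unlike repartition(), it does not perform a full shuffle, instead merging existing partitions to reduce the overall number. Requesting more partitions than a DataFrame currently has with coalesce() leaves the count unchanged. Because coalesce() does not move individual records between the merged groups, it can leave partitions unevenly sized, whereas repartition() redistributes the data more evenly.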
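The difference between the two methods can be modeled with plain Python lists standing in for partitions. Both functions are hypothetical simplifications: toy_coalesce merges whole partitions without splitting them, and toy_repartition redistributes every record round-robin, roughly as a full shuffle would. Spark’s actual record placement differs.

```python
def toy_coalesce(partitions, num_partitions):
    """Merge existing partitions into fewer groups without splitting any of them."""
    if num_partitions >= len(partitions):
        # Like coalesce(), asking for more partitions leaves the count unchanged.
        return [list(p) for p in partitions]
    groups = [[] for _ in range(num_partitions)]
    for i, part in enumerate(partitions):
        # Assign whole partitions to groups in contiguous blocks.
        groups[i * num_partitions // len(partitions)].extend(part)
    return groups


def toy_repartition(partitions, num_partitions):
    """Redistribute every record round-robin, like a full shuffle."""
    groups = [[] for _ in range(num_partitions)]
    records = [r for part in partitions for r in part]
    for i, record in enumerate(records):
        groups[i % num_partitions].append(record)
    return groups


original = [[1, 2, 3, 4, 5, 6], [7], [8], [9, 10]]
print(toy_coalesce(original, 2))       # [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10]]
print(toy_repartition(original, 2))    # [[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]
print(len(toy_coalesce(original, 8)))  # 4
```

The coalesced result keeps the original imbalance (seven records versus three), while the repartitioned result is balanced at the cost of moving every record.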

Dealing With Skewed Data

In some situations, certain partitions may contain significantly more data than others, leading to a condition known as skewed data. Skewed data can cause inefficiencies in parallel processing due to an uneven workload distribution among the worker nodes. To address skewed data in Spark, techniques such as splitting or salting can be used.
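Before choosing a fix, it helps to confirm which keys are responsible for the skew. The sketch below counts key frequencies in plain Python; find_hot_keys and its 30% threshold are hypothetical choices for illustration. In PySpark, comparable counts could come from df.groupBy('id').count().

```python
from collections import Counter


def find_hot_keys(keys, threshold=0.3):
    """Return keys whose share of all records exceeds `threshold`.

    `threshold` is an arbitrary illustrative cutoff, not a Spark setting.
    """
    counts = Counter(keys)
    total = sum(counts.values())
    return {key: n / total for key, n in counts.items() if n / total > threshold}


ids = [12345] * 80 + [1, 2, 3, 4] * 5
print(find_hot_keys(ids))  # {12345: 0.8}
```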

Splitting

In some cases, skewed partitions can be separated into multiple partitions. If a numerical range causes the data to be skewed, the range can often be split into smaller sub-ranges. For example, if a large number of students scored between 65% and 75% on an exam, the test scores can be divided into several sub-ranges, such as 65% to 68%, 69% to 71%, and 72% to 75%.
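The sub-range idea can be expressed as a function that maps each score to a finer-grained key. This is a plain-Python sketch assuming integer scores; score_bucket and the bucket labels are hypothetical. In PySpark, the same mapping could be written with when() column expressions.

```python
# Hypothetical sub-ranges for the crowded 65-75 band; other scores keep a coarse bucket.
SUB_RANGES = [(65, 68), (69, 71), (72, 75)]


def score_bucket(score):
    """Map an integer exam score to a partitioning key."""
    for low, high in SUB_RANGES:
        if low <= score <= high:
            return f"{low}-{high}"
    return "below-65" if score < 65 else "above-75"


for s in (60, 66, 70, 75, 90):
    print(s, score_bucket(s))
```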

If a specific key value is causing the skew, the DataFrame can be divided based on that key. In the example code below, the skew is caused by a large number of records that have an id value of 12345. The filter() transformation is used twice: once to select all records with an id value of 12345, and once to select all of the remaining records. The records are placed into two new DataFrames: df_skew, which contains only the rows that have an id value of 12345, and df_non_skew, which contains all of the other rows. Data processing can be performed on df_skew and df_non_skew separately, after which the resulting data can be combined:

# Split the DataFrame into two DataFrames based on the skewed key.
df_skew = df.filter(df['id'] == 12345)	# contains all rows where id = 12345
# A != comparison is null for rows with a null id, so keep those rows explicitly.
df_non_skew = df.filter((df['id'] != 12345) | df['id'].isNull())	# contains all other rows

# Repartition the skewed DataFrame into more partitions.
df_skew = df_skew.repartition(10)

# Now operations can be performed on both DataFrames separately.
df_result_skew = df_skew.groupBy('id').count()  # just an example operation
df_result_non_skew = df_non_skew.groupBy('id').count()

# Combine the results of the operations together using union().
df_result = df_result_skew.union(df_result_non_skew)

Salting

Another method of distributing data more evenly across partitions is to add a “salt” to the key or keys that are causing the skew. The salt value, typically a random number, is appended to the original key, and the salted key is used for partitioning. This forces a more even distribution of data.

To illustrate this concept, let’s imagine our data is split into partitions for three cities in the US state of Illinois: Chicago has many more residents than the nearby cities of Oak Park or Long Grove, causing the data to be skewed.

Figure: Skewed data on the left shows uneven data partitions for three cities. The salted data on the right evenly distributes data among six city groups.

To distribute the data more evenly using PySpark, we combine the column city with a randomly generated integer to create a new key, called salted_city. With three salt values, “Chicago” becomes “Chicago_0,” “Chicago_1,” and “Chicago_2,” with each new key representing a smaller number of records. The salted key can then be used in transformations such as groupBy(), followed by a second aggregation on the original key to combine the partial results:

from pyspark.sql import functions as F

# In this example, the DataFrame 'df' has a skewed column 'city'.
skewed_column = 'city'
num_salts = 3  # number of salt values per key

# Create a new column 'salted_city' by appending an underscore and a random
# integer from 0 to num_salts - 1 to the original city name.
df = df.withColumn(
    'salted_city',
    F.concat(F.col(skewed_column), F.lit('_'),
             (F.rand() * num_salts).cast('int').cast('string')))

# Now operations can be performed on 'salted_city' instead of 'city'.
# First, aggregate on the salted key; this work is spread across partitions.
df_partial = df.groupBy('salted_city').count()

# Then remove the salt and combine the partial results for each original city.
df_grouped = (df_partial
              .withColumn('original_city', F.regexp_replace('salted_city', r'_\d+$', ''))
              .groupBy('original_city')
              .agg(F.sum('count').alias('count')))
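The two-stage pattern behind salting, aggregating on the salted key and then re-aggregating on the original key, can be checked in plain Python. The city counts and the choice of three salt values are made up for illustration, and Counter stands in for a distributed groupBy().

```python
import random
from collections import Counter

random.seed(0)  # make the illustration reproducible
NUM_SALTS = 3

cities = ["Chicago"] * 900 + ["Oak Park"] * 60 + ["Long Grove"] * 40

# Stage 1: append a random salt so one hot key becomes several smaller keys.
salted = [f"{city}_{random.randrange(NUM_SALTS)}" for city in cities]
partial_counts = Counter(salted)

# Stage 2: strip the salt and add up the partial counts per original city.
final_counts = Counter()
for salted_key, n in partial_counts.items():
    city = salted_key.rsplit("_", 1)[0]
    final_counts[city] += n

print(sorted(partial_counts.items()))  # several smaller "Chicago_*" groups
print(dict(final_counts))  # same totals as the unsalted data: 900, 60, 40
```

The final counts match an unsalted count, which is the property to verify whenever salting is added to an aggregation, while the largest partial group holds only a fraction of the 900 Chicago records.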

Broadcasting

A join() is a common operation in which two data sets are combined based on one or more common keys. Rows from two different data sets can be merged into a single data set by matching values in the specified columns. Because data shuffling across multiple nodes is required, a join() can be a costly operation in terms of network latency.

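In scenarios in which a small data set is being joined with a larger data set, Spark offers an optimization technique called broadcasting. If one of the data sets is small enough to fit into the memory of each worker node, a copy of it can be sent to every node, avoiding the need to shuffle the large data set. The join() operation then happens locally on each node. Spark applies this optimization automatically when it estimates that a table is smaller than spark.sql.autoBroadcastJoinThreshold (10 MB by default), and the broadcast() function requests it explicitly.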

Figure: Broadcasting a smaller DataFrame. A large DataFrame is split into four partitions, each paired with a copy of the small DataFrame, and the join operation happens on the worker nodes holding each partition.

In the following example, the small DataFrame df2 is broadcast across all of the worker nodes, and the join() operation with the large DataFrame df1 is performed locally on each node:

from pyspark.sql.functions import broadcast

# Hint Spark to send df2 to every worker, then join on the shared 'id' column.
df_joined = df1.join(broadcast(df2), 'id')

df2 must be small enough to fit into the memory of the driver, which collects it before broadcasting it, as well as into the memory of each worker node; a DataFrame that is too large can cause out-of-memory errors.
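Conceptually, a broadcast join ships a lookup table built from the small data set to every partition of the large one. The pure-Python sketch below models that idea with lists as partitions; broadcast_hash_join is hypothetical, assumes the join key is the first field of each row, and performs an inner join.

```python
def broadcast_hash_join(large_partitions, small_table):
    """Inner-join every partition of a large table with a small table.

    Hypothetical helper: rows are tuples whose first field is the join key.
    """
    # Build the lookup table once; this is loosely analogous to the copy Spark broadcasts.
    lookup = {}
    for row in small_table:
        lookup.setdefault(row[0], []).append(row[1:])

    joined = []
    for partition in large_partitions:  # each partition is joined locally
        local_rows = []
        for row in partition:
            for extra in lookup.get(row[0], []):
                local_rows.append(row + extra)
        joined.append(local_rows)
    return joined


large = [[(1, "order-a"), (2, "order-b")], [(1, "order-c"), (3, "order-d")]]
small = [(1, "Chicago"), (2, "Oak Park")]
result = broadcast_hash_join(large, small)
print(result)
# [[(1, 'order-a', 'Chicago'), (2, 'order-b', 'Oak Park')], [(1, 'order-c', 'Chicago')]]
```

No rows of the large table move between partitions; only the small lookup table is copied, which is why the technique breaks down once that table no longer fits in memory.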

Filtering Unused Data

When working with high-dimensional data, minimizing computational overhead is essential. Any rows or columns that are not absolutely required should be removed. Two key techniques that reduce computational complexity and memory usage are early filtering and column pruning:

Early filtering: Filtering operations should be applied as early as possible in the data processing pipeline. This cuts down on the number of rows that need to be processed in subsequent transformations, reducing the overall computational load and memory usage.

Column pruning: Many computations involve only a subset of columns in a data set. Columns that are not necessary for data processing should be removed. Column pruning can significantly decrease the amount of data that needs to be processed and stored.

The following code shows an example of the select() operation used to prune columns so that only the columns name and age are kept; with columnar file formats such as Parquet, Spark can avoid reading the other columns from storage at all. The code also demonstrates how to use the filter() operation to include only rows in which the value of age is greater than 21:

df = df.select('name', 'age').filter(df['age'] > 21)
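The benefit of early filtering can be measured in a small pure-Python model that counts how many times an expensive per-row step runs. CountingTransform and the generated people records are hypothetical. Spark’s Catalyst optimizer can often push filters down automatically in DataFrame queries; the sketch shows why that reordering pays off.

```python
class CountingTransform:
    """Stand-in for an expensive per-row transformation that counts its calls."""

    def __init__(self):
        self.calls = 0

    def __call__(self, row):
        self.calls += 1
        return {**row, "name_upper": row["name"].upper()}


people = [{"name": f"user{i}", "age": 15 + i % 20} for i in range(1_000)]

# Late filtering: transform every row, then discard the rows that fail the filter.
late_transform = CountingTransform()
late = [r for r in map(late_transform, people) if r["age"] > 21]

# Early filtering: discard rows first, then transform only what remains.
early_transform = CountingTransform()
early = [early_transform(r) for r in people if r["age"] > 21]

print(late == early, late_transform.calls, early_transform.calls)  # True 1000 650
```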

Minimizing Usage of Python User-defined Functions

Python user-defined functions (UDFs) are custom functions written in Python that can be applied to RDDs or DataFrames. With UDFs, users can define their own custom logic or computations; however, there are performance considerations. Each time a Python UDF is invoked, data needs to be serialized and then deserialized between the Spark JVM and the Python interpreter, which leads to additional overhead due to data serialization, process switching, and data copying. This can significantly impact the speed of your data processing pipeline.

One of the most effective PySpark optimization techniques is to use PySpark’s built-in functions whenever possible. PySpark comes with a rich library of built-in functions that run inside the JVM and can be optimized by Catalyst, avoiding the Python serialization overhead entirely.

In cases in which complex logic can’t be implemented with the built-in functions, vectorized UDFs, also known as Pandas UDFs, can help achieve better performance. Vectorized UDFs use Apache Arrow to transfer data in batches and operate on entire columns of data as Pandas Series rather than on individual rows. This batch processing often leads to better performance than row-wise UDFs.
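Why batching helps can be seen in a pure-Python model that counts function invocations. It is a simplification: each invocation stands in for the per-call overhead that a row-wise Python UDF pays, and the batch size of 10,000 mirrors the default of spark.sql.execution.arrow.maxRecordsPerBatch. Real serialization costs are not modeled.

```python
BATCH_SIZE = 10_000  # mirrors the default of spark.sql.execution.arrow.maxRecordsPerBatch


def multiply_by_two(n):
    """Row-at-a-time function: invoked once per value."""
    return n * 2


def multiply_by_two_batch(batch):
    """Batch function: invoked once per list of values."""
    return [n * 2 for n in batch]


def run_row_wise(values):
    results = [multiply_by_two(v) for v in values]
    return results, len(values)  # one invocation per row


def run_batched(values, batch_size=BATCH_SIZE):
    results, invocations = [], 0
    for start in range(0, len(values), batch_size):
        results.extend(multiply_by_two_batch(values[start:start + batch_size]))
        invocations += 1
    return results, invocations


values = list(range(25_000))
row_results, row_calls = run_row_wise(values)
batch_results, batch_calls = run_batched(values)
print(row_results == batch_results, row_calls, batch_calls)  # True 25000 3
```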

Consider a task in which all of the elements in a column must be multiplied by two. In the following example, this operation is performed using a Python UDF:

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

def multiply_by_two(n):
    # Null values arrive as None, so pass them through unchanged.
    return n * 2 if n is not None else None

multiply_by_two_udf = udf(multiply_by_two, IntegerType())
df = df.withColumn("col1_doubled", multiply_by_two_udf(df["col1"]))

The multiply_by_two() function is a Python UDF which takes an integer n and multiplies it by two. This function is registered as a UDF using udf() and applied to the column col1 within the DataFrame df.

The same multiplication operation can be implemented in a more efficient manner using PySpark’s built-in functions:

from pyspark.sql.functions import col
df = df.withColumn("col1_doubled", col("col1") * 2)

In cases in which the operation cannot be performed using built-in functions and a Python UDF is necessary, a vectorized UDF can offer a more efficient alternative:

import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import IntegerType

@pandas_udf(IntegerType())
def multiply_by_two_pd(s: pd.Series) -> pd.Series:
    # Operates on a whole batch of values at once.
    return s * 2

df = df.withColumn("col1_doubled", multiply_by_two_pd(df["col1"]))

This method applies the function multiply_by_two_pd to an entire series of data at once, reducing the serialization overhead. Note that the input and return of the multiply_by_two_pd function are both Pandas Series. A Pandas Series is a one-dimensional labeled array that can be used to represent the data in a single column in a DataFrame.

Optimizing Performance in Data Processing

As machine learning and big data become more commonplace, engineers are adopting Apache Spark to handle the vast amounts of data that these technologies need to process. Boosting the performance of Spark involves a range of strategies, all designed to optimize the usage of available resources. Implementing the techniques discussed here can help Spark process large volumes of data more efficiently.

Understanding the basics

  • What are Spark optimization techniques?

    While Spark’s default settings provide very good performance, there are several optimization techniques to make processing faster. These involve reducing the amount of data Spark needs to process, balancing the data distribution, and improving the way data is moved around.

  • How can I improve my Spark performance?

    Spark’s performance can be improved through several techniques, including caching frequently used data to reduce the amount of processing and making sure partitions are well-balanced. Additionally, broadcasting can significantly reduce the overhead caused by shuffling data around.

  • What are the five S’s of Spark?

    The five S’s of Spark are skew, spill, shuffle, serialization, and storage. These are the most common reasons for poor performance in Spark. Their effects can be minimized with various techniques, including salting to reduce skew and early filtering to reduce the amount of data that Spark has to deal with.

  • What is an optimizer in Spark?

    Spark’s Catalyst Optimizer is a component that automatically optimizes DataFrame or SQL operations.

  • Is Apache Spark free?

    Apache Spark is open source and free to use; however, third-party providers that offer Spark as a managed service may charge fees.

Necati Demir, PhD


Summit, NJ, United States

Member since November 17, 2015
