When a dataset is very large, it is often better to split it into equal chunks and process each chunk as a separate dataframe. This works when the operation on each row is independent of the other rows. The equally sized chunks can then be processed in parallel, which makes more efficient use of the available resources. In this article, we will discuss how to split a PySpark dataframe into chunks with an equal number of rows.
Creating Dataframe for demonstration:
Python
# importing module
import pyspark

# importing sparksession from pyspark.sql module
from pyspark.sql import SparkSession

# creating sparksession and giving an app name
spark = SparkSession.builder.appName('sparkdf').getOrCreate()

# Column names for the dataframe
columns = ["Brand", "Product"]

# Row data for the dataframe
data = [
    ("HP", "Laptop"),
    ("Lenovo", "Mouse"),
    ("Dell", "Keyboard"),
    ("Samsung", "Monitor"),
    ("MSI", "Graphics Card"),
    ("Asus", "Motherboard"),
    ("Gigabyte", "Motherboard"),
    ("Zebronics", "Cabinet"),
    ("Adata", "RAM"),
    ("Transcend", "SSD"),
    ("Kingston", "HDD"),
    ("Toshiba", "DVD Writer")
]

# Create the dataframe using the above values
prod_df = spark.createDataFrame(data=data, schema=columns)

# View the dataframe
prod_df.show()
In the above code block, we have defined the column names for the dataframe and provided sample data. Our dataframe consists of 2 string-type columns with 12 records.
Example 1: Split dataframe using ‘DataFrame.limit()’
We will make use of the limit() method, together with subtract(), to create ‘n’ equal dataframes.
Syntax: DataFrame.limit(num)
Where num is the maximum number of rows to return; the result is limited to that many rows.
Code:
Python
# Define the number of splits you want
n_splits = 4

# Calculate the row count of each split (floor division)
each_len = prod_df.count() // n_splits

# Start from the original dataframe
copy_df = prod_df

# Iterate once per split
i = 0
while i < n_splits:

    # Get the top `each_len` number of rows.
    # Note: without an orderBy(), which rows limit() returns is not guaranteed
    temp_df = copy_df.limit(each_len)

    # Truncate the `copy_df` to remove
    # the contents fetched for `temp_df`.
    # Note: subtract() is a set difference, so duplicate rows are dropped too
    copy_df = copy_df.subtract(temp_df)

    # View the dataframe
    temp_df.show(truncate=False)

    # Increment the split number
    i += 1
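The split size in this example comes from floor division, so it only covers every row when the row count is a multiple of the number of splits (12 rows and 4 splits here). The arithmetic can be sketched in plain Python as follows. The helper `split_sizes` is a hypothetical name introduced for illustration and is not part of PySpark; it shows one possible way to spread leftover rows over the first few splits.

```python
def split_sizes(total_rows, n_splits):
    """Return a list of split sizes that together cover every row.

    Hypothetical helper for illustration only (not a PySpark API).
    """
    base, remainder = divmod(total_rows, n_splits)
    # The first `remainder` splits take one extra row each
    return [base + 1 if i < remainder else base for i in range(n_splits)]


# 12 rows divide evenly into 4 splits
sizes_even = split_sizes(12, 4)      # [3, 3, 3, 3]

# 13 rows leave one row over, which goes to the first split
sizes_uneven = split_sizes(13, 4)    # [4, 3, 3, 3]

# With floor division alone, 13 rows and 4 splits cover only 12 rows
floor_only_covered = (13 // 4) * 4   # 12

print(sizes_even, sizes_uneven, floor_only_covered)
```

If uneven row counts are possible, the sizes from such a helper could be passed to limit() one at a time instead of a single fixed `each_len`.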
Example 2: Split the dataframe, perform the operation and concatenate the result
We will now split the dataframe into ‘n’ equal parts, concatenate the Brand and Product columns in each part individually, and then append each result to `result_df`. This demonstrates how the previous code can be extended to perform a dataframe operation separately on each split and then combine the individual results into a new dataframe with the same number of rows as the original dataframe.
Python
from pyspark.sql.types import StructType, StructField, StringType
from pyspark.sql.functions import concat, col, lit

# Define the number of splits you want
n_splits = 4

# Calculate the row count of each split (floor division)
each_len = prod_df.count() // n_splits

# Start from the original dataframe
copy_df = prod_df

# Function to modify columns of each individual split
def modify_dataframe(data):
    return data.select(
        concat(col("Brand"), lit(" - "),
               col("Product")).alias("Brand - Product")
    )

# Create an empty dataframe to
# store concatenated results
schema = StructType([
    StructField('Brand - Product', StringType(), True)
])
result_df = spark.createDataFrame(data=[], schema=schema)

# Iterate once per split
i = 0
while i < n_splits:

    # Get the top `each_len` number of rows
    temp_df = copy_df.limit(each_len)

    # Truncate the `copy_df` to remove
    # the contents fetched for `temp_df`
    copy_df = copy_df.subtract(temp_df)

    # Perform operation on the newly created dataframe
    temp_df_mod = modify_dataframe(data=temp_df)
    temp_df_mod.show(truncate=False)

    # Concat the dataframe (union() matches columns by position)
    result_df = result_df.union(temp_df_mod)

    # Increment the split number
    i += 1

result_df.show(truncate=False)