In this article, we will look under the hood at how randomSplit() and sample() work in PySpark with Python. Whenever we work on large datasets in PySpark, we need to split the data into smaller chunks or take some percentage of the data to perform an operation. For this, PySpark provides two methods: randomSplit() and sample(). The randomSplit() method splits the DataFrame according to the provided weights, whereas sample() returns a random sample of the DataFrame.
Required Module
!pip install pyspark
What is randomSplit()?
Anomalies in randomSplit() in PySpark refer to unexpected or inconsistent behavior that can occur when using the randomSplit() function. This function is used to split a dataset into two or more subsets with a specified ratio. However, due to the random nature of the function, it may lead to duplicated or missing rows, or unexpected results when performing joins on the resulting subsets.
Syntax of randomSplit
randomSplit() is used in PySpark to split a data frame according to the provided weights. A data frame is a 2-dimensional data structure, i.e., a table with rows and columns, and its rows are distributed across the splits based on the weights and the seed value, as shown in the short sketch after the parameter list below.
Syntax: DataFrame.randomSplit(weights, seed=None)
Parameters:
- weights: list of doubles used as weights to split the data frame; Spark normalizes the weights if they do not sum to 1.0.
- seed: int, optional; seed for the random number generator used for sampling.
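Before looking at the anomalies, here is a minimal usage sketch (not part of the original example) on a simple DataFrame built with spark.range(); the weights [8.0, 2.0] do not sum to 1.0, so Spark normalizes them to an 80/20 split.

Python3

from pyspark.sql import SparkSession

# get or create a SparkSession
spark = SparkSession.builder.getOrCreate()

# a simple DataFrame with ids 0 to 99
df = spark.range(100)

# weights that do not sum to 1.0 are normalized,
# so [8.0, 2.0] behaves like [0.8, 0.2]
train_df, test_df = df.randomSplit([8.0, 2.0], seed=42)
print(train_df.count(), test_df.count())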
Example
One example of an anomaly in randomSplit() is when the number of rows in each split differs with every execution, which can lead to inaccuracies in data analysis or in machine learning model training. Another example is when the function generates a split in which one of the subsets is empty or has very few rows, which leads to an unbalanced dataset and can cause issues during model training.
Occasionally, using the randomSplit() function may lead to inconsistent and surprising outcomes, such as duplicated or missing rows, or unusual results when performing joins. For instance, if a table of employees is divided into two parts (split1 and split2), and split1 is then joined with another table (department), the number of rows in the resulting table may vary with each execution, because the number of rows in each split may differ.
Python3
# import pyspark.sql for making tables
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName(
    "SparkByExamples.com").getOrCreate()

# make a table employee
emp = [(1, "Sagar", -1, "2018", "10", "M", 3000),
       (2, "G", 1, "2010", "20", "M", 4000),
       (3, "F", 1, "2010", "10", "M", 1000),
       (4, "G", 2, "2005", "10", "M", 2000),
       (5, "Great", 2, "2010", "40", "", -1),
       (6, "Hash", 2, "2010", "50", "", -1)]

# naming columns of employee table
empColumns = ["emp_id", "name", "superior_emp_id",
              "year_joined", "emp_dept_id", "gender", "salary"]

print("Employee Table")

# here, we have created a data frame for table employee
empDF = spark.createDataFrame(data=emp, schema=empColumns)
empDF.printSchema()
empDF.show(truncate=False)

# make another table department
dept = [("F", 10), ("M", 20), ("S", 30), ("I", 40)]

print("Department Table")

# naming columns of department table
deptColumns = ["dept_name", "dept_id"]

# here, we have created a data frame for table department
deptDF = spark.createDataFrame(data=dept, schema=deptColumns)
deptDF.printSchema()
deptDF.show(truncate=False)

print("- Inconsistencies on every run -")

# split the data into two parts and join one of the split parts
# with the department table in order to check the inconsistent
# behaviour of randomSplit()
split1, split2 = empDF.randomSplit([0.5, 0.5])
print("Ist Run:", split1.join(
    deptDF, empDF.emp_dept_id == deptDF.dept_id, "inner").count())

split1, split2 = empDF.randomSplit([0.5, 0.5])
print("IInd Run:", split1.join(
    deptDF, empDF.emp_dept_id == deptDF.dept_id, "inner").count())

split1, split2 = empDF.randomSplit([0.5, 0.5])
print("IIIrd Run:", split1.join(
    deptDF, empDF.emp_dept_id == deptDF.dept_id, "inner").count())

split1, split2 = empDF.randomSplit([0.5, 0.5])
print("IVth Run:", split1.join(
    deptDF, empDF.emp_dept_id == deptDF.dept_id, "inner").count())

split1, split2 = empDF.randomSplit([0.5, 0.5])
print("Vth Run:", split1.join(
    deptDF, empDF.emp_dept_id == deptDF.dept_id, "inner").count())
Output:
PySpark Under the Hood
The randomSplit() function in PySpark is used to randomly split a dataset into two or more subsets with a specified ratio. Under the hood, the function first creates a random number generator; then, for each element in the dataset, it generates a random number between 0 and 1 and compares it to the specified ratio. If the random number is less than the ratio, the element is placed in the first subset; otherwise, it is placed in the second subset.
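To illustrate the idea, here is a conceptual sketch in plain Python; it is not Spark's actual implementation, and the function name conceptual_random_split is made up purely for illustration.

Python3

import random

def conceptual_random_split(rows, ratio=0.5, seed=None):
    # seeded random number generator
    rng = random.Random(seed)
    first, second = [], []
    for row in rows:
        # draw a number in [0, 1) and compare it with the ratio
        if rng.random() < ratio:
            first.append(row)
        else:
            second.append(row)
    return first, second

split_a, split_b = conceptual_random_split(range(10), ratio=0.5, seed=13)
print(split_a, split_b)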
When the function is called multiple times, the subsets generated will be different due to the random nature of the function. To mitigate this, the function allows you to set a seed value for the random number generator so that the subsets generated are reproducible.
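For example, the following short sketch (using a simple range DataFrame, cached so that its partitioning stays stable) shows that the same weights and seed give splits of the same size on repeated calls:

Python3

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# cache the DataFrame so both calls see the same data layout
df = spark.range(100).cache()
df.count()  # materialize the cache

first_a, second_a = df.randomSplit([0.7, 0.3], seed=42)
first_b, second_b = df.randomSplit([0.7, 0.3], seed=42)

# with a fixed seed, both calls produce splits of the same size
print(first_a.count(), first_b.count())
print(second_a.count(), second_b.count())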
The randomSplit() function can also be used on a DataFrame or Dataset, but the splitting is then based on rows rather than on elements as in an RDD.
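A quick comparison sketch (on made-up toy data): randomSplit() is available on both RDDs and DataFrames, and in both cases it returns a list of splits of the same type as the input.

Python3

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# RDD version: splits the RDD's elements
rdd = spark.sparkContext.parallelize(range(100))
rdd_a, rdd_b = rdd.randomSplit([0.5, 0.5], seed=7)
print(rdd_a.count(), rdd_b.count())

# DataFrame version: splits the DataFrame's rows
df = spark.range(100)
df_a, df_b = df.randomSplit([0.5, 0.5], seed=7)
print(df_a.count(), df_b.count())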
For each output split, Spark repeats the same process: the data frame is partitioned, sorted within partitions, and then Bernoulli-sampled. If the original data frame is not cached, the data is re-fetched, re-partitioned, and re-sorted for every split computation. The data distribution across partitions and the sorting order are therefore essential for both randomSplit() and sample(): if either changes when the data is re-fetched, splits may contain duplicate or missing rows, and the same sample with the same seed may produce different results. These inconsistencies do not show up on every execution, but to rule them out completely, persist your data frame, repartition it on a column (or columns), or apply an aggregate function such as groupBy before splitting.
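As a sketch of that fix (reusing the empDF and deptDF DataFrames from the earlier example, so it assumes that code has already been run), persisting the DataFrame before splitting makes the repeated join counts consistent:

Python3

# persist the DataFrame so every split computation reads the same
# materialized data instead of re-fetching and re-sorting it
empDF.persist()
empDF.count()  # materialize the persisted data

split1, split2 = empDF.randomSplit([0.5, 0.5], seed=42)
print("Run 1:", split1.join(
    deptDF, empDF.emp_dept_id == deptDF.dept_id, "inner").count())

split1, split2 = empDF.randomSplit([0.5, 0.5], seed=42)
print("Run 2:", split1.join(
    deptDF, empDF.emp_dept_id == deptDF.dept_id, "inner").count())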
Step 1: Loading DataFrame
You can download the CSV file here.
Python3
# import pyspark along with sparksession
import pyspark
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# read the dataset
df = spark.read.format('csv').option('header', 'true').load('gfg.csv')

# to show the content of dataset
df.show()
Output:
Step 2: Splitting the above Dataframe into two DataFrames
Here, we have split the DataFrame into two parts, i.e., 50% in split1 and 50% in split2. The original DataFrame has 4 rows; after performing randomSplit() with seed=13, it is split into two parts with 2 rows in split1 and 2 rows in split2.
Python3
# import pyspark along with sparksession
import pyspark
from pyspark.sql import SparkSession

# to get an existing sparksession
spark = SparkSession.builder.getOrCreate()

# read the dataset
df = spark.read.format('csv').option('header', 'true').load('gfg.csv')

split1, split2 = df.randomSplit([0.5, 0.5], seed=13)

# count of rows in the original data frame
print("original: ", df.count())

# count of rows in the split data frames
print("split1: ", split1.count())
print("split2: ", split2.count())
Output:
What is the sample() method?
The sample() method is used to get random sample records from the dataset. When we work on a large dataset and want to analyze or test only a chunk of it, the sample() method helps us obtain such a sample to run our code on.
Syntax of sample()
Syntax: sample(withReplacement, fraction, seed=None)
Parameters:
- withReplacement: bool, optional; sample with replacement or not (default False).
- fraction: float; fraction of rows to generate, in the range [0.0, 1.0].
- seed: int, optional; seed for the random number generator used for sampling.
Using Fraction to get a Sample of Records
Using fraction, only an approximate number of records based on the given fraction value is returned. The code below returns roughly 5% of the 100 records. fraction does not guarantee an exact count, so running the code repeatedly may return a different number of rows each time.
Python3
# importing sparksession from pyspark.sql
from pyspark.sql import SparkSession

# to get an existing sparksession
spark = SparkSession.builder.appName(
    "SparkByExamples.com").getOrCreate()

# create a DataFrame with 100 rows (ids 0 to 99)
df = spark.range(100)

# to get roughly 5% of the records
print(df.sample(0.05).collect())
Output:
Using Fraction and Seed
The seed is used to make the random sampling reproducible: the same fraction and seed return the same sample on every run.
Python3
# importing sparksession from pyspark.sql
from pyspark.sql import SparkSession

# to get an existing sparksession
spark = SparkSession.builder.appName(
    "SparkByExamples.com").getOrCreate()

# create a DataFrame with 100 rows (ids 0 to 99)
df = spark.range(100)

# to get a fraction of the records on the basis of the seed value
print(df.sample(0.1, 97).collect())
Output:
Using withReplacement True/False
Whenever we want the sample to be allowed to contain duplicate rows, pass True as withReplacement; otherwise, it defaults to False and there is no need to specify it.
Python3
# importing sparksession from pyspark.sql
from pyspark.sql import SparkSession

# to get an existing sparksession
spark = SparkSession.builder.appName(
    "SparkByExamples.com").getOrCreate()

# create a DataFrame with 100 rows (ids 0 to 99)
df = spark.range(100)

print("With Duplicates:")
# sample a fraction of the records with replacement,
# so the result may contain duplicate rows
print(df.sample(True, 0.2, 98).collect())

print("Without Duplicates: ")
# sample a fraction of the records without replacement
print(df.sample(False, 0.1, 123).collect())
Output: