#1 Learn PySpark from Scratch

  1. Install PySpark in Python
  2. Create a SparkSession
  3. Create a DataFrame from scratch
  4. Inspect DataFrame
  5. Inspect Schema and Data Types
  6. Transform DataFrame to other formats

1. Install PySpark in Python

!pip3 install pyspark
# check the installation
import pyspark

2. Create a SparkSession

A SparkSession can be used to create DataFrames, register DataFrames as tables, execute SQL over tables, cache tables, and read Parquet files.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('PySpark_Scratch_tutorial').getOrCreate()  # change the appName to whatever you like
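
As a quick illustration of the "register tables and run SQL" part of that description, here is a minimal sketch (the temp view name numbers is just an example):

spark.range(3).createOrReplaceTempView("numbers")  # register a tiny DataFrame as a temporary table
spark.sql("SELECT id FROM numbers").show()         # query it with plain SQL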

3. Create a DataFrame from scratch

df = spark.createDataFrame([
    (32, "Alice", 170, 3000), (35, "Bob", 185, 3200),
    (28, "John", 181, 3500), (24, "Jan", 190, 2800),
    (25, "Mary", 155, 2800), (33, "Nina", 165, 3100),
    ],
    schema=["age", "name", "height", "salary"])
df.show()

# you will see this
+---+-----+------+------+
|age| name|height|salary|
+---+-----+------+------+
| 32|Alice|   170|  3000|
| 35|  Bob|   185|  3200|
| 28| John|   181|  3500|
| 24|  Jan|   190|  2800|
| 25| Mary|   155|  2800|
| 33| Nina|   165|  3100|
+---+-----+------+------+
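
By default Spark infers the column types from the data (you will see them in section 5, where they show up as long). If you prefer to declare the types yourself, you can pass an explicit schema instead of a plain list of column names; a minimal sketch:

from pyspark.sql.types import StructType, StructField, StringType, IntegerType

explicit_schema = StructType([
    StructField("age", IntegerType(), True),
    StructField("name", StringType(), True),
    StructField("height", IntegerType(), True),
    StructField("salary", IntegerType(), True),
])
df_typed = spark.createDataFrame([(32, "Alice", 170, 3000)], schema=explicit_schema)
df_typed.printSchema()  # age, height and salary are now integer instead of the inferred long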

4. Inspect DataFrame

show()

This method displays a number of rows from a DataFrame in a tabular format. By default it displays the first 20 rows, but you can pass the number of rows to show as an argument (see the example after the output below).

df.show()

# you should see this
+---+-----+------+------+
|age| name|height|salary|
+---+-----+------+------+
| 32|Alice|   170|  3000|
| 35|  Bob|   185|  3200|
| 28| John|   181|  3500|
| 24|  Jan|   190|  2800|
| 25| Mary|   155|  2800|
| 33| Nina|   165|  3100|
+---+-----+------+------+
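
To display only a subset of rows, pass the number of rows as an argument:

df.show(2)

# you should see this
+---+-----+------+------+
|age| name|height|salary|
+---+-----+------+------+
| 32|Alice|   170|  3000|
| 35|  Bob|   185|  3200|
+---+-----+------+------+
only showing top 2 rows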

collect()

This method retrieves all the data from a DataFrame or RDD and returns it as a list of Row objects (in the case of DataFrames) or Python objects (in the case of RDDs). It’s important to use this method with caution, especially with large datasets, as it collects all the data to the driver program, which may lead to out-of-memory errors.

df.collect()

# you should see this
[Row(age=32, name='Alice', height=170, salary=3000),
 Row(age=35, name='Bob', height=185, salary=3200),
 Row(age=28, name='John', height=181, salary=3500),
 Row(age=24, name='Jan', height=190, salary=2800),
 Row(age=25, name='Mary', height=155, salary=2800),
 Row(age=33, name='Nina', height=165, salary=3100)]
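
The Row objects behave like named tuples, so once collected you can read individual fields by attribute or by key; a small sketch:

rows = df.collect()
print(rows[0].name)       # 'Alice'
print(rows[0]["salary"])  # 3000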

first()

This method returns the first element of a DataFrame or RDD (Resilient Distributed Dataset). It is an action, not a transformation: it triggers computation and returns the result to the driver.

df.first()

# you should see this
Row(age=32, name='Alice', height=170, salary=3000)

take()

This method retrieves a specified number of rows from a DataFrame or RDD and returns them as a list. Unlike collect(), take() retrieves only a subset of the data, which makes it safer to use with large datasets.

df.take(2)

# you should see this
[Row(age=32, name='Alice', height=170, salary=3000),
 Row(age=35, name='Bob', height=185, salary=3200)]

tail()

This method returns the last N rows of a DataFrame as a list of Row objects. It mirrors take(): instead of the first rows, it returns the last ones. Like collect(), it brings the result back to the driver, so avoid calling it with a large N.

df.tail(3)

# you should see this
[Row(age=24, name='Jan', height=190, salary=2800),
 Row(age=25, name='Mary', height=155, salary=2800),
 Row(age=33, name='Nina', height=165, salary=3100)]

5. Inspect Schema and Data Types

printSchema()

This method displays the schema of a DataFrame. The schema defines the structure of the DataFrame, including column names and their corresponding data types.

df.printSchema()

#you should see this
root
 |-- age: long (nullable = true)
 |-- name: string (nullable = true)
 |-- height: long (nullable = true)
 |-- salary: long (nullable = true)

explain()

This method displays the execution plan of a DataFrame or SQL query. It provides insight into how Spark will execute the operation, including the stages, tasks, and optimizations applied.

df.explain()

# you should see this
== Physical Plan ==
*(1) Scan ExistingRDD[age#0L,name#1,height#2L,salary#3L]
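
If you want more than the physical plan, explain() also accepts arguments; passing True prints the parsed, analyzed and optimized logical plans as well:

df.explain(True)  # extended output: logical plans plus the physical plan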

describe()

This method computes basic statistics for the numeric and string columns of a DataFrame, such as count, mean, standard deviation, min, and max.

df.describe().show()

# you should see this
+-------+-----------------+-----+------------------+------------------+
|summary|              age| name|            height|            salary|
+-------+-----------------+-----+------------------+------------------+
|  count|                6|    6|                 6|                 6|
|   mean|             29.5| NULL|174.33333333333334|3066.6666666666665|
| stddev|4.505552130427524| NULL|13.291601358251256| 265.8320271650252|
|    min|               24|Alice|               155|              2800|
|    max|               35| Nina|               190|              3500|
+-------+-----------------+-----+------------------+------------------+
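
You can also restrict describe() to specific columns by passing their names, which keeps the output compact on wide DataFrames:

df.describe("age", "salary").show()  # statistics for the age and salary columns only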

columns

Columns are the individual fields or attributes of a DataFrame; the columns attribute returns their names as a plain Python list.

df.columns
# you should see this 
['age', 'name', 'height', 'salary']

dtypes

This DataFrame attribute returns a list of tuples, one per column, where each tuple contains the column name and its data type.

df.dtypes
# you should see this
[('age', 'bigint'),
 ('name', 'string'),
 ('height', 'bigint'),
 ('salary', 'bigint')]

6. Transform DataFrame to other formats

You can convert a PySpark DataFrame to a pandas DataFrame, to JSON, or to Parquet.

Convert to Pandas DataFrame

df.toPandas()

# you should see this
   age   name  height  salary
0   32  Alice     170    3000
1   35    Bob     185    3200
2   28   John     181    3500
3   24    Jan     190    2800
4   25   Mary     155    2800
5   33   Nina     165    3100

Convert to JSON Format

df.toJSON().collect()

# you should see this
['{"age":32,"name":"Alice","height":170,"salary":3000}',
 '{"age":35,"name":"Bob","height":185,"salary":3200}',
 '{"age":28,"name":"John","height":181,"salary":3500}',
 '{"age":24,"name":"Jan","height":190,"salary":2800}',
 '{"age":25,"name":"Mary","height":155,"salary":2800}',
 '{"age":33,"name":"Nina","height":165,"salary":3100}']
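
toJSON() returns the rows as JSON strings on the driver. If the goal is to persist the data as JSON files instead, the DataFrame writer can do that directly (the output path is just an example):

df.write.json("df_output.json")  # writes one JSON record per line, split across partition files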

Convert to Parquet Format

df.write.parquet("df_output.parquet")

# Check the parquet file
parquet_df = spark.read.parquet("df_output.parquet")
parquet_df.show()

# you should see the same rows (possibly in a different order, since the data is written and read back as multiple partition files)
+---+-----+------+------+
|age| name|height|salary|
+---+-----+------+------+
| 32|Alice|   170|  3000|
| 28| John|   181|  3500|
| 33| Nina|   165|  3100|
| 25| Mary|   155|  2800|
| 24|  Jan|   190|  2800|
| 35|  Bob|   185|  3200|
+---+-----+------+------+
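
If you prefer a single output file rather than one file per partition, you can coalesce the DataFrame to one partition before writing (at the cost of parallelism); the path below is just an example:

df.coalesce(1).write.mode("overwrite").parquet("df_output_single.parquet")  # single part file

# when you are done, stop the session to release its resources
spark.stop()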