In [8]:
import pandas as pd
import numpy as np
from pydataset import data

import matplotlib.pyplot as plt
import seaborn as sns

import pyspark
from pyspark.sql.functions import *
from pyspark.sql.window import Window
from pyspark.sql.types import IntegerType

from datetime import datetime
from dateutil.relativedelta import relativedelta

# Capture one timestamp and derive "one year ago" from it, so the two values
# can never disagree (two separate datetime.now() calls could straddle midnight).
current_time = datetime.now()
last_year = current_time - relativedelta(years=1)

print('Today:', current_time.strftime("%Y-%m-%d"))
print('LY:', last_year.strftime("%Y-%m-%d"))
Today: 2021-11-02
LY: 2020-11-02

Working with PySpark DataFrames

This notebook provides some basic code snippets to perform common DataFrame manipulations, transformations, and actions using PySpark.

For pandas users, I've included some code snippets to perform some of the most common Initial EDA you are probably used to performing upon reading in data.

Acquire and Summarize

In [5]:
# Create (or reuse) the active Spark session.
# getOrCreate() returns the existing session if one is already running in this kernel.

spark = pyspark.sql.SparkSession.builder.getOrCreate()

From Pandas DF

In [9]:
# Import mpg dataset into a pandas DataFrame.

pdf = data('mpg')  # 234 rows x 11 columns of vehicle fuel-economy data
In [10]:
# Convert pandas DF into a Spark DataFrame.
# Column names and Spark types are inferred from the pandas dtypes.

sdf = spark.createDataFrame(pdf)

From SQL Database

# Create PySpark DataFrame from a table using a SQL Query.
# Note: 'table_name' must be registered with the Spark catalog
# (e.g. a Hive table, or a temp view made with createOrReplaceTempView).

query = """
        SELECT * FROM table_name
        """

df = spark.sql(query)

.shape()

In [15]:
# pandas .shape equivalent: .count() is an action returning the row count.
print(sdf.count(), len(sdf.columns))
234 11

.dtypes

In [16]:
# Just like pandas .dtypes — returns a list of (column_name, spark_type) tuples.

sdf.dtypes
Out[16]:
[('manufacturer', 'string'),
 ('model', 'string'),
 ('displ', 'double'),
 ('year', 'bigint'),
 ('cyl', 'bigint'),
 ('trans', 'string'),
 ('drv', 'string'),
 ('cty', 'bigint'),
 ('hwy', 'bigint'),
 ('fl', 'string'),
 ('class', 'string')]
In [5]:
# OR view columns and spark data types this way.
# printSchema() also shows the nullability of each column.

sdf.printSchema()
root
 |-- manufacturer: string (nullable = true)
 |-- model: string (nullable = true)
 |-- displ: double (nullable = true)
 |-- year: long (nullable = true)
 |-- cyl: long (nullable = true)
 |-- trans: string (nullable = true)
 |-- drv: string (nullable = true)
 |-- cty: long (nullable = true)
 |-- hwy: long (nullable = true)
 |-- fl: string (nullable = true)
 |-- class: string (nullable = true)

.columns

In [6]:
# Print DataFrame columns.
# Like pandas .columns, but returns a plain Python list of names.

sdf.columns
Out[6]:
['manufacturer',
 'model',
 'displ',
 'year',
 'cyl',
 'trans',
 'drv',
 'cty',
 'hwy',
 'fl',
 'class']

.head()

In [10]:
# Display the first five rows.
# .show(n) prints an ASCII table; it does not return a DataFrame.

sdf.show(5)
+------------+-----+-----+----+---+----------+---+---+---+---+-------+
|manufacturer|model|displ|year|cyl|     trans|drv|cty|hwy| fl|  class|
+------------+-----+-----+----+---+----------+---+---+---+---+-------+
|        audi|   a4|  1.8|1999|  4|  auto(l5)|  f| 18| 29|  p|compact|
|        audi|   a4|  1.8|1999|  4|manual(m5)|  f| 21| 29|  p|compact|
|        audi|   a4|  2.0|2008|  4|manual(m6)|  f| 20| 31|  p|compact|
|        audi|   a4|  2.0|2008|  4|  auto(av)|  f| 21| 30|  p|compact|
|        audi|   a4|  2.8|1999|  6|  auto(l5)|  f| 16| 26|  p|compact|
+------------+-----+-----+----+---+----------+---+---+---+---+-------+
only showing top 5 rows

# In DataBricks: display() renders a rich, sortable HTML table.

display(sdf.limit(5))

.describe()

In [11]:
# View descriptive statistics (count, mean, stddev, min, max) for our Spark DF.
# Non-numeric columns get null mean/stddev but still report count/min/max.

sdf.describe().show()
+-------+------------+-----------------+------------------+-----------------+-----------------+----------+---+------------------+-----------------+----+-------+
|summary|manufacturer|            model|             displ|             year|              cyl|     trans|drv|               cty|              hwy|  fl|  class|
+-------+------------+-----------------+------------------+-----------------+-----------------+----------+---+------------------+-----------------+----+-------+
|  count|         234|              234|               234|              234|              234|       234|234|               234|              234| 234|    234|
|   mean|        null|             null| 3.471794871794871|           2003.5|5.888888888888889|      null|4.0|16.858974358974358|23.44017094017094|null|   null|
| stddev|        null|             null|1.2919590310839348|4.509646313320439|1.611534484684289|      null|0.0| 4.255945678889394|5.954643441166446|null|   null|
|    min|        audi|      4runner 4wd|               1.6|             1999|                4|  auto(av)|  4|                 9|               12|   c|2seater|
|    max|  volkswagen|toyota tacoma 4wd|               7.0|             2008|                8|manual(m6)|  r|                35|               44|   r|    suv|
+-------+------------+-----------------+------------------+-----------------+-----------------+----------+---+------------------+-----------------+----+-------+

# In DataBricks: render the summary statistics as a rich table.

display(sdf.describe())
In [14]:
# Make it more readable using your good friend pandas.
# toPandas() collects to the driver — safe here because describe() output is tiny.

sdf.describe().toPandas().set_index('summary').T
Out[14]:
summary count mean stddev min max
manufacturer 234 None None audi volkswagen
model 234 None None 4runner 4wd toyota tacoma 4wd
displ 234 3.471794871794871 1.2919590310839348 1.6 7.0
year 234 2003.5 4.509646313320439 1999 2008
cyl 234 5.888888888888889 1.611534484684289 4 8
trans 234 None None auto(av) manual(m6)
drv 234 4.0 0.0 4 r
cty 234 16.858974358974358 4.255945678889394 9 35
hwy 234 23.44017094017094 5.954643441166446 12 44
fl 234 None None c r
class 234 None None 2seater suv

.value_counts()

In [21]:
# pandas .value_counts() equivalent: one row per 'model' with its frequency.
sdf.groupBy('model').count().show()
+-------------------+-----+
|              model|count|
+-------------------+-----+
| grand cherokee 4wd|    8|
|             altima|    6|
|      navigator 2wd|    3|
|        caravan 2wd|   11|
|        4runner 4wd|    6|
|    k1500 tahoe 4wd|    4|
|       camry solara|    7|
|    mountaineer 4wd|    4|
| c1500 suburban 2wd|    5|
|                 a4|    7|
|    f150 pickup 4wd|    7|
|              jetta|    9|
|             sonata|    7|
|                gti|    5|
|       explorer 4wd|    6|
|     pathfinder 4wd|    4|
|            corolla|    5|
|  toyota tacoma 4wd|    7|
|ram 1500 pickup 4wd|   10|
|             passat|    7|
+-------------------+-----+
only showing top 20 rows

.sort()

You can use .orderBy() the same as .sort(), but I'll demo it later as another option combined with the .when() function to show how you can perform a manual sort.

In [20]:
# Frequency table sorted most-common-first, like pandas value_counts().
sdf.groupBy('model').count().sort(desc('count')).show()
+-------------------+-----+
|              model|count|
+-------------------+-----+
|        caravan 2wd|   11|
|ram 1500 pickup 4wd|   10|
|              civic|    9|
|              jetta|    9|
|  dakota pickup 4wd|    9|
|            mustang|    9|
|         a4 quattro|    8|
| grand cherokee 4wd|    8|
|        impreza awd|    8|
|            tiburon|    7|
|  toyota tacoma 4wd|    7|
|             sonata|    7|
|       camry solara|    7|
|              camry|    7|
|                 a4|    7|
|    f150 pickup 4wd|    7|
|             passat|    7|
|        durango 4wd|    7|
|        4runner 4wd|    6|
|       explorer 4wd|    6|
+-------------------+-----+
only showing top 20 rows

# In DataBricks
# Note: the grouping column must exist in the DataFrame — the mpg data has
# no 'department' column, so we group by 'model' as in the cells above.

display(
   sdf
    .groupBy('model')
    .count()
    .sort('count', ascending=False)
)

OR

# Same result using desc(); the original snippet had an unterminated string
# ('department) and referenced a column that does not exist in sdf.
display(
   sdf
    .groupBy('model')
    .count()
    .sort(desc('count'))
)

Prepare and Manipulate

Data Types

.astype()

# I can change the data type of a PySpark column using .withColumn() and .cast() methods.
# DoubleType and StringType must be imported first (only IntegerType is imported
# at the top of this notebook):
from pyspark.sql.types import DoubleType, StringType

df.withColumn( 'col_name', df['col_name'].cast(DoubleType()) )  # For a decimal number
df.withColumn( 'col_name', df['col_name'].cast(IntegerType()) ) # For an integer
df.withColumn( 'col_name', df['col_name'].cast(StringType()) )  # For a string

Filter Columns

.select()

In [23]:
# Select multiple columns from the Spark DataFrame.
# .select() accepts a list (or *args) of column names and returns a new DF.

sdf.select(['year', 'manufacturer', 'model']).show(5)
+----+------------+-----+
|year|manufacturer|model|
+----+------------+-----+
|1999|        audi|   a4|
|1999|        audi|   a4|
|2008|        audi|   a4|
|2008|        audi|   a4|
|1999|        audi|   a4|
+----+------------+-----+
only showing top 5 rows

Using col() & expr()

I can even chain .alias() onto my columns to rename them on the fly.

In [48]:
# Use the col and expr functions with the alias method to create a new DF.
# expr() evaluates a SQL expression string; alias() renames each column on the fly.

sdf.select(
    col('hwy').alias('highway_mileage'),
    col('cty').alias('city_mileage'),
    col('trans').alias('transmission'),  # fixed typo: was 'transimission'
    expr('(hwy + cty) / 2').alias('average_mileage')
).show(5)
+---------------+------------+-------------+---------------+
|highway_mileage|city_mileage|transimission|average_mileage|
+---------------+------------+-------------+---------------+
|             29|          18|     auto(l5)|           23.5|
|             29|          21|   manual(m5)|           25.0|
|             31|          20|   manual(m6)|           25.5|
|             30|          21|     auto(av)|           25.5|
|             26|          16|     auto(l5)|           21.0|
+---------------+------------+-------------+---------------+
only showing top 5 rows

Using Bracket Notation

  • Notice, this is only a single pair of brackets, df['col', 'col']
  • This is unlike pandas double brackets, df[['col', 'col']], to create a subset when you pass a list to the indexing operators [].
In [34]:
# A single pair of brackets with multiple names returns a column subset DF.
sdf['manufacturer', 'model'].show(5)
+------------+-----+
|manufacturer|model|
+------------+-----+
|        audi|   a4|
|        audi|   a4|
|        audi|   a4|
|        audi|   a4|
|        audi|   a4|
+------------+-----+
only showing top 5 rows

# In DataBricks: the same subset, rendered with display().

display(sdf['manufacturer', 'model'].limit(5))

Create Columns

.withColumn()

In [23]:
# Create a new column in a copy of our Spark DataFrame.
# Note: This does not change the original DataFrame.
# concat() joins the columns; lit(' ') supplies the literal space separator.
(
    sdf
    .withColumn( 'make_&_model', concat(sdf['manufacturer'], lit(' '), sdf['model']) )
    .show(5)
)
+------------+-----+-----+----+---+----------+---+---+---+---+-------+------------+
|manufacturer|model|displ|year|cyl|     trans|drv|cty|hwy| fl|  class|make_&_model|
+------------+-----+-----+----+---+----------+---+---+---+---+-------+------------+
|        audi|   a4|  1.8|1999|  4|  auto(l5)|  f| 18| 29|  p|compact|     audi a4|
|        audi|   a4|  1.8|1999|  4|manual(m5)|  f| 21| 29|  p|compact|     audi a4|
|        audi|   a4|  2.0|2008|  4|manual(m6)|  f| 20| 31|  p|compact|     audi a4|
|        audi|   a4|  2.0|2008|  4|  auto(av)|  f| 21| 30|  p|compact|     audi a4|
|        audi|   a4|  2.8|1999|  6|  auto(l5)|  f| 16| 26|  p|compact|     audi a4|
+------------+-----+-----+----+---+----------+---+---+---+---+-------+------------+
only showing top 5 rows

Using .when()

In [26]:
# I can create a column to flag all automatic transmissions if I like.
# when(cond, value).otherwise(default) is Spark's vectorized if/else.
(
    sdf
    .withColumn( 'is_auto', when( col('trans').startswith('a'), 1 )
                .otherwise(0) )
    .show(5)
)
+------------+-----+-----+----+---+----------+---+---+---+---+-------+-------+
|manufacturer|model|displ|year|cyl|     trans|drv|cty|hwy| fl|  class|is_auto|
+------------+-----+-----+----+---+----------+---+---+---+---+-------+-------+
|        audi|   a4|  1.8|1999|  4|  auto(l5)|  f| 18| 29|  p|compact|      1|
|        audi|   a4|  1.8|1999|  4|manual(m5)|  f| 21| 29|  p|compact|      0|
|        audi|   a4|  2.0|2008|  4|manual(m6)|  f| 20| 31|  p|compact|      0|
|        audi|   a4|  2.0|2008|  4|  auto(av)|  f| 21| 30|  p|compact|      1|
|        audi|   a4|  2.8|1999|  6|  auto(l5)|  f| 16| 26|  p|compact|      1|
+------------+-----+-----+----+---+----------+---+---+---+---+-------+-------+
only showing top 5 rows

In [27]:
# I can use the lit() function to insert literal string values, too.
# (.otherwise() also accepts a bare Python string, as shown below.)
(
    sdf
    .withColumn( 'transmission', when( col('trans').startswith('a'), lit('automatic') )
                .otherwise('manual') )
    .show(5)
)
+------------+-----+-----+----+---+----------+---+---+---+---+-------+------------+
|manufacturer|model|displ|year|cyl|     trans|drv|cty|hwy| fl|  class|transmission|
+------------+-----+-----+----+---+----------+---+---+---+---+-------+------------+
|        audi|   a4|  1.8|1999|  4|  auto(l5)|  f| 18| 29|  p|compact|   automatic|
|        audi|   a4|  1.8|1999|  4|manual(m5)|  f| 21| 29|  p|compact|      manual|
|        audi|   a4|  2.0|2008|  4|manual(m6)|  f| 20| 31|  p|compact|      manual|
|        audi|   a4|  2.0|2008|  4|  auto(av)|  f| 21| 30|  p|compact|   automatic|
|        audi|   a4|  2.8|1999|  6|  auto(l5)|  f| 16| 26|  p|compact|   automatic|
+------------+-----+-----+----+---+----------+---+---+---+---+-------+------------+
only showing top 5 rows

# In DataBricks: same derived column, rendered with display().

display(
    sdf
    .withColumn( 'transmission', when( col('trans').startswith('a'), lit('automatic') )
                 .otherwise('manual') )
    .limit(5)
)

Using .map()

PySpark doesn't have one of my favorite pandas functions, .map(), but I can still create a new column based on the values in an existing column; I'll just use the .when() function.

In [46]:
# Bucket highway mileage into labels. The .when() chain is evaluated in order,
# so each condition only catches rows not claimed by an earlier one.
(
    sdf
    .withColumn('fuel_efficiency', when( col('hwy') < 10, lit('terrible') )
                                  .when( col('hwy') < 15, lit('bad') )
                                  .when( col('hwy') < 20, lit('ok') )
                                  .when( col('hwy') < 25, lit('good') )
                                  .when( col('hwy') < 30, lit('really good') )
                                  .otherwise('great')
               )
    .show()
)
+------------+------------------+-----+----+---+----------+---+---+---+---+-------+---------------+
|manufacturer|             model|displ|year|cyl|     trans|drv|cty|hwy| fl|  class|fuel_efficiency|
+------------+------------------+-----+----+---+----------+---+---+---+---+-------+---------------+
|        audi|                a4|  1.8|1999|  4|  auto(l5)|  f| 18| 29|  p|compact|    really good|
|        audi|                a4|  1.8|1999|  4|manual(m5)|  f| 21| 29|  p|compact|    really good|
|        audi|                a4|  2.0|2008|  4|manual(m6)|  f| 20| 31|  p|compact|          great|
|        audi|                a4|  2.0|2008|  4|  auto(av)|  f| 21| 30|  p|compact|          great|
|        audi|                a4|  2.8|1999|  6|  auto(l5)|  f| 16| 26|  p|compact|    really good|
|        audi|                a4|  2.8|1999|  6|manual(m5)|  f| 18| 26|  p|compact|    really good|
|        audi|                a4|  3.1|2008|  6|  auto(av)|  f| 18| 27|  p|compact|    really good|
|        audi|        a4 quattro|  1.8|1999|  4|manual(m5)|  4| 18| 26|  p|compact|    really good|
|        audi|        a4 quattro|  1.8|1999|  4|  auto(l5)|  4| 16| 25|  p|compact|    really good|
|        audi|        a4 quattro|  2.0|2008|  4|manual(m6)|  4| 20| 28|  p|compact|    really good|
|        audi|        a4 quattro|  2.0|2008|  4|  auto(s6)|  4| 19| 27|  p|compact|    really good|
|        audi|        a4 quattro|  2.8|1999|  6|  auto(l5)|  4| 15| 25|  p|compact|    really good|
|        audi|        a4 quattro|  2.8|1999|  6|manual(m5)|  4| 17| 25|  p|compact|    really good|
|        audi|        a4 quattro|  3.1|2008|  6|  auto(s6)|  4| 17| 25|  p|compact|    really good|
|        audi|        a4 quattro|  3.1|2008|  6|manual(m6)|  4| 15| 25|  p|compact|    really good|
|        audi|        a6 quattro|  2.8|1999|  6|  auto(l5)|  4| 15| 24|  p|midsize|           good|
|        audi|        a6 quattro|  3.1|2008|  6|  auto(s6)|  4| 17| 25|  p|midsize|    really good|
|        audi|        a6 quattro|  4.2|2008|  8|  auto(s6)|  4| 16| 23|  p|midsize|           good|
|   chevrolet|c1500 suburban 2wd|  5.3|2008|  8|  auto(l4)|  r| 14| 20|  r|    suv|           good|
|   chevrolet|c1500 suburban 2wd|  5.3|2008|  8|  auto(l4)|  r| 11| 15|  e|    suv|             ok|
+------------+------------------+-----+----+---+----------+---+---+---+---+-------+---------------+
only showing top 20 rows

Rename Columns

.withColumnRenamed

In [31]:
# Rename a column in a copy of our Spark DF.
# Note: Again, this does not mutate the original DF.
# Chain multiple .withColumnRenamed() calls to rename several columns.

sdf.withColumnRenamed('manufacturer', 'make').show(5)
+----+-----+-----+----+---+----------+---+---+---+---+-------+
|make|model|displ|year|cyl|     trans|drv|cty|hwy| fl|  class|
+----+-----+-----+----+---+----------+---+---+---+---+-------+
|audi|   a4|  1.8|1999|  4|  auto(l5)|  f| 18| 29|  p|compact|
|audi|   a4|  1.8|1999|  4|manual(m5)|  f| 21| 29|  p|compact|
|audi|   a4|  2.0|2008|  4|manual(m6)|  f| 20| 31|  p|compact|
|audi|   a4|  2.0|2008|  4|  auto(av)|  f| 21| 30|  p|compact|
|audi|   a4|  2.8|1999|  6|  auto(l5)|  f| 16| 26|  p|compact|
+----+-----+-----+----+---+----------+---+---+---+---+-------+
only showing top 5 rows

Drop Columns

.drop

In [30]:
# This does not mutate my original DataFrame; I would have to reassign if I want to do that.

sdf.drop('fl').show(5)
+------------+-----+-----+----+---+----------+---+---+---+-------+
|manufacturer|model|displ|year|cyl|     trans|drv|cty|hwy|  class|
+------------+-----+-----+----+---+----------+---+---+---+-------+
|        audi|   a4|  1.8|1999|  4|  auto(l5)|  f| 18| 29|compact|
|        audi|   a4|  1.8|1999|  4|manual(m5)|  f| 21| 29|compact|
|        audi|   a4|  2.0|2008|  4|manual(m6)|  f| 20| 31|compact|
|        audi|   a4|  2.0|2008|  4|  auto(av)|  f| 21| 30|compact|
|        audi|   a4|  2.8|1999|  6|  auto(l5)|  f| 16| 26|compact|
+------------+-----+-----+----+---+----------+---+---+---+-------+
only showing top 5 rows

In [31]:
# I can drop as many columns as I want — .drop() takes any number of names.

sdf.drop('fl', 'cyl').show(5)
+------------+-----+-----+----+----------+---+---+---+-------+
|manufacturer|model|displ|year|     trans|drv|cty|hwy|  class|
+------------+-----+-----+----+----------+---+---+---+-------+
|        audi|   a4|  1.8|1999|  auto(l5)|  f| 18| 29|compact|
|        audi|   a4|  1.8|1999|manual(m5)|  f| 21| 29|compact|
|        audi|   a4|  2.0|2008|manual(m6)|  f| 20| 31|compact|
|        audi|   a4|  2.0|2008|  auto(av)|  f| 21| 30|compact|
|        audi|   a4|  2.8|1999|  auto(l5)|  f| 16| 26|compact|
+------------+-----+-----+----+----------+---+---+---+-------+
only showing top 5 rows

Filter Rows

  • You can also use methods you will recognize from python like .contains(), .startswith(), .endswith().
  • With multiple conditions:
    • df.where( (condition1) & (condition2) )
    • df.where(condition1).where(condition2)
    • df.where( (condition1) & ~(condition2) )
    • df.where( (condition1) | (condition2) )
  • Using SQL syntax
    • df.where("col_name = value")
    • df.where("col_name <= value")
  • Using PySpark syntax
    • df.where( col('col_name') == value )
    • df.where( sdf['col_name'] == value )
    • df.where( sdf.col_name == value )

.filter()

In [43]:
# Filter Spark DataFrame to be a subset of compact cars only. I could also use .where() here.
# .filter() and .where() are aliases of each other in PySpark.

sdf.filter(sdf['class'] == 'compact').show(5)
+------------+-----+-----+----+---+----------+---+---+---+---+-------+
|manufacturer|model|displ|year|cyl|     trans|drv|cty|hwy| fl|  class|
+------------+-----+-----+----+---+----------+---+---+---+---+-------+
|        audi|   a4|  1.8|1999|  4|  auto(l5)|  f| 18| 29|  p|compact|
|        audi|   a4|  1.8|1999|  4|manual(m5)|  f| 21| 29|  p|compact|
|        audi|   a4|  2.0|2008|  4|manual(m6)|  f| 20| 31|  p|compact|
|        audi|   a4|  2.0|2008|  4|  auto(av)|  f| 21| 30|  p|compact|
|        audi|   a4|  2.8|1999|  6|  auto(l5)|  f| 16| 26|  p|compact|
+------------+-----+-----+----+---+----------+---+---+---+---+-------+
only showing top 5 rows

.where()

In [35]:
# Filter with multiple conditions using .where(). .filter() will do the exact same thing.
# Each condition must be parenthesized because & binds tighter than ==.

sdf.where( (sdf['class'] == 'compact') & (sdf.year > 2000) ).show(5)
+------------+----------+-----+----+---+----------+---+---+---+---+-------+
|manufacturer|     model|displ|year|cyl|     trans|drv|cty|hwy| fl|  class|
+------------+----------+-----+----+---+----------+---+---+---+---+-------+
|        audi|        a4|  2.0|2008|  4|manual(m6)|  f| 20| 31|  p|compact|
|        audi|        a4|  2.0|2008|  4|  auto(av)|  f| 21| 30|  p|compact|
|        audi|        a4|  3.1|2008|  6|  auto(av)|  f| 18| 27|  p|compact|
|        audi|a4 quattro|  2.0|2008|  4|manual(m6)|  4| 20| 28|  p|compact|
|        audi|a4 quattro|  2.0|2008|  4|  auto(s6)|  4| 19| 27|  p|compact|
+------------+----------+-----+----+---+----------+---+---+---+---+-------+
only showing top 5 rows

# In DataBricks
# (The original snippet was missing the closing parenthesis on display().)

display(
    sdf.where( (sdf['class'] == 'compact') & (sdf.year > 2000) )
    .limit(5)
)
In [46]:
# I can even combine them and filter with multiple conditions using .filter & .where if I like.
# Chained filters AND together — same result as one combined condition.

sdf.filter(sdf['class'] == 'compact').where(sdf.year > 2000).show(5)
+------------+----------+-----+----+---+----------+---+---+---+---+-------+
|manufacturer|     model|displ|year|cyl|     trans|drv|cty|hwy| fl|  class|
+------------+----------+-----+----+---+----------+---+---+---+---+-------+
|        audi|        a4|  2.0|2008|  4|manual(m6)|  f| 20| 31|  p|compact|
|        audi|        a4|  2.0|2008|  4|  auto(av)|  f| 21| 30|  p|compact|
|        audi|        a4|  3.1|2008|  6|  auto(av)|  f| 18| 27|  p|compact|
|        audi|a4 quattro|  2.0|2008|  4|manual(m6)|  4| 20| 28|  p|compact|
|        audi|a4 quattro|  2.0|2008|  4|  auto(s6)|  4| 19| 27|  p|compact|
+------------+----------+-----+----+---+----------+---+---+---+---+-------+
only showing top 5 rows

Using .isin()

In [41]:
# .isin() matches any of the listed values, like SQL's IN clause.
sdf.where( col('class').isin('compact', 'subcompact') ).show()
+------------+----------+-----+----+---+----------+---+---+---+---+----------+
|manufacturer|     model|displ|year|cyl|     trans|drv|cty|hwy| fl|     class|
+------------+----------+-----+----+---+----------+---+---+---+---+----------+
|        audi|        a4|  1.8|1999|  4|  auto(l5)|  f| 18| 29|  p|   compact|
|        audi|        a4|  1.8|1999|  4|manual(m5)|  f| 21| 29|  p|   compact|
|        audi|        a4|  2.0|2008|  4|manual(m6)|  f| 20| 31|  p|   compact|
|        audi|        a4|  2.0|2008|  4|  auto(av)|  f| 21| 30|  p|   compact|
|        audi|        a4|  2.8|1999|  6|  auto(l5)|  f| 16| 26|  p|   compact|
|        audi|        a4|  2.8|1999|  6|manual(m5)|  f| 18| 26|  p|   compact|
|        audi|        a4|  3.1|2008|  6|  auto(av)|  f| 18| 27|  p|   compact|
|        audi|a4 quattro|  1.8|1999|  4|manual(m5)|  4| 18| 26|  p|   compact|
|        audi|a4 quattro|  1.8|1999|  4|  auto(l5)|  4| 16| 25|  p|   compact|
|        audi|a4 quattro|  2.0|2008|  4|manual(m6)|  4| 20| 28|  p|   compact|
|        audi|a4 quattro|  2.0|2008|  4|  auto(s6)|  4| 19| 27|  p|   compact|
|        audi|a4 quattro|  2.8|1999|  6|  auto(l5)|  4| 15| 25|  p|   compact|
|        audi|a4 quattro|  2.8|1999|  6|manual(m5)|  4| 17| 25|  p|   compact|
|        audi|a4 quattro|  3.1|2008|  6|  auto(s6)|  4| 17| 25|  p|   compact|
|        audi|a4 quattro|  3.1|2008|  6|manual(m6)|  4| 15| 25|  p|   compact|
|        ford|   mustang|  3.8|1999|  6|manual(m5)|  r| 18| 26|  r|subcompact|
|        ford|   mustang|  3.8|1999|  6|  auto(l4)|  r| 18| 25|  r|subcompact|
|        ford|   mustang|  4.0|2008|  6|manual(m5)|  r| 17| 26|  r|subcompact|
|        ford|   mustang|  4.0|2008|  6|  auto(l5)|  r| 16| 24|  r|subcompact|
|        ford|   mustang|  4.6|1999|  8|  auto(l4)|  r| 15| 21|  r|subcompact|
+------------+----------+-----+----+---+----------+---+---+---+---+----------+
only showing top 20 rows

Using .like() with %

In [43]:
# Use the .like() operator like you would in SQL.
# % is the SQL wildcard: '%auto%' matches 'auto' anywhere in the string.

sdf.where( col('trans').like('%auto%') ).show(10)
+------------+----------+-----+----+---+--------+---+---+---+---+-------+
|manufacturer|     model|displ|year|cyl|   trans|drv|cty|hwy| fl|  class|
+------------+----------+-----+----+---+--------+---+---+---+---+-------+
|        audi|        a4|  1.8|1999|  4|auto(l5)|  f| 18| 29|  p|compact|
|        audi|        a4|  2.0|2008|  4|auto(av)|  f| 21| 30|  p|compact|
|        audi|        a4|  2.8|1999|  6|auto(l5)|  f| 16| 26|  p|compact|
|        audi|        a4|  3.1|2008|  6|auto(av)|  f| 18| 27|  p|compact|
|        audi|a4 quattro|  1.8|1999|  4|auto(l5)|  4| 16| 25|  p|compact|
|        audi|a4 quattro|  2.0|2008|  4|auto(s6)|  4| 19| 27|  p|compact|
|        audi|a4 quattro|  2.8|1999|  6|auto(l5)|  4| 15| 25|  p|compact|
|        audi|a4 quattro|  3.1|2008|  6|auto(s6)|  4| 17| 25|  p|compact|
|        audi|a6 quattro|  2.8|1999|  6|auto(l5)|  4| 15| 24|  p|midsize|
|        audi|a6 quattro|  3.1|2008|  6|auto(s6)|  4| 17| 25|  p|midsize|
+------------+----------+-----+----+---+--------+---+---+---+---+-------+
only showing top 10 rows

Grouping & Aggregating

I can use the .agg() method to run a function, like .round(), on the aggregate, or even perform more than one aggregate on one or more columns at a time.

  • Notice that I can clean up my aggregate column name using the .alias() method.
    • A cool thing to point out here is that once I alias an aggregate column, I can refer to that column by its alias in my transformation.
  • I can also pass a dictionary to .agg() with one or more columns and one or more types of aggregations.
    sdf
    .groupBy('col_name')
    .agg({'col1': 'agg_func', 'col2': 'agg_func'})
    
  • I can also add an alias for the column and format the number being returned if I like.
    sdf
    .select(stddev("Sales").alias('std'))
    .select(format_number('std',2).alias('std_2digits'))
    

countDistinct()

I can simply use .select() with countDistinct() if I want to know how many unique values I have in a column. I can even throw in a .alias() to clean up my column name.

In [47]:
# countDistinct() counts unique values; alias() cleans up the column name.
sdf.select(countDistinct('trans').alias('transmission_type_count')).show()
+-----------------------+
|transmission_type_count|
+-----------------------+
|                     10|
+-----------------------+

In [11]:
# I can use the distinct and count methods if I want to return the integer.
# (.count() is an action, so this yields a plain Python int, not a DataFrame.)

sdf.select('trans').distinct().count()
Out[11]:
10

.groupBy()

In [36]:
# Row counts per vehicle class — like pandas value_counts() on 'class'.
sdf.groupBy('class').count().show()
+----------+-----+
|     class|count|
+----------+-----+
|subcompact|   35|
|   compact|   47|
|   minivan|   11|
|       suv|   62|
|   midsize|   41|
|    pickup|   33|
|   2seater|    5|
+----------+-----+

In [49]:
# Count rows per transmission type, with a readable aggregate column name.
(
    sdf
    .groupBy('trans')
    .agg( count('trans').alias('row_count') )
    .show()
)
+----------+---------+
|     trans|row_count|
+----------+---------+
|  auto(l4)|       83|
|manual(m6)|       19|
|  auto(s6)|       16|
|  auto(l5)|       39|
|manual(m5)|       58|
|  auto(l6)|        6|
|  auto(av)|        5|
|  auto(s5)|        3|
|  auto(l3)|        2|
|  auto(s4)|        3|
+----------+---------+

Save & Use

I may want to grab a metric to use somewhere else.

In [64]:
# Save average city mileage to a variable.
# Note: this is still a lazy 1x1 DataFrame, not a Python number.

avg_city = sdf.select(avg('cty').alias('average_city'))
avg_city.show()
+------------------+
|      average_city|
+------------------+
|16.858974358974358|
+------------------+

format_number()

Maybe I want to clean up the formatting of my metric.

In [67]:
# format_number(col, d) rounds to d decimal places and returns a formatted string.
avg_city.select(format_number('average_city', 2).alias('average_city')).show()
+------------+
|average_city|
+------------+
|       16.86|
+------------+

.agg()

In [53]:
# Dictionary form of .agg(): {column_name: aggregate_function_name}.
sdf.groupBy('trans').agg({'cty': 'avg', 'hwy': 'avg'}).show()
+----------+------------------+------------------+
|     trans|          avg(cty)|          avg(hwy)|
+----------+------------------+------------------+
|  auto(l4)|15.939759036144578| 21.96385542168675|
|manual(m6)|16.894736842105264|24.210526315789473|
|  auto(s6)|            17.375|           25.1875|
|  auto(l5)|14.717948717948717| 20.71794871794872|
|manual(m5)| 19.25862068965517| 26.29310344827586|
|  auto(l6)|13.666666666666666|              20.0|
|  auto(av)|              20.0|              27.8|
|  auto(s5)|17.333333333333332|25.333333333333332|
|  auto(l3)|              21.0|              27.0|
|  auto(s4)|18.666666666666668|25.666666666666668|
+----------+------------------+------------------+

.round()

I can clean up my values using .round(), too. Yep, I'll chain on .alias() for kicks. Squeaky clean!

In [50]:
# Round each average to 2 decimals and give it a friendly alias.
(
    sdf
    .groupBy('trans')
    .agg(
        round(avg('cty'), 2).alias('city_average'),
        round(avg('hwy'), 2).alias('highway_average'),
    )
    .show()
)
+----------+------------+---------------+
|     trans|city_average|highway_average|
+----------+------------+---------------+
|  auto(l4)|       15.94|          21.96|
|manual(m6)|       16.89|          24.21|
|  auto(s6)|       17.38|          25.19|
|  auto(l5)|       14.72|          20.72|
|manual(m5)|       19.26|          26.29|
|  auto(l6)|       13.67|           20.0|
|  auto(av)|        20.0|           27.8|
|  auto(s5)|       17.33|          25.33|
|  auto(l3)|        21.0|           27.0|
|  auto(s4)|       18.67|          25.67|
+----------+------------+---------------+


Pivot Tables

In [55]:
# Pivot table: rows = transmission type, columns = vehicle class,
# values = number of distinct models in each cell.
model_pivot = (
sdf
    .withColumn( 'transmission_type', when( col('trans').startswith('a'), 'automatic' )
               .otherwise('manual') )
    .groupBy('transmission_type')
    .pivot('class')
    .agg( countDistinct('model').alias('unique_models') )
)
model_pivot.show()
+-----------------+-------+-------+-------+-------+------+----------+---+
|transmission_type|2seater|compact|midsize|minivan|pickup|subcompact|suv|
+-----------------+-------+-------+-------+-------+------+----------+---+
|        automatic|      1|      8|      8|      1|     4|         5| 13|
|           manual|      1|      8|      5|   null|     4|         5|  4|
+-----------------+-------+-------+-------+-------+------+----------+---+

Handle Null Values

.fillna()

This is one way I can handle unwanted Null values. In this context, it makes the most sense.

In [56]:
# I can replace my Null values here with 0 because it means there are no models with a certain transmission type.

model_pivot.fillna(0).show()
+-----------------+-------+-------+-------+-------+------+----------+---+
|transmission_type|2seater|compact|midsize|minivan|pickup|subcompact|suv|
+-----------------+-------+-------+-------+-------+------+----------+---+
|        automatic|      1|      8|      8|      1|     4|         5| 13|
|           manual|      1|      8|      5|      0|     4|         5|  4|
+-----------------+-------+-------+-------+-------+------+----------+---+


Window Functions

.rank() & .dense_rank()

In [62]:
# Create a column that ranks the average mileage by class of vehicle. .rank() skips numbers after rows with a tie.
# The Window partitions by 'class' so ranking restarts within each class,
# ordered by average highway mileage descending (best mileage = rank 1).
(
    sdf
    .groupBy('manufacturer', 'class')
    .agg( avg('hwy').alias('average_highway_mileage') )
    .withColumn( 'mileage_rank', rank().over(Window.partitionBy('class').orderBy(desc('average_highway_mileage'))) )
    .show()
)
+------------+----------+-----------------------+------------+
|manufacturer|     class|average_highway_mileage|mileage_rank|
+------------+----------+-----------------------+------------+
|  volkswagen|subcompact|     32.833333333333336|           1|
|       honda|subcompact|      32.55555555555556|           2|
|      subaru|subcompact|                   26.0|           3|
|     hyundai|subcompact|                   26.0|           3|
|        ford|subcompact|      23.22222222222222|           5|
|      toyota|   compact|     30.583333333333332|           1|
|  volkswagen|   compact|                   28.5|           2|
|      nissan|   compact|                   28.0|           3|
|        audi|   compact|     26.933333333333334|           4|
|      subaru|   compact|                   26.0|           5|
|       dodge|   minivan|     22.363636363636363|           1|
|      subaru|       suv|                   25.0|           1|
|      toyota|       suv|                  18.25|           2|
|     mercury|       suv|                   18.0|           3|
|      nissan|       suv|                   18.0|           3|
|        ford|       suv|      17.77777777777778|           5|
|        jeep|       suv|                 17.625|           6|
|   chevrolet|       suv|      17.11111111111111|           7|
|     lincoln|       suv|                   17.0|           8|
|  land rover|       suv|                   16.5|           9|
+------------+----------+-----------------------+------------+
only showing top 20 rows

In [65]:
# Grab all of the top ranking vehicles for each class.

# Same ranking as above, then keep only the rank-1 row(s) per class.
best_in_class_window = Window.partitionBy('class').orderBy(desc('average_highway_mileage'))

(
    sdf
    .groupBy('manufacturer', 'class')
    .agg(avg('hwy').alias('average_highway_mileage'))
    .withColumn('mileage_rank', rank().over(best_in_class_window))
    .filter('mileage_rank = 1')
    .show()
)
+------------+----------+-----------------------+------------+
|manufacturer|     class|average_highway_mileage|mileage_rank|
+------------+----------+-----------------------+------------+
|  volkswagen|subcompact|     32.833333333333336|           1|
|      toyota|   compact|     30.583333333333332|           1|
|       dodge|   minivan|     22.363636363636363|           1|
|      subaru|       suv|                   25.0|           1|
|      toyota|   midsize|     28.285714285714285|           1|
|      toyota|    pickup|     19.428571428571427|           1|
|   chevrolet|   2seater|                   24.8|           1|
+------------+----------+-----------------------+------------+

In [66]:
# What happens when I filter for mileage_rank when there is a tie in a class?

# Build the ranked DataFrame first, then filter it in a separate step.
ranked_by_class = (
    sdf
    .groupBy('manufacturer', 'class')
    .agg(avg('hwy').alias('average_highway_mileage'))
    .withColumn(
        'mileage_rank',
        rank().over(Window.partitionBy('class').orderBy(desc('average_highway_mileage'))),
    )
)

# Tied rows share a rank, so a single class can return multiple rank-3 rows.
ranked_by_class.filter('mileage_rank = 3').show()
+------------+----------+-----------------------+------------+
|manufacturer|     class|average_highway_mileage|mileage_rank|
+------------+----------+-----------------------+------------+
|      subaru|subcompact|                   26.0|           3|
|     hyundai|subcompact|                   26.0|           3|
|      nissan|   compact|                   28.0|           3|
|     mercury|       suv|                   18.0|           3|
|      nissan|       suv|                   18.0|           3|
|   chevrolet|   midsize|                   27.6|           3|
|       dodge|    pickup|     16.105263157894736|           3|
+------------+----------+-----------------------+------------+

In [60]:
# .dense_rank() does not skip numbers after rows with a tie like .rank() does.

# Window: per-class ordering, best mileage first.
dense_window = Window.partitionBy('class').orderBy(desc('average_highway_mileage'))

# dense_rank() produces consecutive ranks even after ties (1, 2, 3, 3, 4, ...).
(
    sdf
    .groupBy('manufacturer', 'model', 'class')
    .agg(avg('hwy').alias('average_highway_mileage'))
    .withColumn('mileage_rank', dense_rank().over(dense_window))
    .show()
)
+------------+------------------+----------+-----------------------+------------+
|manufacturer|             model|     class|average_highway_mileage|mileage_rank|
+------------+------------------+----------+-----------------------+------------+
|  volkswagen|        new beetle|subcompact|     32.833333333333336|           1|
|       honda|             civic|subcompact|      32.55555555555556|           2|
|     hyundai|           tiburon|subcompact|                   26.0|           3|
|      subaru|       impreza awd|subcompact|                   26.0|           3|
|        ford|           mustang|subcompact|      23.22222222222222|           4|
|      toyota|           corolla|   compact|                   34.0|           1|
|  volkswagen|             jetta|   compact|      29.11111111111111|           2|
|        audi|                a4|   compact|     28.285714285714285|           3|
|      toyota|      camry solara|   compact|     28.142857142857142|           4|
|      nissan|            altima|   compact|                   28.0|           5|
|  volkswagen|               gti|   compact|                   27.4|           6|
|      subaru|       impreza awd|   compact|                   26.0|           7|
|        audi|        a4 quattro|   compact|                  25.75|           8|
|       dodge|       caravan 2wd|   minivan|     22.363636363636363|           1|
|      subaru|      forester awd|       suv|                   25.0|           1|
|      toyota|       4runner 4wd|       suv|     18.833333333333332|           2|
|      nissan|    pathfinder 4wd|       suv|                   18.0|           3|
|     mercury|   mountaineer 4wd|       suv|                   18.0|           3|
|        ford|      explorer 4wd|       suv|                   18.0|           3|
|   chevrolet|c1500 suburban 2wd|       suv|                   17.8|           4|
+------------+------------------+----------+-----------------------+------------+
only showing top 20 rows


Using SQL Syntax

Register a DataFrame as a Temporary View

In [32]:
# I'm creating a temporary view to query. You might have a database you are querying, and not need to complete this step first.

# Registers `sdf` under the name 'sql_mpg' for the lifetime of this Spark session,
# so it can be referenced by name in spark.sql() queries below.
sdf.createOrReplaceTempView('sql_mpg')
In [36]:
# For a simple query, I can just pass the sql query directly into the spark.sql() function.

# Name the query first, then execute it against the registered temp view.
compact_query = 'SELECT * FROM sql_mpg WHERE class = "compact"'
results = spark.sql(compact_query)
In [37]:
# That easily, I'm working with a PySpark DataFrame.

# Preview the first 5 rows of the filtered result.
results.show(5)
+------------+-----+-----+----+---+----------+---+---+---+---+-------+
|manufacturer|model|displ|year|cyl|     trans|drv|cty|hwy| fl|  class|
+------------+-----+-----+----+---+----------+---+---+---+---+-------+
|        audi|   a4|  1.8|1999|  4|  auto(l5)|  f| 18| 29|  p|compact|
|        audi|   a4|  1.8|1999|  4|manual(m5)|  f| 21| 29|  p|compact|
|        audi|   a4|  2.0|2008|  4|manual(m6)|  f| 20| 31|  p|compact|
|        audi|   a4|  2.0|2008|  4|  auto(av)|  f| 21| 30|  p|compact|
|        audi|   a4|  2.8|1999|  6|  auto(l5)|  f| 16| 26|  p|compact|
+------------+-----+-----+----+---+----------+---+---+---+---+-------+
only showing top 5 rows