Skip to main content

Command Palette

Search for a command to run...

PySpark Window Function Examples

Updated
4 min read
P

I am a data engineer at Tesco and this blog is part of a mentoring process to track the progress of my career development journey.

Assignment 1

Find the best-selling and least-selling product for each day.

from datetime import date
from pyspark.sql import Row

# Explicit schema in Spark DDL form. The products column holds a list of
# product codes, so its type is array<string> ("struct<string>" is not a valid
# Spark type). Previously the schema was defined but never passed to
# createDataFrame, so Spark silently inferred the types instead.
schema = "date date, transaction_id int, customer_id int, products array<string>"
data = [
    Row(date=date(2021, 10, 1), transaction_id=1, customer_id=101, products=["p1", "p10"]),
    Row(date=date(2021, 10, 1), transaction_id=2, customer_id=103, products=["p2", "p11", "p2"]),
    Row(date=date(2021, 10, 1), transaction_id=3, customer_id=101, products=["p1", "p10", "p2"]),
    Row(date=date(2021, 11, 1), transaction_id=4, customer_id=102, products=["p1"]),
    Row(date=date(2021, 11, 1), transaction_id=5, customer_id=103, products=["p1", "p2"]),
]
df = spark.createDataFrame(data, schema=schema)
df.show()
+----------+--------------+-----------+-------------+
|      date|transaction_id|customer_id|     products|
+----------+--------------+-----------+-------------+
|2021-10-01|             1|        101|    [p1, p10]|
|2021-10-01|             2|        103|[p2, p11, p2]|
|2021-10-01|             3|        101|[p1, p10, p2]|
|2021-11-01|             4|        102|         [p1]|
|2021-11-01|             5|        103|     [p1, p2]|
+----------+--------------+-----------+-------------+
from pyspark.sql.functions import explode, col, first, last, desc
from pyspark.sql import Window

# One row per product occurrence. A product that appears twice in the same
# basket (p2 in transaction 2) is counted twice. A new name is used so the
# original transactions DataFrame `df` is left intact.
exploded = df.withColumn("product", explode(df.products))

# Units sold per product per day.
df_g1 = exploded.groupBy("product", "date").count()

# Within each day, order products by units sold. "product" is a tie-breaker
# so that first() is deterministic when two products sell the same amount.
window = Window.partitionBy("date").orderBy(col("count"), col("product"))
window_d = Window.partitionBy("date").orderBy(col("count").desc(), col("product"))

# first() over the ascending window picks the least-selling product and over
# the descending window the best-selling one. last() over the ascending window
# does NOT work: with an orderBy and no explicit frame, Spark uses a growing
# frame from the start of the partition up to the current row, so last()
# returns the current row's product. last() only works with a full frame, e.g.
# window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing).
df_g2 = (
    df_g1.withColumn("least_selling", first("product").over(window))
    .withColumn("Best_selling", first("product").over(window_d))
)

# distinct() must run before orderBy(): distinct() shuffles the data and
# discards any ordering applied before it.
df_g2 = df_g2.select("date", "least_selling", "Best_selling").distinct().orderBy("date")

df_g2.show()
# +----------+-------------+------------+
# |      date|least_selling|Best_selling|
# +----------+-------------+------------+
# |2021-10-01|          p11|          p2|
# |2021-11-01|           p2|          p1|
# +----------+-------------+------------+

Assignment 2

Tasks:

  • Calculate Cumulative Sales Amount per Customer

  • Rank Products by Total Sales

  • Calculate Moving Average Sales Amount for the Last 3 Orders per Customer

  • Calculate the Previous and Next Sales Amounts for Each Order

  • Calculate Percentage Contribution of Each Order to Total Sales of the Customer

Load Data Code:

from pyspark.sql import Row
from pyspark.sql import functions as F

# Ten orders from three customers (C001-C003) covering three products
# (P01-P03). OrderDate is kept as an ISO-8601 "YYYY-MM-DD" string; strings in
# this format sort in chronological order, so ordering by OrderDate works
# without converting it to a date type.
data = [
Row(OrderID=1, CustomerID='C001', OrderDate='2023-01-01', SalesAmount=100, ProductID='P01'),
Row(OrderID=2, CustomerID='C002', OrderDate='2023-01-05', SalesAmount=200, ProductID='P01'),
Row(OrderID=3, CustomerID='C001', OrderDate='2023-01-10', SalesAmount=150, ProductID='P02'),
Row(OrderID=4, CustomerID='C003', OrderDate='2023-01-15', SalesAmount=300, ProductID='P01'),
Row(OrderID=5, CustomerID='C001', OrderDate='2023-01-20', SalesAmount=100, ProductID='P03'),
Row(OrderID=6, CustomerID='C002', OrderDate='2023-01-25', SalesAmount=250, ProductID='P03'),
Row(OrderID=7, CustomerID='C003', OrderDate='2023-01-30', SalesAmount=400, ProductID='P02'),
Row(OrderID=8, CustomerID='C001', OrderDate='2023-02-05', SalesAmount=200, ProductID='P01'),
Row(OrderID=9, CustomerID='C002', OrderDate='2023-02-10', SalesAmount=100, ProductID='P02'),
Row(OrderID=10, CustomerID='C003', OrderDate='2023-02-15', SalesAmount=300, ProductID='P03'),
]
# Column types are inferred from the Row values (no explicit schema).
df = spark.createDataFrame(data)

Results:

# a) Calculate Cumulative Sales Amount per Customer
# A cumulative (running) total needs a window ordered by date within each
# customer. A plain groupBy().sum() only yields the final total per customer,
# which here is the last row of each customer (550, 550, 1000).
from pyspark.sql import Window
from pyspark.sql import functions as F

w_cum = (
    Window.partitionBy("CustomerID")
    .orderBy("OrderDate")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
df1 = df.withColumn("CumulativeSales", F.sum("SalesAmount").over(w_cum))
df1.orderBy("CustomerID", "OrderDate").show()
# +-------+----------+----------+-----------+---------------+
# |OrderID|CustomerID| OrderDate|SalesAmount|CumulativeSales|
# +-------+----------+----------+-----------+---------------+
# |      1|      C001|2023-01-01|        100|            100|
# |      3|      C001|2023-01-10|        150|            250|
# |      5|      C001|2023-01-20|        100|            350|
# |      8|      C001|2023-02-05|        200|            550|
# |      2|      C002|2023-01-05|        200|            200|
# |      6|      C002|2023-01-25|        250|            450|
# |      9|      C002|2023-02-10|        100|            550|
# |      4|      C003|2023-01-15|        300|            300|
# |      7|      C003|2023-01-30|        400|            700|
# |     10|      C003|2023-02-15|        300|           1000|
# +-------+----------+----------+-----------+---------------+

# b) Rank Products by Total Sales

from pyspark.sql import Window
from pyspark.sql.functions import col, rank
from pyspark.sql.functions import sum as sum_p  # alias avoids shadowing built-in sum

# rank() gives tied totals the same rank and leaves a gap afterwards: P02 and
# P03 both total 650 and share rank 2. The window has no partitionBy, so Spark
# processes all rows in a single partition, which is fine for a few products.
window = Window.orderBy(col("TotalSales").desc())
df2 = df.groupBy("ProductID").agg(sum_p("SalesAmount").alias("TotalSales"))
df2 = df2.withColumn("ProductRank", rank().over(window))
# Sort for display so tied products always appear in the same order.
df2.select("ProductID", "ProductRank").orderBy("ProductRank", "ProductID").show()

+---------+-----------+
|ProductID|ProductRank|
+---------+-----------+
|      P01|          1|
|      P02|          2|
|      P03|          2|
+---------+-----------+

# c) Calculate Moving Average Sales Amount for the Last 3 Orders per Customer

from pyspark.sql.functions import avg, col
from pyspark.sql.window import Window

# Frame: the current order plus up to two earlier orders of the same customer,
# in date order. A customer's first orders average over however many orders
# exist so far (e.g. order 3 averages orders 1 and 3).
last_three_orders = (
    Window.partitionBy("CustomerID")
    .orderBy(col("OrderDate"))
    .rowsBetween(-2, Window.currentRow)
)
df3 = df.withColumn("moving_avg", avg("SalesAmount").over(last_three_orders))
df3.orderBy("OrderDate").show()
+-------+----------+----------+-----------+---------+------------------+
|OrderID|CustomerID| OrderDate|SalesAmount|ProductID|        moving_avg|
+-------+----------+----------+-----------+---------+------------------+
|      1|      C001|2023-01-01|        100|      P01|             100.0|
|      2|      C002|2023-01-05|        200|      P01|             200.0|
|      3|      C001|2023-01-10|        150|      P02|             125.0|
|      4|      C003|2023-01-15|        300|      P01|             300.0|
|      5|      C001|2023-01-20|        100|      P03|116.66666666666667|
|      6|      C002|2023-01-25|        250|      P03|             225.0|
|      7|      C003|2023-01-30|        400|      P02|             350.0|
|      8|      C001|2023-02-05|        200|      P01|             150.0|
|      9|      C002|2023-02-10|        100|      P02|183.33333333333334|
|     10|      C003|2023-02-15|        300|      P03| 333.3333333333333|
+-------+----------+----------+-----------+---------+------------------+

# d) Calculate Previous and Next Sales Amounts of each order
from pyspark.sql.window import Window
from pyspark.sql.functions import lag, lead

# One window over all orders in date order (no partitioning). lead() reads the
# following order's amount and lag() the preceding one; the first and last
# orders have no neighbour on one side and get null there.
by_order_date = Window.orderBy("OrderDate")
next_amount = lead("SalesAmount").over(by_order_date)
prev_amount = lag("SalesAmount").over(by_order_date)
df4 = df.withColumn("next_amount", next_amount).withColumn("prev_amount", prev_amount)
df4.select("OrderID", "next_amount", "prev_amount").show()

+-------+-----------+-----------+
|OrderID|next_amount|prev_amount|
+-------+-----------+-----------+
|      1|        200|       null|
|      2|        150|        100|
|      3|        300|        200|
|      4|        100|        150|
|      5|        250|        300|
|      6|        400|        100|
|      7|        200|        250|
|      8|        100|        400|
|      9|        300|        200|
|     10|       null|        100|
+-------+-----------+-----------+

# e) Calculate Percentage Contribution of Each Order to Total Sales of the Customer
from pyspark.sql import functions as F
from pyspark.sql.functions import col
from pyspark.sql.functions import sum as sum_p
from pyspark.sql.window import Window

# `from pyspark.sql.functions import round` would shadow Python's built-in
# round() for the rest of the notebook, so Spark's round is used via F.
# A window with no orderBy spans the whole partition, so every order sees
# its customer's full total.
wcust = Window.partitionBy(col("CustomerID"))
df5 = df.withColumn("TotalSales", sum_p("SalesAmount").over(wcust))
# Rounded to whole percent, so a customer's percentages need not add up to
# exactly 100 (C001 adds up to 99).
percentage = F.round(100 * df5.SalesAmount / df5.TotalSales).alias("percentage")
df5.select(df5.OrderID, df5.CustomerID, df5.SalesAmount, percentage).show()

+-------+----------+-----------+----------+
|OrderID|CustomerID|SalesAmount|percentage|
+-------+----------+-----------+----------+
|      4|      C003|        300|      30.0|
|      7|      C003|        400|      40.0|
|     10|      C003|        300|      30.0|
|      1|      C001|        100|      18.0|
|      3|      C001|        150|      27.0|
|      5|      C001|        100|      18.0|
|      8|      C001|        200|      36.0|
|      6|      C002|        250|      45.0|
|      9|      C002|        100|      18.0|
|      2|      C002|        200|      36.0|
+-------+----------+-----------+----------+

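Assignment 3

Unpivot the store-by-product quantity table into (store, product, qty) rows and find the top two products by quantity in each store. Tied products share a rank, so a store can return more than two rows.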

# Quantity (qty) of each product p1-p3 in each of five stores A-E, in wide
# format: one row per store, one column per product.
data = [
    ["A", 1, 6, 7],
    ["B", 2, 7, 6],
    ["C", 3, 8, 5],
    ["D", 4, 9, 4],
    ["E", 5, 8, 3],
]

# Column names: the store identifier followed by one column per product.
columns = ["store", "p1", "p2", "p3"]

# Create the wide-format DataFrame.
dataframe = spark.createDataFrame(data, columns)

# Show the wide-format data.
dataframe.show()

# On Spark 3.4+ the manual unpivot that follows can be replaced by the built-in:
# dataframe.unpivot("store", ["p1","p2","p3"], "product", "qty").show()
+-----+---+---+---+
|store| p1| p2| p3|
+-----+---+---+---+
|    A|  1|  6|  7|
|    B|  2|  7|  6|
|    C|  3|  8|  5|
|    D|  4|  9|  4|
|    E|  5|  8|  3|
+-----+---+---+---+
from functools import reduce

from pyspark.sql.functions import lit, row_number, rank, col, desc
from pyspark.sql.window import Window

# How many of the highest-quantity products to keep per store.
top_n = 2

# Unpivot wide -> long: select each product column as (store, product, qty)
# and stack the pieces. The product columns are taken from the DataFrame,
# so adding a product column needs no code change. union() matches columns
# by position, and every piece has the same column order.
product_cols = [c for c in dataframe.columns if c != "store"]
dp = reduce(
    lambda left, right: left.union(right),
    [
        dataframe.select("store", lit(c).alias("product"), col(c).alias("qty"))
        for c in product_cols
    ],
)

# Rank products within each store by quantity, highest first. rank() gives
# tied quantities the same rank, so a store can return more than top_n rows
# (store D keeps both p1 and p3 at qty 4). Use row_number() instead if exactly
# top_n rows per store are required.
partition = Window.partitionBy("store").orderBy(col("qty").desc())
dp = (
    dp.withColumn("rn", rank().over(partition))
    .filter(col("rn") <= top_n)
    .drop("rn")
    # Window output has no guaranteed order; sort explicitly for display.
    .orderBy("store", col("qty").desc(), "product")
)
dp.show()
+-----+-------+---+
|store|product|qty|
+-----+-------+---+
|    A|     p3|  7|
|    A|     p2|  6|
|    B|     p2|  7|
|    B|     p3|  6|
|    C|     p2|  8|
|    C|     p3|  5|
|    D|     p2|  9|
|    D|     p1|  4|
|    D|     p3|  4|
|    E|     p2|  8|
|    E|     p1|  5|
+-----+-------+---+