Window Functions in PySpark

Window functions in PySpark let you perform calculations across a set of rows that are related to the current row. They are useful for tasks like ranking, cumulative sums, moving averages, and lead/lag comparisons.


🧠 Key Concepts

A Window is defined by:

  • Partitioning: Groups rows into partitions, like SQL's PARTITION BY.
  • Ordering: Sorts the rows within each partition.
  • Frame: Defines the subset of rows to consider relative to the current row.

We use the Window class from pyspark.sql.window to define the window specification.


✅ Common Window Functions

  • row_number(): Assigns a unique, sequential row number per partition.
  • rank(): Assigns a rank; ties share a rank and leave gaps.
  • dense_rank(): Like rank(), but without gaps.
  • lag(), lead(): Access the previous/next row's values.
  • sum(), avg(), min(), max(): Aggregate over a window.
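The three ranking functions differ only in how they number ties. Their semantics can be sketched in plain Python (illustrative only; in PySpark these run distributed over a window, and the values below are made up):

Python
# One partition, already sorted descending; the two 300s are tied.
values = [300, 300, 200, 100]

row_numbers = list(range(1, len(values) + 1))  # unique, sequential

ranks = []        # ties share a rank; the next rank skips (gaps)
dense_ranks = []  # ties share a rank; no gaps
for i, v in enumerate(values):
    if i > 0 and v == values[i - 1]:
        ranks.append(ranks[-1])
        dense_ranks.append(dense_ranks[-1])
    else:
        ranks.append(i + 1)
        dense_ranks.append(dense_ranks[-1] + 1 if dense_ranks else 1)

print(row_numbers)  # [1, 2, 3, 4]
print(ranks)        # [1, 1, 3, 4]
print(dense_ranks)  # [1, 1, 2, 3]

Note how rank() jumps from 1 to 3 after the tie, while dense_rank() continues with 2.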

🔧 Setup

Python
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, rank, dense_rank, lag, lead, sum

spark = SparkSession.builder.appName("WindowFunctions").getOrCreate()

Sample data:

Python
# Illustrative sample data: (category, date, sales). Any small
# dataset with these three columns works for the examples below.
data = [("A", "2024-01-01", 100), ("A", "2024-01-02", 150),
        ("A", "2024-01-03", 130), ("B", "2024-01-01", 200),
        ("B", "2024-01-02", 90)]

df = spark.createDataFrame(data, ["category", "date", "sales"])
df.show()

🪟 Example 1: row_number() per category by date

Python
windowSpec = Window.partitionBy("category").orderBy("date")

df.withColumn("row_num", row_number().over(windowSpec)).show()

📌 Explanation: For each category, it numbers the rows by date.


🪜 Example 2: rank() and dense_rank()

Python
from pyspark.sql.functions import desc

rank_window = Window.partitionBy("category").orderBy(desc("sales"))

df.withColumn("rank", rank().over(rank_window))\
  .withColumn("dense_rank", dense_rank().over(rank_window)).show()

📌 Explanation: Ranks sales within each category in descending order.


🔁 Example 3: lag() and lead()

Python
from pyspark.sql.functions import lag, lead

lag_window = Window.partitionBy("category").orderBy("date")

df.withColumn("prev_day_sales", lag("sales", 1).over(lag_window))\
  .withColumn("next_day_sales", lead("sales", 1).over(lag_window)).show()

📌 Explanation: Fetches the previous and next day's sales for each row.
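At partition edges there is no neighbouring row, so lag() and lead() return null there. A plain-Python sketch of that behaviour (illustrative only, with made-up values):

Python
# One category's sales, ordered by date. Rows with no neighbour
# get None, just as PySpark's lag()/lead() yield null at the
# edges of a partition.
sales = [100, 150, 130]

prev_day = [None] + sales[:-1]  # like lag("sales", 1)
next_day = sales[1:] + [None]   # like lead("sales", 1)

print(prev_day)  # [None, 100, 150]
print(next_day)  # [150, 130, None]

Both functions also accept a default value as a third argument if you'd rather not get nulls.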


➕ Example 4: Cumulative Sum (sum() over rows)

Python
from pyspark.sql.functions import sum

cum_sum_window = Window.partitionBy("category").orderBy("date").rowsBetween(Window.unboundedPreceding, Window.currentRow)

df.withColumn("cumulative_sales", sum("sales").over(cum_sum_window)).show()

📌 Explanation: Computes running total of sales for each category.
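The frame unboundedPreceding..currentRow means each row's result is the sum of everything from the start of its partition up to and including that row. In plain Python the same idea is a per-group running total (illustrative only, with made-up values):

Python
from itertools import accumulate

# Sales per category, ordered by date within each partition.
sales_by_category = {
    "A": [100, 150, 130],
    "B": [200, 90],
}

# Running total within each group, like sum("sales") over the
# unboundedPreceding..currentRow frame.
cumulative = {cat: list(accumulate(vals))
              for cat, vals in sales_by_category.items()}

print(cumulative)  # {'A': [100, 250, 380], 'B': [200, 290]}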


Other common use cases worth exploring include moving averages, top-N per group, and using window functions directly in Spark SQL.
