Window functions in PySpark let you perform calculations across a set of rows that are related to the current row, without collapsing those rows into a single result the way groupBy does. They are useful for tasks like ranking, cumulative sums, moving averages, and lead/lag comparisons.
🧠 Key Concepts
A Window is defined by:
- Partitioning: Groups rows, like SQL's PARTITION BY.
- Ordering: Sorts rows within each partition.
- Frame: Defines the subset of rows to consider relative to the current row.
We use the Window class from pyspark.sql.window to define the window specification.
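For instance, a specification that combines all three parts might look like this (the column names simply mirror the examples below):

```python
from pyspark.sql.window import Window

# Partition by "category", order by "date", and frame the window from the
# start of the partition up to the current row.
spec = (
    Window.partitionBy("category")
          .orderBy("date")
          .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
```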
✅ Common Window Functions
- row_number(): Assigns a unique row number per partition.
- rank(): Assigns a rank, leaving gaps after ties.
- dense_rank(): Like rank(), but with no gaps.
- lag(), lead(): Access previous/next row values.
- sum(), avg(), min(), max(): Aggregates computed over a window.
🔧 Setup
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, rank, dense_rank, lag, lead, sum
spark = SparkSession.builder.appName("WindowFunctions").getOrCreate()
Sample data (the values below are hypothetical, just to make the examples runnable):
# Hypothetical daily sales per product category
data = [("A", "2024-01-01", 100), ("A", "2024-01-02", 150), ("A", "2024-01-03", 120),
        ("B", "2024-01-01", 200), ("B", "2024-01-02", 200), ("B", "2024-01-03", 180)]
df = spark.createDataFrame(data, ["category", "date", "sales"])
df.show()
🪟 Example 1: row_number() per category by date
windowSpec = Window.partitionBy("category").orderBy("date")
df.withColumn("row_num", row_number().over(windowSpec)).show()
📌 Explanation: Within each category, rows are numbered in ascending date order.
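This is also the standard route to top-N per group: filter on the generated column. A brief sketch, keeping only the earliest row per category:

```python
from pyspark.sql.functions import col

# Keep only the first row (earliest date) per category.
df.withColumn("row_num", row_number().over(windowSpec)) \
  .filter(col("row_num") == 1) \
  .drop("row_num") \
  .show()
```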
🪜 Example 2: rank() and dense_rank()
from pyspark.sql.functions import desc
rank_window = Window.partitionBy("category").orderBy(desc("sales"))
df.withColumn("rank", rank().over(rank_window))\
.withColumn("dense_rank", dense_rank().over(rank_window)).show()
📌 Explanation: Ranks sales within each category in descending order.
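The two differ only when sales values tie: rank() leaves a gap after the tie, while dense_rank() does not. With the hypothetical data above, category "B" has two days tied at 200, which illustrates the difference:

```python
from pyspark.sql.functions import col

# Looking at category "B" only (sales 200, 200, 180):
df.filter(col("category") == "B") \
  .withColumn("rank", rank().over(rank_window)) \
  .withColumn("dense_rank", dense_rank().over(rank_window)) \
  .show()
# rank:       1, 1, 3  (gap after the tie)
# dense_rank: 1, 1, 2  (no gap)
```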
🔁 Example 3: lag() and lead()
from pyspark.sql.functions import lag, lead
lag_window = Window.partitionBy("category").orderBy("date")
df.withColumn("prev_day_sales", lag("sales", 1).over(lag_window))\
.withColumn("next_day_sales", lead("sales", 1).over(lag_window)).show()
📌 Explanation: Fetches the previous and next day's sales for each row.
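At the edges of each partition there is no previous or next row, so lag() and lead() return null there. Both take an optional default value to use instead; a small variation on the example above:

```python
# Use 0 instead of null when there is no previous/next row in the partition.
df.withColumn("prev_day_sales", lag("sales", 1, 0).over(lag_window)) \
  .withColumn("next_day_sales", lead("sales", 1, 0).over(lag_window)) \
  .show()
```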
➕ Example 4: Cumulative Sum (sum() over rows)
from pyspark.sql.functions import sum
cum_sum_window = Window.partitionBy("category").orderBy("date") \
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
df.withColumn("cumulative_sales", sum("sales").over(cum_sum_window)).show()
📌 Explanation: Computes running total of sales for each category.
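The same frame mechanism covers the moving averages mentioned at the start. A brief sketch of a 3-row moving average (the current row plus the two preceding rows), assuming the same columns:

```python
from pyspark.sql.functions import avg

# 3-row moving average within each category, ordered by date.
moving_avg_window = (
    Window.partitionBy("category")
          .orderBy("date")
          .rowsBetween(-2, Window.currentRow)
)
df.withColumn("moving_avg_sales", avg("sales").over(moving_avg_window)).show()
```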
Let me know if you want more use cases, such as writing these window functions directly in PySpark SQL!