PySpark GroupBy, Joins and Window Functions Deep Dive
In every PySpark interview, three topics appear almost without exception: groupBy aggregations, join strategies, and window functions. Mastering these three areas covers roughly 70% of what you will be asked. This guide goes deep on each.
GroupBy and Aggregations
The core pattern: groupBy on the key columns, then a single agg() call with one or more aggregate functions. A common interview trap: candidates run multiple separate groupBy calls, each of which shuffles the data again, instead of combining all aggregations into one agg().
from pyspark.sql import functions as F
# Inefficient — two jobs, each scanning and shuffling df
total = df.groupBy("user_id").agg(F.sum("revenue"))
cnt = df.groupBy("user_id").agg(F.count("*"))
# Correct — one job, one shuffle
result = df.groupBy("user_id").agg(
F.sum("revenue").alias("total_revenue"),
F.count("*").alias("order_count"),
F.avg("revenue").alias("avg_order"),
F.max("order_date").alias("last_order_date")
)
Join Strategies
Spark has four physical join strategies: broadcast hash, shuffle hash, sort-merge, and broadcast nested loop (used for non-equi joins). Understanding which one the optimizer picks, and when to override it, is what separates average from strong PySpark candidates.
# 1. Broadcast hash join — small table (spark.sql.autoBroadcastJoinThreshold, 10MB by default)
from pyspark.sql.functions import broadcast
result = orders.join(broadcast(countries), "country_code")
# 2. Sort-merge join — default for large tables (shuffles and sorts both sides)
result = orders.join(customers, "customer_id")
# 3. Shuffle hash join — hint for medium tables (shuffles but skips the sort)
result = orders.hint("shuffle_hash").join(customers, "customer_id")
# Check which join was chosen
result.explain()
Window Functions — The Most-Tested Topic
Window functions are asked in almost every senior data engineering interview. The three patterns you must know cold: ranking, running totals, and lag/lead comparisons.
from pyspark.sql.window import Window
from pyspark.sql import functions as F
# Pattern 1: Top N per group (rank within partition)
w_rank = Window.partitionBy("dept").orderBy(F.desc("salary"))
(df.withColumn("rank", F.dense_rank().over(w_rank))
   .filter(F.col("rank") <= 3))
# Pattern 2: Running total
w_running = (Window.partitionBy("user_id").orderBy("date")
             .rowsBetween(Window.unboundedPreceding, Window.currentRow))
df.withColumn("cumulative_spend", F.sum("amount").over(w_running))
# Pattern 3: Period-over-period change
w_lag = Window.partitionBy("product_id").orderBy("date")
(df.withColumn("prev_revenue", F.lag("revenue", 1).over(w_lag))
   .withColumn("wow_change", F.col("revenue") - F.col("prev_revenue")))
Interview Tips
• Always explain the window spec before writing code — interviewers want to see your reasoning
• Mention the shuffle when talking about joins — it shows you understand the performance implications
• For groupBy, always ask if you need to handle nulls in the key columns before aggregating
• Use .explain() in your answers to show you know how to verify the physical plan