The stack function in Spark takes the number of rows to generate as its first argument, followed by a list of expressions:
stack(n, expr1, expr2, ..., exprk)
The stack function generates n rows by evaluating the expressions: the k expressions are separated into n rows, so each row gets k/n columns. By default the output columns are named col0, col1, and so on, unless aliased.
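Here is a minimal sketch of that behavior, run over constant literals through Spark SQL (assuming a SparkSession named spark is in scope):

// Four expressions split into 2 rows give 4 / 2 = 2 columns per row,
// named col0 and col1 by default.
spark.sql("SELECT stack(2, 1, 2, 3, 4)").show()

+----+----+
|col0|col1|
+----+----+
|   1|   2|
|   3|   4|
+----+----+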
stack() in action
Let’s see the stack function in action. stack() comes in handy when we attempt to unpivot a DataFrame.
Let’s say our data is pivoted and it looks like below.
+-------+---------+-----+---------+----+
|   Name|Analytics|   BI|Ingestion|  ML|
+-------+---------+-----+---------+----+
| Mickey|     null|12000|     null|8000|
| Martin|     null| 5000|     null|null|
|  Jerry|     null| null|     1000|null|
|  Riley|     null| null|     null|9000|
| Donald|     1000| null|     null|null|
|   John|     null| null|     1000|null|
|Patrick|     null| null|     null|1000|
|  Emily|     8000| null|     3000|null|
|   Arya|    10000| null|     2000|null|
+-------+---------+-----+---------+----+
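For reference, this pivoted shape is produced with Spark’s groupBy/pivot, assuming df is the long-form DataFrame built in the full code at the end of this post:

// One column per distinct Project value; sum is the aggregation.
val pivotDF = df.groupBy("Name").pivot("Project").sum("Cost_To_Project")
pivotDF.show()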
We would like to unpivot the data and make it look like below.
+-------+---------+---------------+
|   Name|  Project|Cost_To_Project|
+-------+---------+---------------+
| Mickey|       BI|          12000|
| Mickey|       ML|           8000|
| Martin|       BI|           5000|
|  Jerry|Ingestion|           1000|
|  Riley|       ML|           9000|
| Donald|Analytics|           1000|
|   John|Ingestion|           1000|
|Patrick|       ML|           1000|
|  Emily|Analytics|           8000|
|  Emily|Ingestion|           3000|
|   Arya|Analytics|          10000|
|   Arya|Ingestion|           2000|
+-------+---------+---------------+
We can use the stack function like below to get the job done.
- For every record in pivotDF, we select Name and call the stack function
- The stack function creates 4 rows for every row in pivotDF, with 2 columns: Project and Cost_To_Project
- 4 rows because we have 4 unique projects in our dataset
- The Project value is a hard-coded literal in the stack expression (e.g. 'Analytics')
- The Cost_To_Project value is read from the column of the same name in the dataset (e.g. the Analytics column)
pivotDF.select($"Name", expr("stack(4, 'Analytics', Analytics, 'BI', BI, 'Ingestion', Ingestion, 'ML', ML) as (Project, Cost_To_Project)")).show(false) +------+---------+---------------+ |Name |Project |Cost_To_Project| +------+---------+---------------+ |Mickey|Analytics|null | |Mickey|BI |12000 | |Mickey|Ingestion|null | |Mickey|ML |8000 | |Martin|Analytics|null | |Martin|BI |5000 | |Martin|Ingestion|null | |Martin|ML |null | |Jerry |Analytics|null | |Jerry |BI |null | |Jerry |Ingestion|1000 | |Jerry |ML |null | |Riley |Analytics|null | |Riley |BI |null | |Riley |Ingestion|null | |Riley |ML |9000 | |Donald|Analytics|1000 | ---- ----
Finally, we can filter the DataFrame to drop the records where Cost_To_Project is NULL.
val unPivotDF = pivotDF.select(
  $"Name",
  expr("stack(4, 'Analytics', Analytics, 'BI', BI, 'Ingestion', Ingestion, 'ML', ML) as (Project, Cost_To_Project)")
).where("Cost_To_Project is not null")

unPivotDF.show()

+-------+---------+---------------+
|   Name|  Project|Cost_To_Project|
+-------+---------+---------------+
| Mickey|       BI|          12000|
| Mickey|       ML|           8000|
| Martin|       BI|           5000|
|  Jerry|Ingestion|           1000|
|  Riley|       ML|           9000|
| Donald|Analytics|           1000|
|   John|Ingestion|           1000|
|Patrick|       ML|           1000|
|  Emily|Analytics|           8000|
|  Emily|Ingestion|           3000|
|   Arya|Analytics|          10000|
|   Arya|Ingestion|           2000|
+-------+---------+---------------+
Full code
import org.apache.spark.sql.functions.expr
import spark.sqlContext.implicits._

val data = Seq(
  ("Ingestion", "Jerry", 1000),
  ("Ingestion", "Arya", 2000),
  ("Ingestion", "Emily", 3000),
  ("ML", "Riley", 9000),
  ("ML", "Patrick", 1000),
  ("ML", "Mickey", 8000),
  ("Analytics", "Donald", 1000),
  ("Ingestion", "John", 1000),
  ("Analytics", "Emily", 8000),
  ("Analytics", "Arya", 10000),
  ("BI", "Mickey", 12000),
  ("BI", "Martin", 5000))

val df = data.toDF("Project", "Name", "Cost_To_Project")

// pivot
val pivotDF = df.groupBy("Name").pivot("Project").sum("Cost_To_Project")
pivotDF.show()

// unpivot
val unPivotDF = pivotDF.select(
  $"Name",
  expr("stack(4, 'Analytics', Analytics, 'BI', BI, 'Ingestion', Ingestion, 'ML', ML) as (Project, Cost_To_Project)")
).where("Cost_To_Project is not null")
unPivotDF.show()
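One wrinkle worth noting: the stack expression above hard-codes the four project names. As a sketch that goes slightly beyond the snippet above, you could build the expression from pivotDF's own columns, so the unpivot keeps working if the set of projects changes (unPivotDynamicDF is just an illustrative name):

// Derive the stack arguments from the pivoted column names.
val projectCols = pivotDF.columns.filter(_ != "Name")
val stackArgs = projectCols.map(c => s"'$c', `$c`").mkString(", ")

val unPivotDynamicDF = pivotDF.select(
  $"Name",
  expr(s"stack(${projectCols.length}, $stackArgs) as (Project, Cost_To_Project)")
).where("Cost_To_Project is not null")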