Iterate through a Spark DataFrame using its partitions in Java

My work at WSO2 Inc. mainly revolves around the Business Activity Monitor (BAM)/Data Analytics Server (DAS). For the DAS 3.0 release, we are bringing in Apache Spark as the analytics engine of the WSO2 Carbon Platform, replacing Apache Hadoop and Apache Hive, and I am working on this Spark migration. Spark introduces an interesting concept, the RDD, to the analytics community. I am not going to go into detail about RDDs; please click here for further information. Once an RDD is created in the Spark “world”, it can be used for data manipulation, analysis, etc.

Spark SQL, the SQL query engine for Spark, uses an extension of this RDD called a DataFrame (formerly known as a SchemaRDD). For further information, click here.

Here I will be discussing how to use the partitions of a DataFrame to iterate through the underlying data… along with some useful debugging tips for the Java environment. (I thought this would be useful because Spark is written in Scala, so almost all of its features lean heavily on Scala functionality… and when we bring them to the Java environment, things might not work as expected!)

DataFrames… WTH?

As per Spark,

A DataFrame is a distributed collection of data organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R/Python, but with richer optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: structured data files, tables in Hive, external databases, or existing RDDs.

Problem: How to retrieve data? Take all the elements?

A DataFrame (DF) encapsulates its data in Rows, and we can retrieve these Rows as an array or as a list using the following collect methods of a DF.

def collect(): Array[Row]
def collectAsList(): java.util.List[Row]

But the problem here is that a ‘collect’ method collects all the data under a DF (in RDD jargon, it is an action op). Since Spark uses in-memory processing, if this DF covers a large data set, collecting it pulls every Row back to the driver and can easily exhaust its memory.
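For example, here is a minimal sketch in Java (assuming a Spark 1.3-style org.apache.spark.sql.DataFrame named ‘data’; the class name is just for illustration). Both collect variants are actions that materialize every Row on the driver JVM:

import java.util.List;

import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;

public class CollectSketch {
    // prints every row of the DF; fine for small results, dangerous for large ones
    public static void printAll(DataFrame data) {
        List<Row> rows = data.collectAsList(); // pulls the whole result set to the driver
        for (Row row : rows) {
            System.out.println(row);
        }
    }
}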

Solution: Take data using Partitions!

Using the underlying partitions of a DF gives a better solution to this!

def foreachPartition(f: Iterator[Row] => Unit): Unit

As you can see in the method signature, it takes a function as its parameter, and this function takes a Row Iterator and returns Unit. So it does not collect all the data under the DF at once!

Example:

val b = sc.parallelize(List(1, 2, 3, 4, 5, 6, 7, 8, 9), 3)
b.foreachPartition(x => println(x.reduce(_ + _))) 
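With three partitions, the parallelized list is split into (1, 2, 3), (4, 5, 6) and (7, 8, 9), so each partition prints its own sum (6, 15 and 24) on the worker that holds it, without ever shipping the full data set back to the driver.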

From Scala to Java… 

This looks fairly straightforward in Scala, but when it comes to Java, things are a little ‘messy’! There we have to implement the Scala ‘anonymous function’ ourselves.

In WSO2 DAS we have implemented this in our implementation of Spark BaseRelation.

data.foreachPartition(new AbstractFunction1<Iterator<Row>, BoxedUnit>() {
    @Override
    public BoxedUnit apply(Iterator<Row> v1) {
        // your logic goes here...
        return BoxedUnit.UNIT;
    }
});

Here we have implemented scala.Function1 using scala.runtime.AbstractFunction1. BoxedUnit here is the equivalent of a void result in Java.
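As a side note (this is not what we used in DAS, just a sketch of an alternative): if you are happy to drop down to the Java RDD API, the DF can be converted with toJavaRDD(), and its foreachPartition() then accepts an org.apache.spark.api.java.function.VoidFunction, which already extends java.io.Serializable.

import java.util.Iterator;

import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;

public class JavaRddSketch {
    public static void process(DataFrame data) {
        // convert the DF to a JavaRDD<Row> and iterate its partitions via the Java API
        data.toJavaRDD().foreachPartition(new VoidFunction<Iterator<Row>>() {
            @Override
            public void call(Iterator<Row> rows) throws Exception {
                while (rows.hasNext()) {
                    Row row = rows.next();
                    // your per-row logic goes here...
                }
            }
        });
    }
}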

Troubleshooting tips….

  • One important thing to note here is that, in a distributed environment, this anonymous function will be serialized and shipped to all the Spark workers.
  • All non-serializable objects should be instantiated within the function itself (i.e. inside the apply method), so that they are created on the workers rather than serialized from the driver. A sketch of this is given right after this list.
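
Here is a minimal sketch of that second tip (the class name, the JDBC URL and the per-row write are hypothetical, not part of DAS): a java.sql.Connection is not serializable, so it is opened inside apply(), i.e. on the worker, instead of being held as a field that would have to be shipped from the driver.

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.Iterator;

import org.apache.spark.sql.Row;

import scala.runtime.AbstractFunction1;
import scala.runtime.BoxedUnit;

public class WorkerSideInitFunction extends AbstractFunction1<Iterator<Row>, BoxedUnit>
        implements java.io.Serializable {

    private static final long serialVersionUID = 1L;
    private final String jdbcUrl; // a plain String serializes fine

    public WorkerSideInitFunction(String jdbcUrl) {
        this.jdbcUrl = jdbcUrl;
    }

    @Override
    public BoxedUnit apply(Iterator<Row> rows) {
        // the non-serializable Connection is created here, on the worker,
        // NOT in the constructor on the driver
        try (Connection conn = DriverManager.getConnection(this.jdbcUrl)) {
            while (rows.hasNext()) {
                Row row = rows.next();
                // write 'row' through 'conn' here...
            }
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
        return BoxedUnit.UNIT;
    }
}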

Houston, We’ve Got a Problem!

BUT in the real DAS implementation, we came across the following exception…

org.apache.spark.SparkException: Job aborted due to stage failure: Task not serializable: java.io.NotSerializableException: ...

It turns out that when we implement scala.runtime.AbstractFunction1 in the Java environment, the result is not readily serializable (AbstractFunction1 does not implement the java.io.Serializable interface).

So, as a solution, we wrote our own AbstractFunction1 implementation as follows, and it actually worked!

public class AnalyticsFunction1 extends AbstractFunction1<Iterator<Row>, BoxedUnit>
        implements Serializable {

    private static final long serialVersionUID = -1919222653470217466L;
    private int tId;
    private String tName;
    private StructType sch;

    public AnalyticsFunction1(int tId, String tName, StructType sch) {
        this.tId = tId;
        this.tName = tName;
        this.sch = sch;
    }

    @Override
    public BoxedUnit apply(Iterator<Row> iterator) {
        List<Record> records = new ArrayList<>();
        while (iterator.hasNext()) {
            if (records.size() == AnalyticsConstants.MAX_RECORDS) {
                // flush a full batch to the data service before taking more rows
                try {
                    ServiceHolder.getAnalyticsDataService().put(records);
                } catch (AnalyticsException e) {
                    e.printStackTrace();
                }
                records.clear();
            } else {
                Row row = iterator.next();
                records.add(new Record(this.tId, this.tName,
                                       convertRowAndSchemaToValuesMap(row, this.sch)));
            }
        }

        // flush whatever is left once the iterator is exhausted
        if (!records.isEmpty()) {
            try {
                ServiceHolder.getAnalyticsDataService().put(records);
            } catch (AnalyticsException e) {
                e.printStackTrace();
            }
        }
        return BoxedUnit.UNIT;
    }

    private Map<String, Object> convertRowAndSchemaToValuesMap(Row row, StructType schema) {
        String[] colNames = schema.fieldNames();
        Map<String, Object> result = new HashMap<>();
        for (int i = 0; i < row.length(); i++) {
            result.put(colNames[i], row.get(i));
        }
        return result;
    }

}

It can be instantiated as follows…


data.foreachPartition(new AnalyticsFunction1(tenantId, tableName, data.schema()));

You can access the GitHub repo here.

To sum up!

  • While iterating through a DataFrame, try your best to avoid the ‘collect()’ method.
  • Always prefer a method that returns an iterator over the data.
  • When using ‘foreachPartition()’ in Java, implement the anonymous function in a separate class which implements the java.io.Serializable interface.
  • Read this Databricks Knowledge Base article: General Troubleshooting
