The Pipeline API from Spark-ML is quite useful, and we use it a lot at work, but I find it quite verbose at times. Here is what I'm talking about:

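import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

val first_variable_indexer = new StringIndexer()
  .setInputCol("first_variable")
  .setOutputCol("first_variable_index")
val first_variable_encoder = new OneHotEncoder()
  .setInputCol("first_variable_index")
  .setOutputCol("first_variable_vector")

val second_variable_indexer = new StringIndexer()
  .setInputCol("second_variable")
  .setOutputCol("second_variable_index")
val second_variable_encoder = new OneHotEncoder()
  .setInputCol("second_variable_index")
  .setOutputCol("second_variable_vector")

// A few more of the same

val pipeline = new Pipeline().setStages(Array(
  first_variable_indexer, first_variable_encoder,
  second_variable_indexer, second_variable_encoder))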

Because the basic unit of the API is the transformation (here, StringIndexer and OneHotEncoder), we have to declare each transformation and the columns it acts on. That means writing each intermediate column name twice, once as one stage's output column and once as the next stage's input column, and getting a runtime error if the two don't match. The result is a lot of boilerplate and transformation code that is hard to read.

What I would prefer is a concise way to express the transformations I want. I want to be able to write:

Take column "first_variable"
  pass it through a string indexer
  pass it through a one-hot encoder
  call the result "first_variable_vector"

Let's write that!


It took me some time alone with a pen and paper, and a few iterations, to come up with this final version. I can't remember exactly how I got there, but it now seems like a pretty good idea:

I want my transformations to be plain Scala functions! If they are Scala functions, I can chain them and pass values through them like any other function. The previous pseudocode then translates directly to:

val col1 = col("first_variable")
val col2 = stringIndexer(col1)
val col3 = oneHotEncoder(col2)
val result = col3.build("first_variable_vector")

Only one problem remains: if my transformations are functions, on what values should they operate?

If I remove the transformations from the last example, I'm left with one thing: a value col1. It is created by passing a string to the col function, and it has a build method that takes a string and returns the result I want.

And now is the time to make a choice: what result do I want? Today, I want to build a DSL on top of the Spark Pipeline API, so the result I want is a Pipeline! To make things a little bit more composable, I'll use an Array[PipelineStage], which is what pipelines are made of.

Basic DSL

To sum up, I need a type that represents columns and has a build method that takes a string and returns an Array[PipelineStage]. I'm going to take the laziest definition for that type:

import org.apache.spark.ml.PipelineStage

case class Col(build: String => Array[PipelineStage])

Yup. It's exactly what I wanted. Nothing less, nothing more. A Col is something to which you give the name of the column where you want your output to be, and that gives you the pipeline to make it happen.
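To make this concrete, here is a minimal sketch of a single Col value. It swaps PipelineStage for a hypothetical Stage case class so that it runs without a Spark dependency. A Stage only records the kind of stage and the columns it reads and writes. With Spark, the stage would be a real StringIndexer configured through setInputCol and setOutputCol.

```scala
// Hypothetical stand-in for Spark's PipelineStage, so the example runs
// without Spark: a stage only records its kind and its columns.
final case class Stage(kind: String, inputCol: String, outputCol: String)

// The post's Col, with Stage in place of PipelineStage.
final case class Col(build: String => Array[Stage])

// A hand-written Col: "index first_variable into whichever column you ask for".
val indexedFirstVariable = Col(out => Array(Stage("StringIndexer", "first_variable", out)))

// The caller picks the output column; the Col returns the stages that fill it.
// Here: one StringIndexer stage reading first_variable, writing first_variable_index.
val stages = indexedFirstVariable.build("first_variable_index")
```

The Col itself never decides where its result goes. That decision is deferred until someone calls build, and this is what lets the next transformation choose an intermediate column for it.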

Now that I have the basic type of my DSL, I need to define the functions that operate on it, such as stringIndexer and oneHotEncoder. We'll let the types guide us.

I want to write stringIndexer, whose type is Col => Col, which is basically (String => Array[PipelineStage]) => (String => Array[PipelineStage]). So I have a function that takes a string and returns an Array[PipelineStage], I also have a string, and I have to return an Array[PipelineStage]. I could just pass the string to the first function and be done with it, but that wouldn't be very useful. What I can do instead is call the build function to get a pipeline that outputs to an arbitrary column, then append a new pipeline stage that takes its input from that column, does what I want, and puts its output in the requested column.

The resulting code is:

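def stringIndexer(col: Col): Col = {
  // Ask the input column to write to an intermediate column,
  // then index that column into the requested output column.
  Col((nextCol: String) =>
    col.build("arbitraryName") :+
      new StringIndexer().setInputCol("arbitraryName").setOutputCol(nextCol))
}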

oneHotEncoder can be defined in the exact same way. And with that, we have all we need to write the nice code we mentioned earlier.
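Here is a sketch of oneHotEncoder next to stringIndexer, together with a col function, wired up to the chain from earlier. It relies on a few assumptions that are not in the post, so treat it as one possible version rather than the actual code:

- PipelineStage is replaced by a hypothetical Stage case class, so the wiring can be printed and checked without Spark.
- col is assumed to emit a stage that copies the source column into whatever column it is asked to fill. With Spark, something like an SQLTransformer could play that role.
- Each transformation derives its intermediate column from its own output name instead of reusing a single "arbitraryName", so chained steps never write to the same column.

```scala
// Hypothetical stand-in for Spark's PipelineStage, so the sketch runs without
// Spark: a stage only records what it does and which columns it reads/writes.
final case class Stage(kind: String, inputCol: String, outputCol: String)

// Same shape as the post's Col, with Stage in place of PipelineStage.
final case class Col(build: String => Array[Stage])

// Assumption: col emits a "Copy" stage that copies the source column into
// whichever column it is asked to fill.
def col(name: String): Col = Col(out => Array(Stage("Copy", name, out)))

// Ask the input Col to write to an intermediate column derived from our own
// output name, then append a stage that reads it and writes to `out`.
def stringIndexer(col: Col): Col = Col { out =>
  val tmp = s"${out}_in"
  col.build(tmp) :+ Stage("StringIndexer", tmp, out)
}

// Defined in exactly the same way; only the kind of stage changes.
def oneHotEncoder(col: Col): Col = Col { out =>
  val tmp = s"${out}_in"
  col.build(tmp) :+ Stage("OneHotEncoder", tmp, out)
}

// The chain from earlier in the post.
val col1 = col("first_variable")
val col2 = stringIndexer(col1)
val col3 = oneHotEncoder(col2)
val result = col3.build("first_variable_vector")

result.foreach(println)
// Stage(Copy,first_variable,first_variable_vector_in_in)
// Stage(StringIndexer,first_variable_vector_in_in,first_variable_vector_in)
// Stage(OneHotEncoder,first_variable_vector_in,first_variable_vector)
```

Reading the stages top to bottom, each one reads the column the previous one wrote. This is exactly the input/output bookkeeping that the verbose version made us write by hand.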


We have built a DSL that has columns as first-class values, and transformations as functions from columns to columns.
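Because transformations are ordinary Col => Col values, the standard function combinators apply to them directly. The sketch below does two things:

- It uses Function1.andThen to package "index, then encode" as one reusable transformation.
- It concatenates the stages for two columns into a single array. With Spark, that is the kind of array you would hand to new Pipeline().setStages(...).

It is a hedged sketch that assumes a hypothetical Stage case class in place of PipelineStage (so it runs without Spark), a col that emits a copy stage, and intermediate column names derived from each output name.

```scala
// Hypothetical stand-in for Spark's PipelineStage (no Spark dependency).
final case class Stage(kind: String, inputCol: String, outputCol: String)
final case class Col(build: String => Array[Stage])

// Assumption: col copies the source column into the requested column.
def col(name: String): Col = Col(out => Array(Stage("Copy", name, out)))

// Each transformation writes its input to "<out>_in", then transforms it into out.
def stringIndexer(col: Col): Col = Col { out =>
  val tmp = s"${out}_in"
  col.build(tmp) :+ Stage("StringIndexer", tmp, out)
}

def oneHotEncoder(col: Col): Col = Col { out =>
  val tmp = s"${out}_in"
  col.build(tmp) :+ Stage("OneHotEncoder", tmp, out)
}

// Transformations as function values, composed with Function1.andThen.
val index: Col => Col = stringIndexer
val encode: Col => Col = oneHotEncoder
val indexThenEncode: Col => Col = index.andThen(encode)

// Stages for two columns, concatenated into one array.
val stages: Array[Stage] =
  indexThenEncode(col("first_variable")).build("first_variable_vector") ++
    indexThenEncode(col("second_variable")).build("second_variable_vector")
```

Adding a third column takes one more line rather than two more stage declarations plus their column names.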

This is only the beginning. There is much we can build on these foundations: nice, non-colliding names for intermediate columns, an easy way to define new transformations, and transformations that take multiple input columns. We can even make our DSL check the types of our columns, or use combinators to add some syntactic sugar. I won't develop these points in this post, but all of them can be seen in the complete code it is based on.

There is one last thing I'm not satisfied with in the current approach: col("someColumn").build("anotherColumn") returns an empty pipeline. Nothing ever gets written to anotherColumn, which is not what you would expect. So there is room for improvement!

Thanks for reading about my DSL experiment! If you have any questions about it, don't hesitate to get in touch (I'm @georgesdubus on Twitter)!