Install and compile Cython

Learn how to install and compile Cython with Databricks.

Written by Adam Pavlacka

Last published at: May 19th, 2022

This document explains how to run Spark code with compiled Cython code. The steps are as follows:

  1. Create an example Cython module on DBFS (AWS | Azure).
  2. Add the file to the Spark session.
  3. Create a wrapper method to load the module on the executors.
  4. Run the mapper on a sample dataset.
  5. Generate a larger dataset and compare the performance with a native Python example.

Info

By default, paths use dbfs:/ if no protocol is referenced.

%python

# Write an example cython module to /example/cython/fib.pyx in DBFS.
dbutils.fs.put("/example/cython/fib.pyx", """
def fib_mapper_cython(n):
    '''
    Return the first Fibonacci number >= n, paired with a count of 1.
    '''
    cdef int a = 0
    cdef int b = 1
    cdef int j = int(n)
    while b<j:
        a, b  = b, a+b
    return b, 1
""", True)

# Write an example input file to /example/cython_input/input.txt in DBFS.
# Every line of this file is an integer.
dbutils.fs.put("/example/cython_input/input.txt", """
1
10
100
""", True)

# Take a look at the example input.
dbutils.fs.head("/example/cython_input/input.txt")
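
As a sanity check on the module above, the following is a plain-Python sketch of the same loop. The name fib_mapper_reference is hypothetical and not part of the original example; it only shows what the mapper should return for the three sample inputs, assuming each line arrives as a string.

```python
def fib_mapper_reference(n):
    """Return (first Fibonacci number >= n, 1), mirroring the loop in fib.pyx."""
    a, b = 0, 1
    j = int(n)  # input lines arrive as strings, so convert first
    while b < j:
        a, b = b, a + b
    return b, 1

# The sample input file contains 1, 10 and 100, one per line.
results = [fib_mapper_reference(line) for line in ["1", "10", "100"]]
print(results)  # [(1, 1), (13, 1), (144, 1)]
```

The trailing 1 in each pair is the count that reduceByKey later sums per Fibonacci number.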

Add Cython Source Files to Spark

To make the Cython source files available across the cluster, use sc.addPyFile to add them to Spark:

%python

sc.addPyFile("dbfs:/example/cython/fib.pyx")

Test Cython compilation on the driver node

Before running on the executors, confirm that the module compiles and imports on the driver node.

%python

import pyximport

# Register the pyximport import hook so that .pyx files are compiled on import.
pyximport.install()
import fib

Define the wrapper function to compile and import the module

The print statements run on the executor nodes, so their output appears in the executor stdout logs. You can view these log messages to track the progress of your module.

%python

import sys, os, shutil, cython

def spark_cython(module, method):
  def wrapped(*args, **kwargs):
    print('Entered function with: %s' % (args,))
    global cython_function_
    try:
      # Reuse the function if this Python worker has already loaded it.
      return cython_function_(*args, **kwargs)
    except NameError:
      # First call on this worker: compile the .pyx file and import the function.
      import pyximport
      pyximport.install()
      cython_function_ = getattr(__import__(module), method)
      print('Cython compilation complete')
    print('Defined function: %s' % cython_function_)
    return cython_function_(*args, **kwargs)
  return wrapped

Run the Cython example

The following snippet runs the Fibonacci example on a few data points.

%python

# Use the CSV reader to create a Spark DataFrame, convert it to an RDD, and take the single column from each Row.
lines = spark.read.csv("/example/cython_input/").rdd.map(lambda y: y[0])

mapper = spark_cython('fib', 'fib_mapper_cython')
fib_frequency = lines.map(mapper).reduceByKey(lambda a, b: a+b).collect()
print(fib_frequency)
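
To make the reduceByKey step concrete, here is a single-machine sketch of the same aggregation. The helper reduce_by_key is hypothetical and only imitates the semantics locally; the input pairs are made-up sample data. Results are sorted here for readability, whereas the order returned by collect() in Spark should not be relied on.

```python
def reduce_by_key(pairs, func):
    """Combine values that share a key with func, like a local reduceByKey."""
    merged = {}
    for key, value in pairs:
        merged[key] = func(merged[key], value) if key in merged else value
    return sorted(merged.items())

# Made-up (fibonacci_number, 1) pairs, as the mapper would emit them.
pairs = [(13, 1), (1, 1), (13, 1), (144, 1), (13, 1)]
fib_frequency_local = reduce_by_key(pairs, lambda a, b: a + b)
print(fib_frequency_local)  # [(1, 1), (13, 3), (144, 1)]
```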

Performance comparison

Next, compare the speed of the two implementations. Use the spark.range() API to generate data points from 10,000 up to (but not including) 100,000,000 across 50 Spark partitions, and write the output to DBFS as CSV.

For this test, disable autoscaling (AWS | Azure) so that the cluster has a fixed number of Spark executors.

%python

dbutils.fs.rm("/tmp/cython_input/", True)
spark.range(10000, 100000000, 1, 50).write.csv("/tmp/cython_input/")
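
For a sense of scale, the size of this dataset can be worked out without Spark. The sketch below uses Python's built-in range, which excludes the end value in the same way as spark.range. The per-partition figure assumes an even split across partitions.

```python
start, end, step, num_partitions = 10000, 100000000, 1, 50

# range(start, end, step) excludes `end`; len() on a range does not materialize it.
total_rows = len(range(start, end, step))

# Assumption: rows are spread evenly over the partitions.
rows_per_partition = total_rows / num_partitions
print(total_rows, rows_per_partition)  # 99990000 1999800.0
```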

Normal PySpark code

%python

def fib_mapper_python(n):
  a = 0
  b = 1
  print("Trying: %s" % n)
  while b < int(n):
    a, b = b, a+b
  return (b, 1)

print(fib_mapper_python(2000))

lines = spark.read.csv("/tmp/cython_input/").rdd.map(lambda y: y[0])
fib_frequency = lines.map(lambda x: fib_mapper_python(x)).reduceByKey(lambda a, b: a+b).collect()
print(fib_frequency)
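
The frequencies produced by this job can be cross-checked without running the mapper on every row. Every integer between two consecutive Fibonacci numbers maps to the larger one, so the counts follow from interval arithmetic. The helper expected_fib_frequency below is a hypothetical sketch of that check and assumes start is at least 1.

```python
def expected_fib_frequency(start, end):
    """For n in [start, end), count how many n map to each Fibonacci number."""
    counts = {}
    prev, cur = 0, 1
    while prev < end - 1:
        # Integers in (prev, cur] all map to cur; clip to the dataset bounds.
        lo = max(prev + 1, start)
        hi = min(cur, end - 1)
        if lo <= hi:
            counts[cur] = hi - lo + 1
        prev, cur = cur, prev + cur
    return counts

expected = expected_fib_frequency(10000, 100000000)
print(len(expected), sum(expected.values()))  # 20 99990000
```

Comparing sorted(expected.items()) with the sorted fib_frequency list is one way to confirm that the Python and Cython runs agree.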

Test Cython code

Now test the compiled Cython code.

%python

lines = spark.read.csv("/tmp/cython_input/").rdd.map(lambda y: y[0])
mapper = spark_cython('fib', 'fib_mapper_cython')
fib_frequency = lines.map(mapper).reduceByKey(lambda a, b: a+b).collect()
print(fib_frequency)
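
To put numbers on the comparison, each run can be wrapped in a small timing helper. The sketch below is one possible approach, not part of the original article: timed is a hypothetical helper, and it is exercised on a local list rather than on a Spark job so that it runs anywhere.

```python
import time

def timed(func, *args, **kwargs):
    """Run func once and return (result, elapsed_seconds)."""
    t0 = time.perf_counter()
    result = func(*args, **kwargs)
    return result, time.perf_counter() - t0

def fib_mapper_python(n):
    a, b = 0, 1
    while b < int(n):
        a, b = b, a + b
    return (b, 1)

# Local stand-in for the Spark job: map 10,000 values on a single machine.
result, seconds = timed(lambda: [fib_mapper_python(n) for n in range(10000, 20000)])
print("mapped %d values in %.4f s" % (len(result), seconds))
```

In a notebook, the same helper could wrap the full expression that ends in collect(), once for the Python mapper and once for the Cython mapper. A single run is noisy, so repeating each measurement a few times gives a fairer comparison.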

The test dataset we generated has 50 Spark partitions, so it is written as 50 CSV files. You can list them with dbutils.fs.ls("/tmp/cython_input/").