Symbolic Differentiation in Julia
A Brief Introduction to Metaprogramming in Julia
In contrast to my previous post, which described one way in which Julia allows (and expects) the programmer to write code that directly employs the atomic operations offered by computers, this post is meant to introduce newcomers to some of Julia’s higher-level facilities for metaprogramming. To make metaprogramming more interesting, we’re going to build a system for symbolic differentiation in Julia.
Like Lisp, the Julia interpreter represents Julian expressions using normal data structures: every Julian expression is represented using an object of type Expr. You can see this by typing something like :(x + 1) into the Julia REPL:
```julia
julia> :(x + 1)
:(+(x,1))

julia> typeof(:(x+1))
Expr
```
Looking at the REPL output when we enter an expression quoted using the : operator, we can see that Julia has rewritten our input expression, originally written using infix notation, as an expression that uses prefix notation. This standardization to prefix notation makes it easier to work with arbitrary expressions because it removes a needless source of variation in the format of expressions.
To develop an intuition for what this kind of expression means to Julia, we can use the dump function to examine its contents:
```julia
julia> dump(:(x + 1))
Expr
  head: Symbol call
  args: Array(Any,(3,))
    1: Symbol +
    2: Symbol x
    3: Int64 1
  typ: Any
```
Here you can see that a Julian expression consists of three parts:
- A head symbol, which describes the basic type of the expression. For this blog post, all of the expressions we’ll work with have head equal to :call.
- An Array{Any} that contains the arguments of the head. In our example, the head is :call, which indicates a function call is being made in this expression. The arguments for the function call are:
  - :+, the symbol denoting the addition function that we are calling.
  - :x, the symbol denoting the variable x.
  - 1, the number 1 represented as a 64-bit integer.
- A typ which stores type inference information. We’ll ignore this information as it’s not relevant to us right now.
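For readers on a current release, the same inspection can be sketched with plain field access. This assumes Julia 1.x, where an Expr carries only head and args and no longer has a typ field; ex is just an illustrative variable name.

```julia
# Assumes Julia 1.x; `ex` is an illustrative name, not something from the post.
ex = :(x + 1)

ex.head         # the kind of expression, :call here
ex.args         # a Vector{Any} holding the function symbol and its arguments
ex.args[1]      # the function being called, :+
ex.args[2:end]  # the operands, :x and 1

# In Julia 1.x an Expr has exactly two fields.
fieldnames(Expr) == (:head, :args)  # true
```

Everything the differentiation rules need lives in ex.args: the first entry names the operation and the remaining entries are its operands.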
Because each expression is built out of normal components, we can construct one piecemeal:
```julia
julia> Expr(:call, {:+, 1, 1}, Any)
:(+(1,1))
```
Because this expression only depends upon constants, we can immediately evaluate it using the eval function:
```julia
julia> eval(Expr(:call, {:+, 1, 1}, Any))
2
```
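The {...} array literal and the trailing type argument belong to the Julia release this post was written against. As a minimal sketch, assuming Julia 1.x, the same expression is built by passing the head and the arguments straight to Expr:

```julia
# Assumes Julia 1.x: arguments are passed to Expr directly,
# without a {...} array or a trailing type.
ex = Expr(:call, :+, 1, 1)

ex == :(1 + 1)  # true: the hand-built expression matches the quoted one
eval(ex)        # 2
```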
Symbolic Differentiation in Julia
Now that we know how Julia expressions are built, we can design a very simple prototype system for doing symbolic differentiation in Julia. We’ll build up our system in pieces using some of the most basic rules of calculus:
- The Constant Rule: d/dx c = 0
- The Symbol Rule: d/dx x = 1, d/dx y = 0
- The Sum Rule: d/dx (f + g) = (d/dx f) + (d/dx g)
- The Subtraction Rule: d/dx (f - g) = (d/dx f) - (d/dx g)
- The Product Rule: d/dx (f * g) = (d/dx f) * g + f * (d/dx g)
- The Quotient Rule: d/dx (f / g) = [(d/dx f) * g - f * (d/dx g)] / g^2
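To see how these rules compose, here is a short worked derivation for x + a*x with respect to x, treating a as a symbol other than the target:

```latex
\begin{aligned}
\frac{d}{dx}\,(x + a x)
  &= \frac{d}{dx}\,x + \frac{d}{dx}\,(a x)
  && \text{Sum Rule} \\
  &= 1 + \left(\frac{d}{dx}\,a\right) x + a \left(\frac{d}{dx}\,x\right)
  && \text{Symbol Rule, Product Rule} \\
  &= 1 + 0 \cdot x + a \cdot 1
  && \text{Symbol Rule} \\
  &= 1 + a
\end{aligned}
```

The third line is the unsimplified form that a purely mechanical application of the rules produces; the last step requires a separate simplification pass.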
Implementing these operations is quite easy once you understand the data structure Julia uses to represent expressions. And some of these operations would be trivial regardless.
For example, here’s the Constant Rule in Julia:
```julia
differentiate(x::Number, target::Symbol) = 0
```
And here’s the Symbol Rule:
```julia
function differentiate(s::Symbol, target::Symbol)
    if s == target
        return 1
    else
        return 0
    end
end
```
The first two rules of calculus don’t actually require us to understand anything about Julian expressions. But the interesting parts of a symbolic differentiation system do. To see that, let’s look at the Sum Rule:
```julia
function differentiate_sum(ex::Expr, target::Symbol)
    n = length(ex.args)
    new_args = Array(Any, n)
    new_args[1] = :+
    for i in 2:n
        new_args[i] = differentiate(ex.args[i], target)
    end
    return Expr(:call, new_args, Any)
end
```
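The Array(Any, n) constructor and the three-argument Expr call above are specific to the Julia release used in this post. A minimal sketch of the same rule, assuming Julia 1.x, might look like the following; the two base-case methods are repeated only so the snippet runs on its own.

```julia
# Assumes Julia 1.x. Base cases, repeated so this snippet is self-contained.
differentiate(x::Number, target::Symbol) = 0
differentiate(s::Symbol, target::Symbol) = s == target ? 1 : 0

function differentiate_sum(ex::Expr, target::Symbol)
    # ex.args[1] is :+; everything after it is a summand.
    terms = [differentiate(arg, target) for arg in ex.args[2:end]]
    return Expr(:call, :+, terms...)
end

differentiate_sum(:(x + y + 3), :x) == :(1 + 0 + 0)  # true
```

Because Julia parses x + y + 3 as a single call to + with three arguments, the rule handles sums of any length without extra work.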
The Subtraction Rule can be defined almost identically:
```julia
function differentiate_subtraction(ex::Expr, target::Symbol)
    n = length(ex.args)
    new_args = Array(Any, n)
    new_args[1] = :-
    for i in 2:n
        new_args[i] = differentiate(ex.args[i], target)
    end
    return Expr(:call, new_args, Any)
end
```
The Product Rule is a little more interesting because we need to build up an expression whose components are themselves expressions:
```julia
function differentiate_product(ex::Expr, target::Symbol)
    n = length(ex.args)
    res_args = Array(Any, n)
    res_args[1] = :+
    for i in 2:n
        new_args = Array(Any, n)
        new_args[1] = :*
        for j in 2:n
            if j == i
                new_args[j] = differentiate(ex.args[j], target)
            else
                new_args[j] = ex.args[j]
            end
        end
        res_args[i] = Expr(:call, new_args, Any)
    end
    return Expr(:call, res_args, Any)
end
```
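The nested loops above generalize the two-factor Product Rule to any number of factors: each term of the result differentiates exactly one factor and leaves the others untouched. One way to sketch the same idea, assuming Julia 1.x (the names factors and pieces are illustrative), is:

```julia
# Assumes Julia 1.x. Base cases, repeated so this snippet is self-contained.
differentiate(x::Number, target::Symbol) = 0
differentiate(s::Symbol, target::Symbol) = s == target ? 1 : 0

function differentiate_product(ex::Expr, target::Symbol)
    factors = ex.args[2:end]  # drop the leading :*
    # One term per factor: differentiate that factor, keep the others as they are.
    terms = map(eachindex(factors)) do i
        pieces = copy(factors)
        pieces[i] = differentiate(factors[i], target)
        Expr(:call, :*, pieces...)
    end
    return Expr(:call, :+, terms...)
end

differentiate_product(:(a * x), :x) == :(0 * x + a * 1)  # true
```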
Last, but not least, here’s the Quotient Rule, which is a little more complex. We can code this rule up in a more explicit fashion that doesn’t use any loops so that we can directly see the steps we’re taking:
```julia
function differentiate_quotient(ex::Expr, target::Symbol)
    return Expr(:call,
        {
            :/,
            Expr(:call,
                {
                    :-,
                    Expr(:call,
                        {
                            :*,
                            differentiate(ex.args[2], target),
                            ex.args[3]
                        },
                        Any),
                    Expr(:call,
                        {
                            :*,
                            ex.args[2],
                            differentiate(ex.args[3], target)
                        },
                        Any)
                },
                Any),
            Expr(:call,
                {
                    :^,
                    ex.args[3],
                    2
                },
                Any)
        },
        Any)
end
```
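Building the nested Expr by hand makes every step visible, but the same tree can also be written with quoting and interpolation, where $name splices the value of name into a quoted expression. The following is a sketch under the assumption of Julia 1.x, not the version from the Gist; f, g, df and dg are illustrative names.

```julia
# Assumes Julia 1.x. Base cases, repeated so this snippet is self-contained.
differentiate(x::Number, target::Symbol) = 0
differentiate(s::Symbol, target::Symbol) = s == target ? 1 : 0

function differentiate_quotient(ex::Expr, target::Symbol)
    f, g = ex.args[2], ex.args[3]  # numerator and denominator
    df, dg = differentiate(f, target), differentiate(g, target)
    # Interpolate the four pieces into the Quotient Rule template.
    return :(($df * $g - $f * $dg) / ($g)^2)
end

differentiate_quotient(:(x / y), :x) == :((1 * y - x * 0) / y^2)  # true
```

The quoted template mirrors the formula for the Quotient Rule directly, at the cost of hiding the Expr constructor calls that the explicit version spells out.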
Now that we have all of those basic rules of calculus implemented as functions, we’ll build up a lookup table that we can use to tell our final differentiate function where to send new expressions based on the kind of function that’s being differentiated during each call to differentiate:
```julia
differentiate_lookup = {
    :+ => differentiate_sum,
    :- => differentiate_subtraction,
    :* => differentiate_product,
    :/ => differentiate_quotient
}
```
With all of the core machinery in place, the final definition of differentiate is very simple:
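```julia
function differentiate(ex::Expr, target::Symbol)
    if ex.head == :call
        # Dispatch on the function being called, e.g. :+ or :*
        if has(differentiate_lookup, ex.args[1])
            return differentiate_lookup[ex.args[1]](ex, target)
        else
            error("Don't know how to differentiate $(ex.args[1])")
        end
    else
        # Only function calls are supported; fail clearly for anything else
        error("Don't know how to differentiate expressions with head $(ex.head)")
    end
end
```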
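The {...} dictionary literal and the has function come from the Julia release used in this post. As a rough sketch, assuming Julia 1.x, the same lookup-and-dispatch pattern can be written with Dict and haskey. Only the Sum Rule is registered here to keep the example short; the other rules would be added to the table in the same way.

```julia
# Assumes Julia 1.x. Only the Sum Rule is registered to keep the sketch short.
differentiate(x::Number, target::Symbol) = 0
differentiate(s::Symbol, target::Symbol) = s == target ? 1 : 0

differentiate_sum(ex::Expr, target::Symbol) =
    Expr(:call, :+, [differentiate(arg, target) for arg in ex.args[2:end]]...)

# Dict(...) stands in for the {...} literal, and haskey for has.
differentiate_lookup = Dict{Symbol,Function}(:+ => differentiate_sum)

function differentiate(ex::Expr, target::Symbol)
    ex.head == :call || error("Don't know how to differentiate a $(ex.head) expression")
    op = ex.args[1]
    haskey(differentiate_lookup, op) || error("Don't know how to differentiate $op")
    return differentiate_lookup[op](ex, target)
end

differentiate(:(x + (y + x)), :x) == :(1 + (0 + 1))  # true
```

Because differentiate_sum calls differentiate on each summand, nested sums recurse through the lookup table without any extra code.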
I’ve put all of these snippets together in a single GitHub Gist. To try out this new differentiation function, let’s copy the contents of that Gist into a file called differentiate.jl. We can then load the contents of that file into Julia at the REPL using include, which will allow us to try out our differentiation tool:
```julia
julia> include("differentiate.jl")

julia> differentiate(:(x + x*x), :x)
:(+(1,+(*(1,x),*(x,1))))

julia> differentiate(:(x + a*x), :x)
:(+(1,+(*(0,x),*(a,1))))
```
While the expressions that are constructed by our differentiate function are ugly, they are correct: they just need to be simplified so that things like *(0, x) are replaced with 0. If you’d like to see how to write code to perform some basic simplifications, you can see the simplify function I’ve been building for Julia’s new Calculus package. That codebase includes all of the functionality shown here for differentiate, along with several other rules that make the system more powerful.
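To give a flavor of what such a simplification pass involves, here is a deliberately small sketch. It is a hypothetical helper named simplify_sketch, not the simplify function from the Calculus package; it assumes Julia 1.x and only knows a handful of identities for two-argument + and * calls.

```julia
# Hypothetical helper, NOT the simplify function from the Calculus package.
# Assumes Julia 1.x and only rewrites two-argument + and * calls.
simplify_sketch(x) = x  # numbers and symbols are left as they are

function simplify_sketch(ex::Expr)
    ex.head == :call || return ex
    op = ex.args[1]
    args = map(simplify_sketch, ex.args[2:end])  # simplify the operands first
    if length(args) == 2
        lhs, rhs = args
        if op == :*
            (lhs == 0 || rhs == 0) && return 0  # 0 * x and x * 0
            lhs == 1 && return rhs              # 1 * x
            rhs == 1 && return lhs              # x * 1
        elseif op == :+
            lhs == 0 && return rhs              # 0 + x
            rhs == 0 && return lhs              # x + 0
        end
    end
    return Expr(:call, op, args...)
end

simplify_sketch(:(1 + (0 * x + a * 1))) == :(1 + a)  # true
```

Simplifying the operands before the call itself lets one pass collapse nested terms from the inside out, which is how the result for x + a*x shrinks to 1 + a.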
What I love about Julia is the ease with which one can move from low-level bit operations like those described in my previous post to high-level operations that manipulate Julian expressions. By allowing the programmer to manipulate expressions programmatically, Julia has copied one of the most beautiful parts of Lisp.