Linear computational graph
One way to do reverse mode automatic differentiation is to create a DSL for differentiable functions and use StableNames to recover sharing information and construct a computational graph. We adopt this approach, but with a twist.
Source code
The full source code of the linear AST, as explained in this section, can be found here.
The twist
The derivative of a function \(f\) is a linear function \(f'\) – the best local linear approximation of \(f\) at some point \(x_0\).
There’s no need to construct a graph for the whole function \(f\); a graph for \(f'\) is enough.
Linear functions are much simpler than general differentiable functions and lend themselves much better to the Haskell type system, as we will see later.
The idea is to start like in forward mode differentiation:
data BVar a = BVar
{ bvarValue :: a
, bvarGrad :: Expr a
}
Except that bvarGrad is not the value of the gradient – it’s an abstract syntax tree of the gradient’s computation instead. This way bvarGrad is a linear expression, by construction. Building the graph for bvarGrad only (as opposed to bvarValue) greatly reduces the scope and complexity of the otherwise tricky “reverse” part of the differentiation algorithm.
Linear maps
Automatic differentiation is all about vector spaces and linear maps. Let me quickly introduce them.
Vector spaces are covered by the vector-space package. The two most relevant operations are vector addition and scalar-vector multiplication:
(^+^) :: v -> v -> v
(*^) :: Scalar v -> v -> v
A linear map is a mapping \(f: U \to V\), where \(U\) and \(V\) are vector spaces, satisfying the following conditions: $$ f(u + v) = f(u) + f(v), \qquad f(a \cdot u) = a \cdot f(u) $$
While linear maps are conceptually just functions, we can’t represent all of them as Haskell functions, as that would lead to terrible algorithmic complexity.
The best way to represent a linear map depends on the vector spaces in question and on the nature of the linear map itself. Therefore we introduce a class for linear maps with an operator to evaluate them. The choice of operator comes from the fact that linear maps can be represented as matrices (or, more generally, tensors) and evaluation corresponds to matrix-vector multiplication.
class TensorMul u v where
  type u ✕ v :: Type
  (✕) :: u -> v -> u ✕ v
If f represents a linear map \(U \to V\) and u :: U, then f ✕ u :: V evaluates \(f(u)\).
Such a general operator wouldn’t be very good for a Haskell library. More specific functions have better type inference, better error messages and make code easier to read and navigate. The operator ✕ is supposed to mean a tensor product followed by contraction, but there might be multiple sensible contractions, with no way to choose the right one at each call site. Still, it is very useful for explaining things and demonstrating that quite a few operations are actually the same.
The laws of a linear map can now be translated to Haskell:
f ✕ (u ^+^ v) = f ✕ u ^+^ f ✕ v
f ✕ (a *^ u) = a *^ (f ✕ u)
Multiplication ✕ distributes over addition on the other side, too, because linear maps form a vector space themselves:
(f ^+^ g) ✕ u = f ✕ u ^+^ g ✕ u
(a *^ f) ✕ u = a *^ (f ✕ u)
A common case in backpropagation is the codomain of \(f\) being a scalar. We will name it \(\mathbb{R}\) to make this text more intuitive, though the actual type of the scalar isn’t really important. The gradient of a variable \(u \in U\) in this case is a linear map \(u^*: U \to \mathbb{R}\). The vector space of such linear maps is said to be the dual vector space of \(U\).
Translating this to Haskell and choosing the name du for \(u^*\) gives:
u :: u
du :: du
du ✕ u :: R
We write lowercase u and du because all automatic differentiation code will be polymorphic – u and du are type variables, not concrete types.
Going back to the matrix analogy: if u is a column vector, then du is a row vector and their product is a scalar.
Here du can be seen not only as a (row) vector, but also as a function:
(du ✕) :: u -> R
The vector u can be seen as a function, too:
(✕ u) :: du -> R
There’s a nice symmetry between u and du – both have a data representation, both have a function representation, and each is the dual of the other.
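To make the duality concrete, here is a small hypothetical instance (not from the linked source) where both the vector and its covector are pairs of Doubles and evaluation ✕ is the dot product:
-- Hypothetical example types: V2 plays the role of u (a "column vector"),
-- DV2 plays the role of du (a "row vector").
data V2 = V2 Double Double
data DV2 = DV2 Double Double

-- Evaluating the covector on the vector is the dot product, producing a
-- scalar, just like du ✕ u :: R above.
instance TensorMul DV2 V2 where
  type DV2 ✕ V2 = Double
  DV2 a b ✕ V2 x y = a * x + b * y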
Another important operation besides evaluation is composition. We don’t need another operator, because ✕ fits the bill. If you see linear maps as matrices, composition is matrix multiplication. This usage of ✕ gives rise to an associativity law. Here is the associative law of ✕ together with the laws of the usual Haskell function application and composition operators, put together to show the relation between them:
(f . g) $ u = f $ (g $ u)
(f ✕ g) ✕ u = f ✕ (g ✕ u)
(f . g) . h = f . (g . h)
(f ✕ g) ✕ h = f ✕ (g ✕ h)
PrimFunc
The first ingredient of the linear computational graph is linear functions of a single argument.
data PrimFunc u du v dv = PrimFunc
  { fwdFun :: u -> v
  , backFun :: dv -> du
  }
PrimFunc is made of two parts: u -> v evaluates the function, while dv -> du backpropagates the gradient. Since it is a linear map, it should have a TensorMul instance, but unfortunately we quickly run into the overlapping instances problem. We resort to newtype wrappers to overcome it.
newtype Vec x = Vec { unVec :: x }
This little nuisance is a consequence of the overly general TensorMul class. The instance can now be given:
instance TensorMul (PrimFunc u du v dv) (Vec u) where
  type (PrimFunc u du v dv) ✕ (Vec u) = Vec v
  (PrimFunc f _) ✕ Vec v = Vec (f v)
That was forward mode evaluation. Can you guess which operator we’re going to use for reverse mode? Of course, it has to be ✕. There’s one more way to use it – on the left of the function:
f ✕ u :: v
dv ✕ f :: du
The matrix analogy goes a long way here – if u and v are column vectors and du and dv are row vectors, then f is a matrix and ✕ is matrix-vector or vector-matrix multiplication. Another thing worth mentioning: there are no matrix transpositions in sight. Matrix transposition assumes a Hilbert space, so we shouldn’t expect it here.
Since we already have a newtype wrapper for vectors, we might as well create a different one for gradients.
newtype Cov x = Cov {unCov :: x}
Cov stands for covector. It doesn’t have much to do with variance and covariance; it just indicates that the variable should be positioned on the left side of the function.
instance TensorMul (Cov dv) (PrimFunc u du v dv) where
  type (Cov dv) ✕ (PrimFunc u du v dv) = Cov du
  Cov v ✕ (PrimFunc _ f) = Cov (f v)
The function \(f\) can be seen as a bilinear form. If Haskell allowed such notation:
(✕ f ✕) :: dv -> u -> R
The associative law comes into play again:
dv ✕ (f ✕ u) = (dv ✕ f) ✕ u
dv ✕ fwdFun f u = backFun f dv ✕ u
This means fwdFun and backFun can’t be arbitrary linear maps – the above equation must hold for all choices of dv and u. Mathematically, this law says that backFun must be the transpose of fwdFun. That’s pretty much the definition of the transpose of a linear map.
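As a small illustration (hypothetical, not from the linked source), consider scaling by a constant. Scalar multiplication is its own transpose, so the forward and backward maps coincide, and the law above reduces to dv * (c * u) = (c * dv) * u, which holds by associativity and commutativity of multiplication.
-- Hypothetical primitive: scaling by a constant c.
-- fwdFun and backFun are the same map, because (c *) is its own transpose.
scaleBy :: Double -> PrimFunc Double Double Double Double
scaleBy c = PrimFunc
  { fwdFun  = (c *)
  , backFun = (c *)
  }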
AST
We are ready to start building our AST:
data Expr a da v dv where
  Var :: Expr a da a da
  Func :: PrimFunc u du v dv -> Expr a da u du -> Expr a da v dv
Expr a da v dv is a linear expression of type v with one free variable of type a.
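For example (with hypothetical names, not from the linked source), applying a primitive f to the free variable and then a primitive g to the result is written by nesting Func nodes; the expression reads inside out, just like function composition:
-- Hypothetical helper: the composition g . f applied to the free variable.
composed :: PrimFunc a da u du -> PrimFunc u du v dv -> Expr a da v dv
composed f g = Func g (Func f Var)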
Linear functions with multiple arguments
There’s a difference between linear and bilinear (or multilinear) functions. Linear functions with two variables satisfy this equation: $$ f(x_1+x_2,y_1+y_2) = f(x_1, y_1) + f(x_2, y_2) $$
Multiplication, for example, is bilinear, not linear, because $$ (a+b) \cdot (x+y) \ne a \cdot x + b \cdot y $$
Linearity is a much stronger restriction than multilinearity. It turns out any linear function can be written as a sum of single-variable linear functions: $$ f(x_1, x_2, \dots, x_n) = f_1(x_1) + f_2(x_2) + \cdots + f_n(x_n) $$ for some \(f_1\), \(f_2\), …, \(f_n\).
We have all the pieces to finish the AST definition:
data Expr a da v dv where
  Var :: Expr a da a da
  Func :: PrimFunc u du v dv -> Expr a da u du -> Expr a da v dv
  Sum :: AdditiveGroup v => [Expr a da v dv] -> Expr a da v dv
That’s it! That’s all we need to evaluate in reverse mode.
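For illustration (hypothetical names, not from the linked source), a linear function that applies two primitives g and h to the same input and adds the results, \(x \mapsto g(x) + h(x)\), is represented by a Sum of two Func nodes sharing the variable:
-- Hypothetical example: the linear function mapping x to g(x) + h(x).
sumOfTwo :: AdditiveGroup v => PrimFunc a da v dv -> PrimFunc a da v dv -> Expr a da v dv
sumOfTwo g h = Sum [Func g Var, Func h Var]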
Evaluation
Evaluating Expr directly is inefficient – we should recover sharing information first. Before we do that, though, let’s see what needs to be evaluated.
Expr a da v dv represents a function a -> v, so it’s natural to give it a TensorMul instance. The code writes itself:
instance TensorMul (Expr a da v dv) (Vec a) where
  type Expr a da v dv ✕ Vec a = Vec v
  expr ✕ a = case expr of
    Var -> a                        -- Var is identity function
    Func f v -> f ✕ (v ✕ a)         -- Func f v = f ✕ v
    Sum vs -> sumV [v ✕ a | v <- vs]
Reverse mode evaluation is also straightforward:
instance AdditiveGroup da => TensorMul (Vec dv) (Expr a da v dv) where
  type Vec dv ✕ (Expr a da v dv) = Vec da
  dv ✕ expr = case expr of
    Var -> dv
    Func f v -> (dv ✕ f) ✕ v
    Sum vs -> sumV [dv ✕ v | v <- vs]
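Putting both instances to work, here is a hedged sketch (the helper names are mine, not from the linked source) of how one might evaluate an expression forward and backpropagate a gradient through it:
-- Forward mode: apply the expression, seen as a linear map a -> v, to an input.
runForward :: Expr a da v dv -> a -> v
runForward expr x = unVec (expr ✕ Vec x)

-- Reverse mode: feed the gradient of the output in from the left
-- to obtain the gradient of the input.
runBackward :: AdditiveGroup da => Expr a da v dv -> dv -> da
runBackward expr dy = unVec (Vec dy ✕ expr)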
Transposition
While the evaluation code is mechanical and boring, transposition is more interesting. In order not to spoil the fun, it’s left as a puzzle for the reader.
transposeExpr :: AdditiveGroup da => Expr a da v dv -> Expr dv v da a
transposeExpr = _
Can you fill in the hole? This time the code doesn’t write itself. At first I attacked this problem with the type tetris approach myself, but that proved too hard and I had to reach for pen and paper. You can see my solution here.