Sparsity
It’s easy to run into quadratic complexity when gradients are represented naively. Gradients are often sparse. Consider backpropagating over fst :: (a, b) -> a. If the gradient of a is da, then the gradient of (a, b) is (da, zeroV). The second element is just a zero, but it might be a very fat zero – maybe a large nested record of big arrays, all full of zeros! An extreme case of gradient sparsity appears when indexing a vector: all but one element of the gradient is zero, yet constructing such a vector makes indexing an \(O(n)\) operation.
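To make the cost concrete, here is a sketch (not library code; oneHotGrad is a made-up helper using Data.Vector) of the gradient that naive indexing has to build:

import qualified Data.Vector as V

-- Naive gradient of (! i) with respect to a vector of length n:
-- a one-hot vector that costs O(n) to construct, even though it
-- carries a single non-zero value.
oneHotGrad :: Num a => Int -> Int -> a -> V.Vector a
oneHotGrad n i da = V.generate n (\j -> if j == i then da else 0)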
Sparse gradients
Imperative implementations of backpropagation dodge this problem by updating gradients in place. While that’s possible to do in Haskell, there’s a better way – builders. A builder is a data type for efficient representation of sparse gradients. Or, if you wish, it can be an ST action that bumps the gradient in place.
The Downhill library has a class for builders:
class Monoid (VecBuilder v) => BasicVector v where
  type VecBuilder v :: Type
  sumBuilder :: VecBuilder v -> v
BasicVector is the absolute minimum requirement for a type to be eligible for automatic differentiation.
Functions on graph edges produce builders. Nodes then mconcat them and pass the result to sumBuilder.
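In code, a node’s job amounts to something like this (nodeGradient is a made-up name, not part of the library):

-- What a node conceptually does with the builders coming in from
-- its edges: concatenate them and evaluate the sum once.
nodeGradient :: BasicVector v => [VecBuilder v] -> v
nodeGradient = sumBuilder . mconcat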
For example, the builder for pairs looks like this:
type VecBuilder (a, b) = Maybe (VecBuilder a, VecBuilder b)
Nothing stands for the zero vector. Maybe is important here – without it, mempty wouldn’t be cheap for deeply nested pairs.
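A corresponding instance could look roughly like this – a sketch rather than the library’s exact definition; it assumes that summing an empty builder yields the zero vector of each component:

instance (BasicVector a, BasicVector b) => BasicVector (a, b) where
  type VecBuilder (a, b) = Maybe (VecBuilder a, VecBuilder b)
  -- Nothing is the cheap zero; Just sums each component separately.
  sumBuilder Nothing         = (sumBuilder mempty, sumBuilder mempty)
  sumBuilder (Just (da, db)) = (sumBuilder da, sumBuilder db)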
Better AST
The Expr type in the library differs from the one in the previous part in a few ways. First of all, it hasn’t got pairs of vectors and gradients, such as a da v dv – just a v. Two sets of parameters allow both forward and reverse mode evaluation, but here we only do reverse mode, for which the relevant parameters would be da and dv. We drop the superfluous “d” and call them a and v.
There’s also a little problem with our Expr type. As we’re going to convert it to a graph, we need a clear separation between nodes and edges. Func is definitely an edge. Sum itself is a node, but it contains a mixed bag of adjacent edges and nodes. We disallow this situation of nodes adjacent to nodes by splitting the AST into terms and expressions:
data Term a v where
  Term :: (v -> VecBuilder u) -> Expr a u -> Term a v

data Expr a v where
  ExprVar :: Expr a a
  ExprSum :: BasicVector v => [Term a v] -> Expr a v
There’s v -> VecBuilder u in place of PrimFunc, which adds support for sparse gradients and drops forward mode evaluation. Also, BasicVector replaces AdditiveGroup in ExprSum.
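To see how the pieces fit together, here is a toy gradient type (D is invented for this example) and a hand-built gradient graph of fst, using the pair builder from above:

-- A toy gradient type: its builder is just a list of contributions.
newtype D = D Double

instance BasicVector D where
  type VecBuilder D = [Double]
  sumBuilder = D . sum

-- The gradient graph of fst :: (D, D) -> D: one sum node fed by a
-- single edge from the input variable. Given the output gradient d,
-- the edge emits Just ([d], mempty) – the zero for the second
-- component is never materialized.
gradFst :: Expr (D, D) D
gradFst = ExprSum [Term edge ExprVar]
  where
    edge :: D -> VecBuilder (D, D)
    edge (D d) = Just ([d], mempty)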
Inline nodes
Builders are not enough. Say we have a simple newtype wrapper for a vector. Let’s have a closer look at what happens when we attempt to index it:
newtype MyVector a = MyVector { unMyVector :: Vector a }
myLookup :: MyVector a -> Int -> a
myLookup v i = unMyVector v ! i
If this code were bluntly adapted to work on Expr, the tree would have three nodes with two edges between them:
Let’s see how gradients propagate when we flip the edges:
The indexing function (! i) produces a lightweight gradient builder, as desired – only for the intermediate node (labeled in bold) to convert it into a big fat vector, undoing all the optimization!
BackGrad has the ability to relay gradients without summing them:
newtype BackGrad a v
  = BackGrad
      ( forall x.
        (x -> VecBuilder v) ->
        Term a x
      )
BackGrad turns linear functions (x -> VecBuilder v) into Terms. Alternatively, you could see it as a Term data constructor with a hole in place of the Expr argument. It generalizes Expr:
realNode :: Expr a v -> BackGrad a v
realNode x = BackGrad (\f -> Term f x)
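For example (a trivial sketch; varGrad is just a name for illustration), the differentiation variable itself becomes a BackGrad by wrapping ExprVar:

-- The input variable as a BackGrad: just ExprVar behind realNode.
varGrad :: BackGrad a a
varGrad = realNode ExprVar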
BackGrad also provides a means to apply a linear function without creating a node:
inlineNode ::
  forall r u v.
  (VecBuilder v -> VecBuilder u) ->
  BackGrad r u ->
  BackGrad r v
inlineNode f (BackGrad g) = BackGrad go
  where
    go :: forall x. (x -> VecBuilder v) -> Term r x
    go h = g (f . h)
Note that Expr is a node and Term is an edge. No Expr – no node.
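For instance, unwrapping the MyVector newtype from the indexing example can be expressed as an inline node – a sketch, assuming the wrapper simply reuses the builder of the underlying vector:

-- No node is created here: the gradient builder of the underlying
-- Vector is passed through unchanged (assuming the newtype declares
-- VecBuilder (MyVector a) ~ VecBuilder (Vector a)).
unwrapGrad ::
  VecBuilder (MyVector a) ~ VecBuilder (Vector a) =>
  BackGrad r (MyVector a) ->
  BackGrad r (Vector a)
unwrapGrad = inlineNode id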
Sparse nodes
Inline nodes are still not enough. There’s no good way to access members of tuples, or other product types for that matter. They are important, because this library only differentiates unary functions BVar a -> BVar b. If we have many variables to differentiate with respect to, we have to pack them together into a single tuple or record BVar a. For a complex model, a might be a big structure of nested records. Automatic differentiation starts with a single big variable containing all the data, and there has to be an efficient way to access all of its parts.

Constructing real Expr nodes won’t cut it, because they lose sparsity and make the cost of accessing any member proportional to the size of the whole structure. Inline nodes are not an option either: accessing deeply nested members would create a long chain of inlineNodes, and the cost of traversing the whole chain would have to be paid every time the variable is used. This way a simple traversal of a list would turn into a Schlemiel the painter’s algorithm!
The solution is to store sparse gradients in graph nodes for this use case. Luckily, there’s no need for a new kind of node here. Have a look at the BackGrad definition – there’s no v, only VecBuilder. This means we can choose a different type of node to store the gradient and hide it under BackGrad as if nothing happened. No one can possibly notice.
castBackGrad ::
  forall r v z.
  VecBuilder z ~ VecBuilder v =>
  BackGrad r v -> BackGrad r z
castBackGrad (BackGrad g) = BackGrad g
Sparse gradients are wrapped in the SparseVector newtype for storage in the graph. Storing a naked VecBuilder v runs into a little problem – what’s VecBuilder (VecBuilder v)?
newtype SparseVector v = SparseVector
  { unSparseVector :: VecBuilder v }
Its sumBuilder :: VecBuilder v -> SparseVector v doesn’t really sum anything – it just stores the unevaluated builders.
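The corresponding instance is almost a no-op – a sketch of what such an instance could look like:

-- The builder of a SparseVector is the builder of the underlying
-- vector; "summing" merely wraps the mconcat’ed builders without
-- evaluating them.
instance Monoid (VecBuilder v) => BasicVector (SparseVector v) where
  type VecBuilder (SparseVector v) = VecBuilder v
  sumBuilder = SparseVector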
How does it differ from inline nodes? It turns out the monoid operation of the builders of product types plays a key role in intermediate nodes. It collects gradients from all successor nodes and packs them into a tuple/record before passing them to the parent node as a single unit. This way gradients are assembled bottom-up into a tree of the same shape as the original data. Inline nodes would propagate gradients from each leaf node all the way to the root individually.
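Putting the pieces together, here is a hedged sketch (fstGrad is not the library’s API) of how a sparse node for the first component of a pair could be built from BackGrad, SparseVector and castBackGrad:

fstGrad ::
  forall r a b.
  (Monoid (VecBuilder a), Monoid (VecBuilder b)) =>
  BackGrad r (a, b) ->
  BackGrad r a
fstGrad (BackGrad pair) = castBackGrad (realNode node)
  where
    -- The node stores the gradient as a SparseVector a, i.e. just the
    -- accumulated builders of the first component; the second
    -- component stays an implicit zero (mempty).
    node :: Expr r (SparseVector a)
    node = ExprSum [pair (\(SparseVector da) -> Just (da, mempty))]

Downstream code sees an ordinary BackGrad r a; when gradients are evaluated, this node hands its accumulated builders up to the pair node as a single pair builder – exactly the bottom-up assembly described above.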