Automatic Differentiation


Consider a function $F:\mathbb{R}^n\rightarrow\mathbb{R}^m$ defined as a composition of several functions: $F = C \circ B \circ A$ (i.e., $F(x)=C(B(A(x)))$). By the chain rule:

$$\nabla F(x)=\nabla C(y) \cdot \nabla B(z) \cdot \nabla A(x)$$

where $z=A(x)$ and $y=B(z)$. Recall that $\nabla F(x)$ is the Jacobian matrix of $F$ (and hence an $m\times n$ matrix). Notice that each factor on the right-hand side can be computed independently once we know $x$, $z$, and $y$. This gives us two ways to compute $\nabla F(x)$: through forward accumulation, where

$$\nabla F(x)=\nabla C(y) \cdot (\nabla B(z) \cdot \nabla A(x))$$

or through reverse accumulation, where

$$\nabla F(x)=(\nabla C(y) \cdot \nabla B(z)) \cdot \nabla A(x)$$

In machine learning workloads, the last function is usually a loss function that produces a single scalar, so $m = 1$ and reverse accumulation becomes a sequence of vector-matrix multiplications instead of matrix-matrix multiplications. Hence reverse accumulation is usually implemented with vector-Jacobian product (VJP) operators.
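
To make the two orderings concrete, the following sketch multiplies three small, made-up Jacobians in both orders and counts the scalar multiplications each order needs. The matrices `JA`, `JB`, `JC` and the helper `matmul` are illustrative assumptions, not taken from any library.

```python
def matmul(P, Q):
    """Multiply two matrices stored as lists of rows.

    Returns the product and the number of scalar multiplications used.
    """
    rows, inner, cols = len(P), len(Q), len(Q[0])
    out = [[sum(P[i][k] * Q[k][j] for k in range(inner)) for j in range(cols)]
           for i in range(rows)]
    return out, rows * inner * cols


# Hypothetical Jacobians for F(x) = C(B(A(x))) with n = 4 inputs and m = 1 output.
JA = [[1.0, 2.0, 0.0, 1.0],
      [0.0, 1.0, 3.0, 0.0],
      [2.0, 0.0, 1.0, 1.0]]   # 3 x 4, Jacobian of A at x
JB = [[1.0, 0.0, 2.0],
      [0.0, 1.0, 1.0],
      [1.0, 1.0, 0.0]]        # 3 x 3, Jacobian of B at z
JC = [[2.0, 1.0, 3.0]]        # 1 x 3, Jacobian of C at y

# Forward accumulation: start at the input side, C * (B * A).
BA, cost_ba = matmul(JB, JA)            # matrix-matrix product
forward, cost_c_ba = matmul(JC, BA)
forward_cost = cost_ba + cost_c_ba

# Reverse accumulation: start at the output side, (C * B) * A.
CB, cost_cb = matmul(JC, JB)            # vector-matrix product
reverse, cost_cb_a = matmul(CB, JA)     # vector-matrix product
reverse_cost = cost_cb + cost_cb_a

print(forward, forward_cost)
print(reverse, reverse_cost)
```

Both orders produce the same $1\times4$ row, but in this toy case the reverse order uses 21 scalar multiplications against 48 for the forward order, because every intermediate result stays a row vector.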

Autograd allows developers to write only the forward model using familiar Python APIs (NumPy, TensorFlow, PyTorch, etc); it then performs the differentiation automatically. It has two main components: (1) it defines the VJP for primitives such as np.sum and np.sin; (2) it traces the data flow at runtime and generates a trace graph. This trace graph stores the input ($x$), the intermediate values ($z$ and $y$), and the dependencies between the compute primitives involved ($A$, $B$, and $C$). Because the graph is generated at runtime, Autograd never sees the control flow (e.g., if-else, while) of the program: it records only the linear sequence of operations that actually executed.
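
The two components can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not Autograd's actual implementation or API: the names `Var`, `sin`, and `backward` are hypothetical, and only addition, multiplication, and sine are supported.

```python
import math


class Var:
    """A traced value: the number plus (parent, vjp) pairs recorded at runtime."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # each entry: (parent Var, output grad -> parent grad)
        self.grad = 0.0

    def __add__(self, other):
        return Var(self.value + other.value,
                   ((self, lambda g: g), (other, lambda g: g)))

    def __mul__(self, other):
        return Var(self.value * other.value,
                   ((self, lambda g: g * other.value),
                    (other, lambda g: g * self.value)))


def sin(v):
    """Primitive with its VJP: d sin(v)/dv = cos(v)."""
    return Var(math.sin(v.value), ((v, lambda g: g * math.cos(v.value)),))


def backward(out):
    """Walk the recorded graph from the output back to the inputs."""
    order, seen = [], set()

    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)  # parents are appended before their children

    visit(out)
    out.grad = 1.0
    for node in reversed(order):  # children are processed before their parents
        for parent, vjp in node.parents:
            parent.grad += vjp(node.grad)


def f(x, w):
    # The `if` is decided on concrete values, so only the taken branch is traced.
    if x.value > 0:
        return sin(x * w) + x
    return x * x


x, w = Var(2.0), Var(0.5)
y = f(x, w)
backward(y)
print(y.value, x.grad, w.grad)
```

The trace built for `y` contains only the multiply, sine, and add nodes from the branch that ran; the untaken `x * x` branch leaves no record, which is the "linear flow" behavior described above.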

At this point we can discuss some trade-offs between forward accumulation and reverse accumulation. Forward accumulation has constant memory overhead because it does not need the runtime trace and can compute the gradient directly during the forward pass (you can calculate $\nabla B(z)$ when you calculate $B(z)$, since it does not depend on anything produced by later operators). But consider a DNN with multiple layers whose weights $w_1, w_2, …$ need to be trained: at any point you need to keep all of $\frac{\partial G}{\partial w_1}, \frac{\partial G}{\partial w_2}, …$, where $G$ is the composition of the functions from the beginning up to the current intermediate point. This requires a huge amount of memory. For a regular CNN, the output of a layer can be $32\times32\times3$ (3,072 values), and there can be 30 layers, each with $3\times3\times3\times64$ kernel weights (51,840 weights in total). Keeping the derivative of every output value with respect to every weight takes $3072 \times 51840 \approx 1.6\times10^8$ floating-point numbers. With a more complicated DNN with larger inputs and more kernels, the gradients cannot all be kept in memory, and we have to calculate only a portion of them in each pass. Hence forward accumulation typically requires many forward passes for machine learning networks.
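
The "portion per pass" extreme can be sketched with dual numbers: each forward pass carries the derivative with respect to a single weight, so no trace is stored but one pass per weight is needed. The `Dual` class, the `dsin` helper, and the two-weight `loss` model below are hypothetical and chosen only for illustration.

```python
import math


class Dual:
    """A value paired with its derivative along one chosen input direction."""

    def __init__(self, value, tangent=0.0):
        self.value = value
        self.tangent = tangent

    def __add__(self, other):
        return Dual(self.value + other.value, self.tangent + other.tangent)

    def __mul__(self, other):
        # Product rule, applied as soon as the product itself is computed.
        return Dual(self.value * other.value,
                    self.tangent * other.value + self.value * other.tangent)


def dsin(d):
    return Dual(math.sin(d.value), math.cos(d.value) * d.tangent)


def loss(w1, w2, x):
    # Hypothetical scalar model: sin(w1 * x) * w2.
    return dsin(w1 * x) * w2


def forward_mode_gradient(weights, x):
    """One full forward pass per weight; each pass yields one partial derivative."""
    grads = []
    for i in range(len(weights)):
        # Seed tangent 1 for weight i and 0 for every other input.
        duals = [Dual(w, 1.0 if j == i else 0.0) for j, w in enumerate(weights)]
        out = loss(duals[0], duals[1], Dual(x))
        grads.append(out.tangent)
    return grads


grads = forward_mode_gradient([0.5, 3.0], 2.0)
print(grads)
```

Two weights already cost two forward passes here; a network with tens of thousands of weights would need correspondingly many passes, or the large block of partial derivatives counted above if they were all carried at once.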

In contrast, reverse accumulation needs only one forward and one backward pass, significantly reducing the amount of computation required. During the backward pass, we only need to keep $\frac{\partial F}{\partial w}$ and $\frac{\partial F}{\partial z}$, where $w$ includes the weights of the current layer and $z$ includes all intermediate values passed into the current layer as input. Because $F$ outputs a single scalar, the memory required for these gradients is very limited. Reverse accumulation does, however, depend on the runtime trace and hence has a memory overhead proportional to the depth of the machine learning model. Autograd can reduce this through checkpointing: you can define your own primitive as a series of arithmetic operations (in contrast to the single arithmetic operation of a pre-defined primitive). During reverse accumulation, it simply redoes a forward pass within that primitive to rebuild the trace graph inside it.
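
The checkpointing idea can be illustrated on a chain of scalar layers. In this sketch each segment of `segment` consecutive layers plays the role of one coarse user-defined primitive: only its input is stored during the forward pass, and its internal values are recomputed during the backward pass. The function names and the alternating sine/square layers are assumptions made for the example, not Autograd's mechanism.

```python
import math


def square(v):
    return v * v


def dsquare(v):
    return 2.0 * v


def make_layers(depth):
    """Alternate sin and square layers; each entry is (function, derivative)."""
    return [(math.sin, math.cos) if i % 2 == 0 else (square, dsquare)
            for i in range(depth)]


def grad_full_trace(layers, x):
    """Plain reverse accumulation: store the input of every layer."""
    inputs = []
    for f, _ in layers:
        inputs.append(x)
        x = f(x)
    g = 1.0
    for (_, df), v in zip(reversed(layers), reversed(inputs)):
        g *= df(v)
    return g, len(inputs)


def grad_checkpointed(layers, x, segment):
    """Store only each segment's input; recompute inside it on the way back."""
    starts = list(range(0, len(layers), segment))
    checkpoints = []
    for start in starts:
        checkpoints.append(x)
        for f, _ in layers[start:start + segment]:
            x = f(x)
    g, peak = 1.0, 0
    for start, v in zip(reversed(starts), reversed(checkpoints)):
        block = layers[start:start + segment]
        inputs = []
        for f, _ in block:  # redo the forward pass within this segment
            inputs.append(v)
            v = f(v)
        # Values held now: all checkpoints plus this segment's recomputed inputs.
        peak = max(peak, len(checkpoints) + len(inputs))
        for (_, df), u in zip(reversed(block), reversed(inputs)):
            g *= df(u)
    return g, peak


layers = make_layers(12)
g_full, stored_full = grad_full_trace(layers, 0.7)
g_ckpt, stored_ckpt = grad_checkpointed(layers, 0.7, 4)
print(g_full, stored_full)
print(g_ckpt, stored_ckpt)
```

For 12 layers split into segments of 4, the checkpointed version holds at most 7 values at a time instead of 12 and returns the same gradient, at the price of running each segment's forward pass twice.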

JAX is a domain-specific tracing JIT compiler for generating high-performance accelerator code from pure Python and NumPy machine learning programs. It first takes the original code and generates an intermediate representation, JAX IR, which is used as a static trace for automatic differentiation and optimizations. The backend compiler XLA (Accelerated Linear Algebra) then takes the JAX IR and produces efficient machine code. Compared with Autograd, which generates trace graphs dynamically at runtime, JAX is aware of control flow and can generate more compact trace graphs. However, if the graph changes slightly (say, the input has a different size), JAX has to regenerate the trace.

Roy Frostig, Matthew James Johnson, and Chris Leary. 2018. Compiling machine learning programs via high-level tracing. Systems for Machine Learning (2018), 23–24.
