What is JAX? As described on the main JAX webpage, JAX is "Autograd and XLA, brought together for high-performance machine learning research." In effect, JAX extends the NumPy API into a new library that adds automatic differentiation (Autograd), vectorized mapping (vmap) and just-in-time compilation (JIT), all compiled with Accelerated Linear Algebra (XLA), with Tensor Processing Unit (TPU) support and much more. With these features, problems that depend on linear algebra and matrix methods can be solved more efficiently. The purpose of this article is to show that these features can indeed be used to solve a range of simple to complex optimization problems with matrix methods, and to provide an intuitive understanding of the mathematics and implementation behind the code.
First, we need to import the JAX libraries, along with nanargmin and nanargmax from NumPy, as these were not implemented in JAX at the time of writing. If you are using Google Colab, no installation is required: JAX is developed by Google and comes preinstalled in Colab.
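The exact import list is not shown, so the block below is a minimal sketch, assuming a recent JAX release, of imports that cover the functions discussed in this section. Newer JAX versions also provide jnp.nanargmin and jnp.nanargmax, but the NumPy versions are imported here as the text describes.

```python
# Core JAX imports (our assumed set; adjust to what your code actually uses)
import jax
import jax.numpy as jnp
from jax import grad, jacfwd, jacrev, hessian, jit, vmap

# NaN-aware argmin/argmax taken from plain NumPy
import numpy as np
from numpy import nanargmin, nanargmax

# Quick sanity check: index of the smallest / largest non-NaN entry
x = jnp.array([3.0, jnp.nan, 1.0, 2.0])
idx_min = nanargmin(np.asarray(x))  # NaN is ignored, so this is 2
idx_max = nanargmax(np.asarray(x))  # this is 0
```

Converting with np.asarray makes it explicit that the NaN-aware search runs in NumPy rather than inside JAX.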
2 Grad, Jacobians and Vmap
grad computes the derivative of a scalar-valued function by automatic differentiation. It takes a function and returns a new function that evaluates the gradient of the original one. Because the result is itself a function, calls can be nested: grad(grad(f)) gives the second derivative.
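As a small illustration, here is grad applied once and twice to a cubic; the test function f(x) = x^3 + 2x is our own choice, picked so the derivatives are easy to check by hand. Note that grad expects a scalar (float) output and a float input.

```python
from jax import grad

def f(x):
    # Scalar-valued test function: f(x) = x^3 + 2x
    return x**3 + 2.0 * x

df = grad(f)         # first derivative:  f'(x)  = 3x^2 + 2
d2f = grad(grad(f))  # second derivative: f''(x) = 6x

print(df(2.0))   # 14.0
print(d2f(2.0))  # 12.0
```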
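The Jacobian generalizes this to functions of a vector input, possibly with a vector output: it returns the matrix of all first-order partial derivatives. For a scalar-valued function of a vector, such as a circle function, the Jacobian reduces to the gradient vector, and JAX returns the expected vector of partial derivatives.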
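JAX exposes the Jacobian as jax.jacfwd (forward mode) and jax.jacrev (reverse mode). The original circle function is not shown, so this sketch assumes f(x, y) = x^2 + y^2; the polar-to-Cartesian map is a hypothetical extra example to show a full 2x2 Jacobian matrix.

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev

def circle(v):
    # Assumed circle function: f(x, y) = x^2 + y^2, so the Jacobian is [2x, 2y]
    return v[0]**2 + v[1]**2

def polar_to_cartesian(v):
    # Vector-valued example: (r, theta) -> (r cos theta, r sin theta)
    r, theta = v[0], v[1]
    return jnp.stack([r * jnp.cos(theta), r * jnp.sin(theta)])

J_circle = jacfwd(circle)(jnp.array([1.0, 2.0]))             # [2., 4.]
J_polar = jacrev(polar_to_cartesian)(jnp.array([2.0, 0.0]))  # [[1., 0.], [0., 2.]]
print(J_circle)
print(J_polar)
```

Both modes give the same result; forward mode tends to be cheaper when there are few inputs, reverse mode when there are few outputs.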
Even more interesting, the Hessian of a function can be computed by taking the Jacobian twice, since the Hessian is the Jacobian of the gradient. This composability is what makes JAX so powerful: like grad and the Jacobian, hessian takes a function as input and returns a new function.
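The sketch below builds a Hessian by composing the two Jacobian transforms, which is also how JAX defines its own jax.hessian (jacfwd of jacrev). The circle function is the same assumed x^2 + y^2, and the Rosenbrock function is our own addition to show a non-constant Hessian.

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev, hessian

def circle(v):
    # Assumed circle function: f(x, y) = x^2 + y^2 (constant Hessian 2I)
    return v[0]**2 + v[1]**2

def rosenbrock(v):
    # Classic optimization test function; its Hessian depends on v
    return (1.0 - v[0])**2 + 100.0 * (v[1] - v[0]**2)**2

def my_hessian(f):
    # Jacobian of the Jacobian: reverse mode inside, forward mode outside
    return jacfwd(jacrev(f))

p = jnp.array([1.0, 1.0])
H_circle = my_hessian(circle)(p)      # [[2, 0], [0, 2]]
H_rosen = my_hessian(rosenbrock)(p)   # [[802, -400], [-400, 200]]
H_builtin = hessian(rosenbrock)(p)    # JAX's built-in transform, same result
print(H_rosen)
```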
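Note that these derivatives are computed with automatic differentiation, which is exact up to floating-point round-off, whereas finite differences also incur a truncation error that depends on the chosen step size. Automatic differentiation is also more efficient: reverse mode computes the full gradient of a scalar function at a small constant multiple of the cost of one function evaluation, while finite differences need at least one extra function evaluation per input dimension.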
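To make the accuracy claim concrete, this sketch compares grad against a central finite difference on a test function of our choosing, f(x) = sin(x) e^x, whose derivative we know in closed form. The step size h = 1e-3 and the evaluation points are arbitrary assumptions; vmap is used to apply grad to every point at once.

```python
import jax
jax.config.update("jax_enable_x64", True)  # float64, so round-off does not mask the comparison

import jax.numpy as jnp
from jax import grad, vmap

def f(x):
    # Test function: f(x) = sin(x) * exp(x)
    return jnp.sin(x) * jnp.exp(x)

def df_exact(x):
    # Hand-derived derivative: f'(x) = exp(x) * (sin(x) + cos(x))
    return jnp.exp(x) * (jnp.sin(x) + jnp.cos(x))

xs = jnp.linspace(0.0, 2.0, 5)

# Automatic differentiation, vectorized over all points with vmap
d_ad = vmap(grad(f))(xs)

# Central finite difference: truncation error is O(h^2)
h = 1e-3
d_fd = (f(xs + h) - f(xs - h)) / (2.0 * h)

err_ad = jnp.max(jnp.abs(d_ad - df_exact(xs)))
err_fd = jnp.max(jnp.abs(d_fd - df_exact(xs)))
print(err_ad, err_fd)
```

The autodiff error stays at the level of float64 round-off, while the finite-difference error is several orders of magnitude larger and changes with h.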