1. What is JAX
JAX is a machine learning library which I would describe (very loosely) as NumPy with automatic differentiation that you can run on GPUs (and TPUs too!). It also offers XLA compilation and built-in vectorization and parallelization.
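The following is a minimal sketch, not taken from the original post, of the features listed above. It uses `jax.numpy` for NumPy-style code, `jax.grad` for automatic differentiation, `jax.jit` for XLA compilation, and `jax.vmap` for vectorization. The toy linear model and its data are illustrative assumptions. Multi-device parallelism (`jax.pmap`) is left out because it needs more than one accelerator.

```python
import jax
import jax.numpy as jnp

# NumPy-style code written with jax.numpy (toy linear model, illustrative only).
def loss(w, x, y):
    """Mean squared error of the linear prediction x @ w against targets y."""
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# Automatic differentiation: gradient of `loss` with respect to its first argument, w.
grad_loss = jax.grad(loss)

# XLA compilation: jit traces the function on first call and compiles it with XLA.
fast_grad_loss = jax.jit(grad_loss)

# Vectorization: write the function for a single example...
def predict(w, x_single):
    return jnp.dot(x_single, w)

# ...and let vmap map it over the leading (batch) axis of x, sharing w across the batch.
batched_predict = jax.vmap(predict, in_axes=(None, 0))

w = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.zeros(3)

print(loss(w, x, y))            # mean of [1, 4, 9], i.e. 14/3
print(fast_grad_loss(w, x, y))  # (2/3) * x.T @ [1, 2, 3], i.e. [8/3, 10/3]
print(batched_predict(w, x))    # one prediction per row of x
```

The same `loss` function is reused unchanged by `grad`, `jit` and `vmap`: these transformations compose, which is what makes JAX feel like NumPy with extra superpowers.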
2. Graph Neural Networks
2.1 Graph Convolutional Networks
2.2 Graph Attention Networks
3. JAX implementation
3.1 Graph Convolutional Networks
3.2 Graph Attention Networks
3.3 Main loop
4. Other JAX resources
Link: http://gcucurull.github.io/deep-learning/2020/04/20/jax-graph-neural-networks/