JAX Tutorials
JAX Tutorial by the Engineers at DeepMind
- Link: JAX 101
- JAX as accelerated NumPy
JAX as accelerated NumPy
- The tutorial says JAX works seamlessly on CPU, GPU and TPU, so let's compare the three on Google Colab.
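Before comparing runtimes, it helps to confirm which backend a Colab runtime actually exposes to JAX. A minimal check, using only `jax.default_backend()` and `jax.devices()`:

```python
import jax

# Name of the platform JAX will use by default, e.g. "cpu", "gpu" or "tpu".
print(jax.default_backend())

# All devices visible to JAX in this runtime (one entry per core/chip).
print(jax.devices())
```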
Test 1. Dot Operation
- Following the tutorial's example, a vector of length 1e7 was dotted with itself, and the %timeit magic was used to measure the runtime.
- On CPU: 8.36 ms
- On TPU: 10.7 ms
- On GPU: 283 µs
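- The GPU is therefore about 30 times faster than the CPU and about 38 times faster than the TPU.
- So why use a TPU at all?
- TPUs are specialized accelerators built for large matrix multiplications and designed to be used in a distributed manner, scaled across many chips. A dot product of two vectors is a memory-bound operation that does not exercise the TPU's matrix units, so this test says little about where TPUs shine.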
- Let's try a more interesting test.
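One detail worth checking when reproducing the dot-product numbers above: JAX dispatches work asynchronously, so the JAX documentation recommends calling `block_until_ready()` on the result when benchmarking; otherwise a timing loop can end up measuring mostly dispatch time. The sketch below times the same 1e7-element dot product with `time.perf_counter` instead of `%timeit`. The float32 dtype, the repeat count and the helper name `time_dot` are my own assumptions, not taken from the tutorial.

```python
import time

import jax
import jax.numpy as jnp

# Same size as the test above: a vector with 1e7 elements (float32 is an assumption).
long_vector = jnp.arange(int(1e7), dtype=jnp.float32)


def time_dot(x, repeats=10):
    """Mean wall-clock time in seconds of jnp.dot(x, x) over `repeats` runs."""
    # Warm-up call so one-time compilation and transfer costs are not counted.
    jnp.dot(x, x).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # Wait for the result so the computation itself is timed, not only its dispatch.
        jnp.dot(x, x).block_until_ready()
    return (time.perf_counter() - start) / repeats


print(f"{jax.default_backend()}: {time_dot(long_vector) * 1e3:.3f} ms")
```

In a notebook the equivalent one-liner is `%timeit jnp.dot(long_vector, long_vector).block_until_ready()`.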
Test 2. Using the Lattice Boltzmann Method to Simulate Flow over a Cylinder
- Instead of discretizing and solving the governing PDEs over the whole domain, the Lattice Boltzmann method (LBM) tracks particle distribution functions on a regular lattice: at each time step they relax locally towards equilibrium (collision) and then move to neighboring nodes (streaming). This local update rule is very similar to the cellular automata I implemented before.
- Flow over a cylinder is a common benchmark for validating simulation methods. Here we implement an LBM solver for it to see how the method works, along with a small runtime study.
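The collide-and-stream idea and the cylinder benchmark can be sketched in JAX roughly as follows. This is a minimal D2Q9 BGK sketch, not the implementation used in this post: the grid size, cylinder position and radius, inflow speed `u_in`, relaxation time `tau`, the equilibrium inlet, the zero-gradient outlet, the periodic top and bottom boundaries, and the names `make_solver` and `equilibrium` are all assumptions for illustration.

```python
import time

import jax
import jax.numpy as jnp

# D2Q9 lattice: 9 discrete velocities c_i, their weights w_i, and the index of -c_i.
VELOCITIES = [(0, 0), (1, 0), (0, 1), (-1, 0), (0, -1),
              (1, 1), (-1, 1), (-1, -1), (1, -1)]
C = jnp.array(VELOCITIES, dtype=jnp.float32)
W = jnp.array([4 / 9] + [1 / 9] * 4 + [1 / 36] * 4)
OPPOSITE = jnp.array([0, 3, 4, 1, 2, 7, 8, 5, 6])
LEFT = jnp.array([3, 6, 7])  # populations with c_x = -1


def equilibrium(rho, u):
    """BGK equilibrium for density rho (NX, NY) and velocity u (NX, NY, 2)."""
    cu = jnp.einsum("xyd,qd->xyq", u, C)          # c_i . u
    uu = jnp.sum(u ** 2, axis=-1, keepdims=True)  # |u|^2
    return rho[..., None] * W * (1 + 3 * cu + 4.5 * cu ** 2 - 1.5 * uu)


def make_solver(solid, u_in, tau):
    """Build a jitted run(f, n_steps); `solid` is True on cylinder nodes (hypothetical helper)."""
    ny = solid.shape[1]
    u_inlet = jnp.zeros((1, ny, 2)).at[..., 0].set(u_in)
    f_inlet = equilibrium(jnp.ones((1, ny)), u_inlet)[0]  # shape (NY, 9)

    def step(f):
        rho = jnp.sum(f, axis=-1)
        u = jnp.einsum("xyq,qd->xyd", f, C) / rho[..., None]
        # Collision: relax towards equilibrium with relaxation time tau (BGK).
        f_post = f - (f - equilibrium(rho, u)) / tau
        # Full-way bounce-back: on solid nodes, reverse the incoming populations.
        f_post = jnp.where(solid[..., None], f[..., OPPOSITE], f_post)
        # Streaming: move each population one site along c_i (periodic wrap).
        f_new = jnp.stack(
            [jnp.roll(f_post[..., i], shift=v, axis=(0, 1))
             for i, v in enumerate(VELOCITIES)],
            axis=-1,
        )
        # Inlet (x = 0): impose the equilibrium of a uniform flow u_in.
        f_new = f_new.at[0].set(f_inlet)
        # Outlet (x = NX - 1): zero-gradient copy of the left-moving populations.
        outlet = f_new[-1].at[:, LEFT].set(f_new[-2][:, LEFT])
        return f_new.at[-1].set(outlet)

    @jax.jit
    def run(f, n_steps):
        # n_steps is traced, so changing it does not trigger recompilation.
        return jax.lax.fori_loop(0, n_steps, lambda _, g: step(g), f)

    return run


# Hypothetical setup: 200 x 80 channel, cylinder of radius 10 near the inlet.
NX, NY = 200, 80
x, y = jnp.meshgrid(jnp.arange(NX), jnp.arange(NY), indexing="ij")
solid = (x - NX // 5) ** 2 + (y - NY // 2) ** 2 < (NY // 8) ** 2
u_in, tau = 0.05, 0.6  # lattice units; viscosity nu = (tau - 0.5) / 3

run = make_solver(solid, u_in, tau)
f0 = equilibrium(jnp.ones((NX, NY)), jnp.zeros((NX, NY, 2)).at[..., 0].set(u_in))

run(f0, 1).block_until_ready()  # compile once
start = time.perf_counter()
f = run(f0, 1000).block_until_ready()
elapsed = time.perf_counter() - start
print(f"{jax.default_backend()}: {elapsed:.3f} s for 1000 steps")
```

Because the whole time loop sits inside one `jax.jit`-compiled `fori_loop`, the same script can be timed on the CPU, GPU and TPU runtimes in Colab, mirroring the comparison made for the dot product in Test 1.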