JAX
Notes on the JAX ecosystem for machine learning, from the core transformation model to the surrounding training stack.
Scaling Up: How distributed data parallelism, fully sharded data parallelism, tensor parallelism, and JAX sharding primitives fit together when scaling training.
JAX NumPy: How JAX NumPy differs from NumPy: compiled execution, immutable arrays, explicit randomness, automatic vectorization, pytrees, and explicit sharding.
Introducing Flax NNX: How Flax NNX gives JAX a stateful neural-network API while preserving explicit RNG streams, JIT compilation, autodiff, and Optax updates.
JAX AI Stack: How JAX, XLA, Flax NNX, Optax, Orbax, and Grain fit together into a modern training stack, plus the role of jit, grad, and vmap.