How JAX and MaxText Cut Days Off Frontier LLM Training on Blackwell GPUs

Why LLM Training Throughput Now Decides Who Ships First

Large language model training throughput is the rate at which a training system processes tokens and applies model updates. It has become the decisive factor in how fast AI labs can move from an idea to a frontier‑scale model in production. When pre‑training runs span trillions of tokens across thousands of accelerators, a small reduction in step time adds up to days of saved wall‑clock time and substantial infrastructure savings. Low‑bit mixed‑precision formats promise large speed gains, but they are difficult to tune without hurting convergence. This is where JAX, the MaxText framework, and NVIDIA’s Blackwell GPUs come in. Together they align compiler, framework, and hardware features so that ultra‑low‑precision math can be used safely, which makes the choice of numerical precision a first‑class performance knob for distributed training.
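
The "days of saved time" claim can be checked with back‑of‑the‑envelope arithmetic. Every number below is a hypothetical assumption, not a measurement from NVIDIA or MaxText: a 10T‑token run, 16M tokens per optimizer step, and a step time reduced from 5.0 s to 4.5 s. The sketch also assumes a constant step time and no restarts.

```python
SECONDS_PER_DAY = 86_400


def training_days(total_tokens: float, tokens_per_step: float, step_time_s: float) -> float:
    """Wall-clock days for a run, assuming a constant step time and no restarts."""
    steps = total_tokens / tokens_per_step
    return steps * step_time_s / SECONDS_PER_DAY


# Hypothetical run: 10T tokens at 16M tokens per step -> 625,000 optimizer steps.
TOTAL_TOKENS = 10e12
TOKENS_PER_STEP = 16e6

baseline = training_days(TOTAL_TOKENS, TOKENS_PER_STEP, step_time_s=5.0)
faster = training_days(TOTAL_TOKENS, TOKENS_PER_STEP, step_time_s=4.5)
print(f"baseline {baseline:.1f} days, faster {faster:.1f} days, saved {baseline - faster:.1f} days")
# prints: baseline 36.2 days, faster 32.6 days, saved 3.6 days
```

Under these assumptions, a 10% cut in step time saves 10% of wall‑clock time. At this scale that amounts to more than three and a half days.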

Inside NVFP4: Sub‑byte Precision Built for Blackwell

NVFP4 is a 4‑bit floating‑point format with E2M1 elements, designed for training. It uses two‑level microscaling (an FP8 scale for each small block of values plus an FP32 scale for each tensor) to represent a wide dynamic range with less error than earlier microscaling schemes. According to NVIDIA, native NVFP4 support on the NVIDIA GB300 Grace Blackwell Ultra Superchip delivers 7x the GEMM throughput of native FP8 precision on the NVIDIA Hopper architecture. Combined with a tailored pretraining recipe, that gain shortens training step time, and in NVIDIA's experiments it did so with no measurable accuracy loss versus an FP8 baseline. NVFP4 is applied primarily to the MLP layers of a transformer, which account for most training FLOPs. Attention blocks stay in higher precision to avoid amplifying quantization noise through the softmax. The forward (FPROP), activation‑gradient (DGRAD), and weight‑gradient (WGRAD) GEMMs consume NVFP4 inputs and emit BF16 outputs. Weight updates are ultimately applied to FP32 master weights, which keeps long‑run numerical stability intact.
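
The two‑level scaling idea can be illustrated with a minimal pure‑Python sketch. It is not Transformer Engine's kernel code, and the function names are hypothetical. It rests on several assumptions:

- Elements are rounded onto the E2M1 grid.
- Each 16‑element block gets one scale stored as an FP8 E4M3 value.
- A per‑tensor FP32 scale is chosen so that every block scale fits within E4M3's 448 maximum.
- Rounding goes to the nearest value, with ties broken toward the smaller magnitude. This is a simplification of round‑to‑nearest‑even.

```python
import math

# Non-negative magnitudes representable by FP4 E2M1; the sign bit is separate.
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
E2M1_MAX = 6.0
E4M3_MAX = 448.0  # largest finite FP8 E4M3 value


def e4m3_grid():
    """All non-negative finite FP8 E4M3 values (exponent bias 7, 3 mantissa bits)."""
    values = [m * 2.0 ** -9 for m in range(8)]  # zero plus subnormals
    for exp in range(1, 16):
        for man in range(8):
            if exp == 15 and man == 7:
                continue  # this bit pattern encodes NaN, not a number
            values.append((1 + man / 8) * 2.0 ** (exp - 7))
    return values


E4M3_GRID = e4m3_grid()


def round_to_grid(x, grid):
    """Nearest grid magnitude (ties go to the smaller one); saturates at grid[-1]."""
    mag = min(grid, key=lambda g: abs(abs(x) - g))
    return math.copysign(mag, x)


def quantize_nvfp4(x, block=16):
    """Two-level scaling: one FP32 scale per tensor, one E4M3 scale per block."""
    amax = max(abs(v) for v in x)
    # Per-tensor scale so that every block scale fits inside the E4M3 range.
    tensor_scale = (E2M1_MAX * E4M3_MAX) / amax if amax > 0 else 1.0
    codes, block_scales = [], []
    for start in range(0, len(x), block):
        chunk = x[start:start + block]
        block_amax = max(abs(v) for v in chunk)
        # Map the block's largest magnitude onto E2M1's max, then store the scale in E4M3.
        scale = round_to_grid(block_amax / E2M1_MAX * tensor_scale, E4M3_GRID)
        block_scales.append(scale)
        for v in chunk:
            scaled = v * tensor_scale / scale if scale > 0 else 0.0
            codes.append(round_to_grid(scaled, E2M1_GRID))
    return codes, block_scales, tensor_scale


def dequantize_nvfp4(codes, block_scales, tensor_scale, block=16):
    """Undo both scaling levels to recover approximate real values."""
    return [c * block_scales[i // block] / tensor_scale for i, c in enumerate(codes)]


# Two 16-element blocks; only the second contains a large outlier.
x = [0.1] * 16 + [0.1] * 15 + [50.0]
restored = dequantize_nvfp4(*quantize_nvfp4(x))
print(restored[0], restored[16], restored[31])
```

In this run, the first block keeps its 0.1 values (restored to roughly 0.098). The 0.1 values that share a block with the 50.0 outlier are flushed to zero. Because the scale is stored per block, the outlier's damage stays within its own 16 elements.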

The Five Ingredients of the MaxText NVFP4 Training Recipe

The MaxText NVFP4 recipe combines five techniques so that 4‑bit mixed‑precision training converges as reliably as higher‑precision baselines.

First, micro‑block scaling uses 16‑element blocks, half the 32‑element block size of MXFP4. A single outlier therefore inflates the shared scale of only 15 neighbors instead of 31.

Second, E4M3 block scale factors replace MXFP4's power‑of‑two E8M0 scales, which allows finer‑grained, non‑power‑of‑two scale values. These block scales sit under a per‑tensor FP32 scale. In an 8B‑parameter, 1T‑token experiment, MXFP4 needed about 36% more tokens to match NVFP4’s final loss.

Third, a Random Hadamard Transform (RHT) is applied to the inputs of the WGRAD GEMMs. It spreads outliers into a more Gaussian‑like distribution without breaking 2D scaling.

Fourth, 2D weight scaling assigns one FP8 scale per 16×16 weight block, so the forward pass and the transposed backward pass share the same scales.

Finally, stochastic rounding on the gradient quantizers keeps tiny gradient values from being systematically rounded to zero. Weights and activations use round‑to‑nearest‑even, which has lower error.
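
Two of these ingredients are easy to see in isolation. The following is a pure‑Python sketch that makes several assumptions:

- Stochastic rounding targets the E2M1 grid, and its inputs are taken to be already divided by their block scale.
- The RHT uses a 16‑point Hadamard matrix built with the Sylvester construction, combined with random ±1 signs.
- The helper names are hypothetical. The real recipe runs these steps inside fused GPU kernels.

```python
import bisect
import math
import random

E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def round_nearest(x, grid=E2M1_GRID):
    """Deterministic nearest rounding (ties to the smaller magnitude, for brevity)."""
    return math.copysign(min(grid, key=lambda g: abs(abs(x) - g)), x)


def round_stochastic(x, rng, grid=E2M1_GRID):
    """Round up with probability proportional to distance; unbiased for |x| <= grid[-1]."""
    mag = min(abs(x), grid[-1])  # saturate at the largest representable value
    hi = bisect.bisect_left(grid, mag)
    if grid[hi] == mag:
        return math.copysign(mag, x)
    lo_val, hi_val = grid[hi - 1], grid[hi]
    p_up = (mag - lo_val) / (hi_val - lo_val)
    return math.copysign(hi_val if rng.random() < p_up else lo_val, x)


def hadamard(n):
    """Sylvester construction of an n x n Hadamard matrix; n must be a power of two."""
    h = [[1.0]]
    while len(h) < n:
        h = [row + row for row in h] + [row + [-v for v in row] for row in h]
    return h


def random_hadamard_transform(x, signs):
    """y = (1/sqrt(n)) * H @ diag(signs) @ x, an orthogonal rotation of x."""
    n = len(x)
    scale = 1.0 / math.sqrt(n)
    flipped = [s * v for s, v in zip(signs, x)]
    return [scale * sum(h * v for h, v in zip(row, flipped)) for row in hadamard(n)]


rng = random.Random(0)

# A tiny gradient value (in block-scaled units) below half of E2M1's smallest step.
tiny = 0.1
samples = [round_stochastic(tiny, rng) for _ in range(100_000)]
print(round_nearest(tiny), sum(samples) / len(samples))  # nearest: 0.0; stochastic: about 0.1 on average

# One outlier in a 16-element vector: the rotation spreads it evenly across all entries.
signs = [rng.choice((-1.0, 1.0)) for _ in range(16)]
outlier = [0.0] * 15 + [8.0]
print([abs(v) for v in random_hadamard_transform(outlier, signs)])  # every entry has magnitude 2.0
```

The normalized Hadamard matrix and the sign flips are both orthogonal. Applying the same rotation to both WGRAD operands along their reduction dimension therefore leaves the product unchanged in exact arithmetic, while the operands themselves become easier to quantize.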

JAX and MaxText: Turning NVFP4 Into Distributed Throughput

The MaxText framework, built on JAX, closes the loop between the NVFP4 format and real‑world LLM training throughput by packaging the recipe into a configurable, distributed training stack. In MaxText, switching to NVFP4 is as simple as setting a quantization flag, which exposes two modes. The te_nvfp4 mode enables the full recipe, including the Random Hadamard Transform. The te_nvfp4_no_rht mode drops RHT for the lowest overhead, at some risk to convergence. MaxText uses JAX, NVIDIA Transformer Engine, and Blackwell’s FP4 conversion instructions to quantize the MLP GEMMs while leaving attention in higher precision, so most FLOPs run in 4‑bit without destabilizing training. Because these choices are implemented at the framework level, they scale cleanly across multi‑thousand‑GPU clusters, turning a per‑step latency win into days of saved time on trillion‑token runs.
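
The per‑layer precision choice described above can be pictured as a small lookup table. This sketch is hypothetical: the class, function, and layer‑kind names are illustrative and are not MaxText or Transformer Engine internals. Only the mode strings te_nvfp4 and te_nvfp4_no_rht come from the text. BF16 for attention is an assumed stand‑in for "higher precision".

```python
from dataclasses import dataclass

# Hypothetical mapping: mode string -> whether RHT is applied to WGRAD inputs.
NVFP4_MODES = {"te_nvfp4": True, "te_nvfp4_no_rht": False}


@dataclass(frozen=True)
class GemmPolicy:
    input_format: str   # format the GEMM operands are quantized to
    output_format: str  # format the GEMM writes out
    rht_on_wgrad: bool  # apply the Random Hadamard Transform to WGRAD inputs


def gemm_policy(layer_kind: str, mode: str) -> GemmPolicy:
    """Pick a precision policy for one GEMM (layer_kind names are hypothetical)."""
    if mode not in NVFP4_MODES:
        raise ValueError(f"unknown NVFP4 mode: {mode!r}")
    if layer_kind == "mlp":
        # MLP GEMMs carry most training FLOPs, so they run on NVFP4 inputs.
        return GemmPolicy("nvfp4", "bf16", rht_on_wgrad=NVFP4_MODES[mode])
    # Attention and everything else stay in higher precision (BF16 assumed here).
    return GemmPolicy("bf16", "bf16", rht_on_wgrad=False)


for kind in ("mlp", "attention"):
    for mode in NVFP4_MODES:
        print(kind, mode, gemm_policy(kind, mode))
```

Keeping the decision in one table is what turns the trade‑off into a configuration choice. In this sketch, te_nvfp4_no_rht only flips the RHT bit for the MLP weight‑gradient GEMMs and leaves everything else unchanged.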

Why Framework‑Level Optimizations Are Now a Competitive Moat

As LLMs grow and pre‑training runs stretch over trillions of tokens, framework‑level optimizations like NVFP4 support in MaxText become hard competitive advantages rather than niche tricks. Step‑time improvements compound over hundreds of thousands of iterations, directly affecting how quickly new architectures can be explored and how many models a given cluster can train per year. Labs now track training throughput and throughput per dollar, much as inference benchmarks track serving speed in tokens per second. On those leaderboards, models such as GPT‑oss 120B and Gemini 3.5 Flash are ranked by output rate. The same mentality is now moving upstream into training systems. By tightly aligning numerical formats, compiler transformations, and distributed execution, JAX and MaxText show that lowering precision to NVFP4 on NVIDIA Blackwell GPUs can turn cutting‑edge hardware into faster model development cycles rather than just higher theoretical FLOPs.
