Mixed precision and the art of using fewer bits
Numbers inside a neural network are usually stored in a big, careful format. If we switch most of them to a smaller, lighter format, training can run roughly twice as fast on hardware built for it, and those numbers take half the memory. The whole skill is knowing which few numbers we must leave in the careful format. This article walks you through it slowly, assuming you have never touched PyTorch.
First, what a number even looks like to a computer
Before we touch any machine learning, let us slow right down, because the whole article rests on one idea: a computer does not store a number the way you write it on paper. It stores it using a fixed number of slots called bits, where each bit is just a tiny switch that is either on or off. The more bits you give a number, the more detail it can hold. The fewer bits you give it, the more it has to round off.
A neural network is, underneath everything, a giant pile of numbers. Some of those numbers are the parameters, which are the values the model adjusts as it learns (you will also hear them called weights). Others are the activations, which are the numbers flowing through the network as it processes an input. There are millions or billions of these numbers, and every one of them takes up memory and takes time to multiply. So the format we store them in matters enormously.
The standard, careful format is called fp32, short for 32-bit floating point. It uses 32 bits per number and holds plenty of detail. The lighter format we are going to explore is called bf16, which uses only 16 bits. Half the bits means half the memory, and on the right hardware it means roughly double the speed. The catch, and the reason this is a skill rather than a free lunch, is that 16 bits cannot hold as much detail, so some numbers get rounded. Our job is to figure out where that rounding is harmless and where it is dangerous.
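To put a number on "half the memory", here is a minimal sketch of the arithmetic. The helper name storage_gb and the count of one billion values are made up for illustration; the only facts it relies on are that fp32 takes 4 bytes per value and bf16 takes 2.

```python
# Bytes needed to store one value in each format (32 bits = 4 bytes, 16 bits = 2 bytes).
BYTES_PER_VALUE = {"fp32": 4, "bf16": 2}


def storage_gb(num_values: int, fmt: str) -> float:
    """Return the storage needed for num_values numbers, in gigabytes (10**9 bytes)."""
    return num_values * BYTES_PER_VALUE[fmt] / 1e9


# Hypothetical workload: one billion numbers held in memory at once.
num_values = 1_000_000_000

fp32_gb = storage_gb(num_values, "fp32")
bf16_gb = storage_gb(num_values, "bf16")

print(f"fp32: {fp32_gb} GB")  # fp32: 4.0 GB
print(f"bf16: {bf16_gb} GB")  # bf16: 2.0 GB
```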
If this feels abstract right now, that is completely normal. The next section makes it concrete.
What bf16 keeps and what it throws away
A floating point number is split into two parts. One part, called the exponent, controls how big or small the number can be, all the way from tiny fractions to enormous values. The other part, called the mantissa, controls how many meaningful digits of precision you get. Think of the exponent as choosing the ballpark (thousands, millionths, billions) and the mantissa as choosing the fine detail within that ballpark.
Here is the clever design choice behind bfloat16, the full name for bf16. It keeps the exact same exponent as fp32, so it can represent the same huge range of magnitudes and will not suddenly overflow to infinity or collapse to zero. What it gives up is mantissa: it keeps only 7 mantissa bits, which works out to roughly 2 or 3 meaningful decimal digits. So bf16 always knows roughly how big a number is, but it is fuzzy about the fine detail.
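You can see the "same exponent, shorter mantissa" idea without installing anything. The sketch below uses only Python's standard library. The helper to_bf16 is a name invented for this example, not a PyTorch function: it takes the 32-bit pattern of an fp32 number and keeps only the top 16 bits (1 sign bit, 8 exponent bits, 7 mantissa bits), rounding to the nearest value. It assumes the input is an ordinary finite number.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a finite number to the nearest bfloat16 value (hypothetical helper)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # the fp32 bit pattern as an integer
    bits += 0x7FFF + ((bits >> 16) & 1)                   # round to nearest, ties to even
    bits &= 0xFFFF0000                                    # keep sign, exponent, top 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


pi_bf16 = to_bf16(3.14159265)
print(pi_bf16)  # 3.140625, so only about three digits survive

# The range survives intact: a huge and a tiny number both keep their magnitude.
for value in (1e30, 1e-30):
    rounded = to_bf16(value)
    print(value, "->", rounded, "relative error", abs(rounded - value) / value)
```

The relative error stays below half a percent at every magnitude, which is what "knows roughly how big a number is, but is fuzzy about the fine detail" looks like in practice.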
For most of what a network does, that fuzziness is fine. The workhorse operation in a neural network is matrix multiplication (often shortened to matmul), which is just multiplying and adding up huge grids of numbers. When you add together hundreds or thousands of terms, the little rounding errors tend to cancel each other out, and the hardware typically keeps the running total in a higher-precision format while it adds, so the fuzziness mostly washes away.
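Here is a rough way to watch that washing-out happen, using only the standard library. It is a sketch under stated assumptions, not a model of any particular GPU: the inputs are random positive numbers, to_bf16 is a made-up helper that rounds a value to bf16, and the running total is kept in Python's full-precision float as a stand-in for a higher-precision accumulator.

```python
import random
import struct


def to_bf16(x: float) -> float:
    """Round a finite number to the nearest bfloat16 value (hypothetical helper)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000                   # keep only the top 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


random.seed(0)
n = 4096  # one row times one column of a matmul is a sum of n products
a = [random.uniform(0.5, 1.5) for _ in range(n)]
b = [random.uniform(0.5, 1.5) for _ in range(n)]

exact_terms = [x * y for x, y in zip(a, b)]
rounded_terms = [to_bf16(x) * to_bf16(y) for x, y in zip(a, b)]

# Largest relative error of any single product after rounding its inputs.
worst_term_error = max(abs(r - e) / e for r, e in zip(rounded_terms, exact_terms))

# Relative error of the whole sum, where errors of opposite sign offset each other.
exact_sum = sum(exact_terms)
sum_error = abs(sum(rounded_terms) - exact_sum) / exact_sum

print("worst single term:", worst_term_error)
print("whole sum:        ", sum_error)
```

When you run it, the error of the whole sum should come out well below the error of the worst single term, because some products were rounded up and others down.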
The danger appears in one specific situation: subtracting two numbers that are very close together. When two nearby numbers are each a little fuzzy, the small difference between them can be almost entirely made of fuzz. This has a wonderfully dramatic name, catastrophic cancellation, and it is exactly the kind of operation that certain training methods rely on, so it is worth seeing with your own eyes.
Here is an experiment worth trying. Pick a value, round it to bf16, and see how far the rounded version drifts from the true value. Then subtract two nearby rounded values: the rounding error, which looked harmless on its own, suddenly dominates the answer.
On its own, a single value rounds to bf16 almost perfectly and you would never notice. But subtract two nearby bf16 values and the tiny errors no longer cancel, so the small true difference gets drowned out. This is exactly why some quantities are computed in fp32 instead.
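The same effect can be reproduced in a few lines of plain Python. This is a sketch: to_bf16 is a helper invented for the example, and the two values 1.005 and 1.003 are arbitrary picks that happen to sit on either side of a bf16 rounding boundary.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a finite number to the nearest bfloat16 value (hypothetical helper)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000                   # keep only the top 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


a, b = 1.005, 1.003           # two nearby values; the true difference is about 0.002
a16, b16 = to_bf16(a), to_bf16(b)

print(a16, b16)               # 1.0078125 1.0
print(abs(a16 - a) / a)       # each value on its own is off by well under 1%
print(abs(b16 - b) / b)

true_diff = a - b
bf16_diff = a16 - b16
print(bf16_diff)              # 0.0078125
print(bf16_diff / true_diff)  # the difference is roughly four times too large
```

Near 1.0, neighbouring bf16 values are 0.0078125 apart, so a true gap of 0.002 can only come out as 0 or as 0.0078125. Neither answer is close.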
Letting the computer choose the right precision for each step
Now to the practical question: if some operations are safe in bf16 and some are dangerous, how do we get each one in the right format without hand-labelling thousands of lines of code? PyTorch gives us a tool called autocast that does it automatically. You wrap your model’s forward pass in autocast, and it looks at each operation and picks the precision based on a simple rule.
The rule follows everything we just learned. Big matmuls, such as the linear layers and the attention computation, run in bf16 because they are heavy (so the speedup helps most) and forgiving (so the rounding does not hurt). The delicate operations stay in fp32. Those delicate ones are the steps that boil many values down to a single statistic or otherwise need careful precision: the normalization layers like LayerNorm and RMSNorm (which measure the average size or spread of a group of activations), the softmax (which turns a list of scores into probabilities), and the loss (the single number that measures how wrong the model currently is).
There is one more important detail. Even while activations dart through the network in bf16, the master copy of the parameters always stays in fp32, and so does the optimizer state, which is the bookkeeping the training algorithm keeps to update each parameter smoothly. Only the temporary forward-pass numbers go low. The valuable long-term values stay safe and precise.
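If you have PyTorch installed, you can check this for yourself. The sketch below makes a few assumptions so that it runs anywhere: it uses device_type="cpu" instead of a GPU, a single tiny nn.Linear layer as a stand-in for a model, and a made-up loss. It only inspects dtypes, which is the point being illustrated.

```python
import torch
from torch import nn

layer = nn.Linear(8, 4)                          # stand-in for a real model
opt = torch.optim.AdamW(layer.parameters(), lr=1e-3)
x = torch.randn(2, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = layer(x)                                 # the matmul runs in bf16

loss = y.float().pow(2).mean()                   # made-up loss, computed in fp32
loss.backward()
opt.step()

print(y.dtype)                                   # torch.bfloat16: the temporary activation went low
print(layer.weight.dtype)                        # torch.float32: the master weights did not change
print(layer.weight.grad.dtype)                   # torch.float32: gradients land in the weights' format
print(opt.state[layer.weight]["exp_avg"].dtype)  # torch.float32: the optimizer's bookkeeping
```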
One footnote for completeness. There is an older 16-bit format called fp16 that, unlike bf16, has a narrow exponent. With fp16 the smallest gradients can shrink so far that they round all the way to zero and vanish, a problem called underflow. To rescue them you add a helper called a GradScaler that temporarily multiplies the loss up to keep those tiny values in range. The lovely thing about bf16 is that its wide exponent makes this whole headache disappear, so you usually do not need a GradScaler at all.
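Underflow is easy to see with the standard library alone, because Python's struct module can pack a number into 16-bit fp16 (format code "e"). The sketch below is an illustration, not how GradScaler is implemented: to_fp16 and to_bf16 are made-up helpers, the gradient value 1e-8 is an arbitrary tiny number, and the scale of 65536 is chosen as an assumption (it is the kind of large power of two a scaler would use).

```python
import struct


def to_fp16(x: float) -> float:
    """Round a number to the nearest fp16 value (hypothetical helper)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Round a finite number to the nearest bfloat16 value (hypothetical helper)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000                   # keep only the top 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


tiny_grad = 1e-8   # an arbitrary very small gradient
scale = 65536.0    # assumed scale factor, a large power of two

lost = to_fp16(tiny_grad)
print(lost)        # 0.0: the gradient vanished in fp16

# Scale up before rounding to fp16, then divide the scale back out in full precision.
recovered = to_fp16(tiny_grad * scale) / scale
print(recovered)   # close to 1e-8 again

kept = to_bf16(tiny_grad)
print(kept)        # nonzero: bf16 holds the value without any scaling
```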
To build your intuition, go through the operations in a model one at a time and ask which precision each one lands in. Notice how cleanly the answer follows the rule.
A rule of thumb you can carry with you: if an operation is matmul-heavy and forgiving of small errors, send it to bf16 for speed; if it sums up many terms or feeds directly into the loss, keep it in fp32 for safety.
A few traps worth remembering
Every idea here has a sharp edge that catches people the first time they try it. These are the ones to keep in your pocket.
Where this shows up in a real language model
This is not a toy idea you will outgrow. In a real training script, the forward pass and the loss calculation run inside a block that looks like with torch.autocast(device_type='cuda', dtype=torch.bfloat16):, and loss.backward() is called just after the block ends; the backward pass automatically reuses the precision each operation chose on the way forward. The master weights stay in fp32, and the optimizer (commonly AdamW) does its updates in fp32. If a project uses the older fp16 path instead, you will see a GradScaler wrapped around the backward step and the optimizer step to stop those tiny gradients from underflowing to zero. And in reinforcement learning code, which leans heavily on subtracting nearby log-probabilities, you will spot a deliberate logits.float() that forces the calculation back into fp32 so the small differences stay meaningful. Once you know what to look for, you will see these patterns everywhere.
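The logits.float() pattern can be tried on its own. The sketch below is a toy under loud assumptions: the logits are random numbers rather than the output of a real policy, the vocabulary size of 1000 and the chosen token index 42 are arbitrary, and log_ratio is a helper named for this example. It compares the same log-probability subtraction done directly on bf16 logits and done after casting them to fp32.

```python
import torch

torch.manual_seed(0)
vocab = 1000                                         # arbitrary vocabulary size
logits_old = torch.randn(vocab)                      # stand-in for the old policy's scores
logits_new = logits_old + 0.01 * torch.randn(vocab)  # a slightly updated policy
token = 42                                           # arbitrary chosen action


def log_ratio(new: torch.Tensor, old: torch.Tensor, index: int) -> torch.Tensor:
    """Difference of two log-probabilities for one token (hypothetical helper)."""
    logp_new = torch.log_softmax(new, dim=-1)[index]
    logp_old = torch.log_softmax(old, dim=-1)[index]
    return logp_new - logp_old


# What the model hands you under autocast: logits stored in bf16.
new_bf16, old_bf16 = logits_new.bfloat16(), logits_old.bfloat16()

risky = log_ratio(new_bf16, old_bf16, token)                 # subtraction done in bf16
careful = log_ratio(new_bf16.float(), old_bf16.float(), token)  # cast to fp32 first

print(risky.dtype, risky.item())
print(careful.dtype, careful.item())
```

The cast cannot undo rounding that already happened to the logits, but it stops the log-softmax and the subtraction from adding a second, larger layer of fuzz on top.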
The whole thing, runnable
If you have PyTorch installed, this is the shape of the real code. Read the comments as you go; they walk through exactly what each line is doing and why.
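import torch

# Assumed to exist already: model (returns logits and loss), opt (e.g. AdamW), and a batch xb, yb.

# bf16: same exponent (range) as fp32, but only 7 mantissa bits (precision)
opt.zero_grad(set_to_none=True)  # clear gradients left over from the previous step
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits, loss = model(xb, yb)  # heavy matmuls run in bf16; softmax and loss stay fp32 inside
loss.backward()  # called outside the autocast block; bf16's wide range means no GradScaler needed
opt.step()       # update the fp32 master weights

# The older fp16 path DOES need a scaler so tiny gradients don't round down to zero:
scaler = torch.amp.GradScaler("cuda")  # older PyTorch releases spell this torch.cuda.amp.GradScaler()
opt.zero_grad(set_to_none=True)
with torch.autocast(device_type="cuda", dtype=torch.float16):
    logits, loss = model(xb, yb)
scaler.scale(loss).backward()  # multiply the loss up so small gradients survive fp16
scaler.step(opt)               # unscale the gradients, then step (skipped if any are inf or NaN)
scaler.update()                # adjust the scale factor for the next iteration

# Reinforcement learning: always cast back to fp32 before taking log-probs,
# so subtracting two nearby values doesn't get swamped by rounding error
logp = torch.log_softmax(logits.float(), dim=-1)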