LayerNorm and RMSNorm, the gentle reset button for deep networks
Stack enough layers and the numbers flowing through a network tend to drift, growing huge or shrinking to nothing until training falls apart. Normalization is a small, reliable trick that resets those numbers to a sane size at every layer. Here is the whole idea, built up slowly, assuming you have never touched PyTorch.
First, the problem we are trying to solve
Imagine whispering a message down a long line of people. By the time it reaches the end, it has either grown into a shout or faded into silence. Deep neural networks have the same weakness. A network is built from many stacked layers, and the numbers that flow through it pass from one layer to the next, over and over. If each layer nudges those numbers a little larger, they can balloon into enormous values by the time they reach the top. If each layer shrinks them, they can fade toward zero. Either way the network stops learning, and training quietly breaks.
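To put rough numbers on that drift, here is a tiny sketch in plain Python. The growth factors (1.5 and 0.5) and the layer count of 24 are made up for illustration, and a real layer does far more than multiply by a constant, so treat this as a cartoon of the problem rather than a model of an actual network.

```python
def after_layers(value, factor, num_layers):
    """Multiply a starting value by the same factor once per layer."""
    for _ in range(num_layers):
        value *= factor
    return value

# Hypothetical per-layer factors, chosen only to show the trend.
grown = after_layers(1.0, 1.5, 24)  # every layer makes the number 50% larger
faded = after_layers(1.0, 0.5, 24)  # every layer halves the number

print(f"grown: {grown:.1f}")   # in the tens of thousands
print(f"faded: {faded:.2e}")   # a tiny fraction of the starting value
```

A modest nudge per layer compounds quickly: after only 24 layers the first number has grown by four orders of magnitude and the second has all but vanished.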
Before we go further, a couple of words that will keep coming up. A tensor is just a container of numbers, like a list or a grid of numbers bundled together so we can do math on all of them at once. An activation is one of those numbers as it flows forward through the network. When we say numbers “flow through the network,” these activations are exactly what we mean. You do not need any more than that to follow along.
Normalization is a small, dependable fix for the whispering-down-the-line problem. At each layer we take the activations and reset them to a sensible, predictable size before passing them on. Because every layer receives numbers that are already in a comfortable range, nothing balloons and nothing fades. This is what lets people stack twenty-four layers, or far more, and still have the whole thing train smoothly. Two flavors of this trick power almost every large language model today, and we will meet both.
LayerNorm, center then scale
The first flavor is called LayerNorm, and it does two things in order. First it centers the numbers, and then it scales them.
Let us unpack that with the two ideas it leans on. The mean is just the average of a group of numbers. The standard deviation is a measure of how spread out those numbers are: a small standard deviation means they huddle close to the average, and a large one means they are scattered widely. LayerNorm computes both for a single token’s vector of activations, then subtracts the mean (so the numbers now average to zero) and divides by the standard deviation (so their spread becomes one). After that step, every token’s activations have a mean of about 0 and a standard deviation of about 1, no matter how large or off-center the input was. That is the whole reset.
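The two steps can be sketched by hand with nothing but Python's standard library. The four input values are made up, and the tiny eps term that real LayerNorm adds is left out here to keep the arithmetic easy to follow.

```python
import statistics

x = [2.0, 4.0, 6.0, 8.0]  # one token's activations (made-up values)

mean = statistics.fmean(x)   # the average: 5.0
std = statistics.pstdev(x)   # the spread, dividing by N (the "population" form)

centered = [v - mean for v in x]           # step 1: subtract the mean
normalized = [v / std for v in centered]   # step 2: divide by the spread

print([round(v, 3) for v in normalized])
print(round(statistics.fmean(normalized), 6))   # mean is now 0
print(round(statistics.pstdev(normalized), 6))  # spread is now 1
```

Try multiplying every input by 100 or adding 50 to each one: the normalized values should come out the same, which is the "no matter how large or off-center" property in action.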
There is one more touch. After centering and scaling, LayerNorm multiplies by a learnable scale called γ (the Greek letter gamma) and adds a learnable shift called β (beta). “Learnable” means these are not fixed by us; they are parameters, numbers the network adjusts on its own as it trains, the same way it tunes everything else it learns. We reset the activations to a clean 0 mean and 1 spread, and then we hand the network two dials so it can stretch or move them back if that turns out to help. If this feels strange, that is normal. The short version is: we standardize first, then let the model fine-tune the result.
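Here is a small sketch of what the two dials do, again in plain Python. The helper name `scale_and_shift` and the dial settings are made up for illustration. PyTorch's `nn.LayerNorm` starts with every γ at 1 and every β at 0, so at the start of training the dials change nothing.

```python
import statistics

normalized = [-1.0, 1.0, -1.0, 1.0]  # already has mean 0 and spread 1

def scale_and_shift(values, gamma, beta):
    """Apply one gamma and one beta per feature (hypothetical helper)."""
    return [g * v + b for v, g, b in zip(values, gamma, beta)]

# Starting position of the dials: gamma = 1, beta = 0 leaves everything as is.
untouched = scale_and_shift(normalized, [1.0] * 4, [0.0] * 4)

# Made-up "learned" settings: stretch by 2, shift by 0.5.
adjusted = scale_and_shift(normalized, [2.0] * 4, [0.5] * 4)

print(untouched)
print(statistics.fmean(adjusted), statistics.pstdev(adjusted))  # mean 0.5, spread 2.0
```

With γ = 2 and β = 0.5 the mean moves to 0.5 and the spread to 2, which is exactly the "stretch or move them back" freedom the network is given.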
If you like the precise recipe, here it is in one line: (x − μ) / √(σ² + eps) · γ + β. Here μ is the mean, σ² is the variance (the standard deviation squared), and eps is a tiny number we add so we never accidentally divide by zero. Two small but important details that trip people up: the variance uses the biased estimator (it divides by the number of features N rather than by N − 1), and the eps sits inside the square root, not outside it. Do not worry if those distinctions feel fussy right now; they matter when you reproduce the math by hand, which we will do at the end.
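To see why those two details are not just pedantry, here is a plain-Python sketch that computes both versions of each choice. The input is a made-up, nearly constant token, picked because a small variance makes the eps placement visible; with typical inputs the eps difference is far smaller.

```python
import math
import statistics

x = [0.0, 0.0, 0.0, 0.001]  # made-up, nearly constant activations
eps = 1e-5

biased = statistics.pvariance(x)   # divides by N (what LayerNorm uses)
unbiased = statistics.variance(x)  # divides by N - 1 (not what LayerNorm uses)

inside = math.sqrt(biased + eps)    # eps inside the square root (LayerNorm's form)
outside = math.sqrt(biased) + eps   # eps outside the square root (a common slip)

print(unbiased / biased)  # N / (N - 1), here 4/3
print(inside, outside)    # two noticeably different denominators
```

With only four features the two variance estimates differ by a third, and for this near-constant input the two denominators differ by several times, so a hand-built version that gets either detail wrong will not match the library.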
RMSNorm, just scale
The second flavor is called RMSNorm, and it is LayerNorm with the first half thrown away. It skips the centering step entirely (no subtracting the mean) and drops the β shift too. All it does is divide by the root mean square of the activations, which is another way of measuring their typical size, and then multiply by the learnable scale γ.
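A worked example in plain Python may make "root mean square" less mysterious: square each number, average the squares, take the square root. The input values are made up, γ is taken to be 1, and the small eps term is left out for readability.

```python
import math

x = [3.0, 4.0]  # one token's activations (made-up values)

mean_of_squares = sum(v * v for v in x) / len(x)  # (9 + 16) / 2 = 12.5
rms = math.sqrt(mean_of_squares)                  # about 3.536

out = [v / rms for v in x]  # RMSNorm with gamma = 1

print([round(v, 4) for v in out])
print(round(sum(out) / len(out), 4))  # the mean is NOT zero: nothing was subtracted
print(round(math.sqrt(sum(v * v for v in out) / len(out)), 4))  # the RMS is 1
```

The output has a root mean square of 1, so its typical size is under control, but its average stays close to 1 rather than 0 because the centering step was skipped.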
Why bother with a stripped-down version? Because it is cheaper to compute and, in practice, tends to work about as well. That combination is exactly what you want when you are training enormous models, which is why LLaMA, Mistral, and many other modern large language models reach for RMSNorm instead of LayerNorm. Doing less, faster, with little or no loss in quality is a very good trade.
Here is something you can play with. In the widget below, the input is deliberately given a bad size and offset (every number is multiplied by 5 and then 3 is added), exactly the kind of mess normalization is meant to clean up. Reroll the input to get fresh random numbers, drag the γ and β dials, and switch between the two norms. Watch the output statistics snap back to a sane range no matter how wild the input is.
No matter how large or off-center the input, LayerNorm pulls its output back to a mean of about 0 and a standard deviation of about 1 (before the γ and β dials touch it). RMSNorm fixes the size but leaves the centering alone, so notice that its output mean is not zero. Drag the dials and reroll to feel the difference.
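The same experiment can be sketched in a few lines of plain Python. This is a rough stand-in, not the code behind any particular tool: the helper names `layer_norm` and `rms_norm` are made up, the dials are left at γ = 1 and β = 0, and the vector is 1000 numbers wide so the statistics are steady.

```python
import math
import random
import statistics

random.seed(0)  # change the seed to "reroll" the input

# Deliberately bad input: random numbers multiplied by 5, then 3 added.
x = [random.gauss(0.0, 1.0) * 5.0 + 3.0 for _ in range(1000)]

def layer_norm(values, eps=1e-5):
    """Center, then scale (gamma = 1, beta = 0)."""
    mean = statistics.fmean(values)
    var = statistics.pvariance(values)  # biased variance, eps inside the root
    return [(v - mean) / math.sqrt(var + eps) for v in values]

def rms_norm(values, eps=1e-6):
    """Scale only (gamma = 1); the mean is never subtracted."""
    mean_square = statistics.fmean([v * v for v in values])
    return [v / math.sqrt(mean_square + eps) for v in values]

ln_out = layer_norm(x)
rms_out = rms_norm(x)

for name, out in [("input", x), ("LayerNorm", ln_out), ("RMSNorm", rms_out)]:
    print(f"{name:>9}: mean={statistics.fmean(out):+.3f}  std={statistics.pstdev(out):.3f}")
```

The LayerNorm row should show a mean of about 0 and a standard deviation of about 1, while the RMSNorm row keeps a clearly non-zero mean even though its overall size has been tamed.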
A few traps worth remembering
Every idea here has a sharp edge that catches people the first time. These are the ones to keep in your pocket.
Where this shows up in a real language model
This is not a toy idea you will outgrow; it sits at the heart of every transformer. A transformer is built from repeated blocks, and inside each block a norm is applied right before the attention step and again right before the small neural network called the MLP. Placing the norm first, before each of those steps, is called pre-norm, and it is the layout that keeps deep models stable. There is also one final norm applied just before the LM head, which is the very last layer that turns the network’s internal numbers into predictions about the next word.
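For readers who want to see the layout as code, here is a minimal PyTorch sketch of one pre-norm block followed by the final norm and the LM head. Several things here are assumptions rather than anything stated above: the class name `PreNormBlock`, the sizes (16 features, 4 heads, a vocabulary of 100), the GELU activation, the 4x-wide MLP, and the residual additions `x + ...`, which are the usual companion to pre-norm. The causal mask a language model would need is left out to keep the sketch short.

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Hypothetical minimal transformer block in the pre-norm layout."""

    def __init__(self, n_embed=16, n_head=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(n_embed)  # norm before attention
        self.attn = nn.MultiheadAttention(n_embed, n_head, batch_first=True)
        self.norm2 = nn.LayerNorm(n_embed)  # norm before the MLP
        self.mlp = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.GELU(),
            nn.Linear(4 * n_embed, n_embed),
        )

    def forward(self, x):
        h = self.norm1(x)                                     # normalize first...
        attn_out, _ = self.attn(h, h, h, need_weights=False)  # ...then attend
        x = x + attn_out                                      # residual add (assumed)
        x = x + self.mlp(self.norm2(x))                       # normalize, MLP, add
        return x

torch.manual_seed(0)
block = PreNormBlock()
final_norm = nn.LayerNorm(16)             # the one last norm
lm_head = nn.Linear(16, 100, bias=False)  # 100 is a made-up vocabulary size

x = torch.randn(2, 5, 16)                 # (batch, time, channels)
logits = lm_head(final_norm(block(x)))
print(logits.shape)
```

Swapping `nn.LayerNorm` for an RMSNorm module in the three norm positions would give the variant described earlier, with the rest of the block unchanged.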
A couple of grounding details. The learnable γ and β have shape (n_embed,), meaning there is one dial per feature in a token’s vector. And a from-scratch RMSNorm is genuinely only a few lines: x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * self.weight. The key thing both norms share is that they normalize over the channel axis only, which is the list of features inside a single token. The batch axis (the separate examples we process together) and the time axis (the separate positions in the sequence) are left untouched, so each token is reset on its own.
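A quick way to convince yourself that each token is reset on its own is to change one token and check that the others do not move. The sketch below uses made-up sizes (a batch of 2, 3 positions, 4 features) and a made-up distortion of the first token.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_embed = 4
ln = nn.LayerNorm(n_embed)
print(ln.weight.shape, ln.bias.shape)  # one gamma and one beta per feature

x = torch.randn(2, 3, n_embed)         # (batch, time, channels)
y = ln(x)

# Distort a single token: batch item 0, position 0.
x_changed = x.clone()
x_changed[0, 0] = x_changed[0, 0] * 100.0 + 7.0
y_changed = ln(x_changed)

# Every other token's output should be unaffected.
others_same = (
    torch.allclose(y[0, 1:], y_changed[0, 1:])
    and torch.allclose(y[1], y_changed[1])
)
print(others_same)

# The statistics are computed once per token, so there is one mean per (batch, time) slot.
print(y.mean(-1).shape)
```

A norm that mixed information across the batch or time axes would fail the `others_same` check, since the distorted token would leak into its neighbors' statistics.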
The whole thing, runnable
If you have PyTorch installed, you can paste this into a file and run it. Read the comments as you go; they walk through exactly what each line does.
import torch
import torch.nn as nn
x = torch.randn(2, 3, 4) * 5.0 + 3.0 # a batch of token vectors, deliberately given a bad scale and offset
ln = nn.LayerNorm(4) # LayerNorm over the last dimension, with default eps=1e-5 and learnable gamma/beta
y = ln(x)
row = y[0, 0] # grab one token's vector so we can check the statistics
print(round(row.mean().item(), 5)) # about 0.0: the centering step did its job
print(round(row.std(unbiased=False).item(), 4)) # about 1.0: the scaling step did its job
# now reproduce LayerNorm by hand to see there is no magic inside
mean = x.mean(-1, keepdim=True) # the average of each token's features
var = x.var(-1, keepdim=True, unbiased=False) # the spread, using the biased estimator LayerNorm expects
manual = (x - mean) / torch.sqrt(var + ln.eps) * ln.weight + ln.bias # note: eps lives INSIDE the square root
print(torch.allclose(manual, y, atol=1e-6)) # True: our hand-built version matches PyTorch to within floating-point tolerance
class RMSNorm(nn.Module): # the lighter cousin used by most modern LLMs
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) # one learnable scale per feature, no shift at all

    def forward(self, x):
        # divide by the root-mean-square size, then apply the learnable scale; no mean is ever subtracted
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight