Weight initialization and the trick of sharing one matrix
Before a network learns anything, every one of its numbers has to start somewhere. The starting values turn out to matter enormously: pick them too big or too small and learning stalls before it begins. Here is how to choose them well, plus a lovely space-saving trick called weight tying, all explained assuming you have never touched PyTorch.
A few words before we begin
Let us start with the handful of terms we will lean on, so nothing trips you up later. A neural network is built from layers, and each layer holds a big grid of numbers called weights. A weight is simply a number the network multiplies your data by. When data passes through a layer, every weight scales part of it, the results get added up, and a new set of numbers comes out the other side. Those numbers flowing through the network are called activations, and an activation is nothing more than a single number on its way forward.
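Here is a minimal PyTorch sketch of "multiply, add up, pass along". It checks a built-in nn.Linear layer against the same arithmetic done by hand, and the layer sizes are arbitrary illustration values. nn.Linear also adds a small per-neuron offset called a bias, which comes up again later.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Linear(in_features=3, out_features=2)  # a grid of 2 x 3 weights plus 2 biases
x = torch.tensor([1.0, 2.0, 3.0])                 # one input made of 3 numbers

# what the layer computes: each output is a weighted sum of the inputs plus a bias
manual = layer.weight @ x + layer.bias
activations = layer(x)

print(layer.weight.shape)                   # torch.Size([2, 3])
print(activations)                          # two new activations flowing forward
print(torch.allclose(activations, manual))  # True
```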
The whole point of training is to slowly adjust the weights until the network does something useful. But here is the question this article is about: what should those weights be on step one, before any learning has happened? They cannot all be zero, and as it turns out they cannot just be anything either. The very first values you fill them with, called the initialization (or init for short), quietly decide whether the network can learn at all. If this sounds like a strange thing to fuss over, that is a completely normal first reaction. By the end you will see exactly why it matters.
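This minimal sketch illustrates the "cannot all be zero" part. It uses a hypothetical two-layer network with made-up sizes and random data. When every weight starts at zero, every weight gradient also comes out zero, so the first training step has nothing to act on. Random starting weights do not have this problem.

```python
import torch
import torch.nn as nn

def weight_grads(init_fn):
    # build a tiny 2-layer network (sizes are arbitrary) and fill its weights with init_fn
    torch.manual_seed(0)
    net = nn.Sequential(nn.Linear(4, 8, bias=False), nn.Tanh(), nn.Linear(8, 1, bias=False))
    for layer in (net[0], net[2]):
        init_fn(layer.weight)
    x, target = torch.randn(16, 4), torch.randn(16, 1)
    loss = ((net(x) - target) ** 2).mean()
    loss.backward()
    # total size of the gradient that reaches each weight grid
    return [layer.weight.grad.abs().sum().item() for layer in (net[0], net[2])]

zero_grads = weight_grads(nn.init.zeros_)
random_grads = weight_grads(lambda w: nn.init.normal_(w, mean=0.0, std=0.5))
print(zero_grads)    # [0.0, 0.0]: no signal telling any weight which way to move
print(random_grads)  # both nonzero
```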
Why the starting scale of the weights is everything
Imagine a whisper passed down a long line of people. If each person shouts a little louder than they heard, by the end of the line the message is a deafening roar. If each person mumbles a little quieter, by the end it has faded to silence. A deep network is exactly that line, and each layer either amplifies or shrinks the signal a little as it passes through.
If the initial weights are too large, the activations grow bigger and bigger as they move through the layers until they explode into enormous numbers. If the weights are too small, the activations shrink toward zero and vanish. Either way the network goes deaf, and a quantity called the gradient (the signal that tells each weight which way to adjust) suffers along with it. It blows up in the first case and dies out in the second, so training either becomes unstable or stalls with nothing improving.
Good init aims for the calm middle: it keeps the typical size of the activations roughly constant no matter how deep the network gets, so the whisper arrives at the end at the same volume it started. The way we measure “typical size” here is the variance, which is just a number describing how spread out a bunch of values are. We want that spread to stay steady all the way down.
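Here is a sketch of the arithmetic behind "keep the spread steady". It assumes a plain linear layer y_j = Σ_i w_ji · x_i with no nonlinearity, where every weight has standard deviation σ_w and every input has variance σ_x². It also assumes all of these numbers are independent with mean zero. The variance of a sum of independent terms is the sum of their variances. The variance of a product of two independent zero-mean numbers is the product of their variances. Together, these give:

```latex
\operatorname{Var}(y_j)
  = \sum_{i=1}^{\mathrm{fan\_in}} \operatorname{Var}(w_{ji})\,\operatorname{Var}(x_i)
  = \mathrm{fan\_in}\cdot\sigma_w^2\cdot\sigma_x^2,
\qquad
\operatorname{Var}(y_j) = \sigma_x^2 \iff \sigma_w = \frac{1}{\sqrt{\mathrm{fan\_in}}}
```

So each layer multiplies the variance by fan_in · σ_w². If that factor is above 1, it compounds into an explosion over many layers. If it is below 1, it compounds into a fade.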
How do we pull that off? We draw each weight at random from a normal distribution, the familiar bell curve, and we carefully choose how wide that bell curve is. The width is set by the standard deviation, usually written std, which controls how far the weights typically stray from zero. A few popular recipes:
- A fixed small width like std = 0.02, the simple choice made famous by GPT-2.
- Xavier or Kaiming scaling, which sets the width to about 1/√fan_in. Here fan_in just means the number of inputs feeding into a neuron, so layers with more incoming connections get smaller weights, which keeps the signal balanced. Kaiming init for ReLU layers adds an extra factor of √2, giving √(2/fan_in), because ReLU zeroes out roughly half of the signal.
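PyTorch ships these recipes as functions in torch.nn.init. This quick sketch fills an arbitrary 512 × 512 weight grid with each recipe and measures the resulting width. kaiming_normal_ with nonlinearity="linear" gives exactly 1/√fan_in, and its "relu" setting adds the √2 factor. xavier_normal_ matches 1/√fan_in here only because this grid happens to be square, so fan_in equals fan_out.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
fan_in = fan_out = 512
w = torch.empty(fan_out, fan_in)  # PyTorch stores weights as (outputs x inputs)

def measured_std(init_fn):
    init_fn(w)              # fill w in place with the chosen recipe
    return w.std().item()   # measure how wide the bell curve came out

stds = {
    "fixed 0.02": measured_std(lambda t: nn.init.normal_(t, mean=0.0, std=0.02)),
    "xavier": measured_std(nn.init.xavier_normal_),
    "kaiming (linear)": measured_std(lambda t: nn.init.kaiming_normal_(t, nonlinearity="linear")),
    "kaiming (relu)": measured_std(lambda t: nn.init.kaiming_normal_(t, nonlinearity="relu")),
}
print(f"1/sqrt(fan_in) = {1 / math.sqrt(fan_in):.4f}")
for name, s in stds.items():
    print(f"{name:>17}: {s:.4f}")
```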
Many large language models add one more touch. They take the weights inside certain layers (the residual projections, the parts that write back into the network’s main signal pathway) and shrink them by an extra factor of 1/√(2·n_layer), where n_layer is the number of transformer blocks. The 2 is there because each block writes into that pathway twice: once from its attention layer and once from its MLP layer. This stops the main pathway from slowly accumulating more and more variance as it travels through block after block. If the details here feel dense, do not worry: the big idea is simply that we tune the width of the bell curve so the signal neither explodes nor fades.
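Here is a rough sketch of why the extra shrink helps. It models each of the 2 · n_layer residual writes as a LayerNorm followed by a single random projection, whose output is added back onto the main pathway. That is a big simplification of real attention and MLP layers. The sizes are GPT-2-small-like but otherwise illustrative.

```python
import torch
import torch.nn.functional as F

def residual_variance(proj_std, n_layer=12, C=768, batch=256, seed=0):
    g = torch.Generator().manual_seed(seed)
    x = torch.randn(batch, C, generator=g)  # the main pathway starts at variance ~1
    for _ in range(2 * n_layer):            # two residual writes per block
        h = F.layer_norm(x, (C,))           # each sublayer reads a normalized copy
        w = torch.randn(C, C, generator=g) * proj_std  # its random output projection
        x = x + h @ w.T                     # ...and adds the result back onto the pathway
    return x.var().item()

n_layer = 12
plain = residual_variance(0.02, n_layer)
scaled = residual_variance(0.02 / (2 * n_layer) ** 0.5, n_layer)
print(f"pathway variance after {n_layer} blocks, std 0.02:          {plain:.2f}")
print(f"pathway variance after {n_layer} blocks, std 0.02/sqrt(2n): {scaled:.2f}")
```

In this toy model, each write adds roughly C · std² of variance. With 2 · n_layer writes, the unscaled version keeps piling up. The scaled version spreads one write's worth of variance across the whole stack.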
You can see this in a small experiment. Pick an initialization scheme, then track the activation size as a unit-variance input (one whose spread starts at exactly 1) passes through several linear layers, say 6. If the width is too wide, the size climbs toward an explosion. If it is too narrow, the size collapses toward nothing. Only the fan-in-tuned width keeps the activations hovering near 1 the whole way down.
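The experiment can be sketched in a few lines. The sketch assumes plain linear layers with no nonlinearity and a fan_in of 512. The three widths are illustrative stand-ins for "too wide", "too narrow", and "fan-in tuned".

```python
import torch

def variance_per_layer(std, fan_in=512, n_layers=6, batch=1024, seed=0):
    g = torch.Generator().manual_seed(seed)
    x = torch.randn(batch, fan_in, generator=g)  # unit-variance input
    history = []
    for _ in range(n_layers):
        w = torch.randn(fan_in, fan_in, generator=g) * std  # weights drawn with the chosen width
        x = x @ w.T                                         # one linear layer
        history.append(x.var().item())
    return history

fan_in = 512
schemes = {"too wide (0.1)": 0.1, "too narrow (0.01)": 0.01, "fan-in tuned": fan_in ** -0.5}
results = {name: variance_per_layer(std, fan_in) for name, std in schemes.items()}
for name, history in results.items():
    print(f"{name:>17}: " + "  ".join(f"{v:.3g}" for v in history))
```

Each layer multiplies the variance by roughly fan_in × std². That factor is 5.12, 0.0512, and 1 for the three widths, and six layers compound it.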
Weight tying, or how one matrix can do two jobs
Now for a genuinely satisfying trick. In a language model, two different places use a big grid of numbers, and it turns out they can share the very same grid.
At the very start, the input embedding turns each word (more precisely each token, a small chunk of text the model reads) into a list of numbers called a vector. That grid has shape (vocab × C), meaning one row per word in the vocabulary and C numbers per row. At the very end, the output head does the reverse: it takes the model’s internal vector and produces a score for every possible next word. Those scores are called logits, and a logit is just a raw, unsquashed number saying how strongly the model favors each word. The output head has shape (C × vocab).
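Here is a minimal sketch of the two ends of a language model. It uses made-up sizes and token ids and prints the shape of each piece. In a real model, the transformer blocks sit between these two steps; they are skipped here.

```python
import torch
import torch.nn as nn

vocab_size, C = 1000, 64                     # illustrative sizes
token_embed = nn.Embedding(vocab_size, C)    # input side: one row of C numbers per token
head = nn.Linear(C, vocab_size, bias=False)  # output side: one score per token

token_ids = torch.tensor([[5, 42, 7], [1, 2, 3]])  # 2 sequences of 3 token ids
vectors = token_embed(token_ids)                   # look up one row per token id
logits = head(vectors)                             # a raw score for every vocabulary entry

print(vectors.shape)             # torch.Size([2, 3, 64])
print(logits.shape)              # torch.Size([2, 3, 1000])
print(token_embed.weight.shape)  # torch.Size([1000, 64])
print(head.weight.shape)         # torch.Size([1000, 64]), stored as (outputs x inputs)
```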
Look closely and you will notice those two grids are mirror images: the same two dimensions, just flipped (what we call transposed). Weight tying is the simple decision to use one single grid for both jobs, written in code as head.weight = token_embed.weight. No explicit flip is needed in that line. PyTorch's nn.Linear stores its weight as (outputs × inputs), which for the head is (vocab × C), exactly the embedding table's shape. The payoff is real on three fronts. First, you delete an entire vocab × C block of numbers, which is huge when the vocabulary is fifty thousand words or more, so the model gets noticeably smaller. Second, it acts as a gentle form of regularization: it nudges the model away from memorizing its training data and toward patterns that generalize. Third, it often improves quality outright. The shared grid receives learning signal from both ends at once, from the embedding lookup at the front and the output head at the back, and that double duty tends to make it better at both.
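Here is a minimal sketch of tying in PyTorch. It uses a hypothetical stand-in module holding only the two ends, with made-up sizes. After the assignment, both names refer to one tensor, and counting parameters shows the vocab × C block disappearing. The last line redoes that count at GPT-2 (small) scale, which has a 50,257-token vocabulary and C = 768.

```python
import torch.nn as nn

class Ends(nn.Module):
    # hypothetical stand-in: just the input embedding and the output head
    def __init__(self, vocab_size, C, tie):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, C)
        self.head = nn.Linear(C, vocab_size, bias=False)
        if tie:
            self.head.weight = self.token_embed.weight  # one shared matrix

def n_params(model):
    return sum(p.numel() for p in model.parameters())  # a shared tensor is counted once

vocab_size, C = 1000, 64
untied, tied = Ends(vocab_size, C, tie=False), Ends(vocab_size, C, tie=True)
print(tied.head.weight is tied.token_embed.weight)      # True
print(n_params(untied) - n_params(tied))                # 64000, i.e. vocab_size * C
print(f"{50257 * 768:,} numbers saved at GPT-2 scale")  # 38,597,376
```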
Pitfalls worth memorizing
Every idea here has a sharp edge that catches people the first time. Here are the ones to keep in your pocket.
How it shows up in a real LLM repo
This is not a toy idea you will outgrow. In an actual codebase, the model defines a small function usually called _init_weights and applies it to every layer at once with self.apply(...). Inside that function: every nn.Linear layer (a standard layer that multiplies its input by a grid of weights) gets its weights filled from normal_(0, 0.02), with its biases (small per-neuron offset numbers added at the end) set to zero. Every embedding gets the same 0.02 treatment. The residual projections get that extra 1/√(2·n_layer) shrink we met earlier. Weight tying is then a single line, self.lm_head.weight = self.token_embedding.weight, and because the two now point at the same grid, the saved model stores that matrix only once. Small decisions, real savings.
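Here is a sketch of that pattern, with both the init and the tying inside the model's __init__. The class SketchLM, its sizes, and its single stand-in projection are hypothetical, not taken from any particular repo. The sketch saves both a tied and an untied copy into an in-memory buffer with torch.save. The tied checkpoint comes out at roughly half the size, because the shared matrix is written only once.

```python
import io
import torch
import torch.nn as nn

class SketchLM(nn.Module):
    # hypothetical model: an embedding, one stand-in projection, and an output head
    def __init__(self, vocab_size=1000, C=64, tie=True):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, C)
        self.c_proj = nn.Linear(C, C)  # stand-in for the rest of the network
        self.lm_head = nn.Linear(C, vocab_size, bias=False)
        if tie:
            self.lm_head.weight = self.token_embedding.weight  # weight tying
        self.apply(self._init_weights)  # call _init_weights on every submodule

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

def checkpoint_bytes(model):
    buf = io.BytesIO()
    torch.save(model.state_dict(), buf)  # serialize the way a checkpoint would be
    return len(buf.getvalue())

tied, untied = SketchLM(tie=True), SketchLM(tie=False)
print(checkpoint_bytes(tied), checkpoint_bytes(untied))
```

The shared matrix belongs to two modules, so self.apply initializes it twice in the tied model. With the same 0.02 recipe for both, that is harmless.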
The whole thing, runnable
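If you have PyTorch installed, you can run this top to bottom and see the init and tying ideas from above in code. To keep it self-contained, it builds a small hypothetical stand-in model, TinyLM, with made-up sizes in place of a real transformer. The comments walk through what each line is doing.
```python
import torch, torch.nn as nn

n_layer = 4               # number of blocks (illustrative value)
vocab_size, C = 1000, 64  # vocabulary size and embedding width (illustrative values)

class TinyLM(nn.Module):
    # a hypothetical stand-in model holding only the pieces the init code touches
    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, C)
        # one residual projection per block; its parameter name ends in "proj.weight"
        self.blocks = nn.ModuleList(nn.ModuleDict({"proj": nn.Linear(C, C)}) for _ in range(n_layer))
        self.lm_head = nn.Linear(C, vocab_size, bias=False)

def _init_weights(m):
    # m is one layer of the model; we fill its weights based on what kind it is
    if isinstance(m, nn.Linear):
        # a Linear layer: draw weights from a bell curve of width 0.02, centered at 0
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)  # start the bias offsets at 0
    elif isinstance(m, nn.Embedding):
        # an Embedding (the token-to-vector table): same small bell curve
        nn.init.normal_(m.weight, mean=0.0, std=0.02)

model = TinyLM()
model.apply(_init_weights)  # run _init_weights on every layer in the model
# shrink the residual-projection weights so signal doesn't pile up across blocks
for name, p in model.named_parameters():
    if name.endswith("proj.weight"):
        nn.init.normal_(p, mean=0.0, std=0.02 / (2 * n_layer) ** 0.5)
# weight tying: point the output head and the input embedding at one shared matrix
model.lm_head.weight = model.token_embedding.weight
```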