The training loop, where a model actually learns
This is the heartbeat of every neural network. One step is just five small moves done in a fixed order, repeated again and again, and slowly the model gets better. We will walk through each move gently, assuming you have never written a line of PyTorch.
What we mean by training
Before we touch any code, let us be clear about what is happening when a model learns, because once you see it the rest falls into place.
A neural network is full of numbers called parameters. You can think of a parameter as a single adjustable knob. A large model has billions of these knobs, and at the start they are set to random values, so the model is basically guessing. Training is the patient process of nudging every knob a tiny bit, over and over, until the model’s guesses get good.
To know which way to nudge each knob, we need a way to measure how wrong the model currently is. That measurement is a single number called the loss. A high loss means the model is far off. A low loss means it is doing well. So the whole goal of training is simple to state: make the loss go down. Everything below is just the machinery that makes that happen, one careful step at a time.
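Here is a minimal sketch of the loss as "a single number". It assumes a made-up four-word vocabulary and uses cross-entropy, a common loss for next-word prediction. It scores one confident, correct guess and one confident, wrong guess.

```python
import torch
import torch.nn.functional as F

# A made-up vocabulary of 4 possible next words; the correct next word is number 2.
target = torch.tensor([2])

# Raw scores (logits) for one example: first from a model that strongly favours the right word...
good_logits = torch.tensor([[0.1, 0.2, 4.0, 0.3]])
# ...then from a model that strongly favours a wrong word (number 0).
bad_logits = torch.tensor([[4.0, 0.2, 0.1, 0.3]])

good_loss = F.cross_entropy(good_logits, target)  # a single number: a tensor with no dimensions
bad_loss = F.cross_entropy(bad_logits, target)

print(f"good guess: {good_loss.item():.3f}   bad guess: {bad_loss.item():.3f}")
# good guess: 0.065   bad guess: 3.965
```

Training would push the model's scores toward the first pattern, because that is what drives this number down.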
If any of these words feel unfamiliar right now, that is completely normal. We will define each one again the moment we use it.
Five moves, in order
Here is the secret that surprises most beginners: one training step is always the same five moves, in the same order, no matter how big or fancy the model is. Learn these five and you have learned the engine that powers all of it.
- Zero the gradients with zero_grad(). For each knob, a gradient tells us which direction would make the loss grow and how steeply. Nudging the knob the opposite way lowers the loss. PyTorch adds each new gradient onto whatever is already stored, so before each new step we wipe last step’s gradients away. If we forget, old gradients pile on top of new ones and the step is taken in a muddled direction.
- Run the forward pass by calling the model, which runs its forward() method. This means feeding data through the model to get its predictions. For a language model the raw predictions are called logits, which are just the model’s unscaled scores for every possible next token (a word or a piece of a word). From those predictions and the correct answers, we compute the loss, our single number for how wrong we were.
- Backward pass with backward(). This is where the magic lives. Calling backward() on the loss makes PyTorch work backwards through every calculation and fill in the gradient for each parameter automatically, storing it in a slot called .grad. You never have to do this calculus by hand, which is one of the kindest things PyTorch does for you.
- Clip the gradients with clip_grad_norm_(). Once in a while a batch of data produces enormous gradients that would shove the knobs way too far and wreck progress. Clipping measures the combined size (the norm) of all the gradients together. If that size is above a ceiling (1.0 in the loop below), it scales every gradient down by the same factor, so they keep their direction but their combined size drops to the ceiling. One wild batch then cannot blow everything up. Think of it as a safety governor.
- Take a step with step(). Now we finally nudge every knob, using its gradient to move it in the direction that lowers the loss. The tool that does the nudging is called the optimizer (here, one called AdamW). The learning rate sets the size of each nudge: a small learning rate gives gentle nudges and a large one gives bold nudges.
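Here is a minimal sketch of the five moves as a single step. It assumes a tiny nn.Linear model and random made-up data in place of a real language model. Unlike the model in the full loop later, this stand-in returns only logits, so the loss is computed separately with cross-entropy.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

model = nn.Linear(3, 2)                           # tiny stand-in: 3 input numbers -> 2 scores
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
xb = torch.randn(8, 3)                            # a made-up batch of 8 inputs
yb = torch.randint(0, 2, (8,))                    # the made-up correct answer for each input
before = model.weight.detach().clone()            # remember the knobs so we can see them move

opt.zero_grad(set_to_none=True)                   # 1. zero: clear any old gradients
logits = model(xb)                                # 2. forward: calling the model runs forward()
loss = F.cross_entropy(logits, yb)                #    ...and the loss is a single number
loss.backward()                                   # 3. backward: fills .grad on every parameter
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # 4. clip
opt.step()                                        # 5. step: AdamW nudges every parameter

print(f"loss {loss.item():.4f}, gradient norm before clipping {total_norm.item():.4f}")
print("weights changed:", not torch.equal(before, model.weight.detach()))
```

clip_grad_norm_ returns the combined gradient norm measured before clipping. That makes it a handy number to log if you want to know how often clipping is firing.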
The order genuinely matters. If you call step before backward, you nudge the knobs before you have worked out which way to nudge them. If you forget to zero, stale gradients contaminate the new ones. Clipping has to sit between backward and step: it must come after backward so there are gradients to clip, and before step so the optimizer uses the clipped values. The frustrating part is that getting the order wrong rarely throws an error; training just quietly fails to improve. So commit these five to memory: zero, forward, backward, clip, step.
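The "no error, no progress" failure is easy to reproduce. This sketch assumes a tiny nn.Linear model, made-up data and plain SGD, and it calls step() before backward() on every iteration. Because zero_grad(set_to_none=True) has just cleared every .grad to None, the optimizer has nothing to apply and skips each parameter. The weights never move, and PyTorch never complains.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(2, 1)                        # a tiny stand-in model
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(4, 2), torch.randn(4, 1)    # made-up inputs and correct answers
before = model.weight.detach().clone()

for _ in range(5):
    opt.zero_grad(set_to_none=True)            # 1. zero: every .grad becomes None
    loss = ((model(x) - y) ** 2).mean()        # 2. forward
    opt.step()                                 # WRONG ORDER: .grad is still None, so nothing moves
    loss.backward()                            # 3. backward arrives too late for this step

print("weights changed:", not torch.equal(before, model.weight.detach()))  # weights changed: False
```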
In a real run, the loss falls quickly at first and then flattens, which is the normal shape you will see again and again. Clipping tends to fire on the big early steps and only rarely once training settles. Each completed step also advances the learning rate schedule.
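To see that shape in numbers, here is a toy run with a single knob. It assumes a made-up problem (learn w so that w * x matches 3 * x) and swaps in plain SGD for AdamW, which makes the effect of clipping on step size easy to read. The knob starts far from its target, so the first gradients are large and get clipped. Once it gets close, the gradients shrink below the ceiling and clipping stops.

```python
import torch

w = torch.zeros(1, requires_grad=True)        # one knob, starting far from its target of 3.0
opt = torch.optim.SGD([w], lr=0.05)
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
y = 3.0 * x                                   # the made-up "correct answers"

losses, clipped = [], []
for step in range(100):
    opt.zero_grad(set_to_none=True)
    loss = ((w * x - y) ** 2).mean()          # mean squared error: one number
    loss.backward()
    norm = torch.nn.utils.clip_grad_norm_([w], max_norm=1.0)  # norm measured before clipping
    opt.step()
    losses.append(loss.item())
    clipped.append(norm.item() > 1.0)         # did clipping fire on this step?

print("first losses:", [round(v, 3) for v in losses[:3]])
print("last losses:", [round(v, 5) for v in losses[-3:]])
print("steps where clipping fired:", sum(clipped), "of", len(clipped))
```

Comparing the returned norm with max_norm is a simple way to tell whether clipping fired on a given step.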
A few traps worth remembering
Every idea here has a sharp edge that catches people the first time, so here are the ones worth keeping in your pocket.
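One of those edges, named in the five moves above, is that PyTorch adds each new gradient onto whatever is already stored in .grad. This small sketch assumes a single made-up knob w and the loss (3 * w) ** 2, whose gradient is 18 * w, or 36 at w = 2. It runs backward() twice without zeroing in between.

```python
import torch

w = torch.tensor([2.0], requires_grad=True)   # a single made-up knob


def compute_loss():
    return (w * 3.0) ** 2                     # gradient is 18 * w, which is 36 at w = 2


compute_loss().backward()
fresh = w.grad.clone()                        # 36: the gradient from one backward pass

compute_loss().backward()                     # forgot to zero first...
stale = w.grad.clone()                        # ...so PyTorch added a second 36 on top: 72

w.grad = None                                 # roughly what opt.zero_grad(set_to_none=True) does per parameter
compute_loss().backward()
print(fresh.item(), stale.item(), w.grad.item())  # 36.0 72.0 36.0
```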
How this shows up in a real language model
It is encouraging to know that the loop you just learned is not a toy. The training loop inside a real large language model is exactly these same five moves, with one extra wrinkle for memory.
That wrinkle is called gradient accumulation. Sometimes a batch of data (a batch is simply a group of examples processed together) is too large to fit in memory all at once. So we split it into smaller pieces called micro-batches, run the forward and backward pass on each piece, and let the gradients quietly add up across all of them before we take a single step. Each micro-batch's loss is divided by the number of micro-batches, grad_accum, so the added-up gradients form an average. With equal-sized micro-batches, that average matches what one big batch would have produced.
In code, the recipe is:
- call zero_grad(set_to_none=True) once;
- loop over the micro-batches, each doing loss = model(xb, yb)[1] / grad_accum; loss.backward();
- call clip_grad_norm_(model.parameters(), 1.0) and set the scheduled learning rate;
- finally call opt.step() once.

To track progress, losses are printed as plain numbers using .item(), which pulls the number out of a PyTorch tensor. Every so often a separate check measures how the model is doing on data it was not trained on.
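Here is a small check that accumulating over four micro-batches, with each loss divided by grad_accum, gives the same gradient as one pass over the whole batch. It assumes a toy nn.Linear model and cross-entropy in place of the language model. The match relies on the loss being an average and on the micro-batches being equal in size.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(5, 3)                  # toy stand-in model
x = torch.randn(16, 5)                   # one made-up "full" batch of 16 examples
y = torch.randint(0, 3, (16,))
grad_accum = 4                           # split it into 4 micro-batches of 4

# Reference: the whole batch in one forward/backward pass.
model.zero_grad(set_to_none=True)
F.cross_entropy(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulation: zero once, then let each micro-batch's gradient add up.
model.zero_grad(set_to_none=True)
for xb, yb in zip(x.chunk(grad_accum), y.chunk(grad_accum)):
    loss = F.cross_entropy(model(xb), yb) / grad_accum   # divide so the sum becomes an average
    loss.backward()
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))  # True, up to floating-point rounding
```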
The whole thing, runnable
Here is the complete loop. If the symbols still look dense, do not worry; read the comments line by line and you will see it is just our five moves wrapped in a counter. Each line is doing one of the things we have already talked through.
```python
import torch

# model, opt, schedule, get_batch, estimate_loss, max_steps, grad_accum and
# eval_interval are assumed to be defined earlier in the full script.
for step in range(max_steps):                    # repeat the five moves many times
    lr = schedule(step)                          # pick this step's learning rate (the nudge size)
    for g in opt.param_groups:                   # tell the optimizer to use that learning rate
        g["lr"] = lr
    opt.zero_grad(set_to_none=True)              # MOVE 1: wipe last step's gradients
    for _ in range(grad_accum):                  # split the batch into smaller micro-batches
        xb, yb = get_batch("train")              # grab a micro-batch: inputs xb and correct answers yb
        logits, loss = model(xb, yb)             # MOVE 2: forward pass gives predictions and the loss
        (loss / grad_accum).backward()           # MOVE 3: backward pass adds into every .grad, scaled to average
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # MOVE 4: cap the gradient size for safety
    opt.step()                                   # MOVE 5: nudge every parameter to lower the loss
    if step % eval_interval == 0:                # every so often, check our progress
        print(step, estimate_loss())             # measured under eval() + no_grad(), with no learning
```
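The loop above leans on names defined elsewhere: model, opt, schedule, get_batch and estimate_loss. Here is one self-contained sketch you can actually run, with toy stand-ins for each. The bigram model, the repeated toy text and the warmup-plus-cosine schedule are all assumptions chosen for brevity, not the helpers from the original project. The loop itself is unchanged.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy data: one short string repeated, one character per token.
text = "hello world. " * 200
vocab = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(vocab)}
data = torch.tensor([stoi[ch] for ch in text], dtype=torch.long)
n = int(0.9 * len(data))
splits = {"train": data[:n], "val": data[n:]}  # "val" is held out: never trained on
block_size, batch_size = 8, 16


def get_batch(split):
    # Hypothetical helper: random windows of text; yb is xb shifted one character ahead.
    d = splits[split]
    ix = torch.randint(len(d) - block_size, (batch_size,))
    xb = torch.stack([d[i:i + block_size] for i in ix])
    yb = torch.stack([d[i + 1:i + block_size + 1] for i in ix])
    return xb, yb


class BigramLM(nn.Module):
    # Hypothetical stand-in model: each character looks up scores for the next one.
    def __init__(self, vocab_size):
        super().__init__()
        self.table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets):
        logits = self.table(idx)  # (batch, time, vocab)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss


max_steps, grad_accum, eval_interval = 300, 4, 100
base_lr, warmup = 5e-2, 30


def schedule(step):
    # Assumed schedule: linear warmup, then cosine decay down to 10% of base_lr.
    if step < warmup:
        return base_lr * (step + 1) / warmup
    progress = (step - warmup) / max(1, max_steps - warmup)
    return base_lr * (0.1 + 0.45 * (1 + math.cos(math.pi * progress)))


@torch.no_grad()
def estimate_loss(eval_iters=20):
    # Average loss on each split, in eval() mode and without tracking gradients.
    model.eval()
    out = {}
    for split in ("train", "val"):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb = get_batch(split)
            _, loss = model(xb, yb)
            losses[k] = loss.item()
        out[split] = round(losses.mean().item(), 4)
    model.train()
    return out


model = BigramLM(len(vocab))
opt = torch.optim.AdamW(model.parameters(), lr=base_lr)

for step in range(max_steps):
    lr = schedule(step)
    for g in opt.param_groups:
        g["lr"] = lr
    opt.zero_grad(set_to_none=True)
    for _ in range(grad_accum):
        xb, yb = get_batch("train")
        logits, loss = model(xb, yb)
        (loss / grad_accum).backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
    if step % eval_interval == 0:
        print(step, estimate_loss())
```

The printed losses start near the level of random guessing and fall as the table learns which character tends to follow which.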