Saving and loading checkpoints so you never lose a training run
A checkpoint is much more than the model's weights. To pick up a training run exactly where it stopped, you also need the optimizer's memory, the step counter, and the random number state. We will build that complete bundle together, save it, load it back, and see why each piece matters, all assuming you have never touched PyTorch.
First, why this matters
Training a neural network can take hours, days, or for the biggest models, weeks. Now imagine your machine reboots in the middle of that, or you simply want to stop for the night and continue tomorrow. If you cannot save your progress, you start again from scratch every time, and that is heartbreaking. A checkpoint is the answer: it is a snapshot of everything the run needs to continue, written to a file on disk so you can come back later and resume as if nothing happened.
Before we go further, a couple of words that will keep coming up, defined in plain language. A model is the network you are training, the thing that makes predictions. Its weights, also called parameters, are the many numbers inside it that get nudged a little on every training step until the model gets good at its job. A tensor is just PyTorch’s name for an array of numbers, like a list or a grid of values. That is all you need for now. If any of this still feels fuzzy, that is completely normal, and it will settle as we go.
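To make those words concrete, here is a tiny sketch. It assumes PyTorch is installed, and the layer size (3 inputs, 2 outputs) is an arbitrary choice for illustration. The layer's weights are tensors, and counting their numbers tells you how many values training will nudge.

```python
import torch
import torch.nn as nn

# A hypothetical tiny model: one fully connected layer from 3 inputs to 2 outputs.
layer = nn.Linear(3, 2)

# Its weights (parameters) are tensors: grids of numbers that training will nudge.
print(layer.weight.shape)  # torch.Size([2, 3])
print(layer.bias.shape)    # torch.Size([2])

# Count every individual number the optimizer will update.
n_weights = sum(p.numel() for p in layer.parameters())
print(n_weights)  # 8

# A tensor can also be built by hand from a nested list, like a small grid.
grid = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
print(grid.shape)  # torch.Size([2, 2])
```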
The surprising part, and the heart of this whole article, is that saving the weights alone is not enough to resume training cleanly. There are a few other ingredients, and leaving any of them out causes quiet, confusing problems later. Let us see what they are.
The state_dict is the thing we actually save
In PyTorch, the model hands you its weights through a method called model.state_dict(). Think of state_dict as a labelled drawer holding every parameter in the model (plus a few non-trainable values that some layers keep, such as the running averages inside BatchNorm layers). It is a dictionary (a lookup table of names to values) that maps each parameter’s name, like layer1.weight, to the tensor holding its numbers. The names come straight from the structure of your model, so they read like a little address for each piece.
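Here is a small sketch of what that drawer looks like. It uses nn.Sequential with an OrderedDict so we can choose the layer names ourselves; "layer1" and "layer2" are hypothetical names picked to match the example above.

```python
from collections import OrderedDict

import torch.nn as nn

# A hypothetical two-layer model; the names "layer1" and "layer2" become part of the keys.
model = nn.Sequential(OrderedDict([
    ("layer1", nn.Linear(3, 4)),
    ("layer2", nn.Linear(4, 1)),
]))

# state_dict() maps each name to the tensor that holds its numbers.
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# layer1.weight (4, 3)
# layer1.bias (4,)
# layer2.weight (1, 4)
# layer2.bias (1,)
```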
Once you have that drawer, torch.save writes it to a file. Later, torch.load reads the file back into a dictionary, and model.load_state_dict copies each saved tensor into the right spot in the model, matching by name and checking that the shapes agree. So far so good. If all you want is to run the trained model and make predictions, those saved weights are genuinely all you need.
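A minimal sketch of that round trip, assuming a throwaway nn.Linear model and a temporary folder in place of your real file. It also shows the shape check in action: by default (strict=True) load_state_dict refuses tensors whose shapes do not fit, and it also complains about missing or unexpected names.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(3, 2)  # hypothetical model whose weights we want to keep

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "weights.pt")
    torch.save(model.state_dict(), path)          # write the drawer to disk
    state = torch.load(path, map_location="cpu")  # read the file back into a dict

# load_state_dict copies each saved tensor into the parameter with the same name.
same_shape = nn.Linear(3, 2)
same_shape.load_state_dict(state)
print(torch.equal(same_shape.weight, model.weight))  # True

# A model whose shapes differ cannot accept those tensors, so loading raises an error.
wrong_shape = nn.Linear(3, 5)
try:
    wrong_shape.load_state_dict(state)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("RuntimeError:", str(err).splitlines()[0])
```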
But to resume training, you need more, and here is the reason. The thing that updates your weights each step is called the optimizer, and it works from gradients. A gradient, by the way, is simply the direction and amount each weight should change to reduce the model’s error. A very popular optimizer, Adam, does not just look at the current step's gradient. For every weight it keeps two extra running averages, often called the moments, or m and v: m tracks recent gradients and v tracks their squares. Together they let Adam take smoother, better-scaled steps. If you reload the weights but throw away Adam’s memory, the optimizer wakes up with amnesia and stumbles for a while before finding its footing again.
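You can peek at that memory directly. This sketch assumes a tiny throwaway model and random made-up data. In PyTorch's Adam, m is stored under the name "exp_avg" and v under "exp_avg_sq", next to a per-parameter step count, and all of it appears only after the first update.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 2)  # hypothetical tiny model
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

print(len(opt.state))  # 0: Adam creates its memory lazily, on the first step

# One training step on made-up data so Adam has something to remember.
loss = model(torch.randn(8, 3)).pow(2).mean()
loss.backward()
opt.step()

# For each parameter, Adam now stores m ("exp_avg"), v ("exp_avg_sq") and a step count.
state = opt.state[model.weight]
print(sorted(state.keys()))                             # ['exp_avg', 'exp_avg_sq', 'step']
print(state["exp_avg"].shape == model.weight.shape)     # True: one m value per weight
print(state["exp_avg_sq"].shape == model.weight.shape)  # True: one v value per weight

# opt.state_dict() is what goes into the checkpoint; its "state" is keyed by parameter index.
print(sorted(opt.state_dict()["state"].keys()))         # [0, 1]
```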
So a complete, resumable checkpoint usually holds four things: the model’s weights, the optimizer’s state (Adam’s m and v), the current step (how many updates you have done so far, which also tells the learning rate schedule where it is), and ideally the RNG state, the random number generator’s position, so that shuffling and any randomness continue exactly as before.
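The RNG state is the piece people most often forget. Here is a minimal sketch that captures and restores PyTorch's generator, Python's random module, and the CUDA generators when a GPU is present. The helper names capture_rng_state and restore_rng_state are hypothetical, and the sketch assumes your run's randomness comes from these generators; if you also use NumPy, you would add its state in the same way.

```python
import random

import torch

def capture_rng_state():
    """Hypothetical helper: snapshot the random number generators the run uses."""
    state = {"torch": torch.get_rng_state(), "python": random.getstate()}
    if torch.cuda.is_available():
        state["cuda"] = torch.cuda.get_rng_state_all()
    return state

def restore_rng_state(state):
    """Hypothetical helper: put every generator back where the snapshot left it."""
    torch.set_rng_state(state["torch"])
    random.setstate(state["python"])
    if "cuda" in state and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(state["cuda"])

torch.manual_seed(123)
random.seed(123)
saved = capture_rng_state()                 # this dict would go into the checkpoint as "rng"
first = (torch.rand(3), random.random())    # draws made after the snapshot
restore_rng_state(saved)
again = (torch.rand(3), random.random())    # the same draws, replayed
print(torch.equal(first[0], again[0]), first[1] == again[1])  # True True
```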
Think about what happens if you leave one of these pieces out. Leave out the optimizer’s state and the run still continues, but the first steps after resuming go wrong, because Adam’s memory restarts cold: on its very first update, a freshly created Adam moves every weight by roughly the full learning rate, no matter how small that weight's gradient is.
Saving only the weights is perfect when you just want to run the model and make predictions. For an exact resume, you want all four pieces; otherwise the optimizer's momentum, the learning rate schedule, and the data order all quietly drift away from your original run.
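You can watch that drift happen with a small experiment. This sketch, built on a toy linear model and fixed made-up data (so shuffling plays no part), trains 10 uninterrupted steps as a reference, then repeats the run with a pause after step 5: once resuming with the optimizer's state and once with the weights alone. The first comparison should report a match; the weights-only resume, which restarts Adam cold, ends up somewhere else.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)
X = torch.randn(32, 4)  # fixed toy inputs: no shuffling, so only the optimizer can differ
y = torch.randn(32, 1)  # fixed toy targets
init = nn.Linear(4, 1).state_dict()  # shared starting weights for every run

def fresh():
    """A new model with the shared starting weights, plus a brand-new Adam."""
    model = nn.Linear(4, 1)
    model.load_state_dict(init)
    return model, torch.optim.Adam(model.parameters(), lr=0.1)

def train(model, opt, n_steps):
    for _ in range(n_steps):
        opt.zero_grad()
        loss = ((model(X) - y) ** 2).mean()
        loss.backward()
        opt.step()

# Reference: 10 steps with no interruption.
ref_model, ref_opt = fresh()
train(ref_model, ref_opt, 10)

# Interrupted run: 5 steps, then save weights and optimizer state to an in-memory "file".
model, opt = fresh()
train(model, opt, 5)
buf = io.BytesIO()
torch.save({"model": model.state_dict(), "optimizer": opt.state_dict()}, buf)

def resume(with_optimizer):
    """Reload the checkpoint into a new model and optimizer, then finish the last 5 steps."""
    buf.seek(0)
    ckpt = torch.load(buf)
    new_model, new_opt = fresh()
    new_model.load_state_dict(ckpt["model"])
    if with_optimizer:
        new_opt.load_state_dict(ckpt["optimizer"])  # restore Adam's m, v and step count
    train(new_model, new_opt, 5)
    return new_model

full = resume(with_optimizer=True)
weights_only = resume(with_optimizer=False)
print("full checkpoint matches:", torch.allclose(full.weight, ref_model.weight))
print("weights-only matches:", torch.allclose(weights_only.weight, ref_model.weight))
```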
How this shows up in a real language model
This is not a toy idea you will outgrow. In a real training setup, the trainer saves a bundle of {model, optimizer, step, config} to disk every so often, plus a separate best checkpoint whenever the model’s score on held-out data improves. When you resume, it loads each piece back into place: the weights with model.load_state_dict(ckpt['model']), the optimizer’s memory with opt.load_state_dict(ckpt['optimizer']), and the step counter so the learning rate schedule continues from the right spot.
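Here is a minimal sketch of the save side of that setup, under clearly stated assumptions. The learning rate schedule lr_at (linear warmup, then cosine decay) and the helper save_checkpoints are hypothetical, the numbers are arbitrary, and the validation losses are made up. The schedule's only input is the step, which is exactly why the step has to live in the checkpoint.

```python
import math
import os
import tempfile

import torch
import torch.nn as nn

def lr_at(step, max_lr=6e-4, min_lr=6e-5, warmup=100, total=5000):
    """Hypothetical warmup + cosine schedule; the saved step is its only input."""
    if step < warmup:
        return max_lr * (step + 1) / warmup
    progress = min(1.0, (step - warmup) / (total - warmup))
    return min_lr + 0.5 * (1.0 + math.cos(math.pi * progress)) * (max_lr - min_lr)

def save_checkpoints(model, opt, step, cfg, val_loss, best_val_loss, out_dir):
    """Hypothetical helper: always refresh ckpt.pt; also write best.pt when held-out loss improves."""
    best = min(val_loss, best_val_loss)
    bundle = {
        "model": model.state_dict(),
        "optimizer": opt.state_dict(),
        "step": step,
        "config": cfg,
        "best_val_loss": best,  # saved so a resumed run still knows what "best" means
    }
    torch.save(bundle, os.path.join(out_dir, "ckpt.pt"))
    if val_loss < best_val_loss:
        torch.save(bundle, os.path.join(out_dir, "best.pt"))
    return best

# Tiny usage example with a throwaway model and made-up validation losses.
model = nn.Linear(4, 1)
opt = torch.optim.Adam(model.parameters(), lr=lr_at(0))
cfg = {"max_lr": 6e-4, "total": 5000}
with tempfile.TemporaryDirectory() as out_dir:
    best = float("inf")
    for step, val_loss in [(500, 2.0), (1000, 1.5), (1500, 1.7)]:
        best = save_checkpoints(model, opt, step, cfg, val_loss, best, out_dir)
    latest = torch.load(os.path.join(out_dir, "ckpt.pt"), map_location="cpu")
    best_ckpt = torch.load(os.path.join(out_dir, "best.pt"), map_location="cpu")

print(latest["step"], best_ckpt["step"], best)  # 1500 1000 1.5
# With the step restored, the schedule continues; if the step were lost, warmup would start over.
print(f"lr with saved step: {lr_at(latest['step']):.2e}, lr if step is lost: {lr_at(0):.2e}")
```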
A couple of practical touches show up here too. When training is spread across many GPUs at once, only one of them (the process labelled rank 0) actually writes the file, so you do not save the same thing many times over. And when loading, it is common to read the checkpoint onto the CPU first by passing map_location="cpu" to torch.load, then move the model onto the GPU before restoring the optimizer, because PyTorch places the optimizer's saved memory on the same device as the parameters it belongs to. This avoids a class of confusing errors about which device the tensors landed on. If those last details feel advanced, do not worry: the core idea (bundle everything, save it, load it back) is what really matters.
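The rank 0 rule can be sketched in a few lines with torch.distributed. The helper names is_rank0 and save_on_rank0 are hypothetical, and the barrier (which makes the other processes wait until the file is fully written) is a common addition rather than something every setup needs. When the script is not running distributed at all, the single process counts as rank 0 and simply saves.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn

def is_rank0():
    """True on the process labelled rank 0, or when not running distributed at all."""
    return not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0

def save_on_rank0(bundle, path):
    """Hypothetical helper: only rank 0 writes; everyone else waits for it to finish."""
    if is_rank0():
        torch.save(bundle, path)
    if dist.is_available() and dist.is_initialized():
        dist.barrier()  # nobody moves on (or reads the file) before rank 0 is done

# Single-process usage: this process is rank 0, so it saves.
model = nn.Linear(4, 1)
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "ckpt.pt")
    save_on_rank0({"model": model.state_dict(), "step": 0}, path)
    ckpt = torch.load(path, map_location="cpu")
print(sorted(ckpt))  # ['model', 'step']
```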
The whole thing, runnable
If you have PyTorch installed, you can adapt this into your own training script. Read the comments as you go; they walk through exactly what each line does.
import torch
import torch.nn as nn
from types import SimpleNamespace

# Stand-ins so the snippet runs on its own; in a real script these already exist.
cfg = SimpleNamespace(lr=3e-4, max_steps=1000)         # the run's settings
model = nn.Linear(4, 2)                                # your real model goes here
opt = torch.optim.Adam(model.parameters(), lr=cfg.lr)  # the optimizer being resumed
step = 0                                               # how many updates have finished
device = "cuda" if torch.cuda.is_available() else "cpu"

# Save a full, resumable checkpoint: bundle every piece the run needs into one dict.
torch.save({
    "model": model.state_dict(),    # the network's weights (its learned numbers)
    "optimizer": opt.state_dict(),  # Adam's memory: the m and v moments per parameter
    "step": step,                   # how many updates we have done so far
    "config": cfg.__dict__,         # the run's settings, handy for sanity checks later
}, "ckpt.pt")                       # write it all to a file on disk

# Resume: read the bundle back and put each piece where it belongs.
ckpt = torch.load("ckpt.pt", map_location="cpu")  # load onto CPU first to avoid device surprises
model.load_state_dict(ckpt["model"])    # restore the weights
model.to(device)                        # move the model BEFORE restoring the optimizer...
opt.load_state_dict(ckpt["optimizer"])  # ...because Adam's m and v are placed on the parameters' device
start_step = ckpt["step"]               # "step" counts finished updates, so it is also the next step's index

# Inference only: if you just want to run the model, the weights alone are enough.
torch.save(model.state_dict(), "weights.pt")
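And here is the matching load for the inference-only case, as a sketch with a hypothetical nn.Linear standing in for your trained model and a temporary folder standing in for your real path. You rebuild the same architecture, fill in the weights, and switch to evaluation mode. Passing weights_only=True asks torch.load to unpickle only tensors and plain containers, which is the safer choice for files you did not create yourself; recent PyTorch versions make it the default.

```python
import os
import tempfile

import torch
import torch.nn as nn

trained = nn.Linear(4, 2)  # hypothetical stand-in for your trained model

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "weights.pt")
    torch.save(trained.state_dict(), path)  # weights only: no optimizer, no step

    # Later, possibly on another machine: rebuild the same architecture, then fill in the weights.
    model = nn.Linear(4, 2)
    model.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))

model.eval()           # switch layers like dropout to prediction behaviour
with torch.no_grad():  # no gradients needed when only predicting
    x = torch.randn(1, 4)
    same = torch.equal(model(x), trained(x))
print(same)  # True
```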