Training across many GPUs without losing your mind
When one GPU is too slow, the trick is wonderfully simple: put a full copy of your model on every GPU, feed each copy a different slice of the data, and then average their lessons so all the copies stay perfectly in sync. Here is the whole idea, built up slowly, assuming you have never touched PyTorch.
First, the problem we are trying to solve
Training a big model means showing it mountains of examples, one batch at a time, and nudging it to do a little better after each batch. A batch is just a handful of examples bundled together, and a GPU is the specialised chip that does the heavy number crunching. On a single GPU, that nudging happens at one fixed speed, and for a large model it can take days or even weeks. At some point you cannot make the chip any faster, so the only way forward is to use more than one.
Here is the natural question. If you had eight GPUs sitting in front of you, how would you get them to share the work without each one quietly drifting off and learning something different? That is exactly the puzzle this article solves, and the answer is more elegant than you might expect.
Before we dive in, one more word worth pinning down. When the model studies a batch, it produces a gradient, which is simply a set of numbers saying “to do better, turn this knob up a bit and that knob down a bit.” Every knob inside the model is called a parameter, and the gradient is the model’s advice for how to adjust all of them. That is the whole vocabulary you need to follow along.
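To make “nudging the knobs” concrete, here is a minimal sketch of one plain gradient-descent step in Python. The parameter names, values and learning rate (lr, the size of each nudge) are made up for illustration. Real optimizers do more bookkeeping, but the core move is the same: step each parameter a little against its gradient.

```python
# One "nudge": move every parameter a small step against its gradient.
# Parameter names, values and the learning rate are made up for illustration.
def sgd_step(params, grads, lr=0.1):
    """Return new parameters after one plain gradient-descent step."""
    return {name: value - lr * grads[name] for name, value in params.items()}

params = {"knob_a": 1.0, "knob_b": -2.0}
grads = {"knob_a": 0.5, "knob_b": -1.0}  # positive: turn knob_a down; negative: turn knob_b up
print(sgd_step(params, grads))  # knob_a drops to about 0.95, knob_b rises to about -1.9
```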
Copy the model, split the data, then share the lessons
The strategy has a name, DistributedDataParallel, which people shorten to DDP. Despite the long name, the idea fits in three short steps, and once it clicks you will see it everywhere.
First we replicate. We place an identical copy of the model on each GPU. Each copy runs in its own process, called a rank, which is just a numbered worker: rank 0, rank 1, and so on. At the start, all the copies are exactly the same.
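Why do the copies start out identical? If every rank built its model independently, random initialization would give each one different starting numbers. PyTorch’s DDP avoids this by broadcasting rank 0’s parameters to every other rank when the wrapper is created. The sketch below simulates that idea in plain Python; init_params and broadcast_from_rank0 are hypothetical stand-ins, not PyTorch functions.

```python
import copy
import random

def init_params(seed, n=3):
    """Hypothetical model setup: n random parameters drawn from a per-rank seed."""
    rng = random.Random(seed)
    return [rng.uniform(-1, 1) for _ in range(n)]

def broadcast_from_rank0(rank_params):
    """Give every rank its own copy of rank 0's parameters."""
    return [copy.deepcopy(rank_params[0]) for _ in rank_params]

world_size = 4
params = [init_params(seed=rank) for rank in range(world_size)]  # copies start out different
params = broadcast_from_rank0(params)                            # now every rank matches rank 0
print(all(p == params[0] for p in params))  # True
```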
Second we scatter the data. The big batch gets sliced up, and each rank receives a different slice. So with eight GPUs, each one studies one eighth of the batch at the same moment. Every rank runs its own forward pass (making a prediction) and its own backward pass (working out its gradient), all in parallel. Because each rank saw different examples, each one comes back with a slightly different gradient. This is the moment where, if we did nothing, the copies would start to disagree and drift apart.
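Here is one way to hand each rank its own slice, written as a plain-Python sketch. It mimics, in simplified form, the strided split that PyTorch’s DistributedSampler uses with shuffling turned off: pad the list of example indices until it divides evenly, then rank r takes every world_size-th index starting at position r. The name shard_indices is our own, and the real sampler shuffles by default.

```python
import math

def shard_indices(num_examples, rank, world_size):
    """Return the example indices that `rank` should study (no shuffling)."""
    per_rank = math.ceil(num_examples / world_size)
    total = per_rank * world_size
    indices = list(range(num_examples))
    while len(indices) < total:                  # pad by reusing examples from the start
        indices += indices[: total - len(indices)]
    return indices[rank:total:world_size]        # every world_size-th index, starting at rank

for r in range(4):
    print(r, shard_indices(10, r, 4))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]   <- examples 0 and 1 are reused as padding
# 3 [3, 7, 1]
```

Padding keeps every rank’s slice the same size, so no rank finishes early and sits waiting for the others.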
Third, and this is the heart of it, we do an all-reduce. Every rank shares its gradient with all the others, the gradients are added up and divided by the number of ranks, and that one shared average is handed back to every rank. The word “reduce” here just means “combine many numbers into one,” and “all” means “and give the result to everyone.” Now every copy holds the identical averaged gradient, so when each one adjusts its parameters, they all make the exact same adjustment. They stay perfectly in step.
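The averaging all-reduce can be simulated in a few lines. This sketch shows the result, not how the data actually travels: libraries such as NCCL use ring- and tree-shaped communication patterns so that every GPU does not have to talk to every other GPU, but each rank still ends up with the same average. In PyTorch you rarely write this yourself, because DDP does it for you; done by hand it would be dist.all_reduce(grad, op=dist.ReduceOp.SUM) followed by dividing by dist.get_world_size().

```python
def all_reduce_mean(rank_grads):
    """Simulate an averaging all-reduce over plain Python lists.

    rank_grads[r] is the gradient rank r computed (one number per parameter).
    Returns one entry per rank, each holding the same element-wise average.
    """
    world_size = len(rank_grads)
    summed = [sum(column) for column in zip(*rank_grads)]  # "reduce": combine into one
    averaged = [total / world_size for total in summed]
    return [list(averaged) for _ in range(world_size)]     # "all": everyone gets a copy

before = [[1.0, -2.0], [3.0, 0.0], [2.0, -1.0], [2.0, -1.0]]  # four ranks disagree
after = all_reduce_mean(before)
print(after)  # [[2.0, -1.0], [2.0, -1.0], [2.0, -1.0], [2.0, -1.0]]
```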
That last part is what makes the whole thing trustworthy. Because the copies always step together, they never diverge, and at the end of training all eight ranks hold the same finished model. You have effectively trained on eight times as much data per step, which is why DDP can make training close to eight times faster on eight GPUs. It falls a little short of a perfect eightfold gain, because the ranks spend some time exchanging gradients during every all-reduce.
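Why is averaging the right thing to do? When every rank gets a slice of the same size and the loss is a mean over examples, the average of the per-slice gradients equals the gradient one GPU would have computed on the whole batch. The toy check below uses a hypothetical one-parameter model, y_hat = w * x, with made-up data, and then confirms that replicas applying the same averaged gradient stay identical.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)**2) with respect to w, for the toy model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]        # made-up batch of 8 examples
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.0, 13.8, 16.2]
w, world_size, lr = 0.5, 4, 0.01                      # every replica starts at w = 0.5

# Each rank studies an equal-sized slice (2 examples) and gets its own gradient.
per_rank = [grad_mse(w, xs[r::world_size], ys[r::world_size]) for r in range(world_size)]
averaged = sum(per_rank) / world_size  # what the all-reduce hands back
full_batch = grad_mse(w, xs, ys)       # what one GPU would get from the whole batch

print(per_rank)              # four different numbers
print(averaged, full_batch)  # equal, up to floating-point rounding

# Every replica applies the same averaged gradient, so all copies remain identical.
replicas = [w - lr * averaged for _ in range(world_size)]
print(len(set(replicas)) == 1)  # True
```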
Before all-reduce the ranks hold different numbers, because each one studied different data. After all-reduce they all match, and that agreement is exactly what keeps the copies from drifting apart. Notice too that the amount of data you cover per step is the per-GPU batch size multiplied by the number of GPUs.
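The arithmetic is worth doing once. In this sketch the per-GPU batch size and dataset size are made-up numbers, and it assumes the data loader keeps the final partial batch. Eight GPUs cover eight times as many examples per step, so one pass over the data takes about an eighth as many steps.

```python
import math

def step_stats(per_gpu_batch, world_size, dataset_size):
    """Return (examples covered per optimizer step, steps per pass over the data)."""
    global_batch = per_gpu_batch * world_size
    steps_per_epoch = math.ceil(dataset_size / global_batch)  # the last step may be partial
    return global_batch, steps_per_epoch

print(step_stats(32, 1, 100_000))  # (32, 3125)   one GPU
print(step_stats(32, 8, 100_000))  # (256, 391)   eight GPUs
```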
A few traps worth remembering
This idea is simple at heart, but it has a few sharp edges that catch almost everyone the first time. Here are the ones worth keeping in your pocket.
How this shows up in a real language model project
In a real codebase you do not start the program the usual way, with python train.py (train.py standing in for your own script). Instead you launch it with a helper called torchrun, for example torchrun --nproc_per_node=8 train.py, which means “start eight copies of the script, one per GPU.” When each copy wakes up, it reads three values from its environment to learn who it is: RANK (its global worker number), LOCAL_RANK (which GPU on this machine it should use), and WORLD_SIZE (how many workers there are in total).
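Reading those three values takes only a few lines. This sketch falls back to a single-process setup (rank 0 out of 1) when the variables are missing, for example when the script is started with plain python; that fallback is our own convenience choice, not something torchrun requires. The function name read_dist_env is hypothetical.

```python
import os

def read_dist_env():
    """Return (rank, local_rank, world_size) as set by torchrun.

    Assumption: missing variables mean a single, non-distributed process.
    """
    rank = int(os.environ.get("RANK", 0))              # global worker number
    local_rank = int(os.environ.get("LOCAL_RANK", 0))  # which GPU on this machine
    world_size = int(os.environ.get("WORLD_SIZE", 1))  # total number of workers
    return rank, local_rank, world_size

print(read_dist_env())  # (0, 0, 1) when run without torchrun
```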
Each copy then calls init_process_group('nccl'), which opens the communication channel the ranks use to talk to each other; NCCL is NVIDIA’s library for fast GPU-to-GPU communication. It moves its model onto its assigned GPU, cuda:local_rank, and wraps it in DDP(model, device_ids=[local_rank]). A helper called DistributedSampler, handed to the DataLoader, takes care of the data slicing, making sure each rank sees a different set of examples. From there it is almost magic: the backward pass triggers the gradient all-reduce automatically, so you never have to call it by hand. And because all ranks end up identical, you only need to save the model and write log messages from one of them, conventionally rank 0, so you do not get eight copies of every file.
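The “only rank 0 writes” rule is usually wrapped in a tiny helper. The names is_main_process and log0 are hypothetical; this sketch reads the RANK variable that torchrun sets, and inside an initialized process group you could ask dist.get_rank() instead.

```python
import os

def is_main_process():
    """True only on rank 0, the worker that saves checkpoints and writes logs."""
    return int(os.environ.get("RANK", 0)) == 0

def log0(message):
    """Print only from rank 0, so eight workers do not print eight copies."""
    if is_main_process():
        print(message)

log0("starting training")  # printed once, not once per GPU
```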
The whole thing, runnable
Here is a compact version of a real DDP training loop. You would launch it with torchrun as described above, not by running it directly. Read the comments as you go; they walk through what each line is doing.
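import os, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

dist.init_process_group("nccl")  # open the channel the GPUs use to talk
rank, local = int(os.environ["RANK"]), int(os.environ["LOCAL_RANK"])  # who am I?
torch.cuda.set_device(local)  # claim this worker's own GPU
device = torch.device("cuda", local)
model = Model().to(device)  # put a full copy of the model on that GPU (Model is your own class)
model = DDP(model, device_ids=[local])  # sync the copies and hook up the automatic all-reduce
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)  # example optimizer and learning rate
sampler = DistributedSampler(dataset)  # gives each rank a different slice (dataset is your own)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)  # 32 examples per GPU per step
for epoch in range(num_epochs):  # num_epochs is your own setting
    sampler.set_epoch(epoch)  # reshuffle differently each epoch, identically on every rank
    for xb, yb in loader:
        loss = model(xb.to(device), yb.to(device))[1]  # assumes the model returns (logits, loss)
        loss.backward()  # gradients are averaged across all ranks automatically
        opt.step(); opt.zero_grad(set_to_none=True)  # every rank takes the identical step
if rank == 0:  # save and log from just one worker, to avoid duplicates
    torch.save(model.module.state_dict(), "ckpt.pt")  # .module is the unwrapped model
dist.destroy_process_group()  # close the channel cleanly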