From DDP to ZeRO-2 and ZeRO-3
Preface
If you’ve ever trained or fine-tuned a large language model in PyTorch, you’ve probably started with Distributed Data Parallel (DDP). DDP is simple: replicate the whole model on every GPU, split the batch, and all-reduce the gradients. It scales compute across GPUs, but every GPU holds a full copy of the parameters, gradients, and optimizer state. That’s fine for GPT-2, but a dead end for 70B-parameter models.
This is where ZeRO (Zero Redundancy Optimizer) comes in. Instead of replicating everything, it shards model states across GPUs. There are three stages: ZeRO-1 shards only the optimizer states, ZeRO-2 adds the gradients, and ZeRO-3 adds the parameters themselves. In practice, ZeRO-2 helps when optimizer states + grads are the problem, while ZeRO-3 is what you reach for when the parameters themselves don’t fit.
ZeRO-2: Sharding grads and optimizer states
In ZeRO-2:
- Parameters: still fully replicated on every GPU.
- Gradients: reduced into shards after backward.
- Optimizer states: also sharded, so each GPU only keeps its piece.
 
✨ Memory win:
You eliminate the duplicated gradients and optimizer states.
⚡ Comms cost:
Roughly the same total volume as DDP: gradients are reduce-scattered into shards instead of all-reduced, and the updated parameters are all-gathered back after the optimizer step, which can add a bit of overhead in practice.
🎯 Use ZeRO-2 when:
- Your model weights fit per GPU.
- But optimizer states + gradients blow up memory.
 
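A common way to actually turn on ZeRO-2 is DeepSpeed, the library the ZeRO paper came from. Here's a minimal sketch: the config keys are DeepSpeed's, but the batch size, precision choice, and the surrounding model/optimizer setup are placeholders you'd swap for your own.

import deepspeed

# ZeRO stage 2: shard gradients + optimizer states, keep params replicated
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 2},
}

# model and optim are assumed to be your usual nn.Module and torch optimizer
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optim,
    config=ds_config,
)
# then train with engine(batch), engine.backward(loss), engine.step()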
ZeRO-3: Sharding parameters too
In ZeRO-3, the big leap is that parameters themselves are sharded:
- Forward: before computing a layer (say q_proj), GPUs do an all-gather of the parameter shards so every rank has the full weight just in time. The full weight is released again immediately after use.
- Backward: gradients are reduce-scattered back into shards so each GPU only keeps its slice.
- Optimizer states: remain sharded.
 
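To make the data flow concrete, here's a schematic sketch of the two collectives involved. This is not how DeepSpeed or FSDP actually implement it (they work on flattened, bucketed parameters and overlap communication with compute), and the function names are made up for illustration:

import torch
import torch.distributed as dist

def gather_full_weight(shard: torch.Tensor) -> torch.Tensor:
    # Forward: every rank contributes its parameter shard so the full
    # weight exists just in time for this layer's compute.
    world = dist.get_world_size()
    full = torch.empty(world * shard.numel(), dtype=shard.dtype, device=shard.device)
    dist.all_gather_into_tensor(full, shard)
    return full  # released again right after the layer runs

def shard_gradient(full_grad: torch.Tensor) -> torch.Tensor:
    # Backward: sum the gradient across ranks, but each rank only keeps
    # the slice matching its parameter shard.
    world = dist.get_world_size()
    out = torch.empty(full_grad.numel() // world, dtype=full_grad.dtype, device=full_grad.device)
    dist.reduce_scatter_tensor(out, full_grad, op=dist.ReduceOp.SUM)
    return out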
✨ Memory win:
No GPU ever stores the full parameter set, only shards. This is why ZeRO-3 enables training/fine-tuning models with hundreds of billions of parameters.
⚡ Comms cost:
Every layer requires an all-gather in the forward pass and a reduce-scatter in the backward pass. That per-layer communication is the price of parameter sharding, and it’s the main reason ZeRO-2 still exists alongside ZeRO-3.
Note: If your GPUs are only connected over PCIe, this will dominate runtime. On NVLink, it’ll be a lot faster.
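A rough back-of-the-envelope shows why: ZeRO-3 moves roughly 3× the parameter bytes per rank per step (all-gather in forward, all-gather again in backward, reduce-scatter of gradients). The bandwidth figures below are assumptions, and real runs overlap communication with compute, so treat this as an order-of-magnitude sketch only.

param_bytes = 8e9  # ~4B params in bf16, the example used below

# ZeRO-3 per step: all-gather params (forward) + all-gather params (backward)
# + reduce-scatter grads ≈ 3x the parameter bytes per rank
bytes_per_step = 3 * param_bytes

links = {
    "PCIe 4.0 x16 (~25 GB/s effective, assumed)": 25e9,
    "NVLink (~250 GB/s effective, assumed)": 250e9,
}
for name, bw in links.items():
    print(f"{name}: ~{bytes_per_step / bw:.2f} s of comms per step")
# ~0.96 s vs ~0.10 s per step -- the gap the note above is talking about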
🎯 Use ZeRO-3 when:
- The model weights themselves don’t fit.
- You’re willing to trade extra communication for being able to run at all.
 
Memory breakdown: DDP vs ZeRO-2 vs ZeRO-3
For the rest of this post, let’s use a concrete reference point: Qwen3-4B-Thinking-2507.
- Params: ~4B × 2 bytes (bf16) ≈ 8 GB.
- Gradients: same size ≈ 8 GB.
- Adam optimizer states (momentum + variance): 2 × params ≈ 16 GB, assuming they’re kept in bf16 like the weights (fp32 optimizer states would be larger still).
- Total model state ≈ 32 GB.
 
We’ll also assume you’re on 2×A10G GPUs (24 GB each).
DDP (replicate all)
- Params: 8 GB
- Gradients: 8 GB
- Optimizer states: 16 GB
- Total per GPU: 32 GB
 
❌ That’s more than a single A10G can hold. Training fails before you even start.
ZeRO-2 (shard grads + opt)
- Params: 8 GB (still fully replicated)
- Gradients: 8 GB ÷ 2 GPUs = 4 GB
- Optimizer states: 16 GB ÷ 2 GPUs = 8 GB
- Total per GPU: 20 GB
 
⚠️ This technically fits into 24 GB VRAM, but leaves very little headroom for activations, dataloader buffers, etc.
ZeRO-3 (shard everything)
- Params: 8 GB ÷ 2 GPUs = 4 GB
- Gradients: 8 GB ÷ 2 GPUs = 4 GB
- Optimizer states: 16 GB ÷ 2 GPUs = 8 GB
- Total per GPU: 16 GB
 
✅ This leaves ~8 GB free on each A10G for activations, buffers, and dataloader overhead. It’s the only setup here that’s actually comfortable for fine-tuning.
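If you want to redo this arithmetic for your own model and GPU count, the three breakdowns above collapse into three small formulas. A minimal sketch using the same assumptions as above (bf16 params and grads, Adam states ≈ 2× params):

def per_gpu_gb(params_gb: float, n_gpus: int, strategy: str) -> float:
    # Model state in GB: params (P), gradients (G = P), Adam states (O = 2P)
    p, g, o = params_gb, params_gb, 2 * params_gb
    if strategy == "ddp":      # replicate everything
        return p + g + o
    if strategy == "zero2":    # shard gradients + optimizer states
        return p + (g + o) / n_gpus
    if strategy == "zero3":    # shard everything
        return (p + g + o) / n_gpus
    raise ValueError(strategy)

for s in ("ddp", "zero2", "zero3"):
    print(s, per_gpu_gb(8, 2, s))   # ddp 32.0, zero2 20.0, zero3 16.0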
Code: Training with FSDP
You’ll often see people mention FSDP (Fully Sharded Data Parallel) instead of ZeRO. FSDP (v2 here, via the fully_shard API) is PyTorch’s native implementation of ZeRO-3-style sharding, and it’s exactly what we’re doing below. We’ll launch it with torchrun.
Code Snippet
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard
from model import Transformer
...
# torchrun launches one process per GPU and sets LOCAL_RANK for each
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group("nccl")

model = Transformer()
# Shard each transformer block, then the remaining top-level params (ZeRO-3 style)
for layer in model.layers:
    fully_shard(layer)
fully_shard(model)

optim = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = torch.nn.CrossEntropyLoss()

for _ in range(epochs):           # epochs and dl come from the elided setup above
    for xb, yb in dl:
        xb, yb = xb.cuda(), yb.cuda()
        out = model(xb)           # params are all-gathered per layer, just in time
        loss = loss_fn(out, yb)
        loss.backward()           # grads are reduce-scattered back into shards
        optim.step()              # each rank updates only its own shard
        optim.zero_grad()

dist.destroy_process_group()
Run training
torchrun --nproc_per_node=2 train.py
Takeaways
- DDP: replicate everything, all-reduce grads. Simple, but memory heavy.
- ZeRO-2: shard grads + optimizer states. Saves memory if params fit.
- ZeRO-3: shard everything. Saves massive memory, but introduces per-layer all-gather + reduce-scatter.
 
What's next
I think I'm done with these techniques for now. The last piece of this series is scaling inference across multiple GPUs; we'll look at that next.