From DDP to ZeRO-2 and ZeRO-3

Preface

If you’ve ever trained or fine-tuned a large language model in PyTorch, you’ve probably started with Distributed Data Parallel (DDP). DDP is simple: replicate the whole model on every GPU, split the batch, and all-reduce the gradients. It scales compute across GPUs, but every GPU holds a full copy of the parameters, gradients, and optimizer state. That’s fine for GPT-2, but a dead end for 70B-parameter models.
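As a reference point, the DDP baseline is just a wrapper around a fully replicated model. A minimal sketch (using a single Linear layer as a stand-in model):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# torchrun gives each process its own rank and GPU
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(1024, 1024).cuda()    # stand-in model: a full replica lives on every GPU
model = DDP(model, device_ids=[local_rank])   # backward() all-reduces gradients across ranks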

This is where ZeRO (Zero Redundancy Optimizer) comes in. Instead of replicating everything, it shards model states across GPUs. There are three stages: ZeRO-1 shards only the optimizer states, ZeRO-2 also shards gradients, and ZeRO-3 shards the parameters themselves as well. In practice, ZeRO-2 helps when optimizer states + grads are the problem, while ZeRO-3 is what you reach for when the parameters themselves don’t fit.

ZeRO-2: Sharding grads and optimizer states

In ZeRO-2:

  • Parameters: still fully replicated on every GPU.
  • Gradients: reduce-scattered into shards after backward.
  • Optimizer states: also sharded, so each GPU only keeps its piece.

Memory win:
you eliminate duplicated gradients and optimizer states.

Comms cost:
Roughly the same total volume as DDP: gradients are reduce-scattered instead of all-reduced, and the updated parameters are all-gathered back after the optimizer step. The extra collectives can add a little overhead in practice.

🎯 Use ZeRO-2 when:

  • Your model weights fit per GPU.
  • But optimizer states + gradients blow up memory.
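With DeepSpeed (where ZeRO comes from), stage 2 is mostly a config switch. A minimal sketch, assuming you already have a model built and DeepSpeed installed; the batch size and learning rate here are just illustrative values:

import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 4,                      # illustrative value
    "bf16": {"enabled": True},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-5}},     # illustrative value
    "zero_optimization": {"stage": 2},                         # shard gradients + optimizer states
}

# DeepSpeed builds the sharded optimizer and handles the reduce-scatter for you
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config
)

From there, the training loop calls model_engine.backward(loss) and model_engine.step() instead of plain loss.backward() / optim.step().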

ZeRO-3: Sharding parameters too

In ZeRO-3, the big leap is that parameters themselves are sharded:

  • Forward: before computing a layer (say q_proj), the GPUs all-gather that layer's parameter shards so every rank has the full weight just in time; the gathered copy is freed again as soon as the layer has run.
  • Backward: gradients are reduce-scattered back into shards so each GPU only keeps its slice.
  • Optimizer states: remain sharded.

Memory win:
No GPU ever stores the full parameter set, only shards. This is why ZeRO-3 enables training/fine-tuning models with hundreds of billions of parameters.

Comms cost:
every layer needs an all-gather of its parameter shards in the forward pass (and typically again in the backward pass, since the full weights are needed to compute gradients), plus a reduce-scatter of gradients. This cost is meaningful, and it’s exactly why ZeRO-2 still exists as a separate stage. Note: if your GPUs are only connected over PCIe, this communication can dominate runtime. On NVLink, it’ll be a lot faster.
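To make that concrete, here is a rough, hand-rolled sketch of what the sharded forward and gradient sharding look like for a single linear weight, using plain torch.distributed collectives. This is purely illustrative (the function names and sharding layout are mine), not how FSDP is actually implemented:

import torch
import torch.distributed as dist

def sharded_linear_forward(x, w_shard, world_size):
    # w_shard: this rank's (out_features // world_size, in_features) slice of the weight
    full_w = torch.empty(w_shard.shape[0] * world_size, w_shard.shape[1],
                         device=w_shard.device, dtype=w_shard.dtype)
    dist.all_gather_into_tensor(full_w, w_shard)   # every rank materializes the full weight
    out = x @ full_w.T                             # compute this layer
    return out                                     # full_w goes out of scope -> gathered copy is freed

def shard_gradient(full_grad, world_size):
    # full_grad: locally computed gradient for the full weight
    grad_shard = torch.empty(full_grad.shape[0] // world_size, full_grad.shape[1],
                             device=full_grad.device, dtype=full_grad.dtype)
    dist.reduce_scatter_tensor(grad_shard, full_grad)  # sum across ranks, keep only our slice
    return grad_shard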

🎯 Use ZeRO-3 when:

  • The model weights themselves don’t fit.
  • You’re willing to trade extra communication for being able to run at all.

Memory breakdown: DDP vs ZeRO-2 vs ZeRO-3

For the rest of this post, let’s use a concrete reference point: Qwen3-4B-Thinking-2507.

  • Params: ~4B × 2 bytes (bf16) ≈ 8 GB.
  • Gradients: same size ≈ 8 GB.
  • Adam optimizer states: two per parameter (momentum and variance), ≈ 2 × params ≈ 16 GB in this simplified bf16 accounting.
  • Total model state = 32 GB.

We’ll also assume you’re on 2×A10G GPUs (24 GB each).

DDP (replicate all)

  • Params: 8 GB
  • Gradients: 8 GB
  • Optimizer states: 16 GB
  • Total per GPU: 32 GB

❌ That’s more than a single A10G can hold. Training fails before you even start.

ZeRO-2 (shard grads + opt)

  • Params: 8 GB (still fully replicated)
  • Gradients: 8 GB ÷ 2 GPUs = 4 GB
  • Optimizer states: 16 GB ÷ 2 GPUs = 8 GB
  • Total per GPU: 20 GB

⚠️ This technically fits into 24 GB VRAM, but leaves very little headroom for activations, dataloader buffers, etc.

ZeRO-3 (shard everything)

  • Params: 8 GB ÷ 2 GPUs = 4 GB
  • Gradients: 8 GB ÷ 2 GPUs = 4 GB
  • Optimizer states: 16 GB ÷ 2 GPUs = 8 GB
  • Total per GPU: 16 GB

✅ This leaves ~8 GB free on each A10G for activations, buffers, and dataloader overhead. It’s the only setup here that’s actually comfortable for fine-tuning.
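If you want to redo this arithmetic for a different model or GPU count, the bookkeeping fits in a few lines of Python (same simplifying assumptions as above: bf16 params and grads, two 2-byte Adam states per parameter):

def per_gpu_gb(n_params, world_size, shard_grads=False, shard_opt=False, shard_params=False,
               bytes_per_param=2, opt_states_per_param=2):
    params = n_params * bytes_per_param / 1e9                        # parameter memory in GB
    grads = n_params * bytes_per_param / 1e9                         # gradient memory in GB
    opt = n_params * bytes_per_param * opt_states_per_param / 1e9    # optimizer state in GB
    return (params / (world_size if shard_params else 1)
            + grads / (world_size if shard_grads else 1)
            + opt / (world_size if shard_opt else 1))

N, W = 4e9, 2  # ~4B params, 2 GPUs
print("DDP:   ", per_gpu_gb(N, W))                                    # 32 GB
print("ZeRO-2:", per_gpu_gb(N, W, shard_grads=True, shard_opt=True))  # 20 GB
print("ZeRO-3:", per_gpu_gb(N, W, shard_grads=True, shard_opt=True, shard_params=True))  # 16 GB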

Code: Training with FSDP

You’ll often see people mention FSDP (Fully Sharded Data Parallel) instead of ZeRO. FSDP is PyTorch’s native implementation of ZeRO-3-style sharding, and the snippet below uses the newer FSDP2 fully_shard API. We’ll launch it with torchrun at the end.

Code Snippet

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard
from model import Transformer
...

# torchrun launches one process per GPU and sets RANK / LOCAL_RANK / WORLD_SIZE
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = Transformer()
# Shard each transformer block, then the remaining top-level parameters.
for layer in model.layers:
    fully_shard(layer)
fully_shard(model)

optim = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = torch.nn.CrossEntropyLoss()

for _ in range(epochs):
    for xb, yb in dl:
        xb, yb = xb.cuda(), yb.cuda()
        out = model(xb)          # parameters are all-gathered layer by layer
        loss = loss_fn(out, yb)
        loss.backward()          # gradients are reduce-scattered into shards
        optim.step()             # each rank updates only its own shard
        optim.zero_grad()

Run training

torchrun --nproc_per_node=2 train.py
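torchrun starts one process per GPU (two here) and sets RANK, LOCAL_RANK, and WORLD_SIZE in each process’s environment, which is exactly what the init_process_group and set_device calls in the script rely on.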

Takeaways

  • DDP: replicate everything, all-reduce grads. Simple, but memory heavy.
  • ZeRO-2: shard grads + optimizer states. Saves memory if params fit.
  • ZeRO-3: shard everything. Saves massive memory, but introduces per-layer all-gather + reduce-scatter.

What's next

I think I'm done with these techniques for now. The last part left in this series is to see how to scale inference across multiple GPUs. We'll look at that next.