Distributed Data Parallel (DDP) for Training Models

Preface

Training large models on a single GPU can be painfully slow. PyTorch's Distributed Data Parallel (DDP) is the standard way to scale training across multiple GPUs (and even across nodes). It ensures that gradients are synchronized efficiently so your model converges just like single-GPU training — but much faster.

In this post, let’s walk through how DDP works conceptually, how to structure your code, and how batch size scaling plays into it. I’ll also show you how to implement a very simple version of DDP yourself, so you can see what’s happening under the hood.

How DDP works

The core idea behind DDP is to split the data across GPUs: the model is replicated on each GPU, and the gradients are synchronized after every backward pass. It works as follows:

  • Replica per GPU: Each rank holds the same model on its own GPU.
  • Sharded data: Each rank gets a unique mini-batch shard.
  • Backward computes grads locally: Autograd produces param.grad tensors on each rank.
  • Bucketed sync after backward: As soon as a bucket of grads is ready, DDP syncs the bucket, averaging the gradients according to the world_size (see the worked example after this list).
  • Identical optimizer step: Because grads are identical everywhere, optimizer.step() updates all replicas the same way.
  • world_size = total number of participating processes (usually equal to number of GPUs).
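
To make the averaging concrete, here is a tiny illustration with plain tensors (no process group involved, purely to show the arithmetic): two ranks compute different gradients for the same parameter on their own data shards, and the sum-then-divide leaves both with the identical averaged gradient.

import torch

world_size = 2
grad_rank0 = torch.tensor([0.2, -1.0])   # gradient computed on rank 0's shard
grad_rank1 = torch.tensor([0.6,  3.0])   # gradient computed on rank 1's shard

# What all_reduce(SUM) followed by division achieves: both ranks hold the average.
avg_grad = (grad_rank0 + grad_rank1) / world_size
print(avg_grad)   # tensor([0.4000, 1.0000])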

Step-by-step

Initialize the distributed environment

Every process (rank) first joins a process group so the ranks can communicate with each other:

import torch.distributed as dist

# NCCL is the recommended backend for multi-GPU (CUDA) training
dist.init_process_group(backend="nccl")
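
By default, init_process_group reads the rank, world size, and master address from environment variables, and the usual way to set those is to launch the script with torchrun. A minimal sketch of picking them up in your script (assuming a torchrun launch):

import os
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")

rank = dist.get_rank()                      # global index of this process
world_size = dist.get_world_size()          # total number of processes
local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node, set by torchrun

torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")

On a single machine with 4 GPUs you would launch this as torchrun --nproc_per_node=4 train.py.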

Wrap the Model

PyTorch gives you an easy wrapper around your model:

from torch.nn.parallel import DistributedDataParallel as DDP

model = MyModel().to(device)                 # device = cuda:local_rank for this process
model = DDP(model, device_ids=[local_rank])  # pin this replica to this process's GPU

What this does:

  • Each rank gets its own model replica pinned to one GPU.
  • After .backward(), autograd hooks automatically all_reduce the gradients across ranks.
  • Gradients are averaged, so every replica sees identical gradients and optimizer.step() keeps the weights in sync.

Distributing the Data

Each GPU should see a different shard of the dataset. That’s where DistributedSampler comes in:

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

train_dataset = MyDataset()
sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
train_loader = DataLoader(
    train_dataset,
    batch_size=32,    # per-GPU batch size
    sampler=sampler   # shuffling is handled by the sampler, so don't pass shuffle=True here
)

This ensures no two GPUs process the same examples in an epoch, and it keeps your effective batch size predictable. Let's put some numbers in: if each GPU processes 32 samples and world_size=4, your effective batch size is 32 × 4 = 128.
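
You can see the sharding itself without any GPUs or process group: passing num_replicas and rank explicitly lets you build the sampler in a single process (purely for illustration) and inspect the disjoint index shards each rank would receive.

from torch.utils.data.distributed import DistributedSampler

dataset = list(range(8))  # stand-in dataset with 8 samples
shard0 = list(DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False))
shard1 = list(DistributedSampler(dataset, num_replicas=2, rank=1, shuffle=False))
print(shard0)  # [0, 2, 4, 6]
print(shard1)  # [1, 3, 5, 7]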

Training Loop

Here’s how a minimal loop looks:

for epoch in range(epochs):
    sampler.set_epoch(epoch)  # ensures a different shuffle order each epoch

    for batch in train_loader:
        optimizer.zero_grad()
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()

What's happening under the hood

When you call .backward(), DDP automatically performs the equivalent of the following for each bucket of gradients:

dist.all_reduce(grad, op=dist.ReduceOp.SUM)
grad /= world_size

Because the averaged gradients are identical on every rank, optimizer.step() applies the same update on every replica and the weights stay in sync.

Gradient accumulation

DDP syncs gradients on every backward call by default. Gradient accumulation is a manual trick you can combine with DDP to reduce communication overhead: gradients are accumulated locally (not synced) over several micro-batches and synced only on the last one before the optimizer step. This matters because in distributed training the network communication tends to be the bottleneck. PyTorch exposes this pattern through DDP's no_sync() context manager, sketched below.
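
Here is a minimal sketch of gradient accumulation with no_sync(), reusing the model, loader, and device from above; accum_steps is an illustrative hyperparameter:

accum_steps = 4  # number of micro-batches to accumulate before syncing

for step, (inputs, targets) in enumerate(train_loader):
    inputs, targets = inputs.to(device), targets.to(device)
    is_last_micro_batch = (step + 1) % accum_steps == 0

    if is_last_micro_batch:
        # Normal backward: DDP all_reduces the accumulated gradients here.
        loss = loss_fn(model(inputs), targets) / accum_steps
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    else:
        # Skip the gradient all_reduce for this micro-batch.
        with model.no_sync():
            loss = loss_fn(model(inputs), targets) / accum_steps
            loss.backward()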

DIY DDP Implementation

Here is a very simple (and not production ready!) implementation of DDP:

import torch
import torch.distributed as dist

class MyDDP(torch.nn.Module):

    def __init__(self, module, world_size):
        super().__init__()
        self.module = module
        self.world_size = world_size
        for param in self.module.parameters():
            # Like real DDP, start from identical weights by broadcasting rank 0's parameters.
            dist.broadcast(param.data, src=0)
            # The hook fires during backward with this parameter's gradient.
            param.register_hook(self.sync_gradients)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def sync_gradients(self, grad):
        # Sum the gradient across all ranks, then divide to get the average.
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        grad /= self.world_size
        return grad
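
A quick sketch of how you might exercise it; the toy linear model, the loss, and the two-process launch are just illustrative:

import os
import torch
import torch.distributed as dist

# Run with, e.g.: torchrun --nproc_per_node=2 my_ddp_demo.py
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(10, 1).to(local_rank)           # toy model
model = MyDDP(model, world_size=dist.get_world_size())

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
inputs = torch.randn(32, 10, device=local_rank)
targets = torch.randn(32, 1, device=local_rank)

loss = torch.nn.functional.mse_loss(model(inputs), targets)
loss.backward()   # MyDDP's hooks all_reduce and average the gradients here
optimizer.step()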

Key Takeaways

  • Wrap your model in DistributedDataParallel.
  • Use DistributedSampler to shard data.
  • Remember that effective batch size scales with number of GPUs.
  • Under the hood, DDP is just averaging gradients across ranks.
  • You can write your own simple DDP class once you understand the mechanics!

With these steps, you can scale your training across multiple GPUs and finish jobs much faster while maintaining correctness.

What's next

We've seen how to scale training across multiple GPUs using DDP. In the next post, we'll see how to make scaling during training even better using FSDP and ZeRO.