Let’s get straight to the point:

import os
import time

import torch
import torch.nn as nn
import torch.distributed as dist

from model import resnet152
from dataset import get_data_loaders
from torch.nn.parallel import DistributedDataParallel as DDP

# Training hyper-parameters shared by train().
learning_rate: float = 1e-3  # step size handed to the Adam optimizer
num_epochs: int = 40  # epochs are numbered 0 .. num_epochs - 1
# NOTE(review): momentum and weight_decay are not referenced anywhere in the
# visible code (Adam is constructed with lr only) — confirm whether they
# should be wired into the optimizer or removed.
momentum: float = 0.9
weight_decay: float = 1e-5

def setup(backend="nccl"):
    """Initialise the default process group for this worker process.

    Uses the default ``env://`` initialisation, i.e. rank, world size and the
    rendezvous address are read from the environment variables set by the
    launcher (torchrun), so none of them is passed explicitly.

    Args:
        backend: Collective-communication backend; NCCL is the usual choice
            for multi-GPU training.

    Calling it again after the group exists is a no-op.
    """
    if not dist.is_initialized():
        dist.init_process_group(backend=backend)

def cleanup():
    """Tear down the default process group; do nothing if none was initialised."""
    if dist.is_initialized():
        dist.destroy_process_group()

def _strip_ddp_prefix(state_dict):
    """Return a copy of ``state_dict`` with DDP's ``module.`` key prefix removed."""
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _atomic_save(obj, path):
    """Save ``obj`` to ``path`` so that a crash mid-write never leaves a truncated file."""
    # Per-process temporary name, since several ranks may write into the same directory.
    tmp_path = f"{path}.tmp{dist.get_rank()}"
    torch.save(obj, tmp_path)
    os.replace(tmp_path, path)  # atomic rename when both paths are on one filesystem


def _accuracy(model, loader, device):
    """Return the accuracy (in percent) of ``model`` on ``loader``, pooled over all ranks.

    Every rank must call this, because the per-rank counts are combined with an
    all-reduce. Returns 0.0 if no sample was seen on any rank.
    """
    model.eval()  # use BatchNorm running statistics, do not update them
    counts = torch.zeros(2, dtype=torch.long, device=device)  # [correct, total]
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            counts[0] += (predicted == labels).sum()
            counts[1] += labels.size(0)
    dist.all_reduce(counts)  # default reduction op is SUM
    correct, total = counts.tolist()
    return 100 * correct / total if total else 0.0


def train(rank, world_size):
    """Train ResNet-152 with DDP, resuming from ``last.pth`` when it exists.

    Args:
        rank: Local rank of this process (``LOCAL_RANK``); used as the CUDA
            device index. Local rank 0 of every node writes the checkpoints.
        world_size: Total number of processes, forwarded to ``get_data_loaders``.

    The default process group must already be initialised (see ``setup``).

    Every rank has to run the same number of epochs, otherwise the ranks that
    keep going block forever in DDP's collectives. For that reason only global
    rank 0 reads the checkpoint. Its start epoch is broadcast to all ranks, and
    its weights reach every rank through DDP's construction-time state broadcast.
    Other nodes therefore do not need their own copy of ``last.pth``.

    NOTE: Adam's moment estimates are not checkpointed. After a resume they
    restart from zero on all ranks (consistently, so replicas do not diverge).
    """
    torch.cuda.set_device(rank)
    is_global_zero = dist.get_rank() == 0

    model = resnet152()
    begin = 0
    if is_global_zero and os.path.exists("last.pth"):
        obj = torch.load("last.pth", map_location="cpu")
        print(f"Rank{rank} load 'last.pth' with epoch: {obj['epoch']}")
        # The checkpoint was saved from the DDP wrapper, so its keys carry "module."
        model.load_state_dict(_strip_ddp_prefix(obj["model"]))
        begin = obj["epoch"] + 1

    # DDP's constructor broadcasts rank 0's parameters and buffers to all ranks.
    model = DDP(model.to(rank), device_ids=[rank])

    # Make every rank agree on the first epoch to run.
    begin_tensor = torch.tensor([begin], device=rank)
    dist.broadcast(begin_tensor, src=0)
    begin = int(begin_tensor.item())
    print(f"Rank{rank} begin at {begin}")

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    start = time.time()
    trainloader, testloader = get_data_loaders(rank, world_size)
    # Presumably a DistributedSampler; if so, it needs set_epoch() to reshuffle each epoch.
    sampler = getattr(trainloader, "sampler", None)

    for epoch in range(begin, num_epochs):
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)

        model.train()
        running_loss = 0.0
        for images, labels in trainloader:
            images, labels = images.to(rank), labels.to(rank)
            loss = criterion(model(images), labels)

            # backward and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        trainset_accu = _accuracy(model, trainloader, rank)
        testset_accu = _accuracy(model, testloader, rank)

        if rank == 0:
            print(
                f"[{epoch}] Accu: {trainset_accu:.2f}%, {testset_accu:.2f}% "
                f"| {(time.time() - start) / 60.0:.1f} mins, loss: {running_loss}"
            )
            torch.save(model.state_dict(), f"cifar100_{epoch}.pth")
            _atomic_save({"model": model.state_dict(), "epoch": epoch}, "last.pth")

    print("Training is done")
    print("Total Training Time (second):", time.time() - start)

if __name__ == "__main__":
    # torchrun sets these for every worker process it launches.
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    setup()  # DDP needs an initialised process group
    try:
        train(local_rank, world_size)
    finally:
        cleanup()  # release the process group even if training raises

The main training code comes from this notebook (many thanks to @batuhan3526), and the code for the distributed environment is from here. I haven’t included the dataset code, since this doc already explains it well enough.

To run this snippet on two nodes (each with two GPUs), I use the powerful `torchrun`:

# Replace train.py with the path of the training script above.
torchrun \
  --rdzv-backend=c10d \
  --rdzv-endpoint=rogpt1:23456 \
  --nnodes=1:2 \
  --max-restarts=3 \
  --nproc-per-node=2 \
  train.py

In the snippet above, the rank-0 process of each node (the `rank` passed to `train` is `LOCAL_RANK`) saves a checkpoint after every epoch. If any process fails, torchrun restarts all workers, and training resumes from the last saved epoch + 1.

I also tried letting only the rank-0 process on node 0 save the checkpoint. However, because the other nodes then have no checkpoint to load, the restarted job got stuck and never made progress.