Let’s get straight to the point:
```python
import os
import time

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from model import resnet152
from dataset import get_data_loaders

learning_rate = 0.001
num_epochs = 40
momentum = 0.9
weight_decay = 1e-5


def setup():
    # initialize the process group
    dist.init_process_group("nccl")


def cleanup():
    dist.destroy_process_group()


def train(rank, world_size):
    setup()
    # one process per GPU: bind this process to its local GPU
    torch.cuda.set_device(rank)

    model = resnet152().to(rank)
    model = DDP(model, device_ids=[rank])

    # Resume from the node-local checkpoint if it exists. Every rank loads it,
    # so all ranks agree on the weights and on the epoch to resume from.
    if os.path.exists("last.pth"):
        obj = torch.load("last.pth", map_location=f"cuda:{rank}")
        print(f"Rank{rank} load 'last.pth' with epoch: {obj['epoch']}")
        model.load_state_dict(obj["model"])
        begin = obj["epoch"] + 1
    else:
        begin = 0
    print(f"Rank{rank} begin at {begin}")

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    start = time.time()
    running_loss = 0
    trainloader, testloader = get_data_loaders(rank, world_size)

    for epoch in range(begin, num_epochs):
        # make the DistributedSampler reshuffle differently every epoch
        trainloader.sampler.set_epoch(epoch)
        for index, (images, labels) in enumerate(trainloader):
            # gpu
            images, labels = images.to(rank), labels.to(rank)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # backward and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # accuracy on the training set
        correct = 0
        total = 0
        with torch.no_grad():
            for data in trainloader:
                images, labels = data
                # gpu
                images, labels = images.to(rank), labels.to(rank)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        trainset_accu = 100 * correct / total

        # accuracy on the test set
        correct = 0
        total = 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                # gpu
                images, labels = images.to(rank), labels.to(rank)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        testset_accu = 100 * correct / total

        if rank == 0:
            print(
                f"[{epoch}] Accu: {trainset_accu:.2f}%, {testset_accu:.2f}% "
                f"| {(time.time() - start) / 60.0:.1f} mins, loss: {running_loss}"
            )
            torch.save(model.state_dict(), f"cifar100_{epoch}.pth")
            torch.save({"model": model.state_dict(), "epoch": epoch}, "last.pth")
        running_loss = 0.0

    end = time.time()
    stopWatch = end - start
    print("Training is done")
    print("Total Training Time (second):", stopWatch)
    cleanup()


if __name__ == "__main__":
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    train(local_rank, world_size)
```
The main training code comes from this notebook (many thanks to @batuhan3526), and the code for the distributed environment is from here. I haven’t pasted the dataset code, since that doc already gives a sufficient introduction.
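Still, for readers who don’t want to follow the link, here is a minimal sketch of what a `get_data_loaders` helper could look like for CIFAR-100 under DDP. The batch size, transforms, and normalization stats below are placeholder choices of mine, not necessarily what the notebook uses; the essential piece is the `DistributedSampler`, which gives every process its own shard of the training set.

```python
# A minimal sketch of a get_data_loaders helper for CIFAR-100 under DDP.
# Batch size, transforms, and normalization stats are placeholder choices,
# not necessarily what the referenced notebook uses.
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def get_data_loaders(rank, world_size, batch_size=128):
    transform = transforms.Compose([
        transforms.ToTensor(),
        # commonly used CIFAR-100 channel means and stds
        transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
    ])
    trainset = torchvision.datasets.CIFAR100(
        root="./data", train=True, download=True, transform=transform
    )
    testset = torchvision.datasets.CIFAR100(
        root="./data", train=False, download=True, transform=transform
    )
    # DistributedSampler reads the *global* rank and world size from the
    # default process group, so each of the four processes (2 nodes x 2 GPUs)
    # gets its own shard of the training set; the rank/world_size arguments
    # are kept only to match the call site in train().
    train_sampler = DistributedSampler(trainset)
    trainloader = DataLoader(
        trainset, batch_size=batch_size, sampler=train_sampler, num_workers=2
    )
    testloader = DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=2
    )
    return trainloader, testloader
```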
To run this snippet on two nodes (each with two GPUs), I need the powerful “torchrun”:
```bash
torchrun \
    --rdzv-backend=c10d \
    --rdzv-endpoint=rogpt1:23456 \
    --nnodes=1:2 \
    --max-restarts=3 \
    --nproc-per-node=2 \
    train.py
```
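The same command is launched on each node; the workers discover each other through the c10d rendezvous endpoint on rogpt1:23456, and `--nnodes=1:2` tells torchrun the job may run with anywhere between one and two nodes. If you want to double-check the rendezvous before starting the real job, a tiny standalone probe like the one below can be launched with exactly the same torchrun command. This is my own sketch; the file name `sanity_check.py` is just an example.

```python
# sanity_check.py - a minimal standalone probe, launched with the same torchrun
# command as train.py, to confirm that the two-node rendezvous actually works.
import os

import torch
import torch.distributed as dist

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# every rank contributes 1; after all_reduce the value equals the world size
t = torch.ones(1, device=f"cuda:{local_rank}")
dist.all_reduce(t)
print(
    f"global rank {dist.get_rank()}/{dist.get_world_size()}, "
    f"local rank {local_rank}, all_reduce result: {t.item()}"
)
dist.destroy_process_group()
```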
With the command above, the local rank-0 process on each node saves the checkpoint, so every node keeps its own copy of `last.pth`. If any process fails, torchrun (within the budget set by `--max-restarts=3`) tears down and relaunches the whole worker group, and training resumes from epoch + 1.
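If you want to see which restart attempt a worker belongs to, torchrun also exposes this through environment variables such as `TORCHELASTIC_RESTART_COUNT`. A few lines like the following could be added near the top of `train()`; this is purely informational, and the actual resume point still comes from `last.pth`.

```python
# Optional: report which elastic restart attempt this worker belongs to.
# Both variables are set by torchrun; the defaults below are only fallbacks
# for running the script without torchrun.
restart_count = int(os.environ.get("TORCHELASTIC_RESTART_COUNT", "0"))
max_restarts = int(os.environ.get("TORCHELASTIC_MAX_RESTARTS", "0"))
print(f"Rank{rank}: restart attempt {restart_count}/{max_restarts}")
```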
I also tried letting only the rank-0 process on node-0 save the checkpoint. However, since the other node then has no checkpoint to load, the restart failed and the job got stuck in an endless loop.
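For what it’s worth, a single writer could probably be made to work by letting global rank 0 read `last.pth` and broadcasting the checkpoint to every other rank, e.g. with `dist.broadcast_object_list`, so the other node never needs the file on its local disk. Below is only a sketch of that idea (the helper name and all details are mine, and I haven’t validated it in this setup); it is not what I ended up using.

```python
# Sketch: single-writer checkpoint loading. Only global rank 0 touches the
# filesystem; everyone else receives the checkpoint over the process group,
# so non-node-0 workers no longer need a local last.pth.
def load_checkpoint_broadcast(model, device):
    if dist.get_rank() == 0 and os.path.exists("last.pth"):
        obj = [torch.load("last.pth", map_location="cpu")]
    else:
        obj = [None]
    # the object is serialized to tensors on `device` before broadcasting
    dist.broadcast_object_list(obj, src=0, device=device)
    if obj[0] is None:
        # no checkpoint anywhere yet: start from scratch
        return 0
    model.load_state_dict(obj[0]["model"])
    return obj[0]["epoch"] + 1
```

In `train()`, this helper would take the place of the `os.path.exists` block, something like `begin = load_checkpoint_broadcast(model, torch.device(f"cuda:{rank}"))`.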