Apex is NVIDIA's tool for enabling mixed-precision training.
    import apex.amp as amp

    net, optimizer = amp.initialize(net, optimizer, opt_level="O2")

    # forward
    outputs = net(inputs)
    loss = criterion(outputs, targets)
    optimizer.zero_grad()

    # float16 backward
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
    ...
    torch.save(net, "model.pth")
After I changed my code to use Apex, saving the model with torch.save(net, "model.pth") raised this error:
AttributeError: Can't pickle local object '_initialize.<locals>.patch_forward.<locals>.new_fwd'
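The name in the error message suggests that the model's forward was replaced by a function defined inside another function, and pickle cannot serialize such local functions. Here is a minimal sketch that reproduces the same kind of failure without Apex or PyTorch. TinyNet, try_pickle, and this patch_forward are hypothetical stand-ins written for illustration, not Apex's actual code:

```python
import pickle


def patch_forward(old_fwd):
    # Hypothetical stand-in for the wrapper named in the error message:
    # new_fwd is a *local* function, so pickle cannot look it up by name.
    def new_fwd(*args, **kwargs):
        return old_fwd(*args, **kwargs)
    return new_fwd


class TinyNet:
    """Hypothetical stand-in for a model; not a torch.nn.Module."""

    def __init__(self):
        self.weight = 2.0

    def forward(self, x):
        return self.weight * x


def try_pickle(obj):
    """Return None if obj pickles cleanly, otherwise the exception raised."""
    try:
        pickle.dumps(obj)
        return None
    except (AttributeError, pickle.PicklingError) as exc:
        return exc


net = TinyNet()
# The wrapper is stored on the instance, so it becomes part of the
# state that pickle has to serialize along with the weights.
net.forward = patch_forward(net.forward)

whole_object_error = try_pickle(net)               # whole object: fails
params_error = try_pickle({"weight": net.weight})  # plain data only: works

print("pickling the whole object:", whole_object_error)
print("pickling only the parameters:", params_error)
```

The patched object still works when called. Only serializing the whole object breaks, while its plain parameter values pickle without trouble.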
Someone has already reported this problem, but it seems no one has addressed it: link. The only solution I have found comes from a Chinese blog: link. It recommends saving only the model parameters: