torchao is a Python library that supports PyTorch-native quantization and sparsity for training and inference. I just finished some experiments with it for my image-classification project, which uses a CNN model built with PyTorch. Below are some conclusions.

My project already used Automatic Mixed Precision with bfloat16, but convert_to_float8_training still cut VRAM usage by about 60% on my RTX 4090.

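The conversion itself is only a couple of lines. Here is a minimal sketch of what I did, not my real training code: the tiny model and the module_filter_fn are placeholders, the divisible-by-16 filter follows the example in the torchao README, and as far as I can tell only nn.Linear layers get converted.

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# Stand-in for my real CNN; the conversion targets the nn.Linear layers.
model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(64, 4096),
    nn.ReLU(),
    nn.Linear(4096, 1000),
).to("cuda")

def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
    # Skip linear layers whose dimensions are not multiples of 16,
    # since the float8 kernels require that.
    if isinstance(mod, nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# Swap eligible nn.Linear modules for their float8 training equivalents, in place.
convert_to_float8_training(model, module_filter_fn=module_filter_fn)

# torch.compile is recommended to actually get the float8 speedups.
model = torch.compile(model)
```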
AdamW8bit only brought VRAM down from 22.6 GB to 22.4 GB, so not much of a saving.
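Switching the optimizer is a one-line change. A rough sketch, assuming the torchao.optim import path (older releases keep it under torchao.prototype.low_bit_optim); the small saving makes sense to me, since 8-bit state only shrinks memory in proportion to the parameter count, and most of my VRAM goes to activations:

```python
import torch
import torch.nn as nn
from torchao.optim import AdamW8bit  # older torchao: torchao.prototype.low_bit_optim

model = nn.Linear(1024, 1024).to("cuda")  # stand-in for the real CNN

# Drop-in replacement for torch.optim.AdamW; the Adam state tensors are stored
# in 8 bits instead of 32, so the saving scales with the number of parameters
# rather than with activation memory.
optimizer = AdamW8bit(model.parameters(), lr=1e-3, weight_decay=1e-2)
```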

I didn't see any VRAM difference after switching to CPUOffloadOptimizer, and since it doesn't work well with a learning-rate scheduler, I'm inclined to give up on it.
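For reference, this is roughly how I wired it up (a sketch only; offload_gradients and the manual learning-rate update via param_groups reflect my reading of the torchao docs and are not something I have verified):

```python
import torch
import torch.nn as nn
from torchao.optim import CPUOffloadOptimizer  # older torchao: torchao.prototype.low_bit_optim

model = nn.Linear(1024, 1024).to("cuda")  # stand-in for the real CNN

# Wraps a regular optimizer; optimizer state (and optionally gradients) live in
# CPU RAM and the optimizer step runs on the CPU.
optimizer = CPUOffloadOptimizer(
    model.parameters(),
    torch.optim.AdamW,       # the optimizer that actually does the update
    offload_gradients=True,  # also move gradients to CPU after backward
    lr=1e-3,
)

# There is no LR-scheduler support, so any schedule has to be applied by hand
# each step (assuming the wrapper exposes param_groups like a normal optimizer):
def set_lr(new_lr: float) -> None:
    for group in optimizer.param_groups:
        group["lr"] = new_lr
```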