torchao is a Python library that supports PyTorch-native quantization and sparsity for training and inference. I just finished some experiments with it for my image-classification project, which uses a CNN model built with PyTorch. Below are my conclusions.

My project already used Automatic Mixed Precision with bfloat16, but convert_to_float8_training still cut VRAM usage by about 60% (on my RTX 4090 GPU):

import torch
from torchao.float8 import convert_to_float8_training

def module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
    # Example: exclude the output layer from float8 conversion
    if fqn == "output":
        return False
    # Example: exclude linear layers with dimensions not divisible by 16,
    # since the float8 matmul kernels require dims divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# m is the model being trained (here, the bfloat16-AMP CNN)
convert_to_float8_training(m, module_filter_fn=module_filter_fn)
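
One note: torchao's float8 training is meant to be paired with torch.compile, which fuses the extra scaling/casting ops the conversion inserts. A minimal sketch of that step (not part of my original runs):

# compile after conversion so the float8 ops get fused
m = torch.compile(m)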


AdamW8bit only decreased VRAM from 22.6GB to 22.4GB, not much. That is presumably because in this run most of the VRAM goes to activations rather than optimizer state.
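
For completeness, here is a minimal sketch of how I plugged it in. Note that the import path has moved between torchao versions (older releases expose it under torchao.prototype.low_bit_optim), so adjust for your version:

from torchao.optim import AdamW8bit

# drop-in replacement for torch.optim.AdamW with 8-bit optimizer state
optimizer = AdamW8bit(m.parameters(), lr=1e-3)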

I didn't see any VRAM difference after switching to CPUOffloadOptimizer, and since it also doesn't work well with learning-rate schedulers, I decided to give up on it.
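
For anyone who still wants to try it, the sketch below shows the usage I tested, under the assumption that your torchao version exposes CPUOffloadOptimizer in torchao.optim. The manual LR update is the workaround torchao suggests for the scheduler limitation; new_lr is a hypothetical value from your own schedule:

from torch.optim import AdamW
from torchao.optim import CPUOffloadOptimizer

# keep optimizer state (and optionally gradients) in CPU RAM instead of VRAM
optimizer = CPUOffloadOptimizer(m.parameters(), AdamW, offload_gradients=True, lr=1e-3)

# LR schedulers don't compose with this optimizer; set the LR manually instead:
for group in optimizer.param_groups:
    group["lr"] = new_lr  # new_lr: computed by your own schedule (hypothetical)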