`torchao` is a Python library that supports PyTorch-native quantization and sparsity for training and inference. I just finished some experiments with it for my image-classification project, which uses a CNN model built with PyTorch. Below are some conclusions.
My project already used Automatic Mixed Precision with `bfloat16`, but `convert_to_float8_training` still easily cut VRAM usage by about 60% (on my RTX 4090 GPU):
```python
import torch
from torchao.float8 import convert_to_float8_training

def module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
    # Example: Exclude the output layer from float8 conversion
    if fqn == "output":
        return False
    # Example: Exclude linear layers with dimensions not divisible by 16
    # (the float8 kernels require this)
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# m is the model being trained (a torch.nn.Module); converted in place
convert_to_float8_training(m, module_filter_fn=module_filter_fn)
```
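One thing worth noting: as I understand the torchao docs, the float8 recipe is meant to be paired with `torch.compile` so the casts and scaled matmuls get fused into efficient kernels. A minimal sketch of that step, with `m` being the converted model:

```python
# Sketch: compile the converted model to get the fused float8 kernels
m = torch.compile(m)
```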
The 8-bit optimizer `AdamW8bit` only decreased VRAM from 22.6 GB to 22.4 GB, not too much.
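For reference, here is roughly how it slots in (a minimal sketch; in recent torchao releases the optimizer lives in `torchao.optim`, while older ones expose it as `torchao.prototype.low_bit_optim`, and the toy model just stands in for the real CNN):

```python
import torch
from torchao.optim import AdamW8bit  # older torchao: torchao.prototype.low_bit_optim

# Toy stand-in for the real CNN
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1), torch.nn.Flatten(),
    torch.nn.Linear(16, 10),
).cuda()

# Drop-in replacement for torch.optim.AdamW; optimizer states are stored
# in 8 bits, so the savings scale with parameter count. A CNN has
# relatively few parameters, which likely explains my small 0.2 GB saving.
optimizer = AdamW8bit(model.parameters(), lr=1e-3, weight_decay=1e-2)
```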
I didn't see any VRAM difference after using `CPUOffloadOptimizer`, and since it also didn't work well with my learning-rate scheduler, I'm inclined to give it up.
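For completeness, this is the kind of setup I tried (a sketch assuming torchao's `CPUOffloadOptimizer` wrapper API; same import-path caveat as above):

```python
import torch
from torchao.optim import CPUOffloadOptimizer  # older torchao: torchao.prototype.low_bit_optim

model = torch.nn.Linear(512, 512).cuda()  # stand-in for the real model

# Wraps a regular optimizer class; optimizer states (and optionally the
# gradients) are kept in CPU RAM instead of VRAM.
optimizer = CPUOffloadOptimizer(
    model.parameters(),
    torch.optim.AdamW,       # base optimizer to run on CPU
    offload_gradients=True,  # also offload gradients
    lr=1e-3,
)

# The scheduler problem: per the torchao docs (as I understand them), LR
# schedulers are not supported with this wrapper; if you need a schedule,
# you are supposed to set the LR manually each step, e.g.
#   for group in optimizer.param_groups:
#       group["lr"] = new_lr
```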