`torchao` is a Python library that supports PyTorch-native quantization and sparsity for training and inference. I just finished some experiments with it for my image-classification project, which uses a CNN model built with PyTorch. Below are some conclusions.

My project already used Automatic Mixed Precision with `bfloat16`, but `convert_to_float8_training`
still easily cut VRAM usage by about 60% (on my RTX 4090 GPU):
```python
import torch
from torchao.float8 import convert_to_float8_training

def module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
    # Example: exclude the output layer from float8 conversion
    if fqn == "output":
        return False
    # Example: exclude linear layers with dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# m is the torch.nn.Module to convert (the project's CNN model)
convert_to_float8_training(m, module_filter_fn=module_filter_fn)
```
`AdamW8bit` only decreased VRAM from 22.6 GB to 22.4 GB, not much.
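That small saving makes sense for a CNN: `AdamW8bit` only shrinks the optimizer state, which scales with parameter count, while most of the 22.6 GB is activation memory. A rough back-of-envelope sketch (hypothetical toy model; the layer sizes and numbers are illustrative, not my actual network):

```python
import torch

# Hypothetical toy CNN (illustrative only, not the project's actual model).
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, 3),
    torch.nn.Conv2d(64, 128, 3),
    torch.nn.Flatten(),
    torch.nn.Linear(128 * 28 * 28, 10),  # sized for 32x32 inputs
)

n_params = sum(p.numel() for p in model.parameters())
# AdamW keeps two fp32 states (exp_avg, exp_avg_sq) per parameter: 8 bytes/param.
fp32_state_mb = n_params * 8 / 2**20
# An 8-bit optimizer stores those states in ~1 byte each: ~2 bytes/param.
int8_state_mb = n_params * 2 / 2**20
print(f"{n_params:,} params: fp32 optimizer states ~{fp32_state_mb:.1f} MiB, "
      f"8-bit states ~{int8_state_mb:.1f} MiB")
```

Even a four-fold reduction of a few-MiB optimizer state barely registers against tens of GB of activations, which matches the 0.2 GB I measured.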
I didn't see any VRAM difference after using `CPUOffloadOptimizer` either. Since it doesn't work well with LR schedulers, I tend to give it up.
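The scheduler issue appears to be structural: PyTorch LR schedulers require an actual `torch.optim.Optimizer` subclass, and my understanding is that `CPUOffloadOptimizer` is a wrapper class rather than a subclass. A minimal sketch of the rejection, using a hypothetical stand-in wrapper (not torchao's class):

```python
import torch

# Hypothetical stand-in for an optimizer wrapper that mimics the Optimizer
# interface without subclassing torch.optim.Optimizer.
class OffloadWrapper:
    param_groups = [{"lr": 1e-3}]
    def step(self):
        pass
    def zero_grad(self):
        pass

try:
    torch.optim.lr_scheduler.LambdaLR(OffloadWrapper(), lr_lambda=lambda e: 1.0)
    rejected = False
except TypeError:
    # LRScheduler.__init__ checks isinstance(optimizer, Optimizer)
    rejected = True
print("scheduler rejected the wrapper:", rejected)
```

The usual workaround is to skip the scheduler and set `group["lr"]` on `optim.param_groups` manually each epoch, which I didn't want to maintain.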