I am building a small repo of multi-modal models (CLIP, ALBEF, BLIP, etc.). The GPT code is mostly taken from nanoGPT. I then got curious about how much “Flash Attention” and “torch.compile()” actually help.
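For reference, the attention code follows nanoGPT, so enabling Flash Attention is essentially switching from the manual softmax attention to torch.nn.functional.scaled_dot_product_attention. A minimal sketch of the two paths (simplified; the real module also has the qkv/output projections and dropout, and the shapes follow nanoGPT's (batch, n_head, seq_len, head_dim) layout):

```python
import math
import torch
import torch.nn.functional as F

def causal_attention(q, k, v, use_flash: bool):
    # q, k, v: (batch, n_head, seq_len, head_dim)
    if use_flash:
        # Flash Attention path: fused kernel, no full (T, T) attention matrix materialized
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Original path: explicit scores -> causal mask -> softmax -> weighted sum
    T = q.size(-2)
    att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
    causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
    att = att.masked_fill(~causal_mask, float("-inf"))
    att = F.softmax(att, dim=-1)
    return att @ v
```

Note that scaled_dot_product_attention only dispatches to the FlashAttention kernel when the backend supports it (e.g. fp16/bf16 on recent CUDA GPUs); otherwise it falls back to the memory-efficient or math implementation, which may give a smaller speedup than expected.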
The metrics with my original code (w/o Flash Attention, w/o torch.compile()):
[100] loss: 4.0315 time 23.7708
[200] loss: 4.0020 time 23.9010
[300] loss: 3.8115 time 23.9407
[400] loss: 3.7021 time 23.9785
[500] loss: 3.6626 time 24.0076
[600] loss: 3.7109 time 24.0060
The metrics after adding Flash Attention:
[100] loss: 4.1204 time 23.0655
[200] loss: 3.8950 time 23.2243
[300] loss: 3.9116 time 23.2714
[400] loss: 3.7837 time 23.2864
[500] loss: 3.8313 time 23.2993
[600] loss: 3.9138 time 23.3255
The metrics after adding both Flash Attention and torch.compile():
[100] loss: 3.9969 time 14.8842
[200] loss: 3.8506 time 15.0004
[300] loss: 3.8702 time 15.0050
[400] loss: 3.7977 time 15.0061
[500] loss: 3.7374 time 15.0492
[600] loss: 3.6589 time 15.0661
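The torch.compile() part is just a one-line wrap of the model before the training loop. A minimal, self-contained sketch (the tiny nn.Sequential model here is only a stand-in for the nanoGPT-style GPT, and the shapes and hyperparameters are made up for illustration):

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for the real GPT model (any nn.Module is wrapped the same way)
model = nn.Sequential(nn.Linear(768, 768), nn.GELU(), nn.Linear(768, 768)).to(device)
model = torch.compile(model)  # the one-line change; kernels get fused by TorchInductor

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
x = torch.randn(8, 768, device=device)

for step in range(10):
    # the first steps are slower while the graph is traced and compiled;
    # the speedup shows up in the steady-state steps
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```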
From these numbers, Flash Attention alone only brings the reported time down from about 24.0 to about 23.3 (roughly 3%), while adding torch.compile() cuts it to about 15.0 (roughly 37% lower than the baseline). So “torch.compile()” seems to matter much more than “Flash Attention” here.