I am building a small repo about multi-modal models (CLIP, ALBEF, BLIP, etc.). The GPT code is mainly from nanoGPT. Along the way I became curious about the performance of “Flash Attention” and “torch.compile()”.

The metrics with my original code (w/o Flash Attention, w/o torch.compile()):

[100] loss: 4.0315 time 23.7708
[200] loss: 4.0020 time 23.9010
[300] loss: 3.8115 time 23.9407
[400] loss: 3.7021 time 23.9785
[500] loss: 3.6626 time 24.0076
[600] loss: 3.7109 time 24.0060

The metrics after adding Flash Attention:

[100] loss: 4.1204 time 23.0655
[200] loss: 3.8950 time 23.2243
[300] loss: 3.9116 time 23.2714
[400] loss: 3.7837 time 23.2864
[500] loss: 3.8313 time 23.2993
[600] loss: 3.9138 time 23.3255
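For reference, “adding Flash Attention” in nanoGPT-style code usually means replacing the manual softmax attention with `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0+). PyTorch then picks a fused backend and uses a FlashAttention kernel when the device, dtype and shapes allow it (e.g. fp16/bf16 on a supported GPU). The sketch below is a simplified, assumed version of nanoGPT's `CausalSelfAttention`, not the exact code from my repo: dropout is removed and `use_flash` is a hypothetical flag name. Both paths compute the same attention; only the kernel differs.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalSelfAttention(nn.Module):
    """Simplified nanoGPT-style causal self-attention (no dropout)."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, use_flash: bool = True):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)  # fused q, k, v projection
        self.c_proj = nn.Linear(n_embd, n_embd)      # output projection
        # scaled_dot_product_attention only exists in PyTorch >= 2.0
        self.use_flash = use_flash and hasattr(F, "scaled_dot_product_attention")
        # lower-triangular causal mask, only used by the manual path
        mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("bias", mask.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(C, dim=2)
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        if self.use_flash:
            # fused kernel; never materializes the full (T, T) attention matrix
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            # manual attention: builds the (T, T) score matrix explicitly
            att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
            att = F.softmax(att, dim=-1)
            y = att @ v
        # re-assemble heads: (B, n_head, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
```

Loading the same weights into a `use_flash=True` and a `use_flash=False` instance should give matching outputs up to float tolerance, which is a quick sanity check that the swap did not change the model.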

The metrics after adding Flash Attention and torch.compile():

[100] loss: 3.9969 time 14.8842
[200] loss: 3.8506 time 15.0004
[300] loss: 3.8702 time 15.0050
[400] loss: 3.7977 time 15.0061
[500] loss: 3.7374 time 15.0492
[600] loss: 3.6589 time 15.0661

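So “torch.compile()” seems to be much more powerful than “Flash Attention”, at least in this setup: Flash Attention alone reduces the time per logging interval from about 23.9 to about 23.2 (roughly 3% less), while adding torch.compile() on top brings it down to about 15.0 (roughly 35% less than with Flash Attention alone). The loss curves stay in a similar range in all three runs.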
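To put numbers on that comparison, here is a small stdlib-only sketch that averages the six logged times of each run and computes the speedup relative to the baseline. It assumes every logged time covers the same number of iterations; the logs do not state the unit.

```python
from statistics import mean

# Times copied from the three logs above (unit as printed by the training script).
times = {
    "baseline": [23.7708, 23.9010, 23.9407, 23.9785, 24.0076, 24.0060],
    "flash": [23.0655, 23.2243, 23.2714, 23.2864, 23.2993, 23.3255],
    "flash+compile": [14.8842, 15.0004, 15.0050, 15.0061, 15.0492, 15.0661],
}

# Mean time per logging interval for each run.
avg = {name: mean(ts) for name, ts in times.items()}

for name, t in avg.items():
    # speedup = baseline time / this run's time (>1 means faster)
    print(f"{name:>14}: mean {t:.3f}, speedup vs baseline {avg['baseline'] / t:.2f}x")
```

With these logs this works out to roughly 1.03x for Flash Attention alone and roughly 1.60x for Flash Attention plus torch.compile().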