Performance of Flash Attention and torch.compile()

source link: https://donghao.org/2024/03/01/performance-of-flash-attention-and-torch-compile/
I am building a small repo of multi-modal models (CLIP, ALBEF, BLIP, etc.); the GPT code is mainly from nanoGPT. Along the way I became curious about the performance impact of "Flash Attention" and "torch.compile()".
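For reference, the baseline run uses the usual manual causal self-attention, roughly like the sketch below (nanoGPT-style; the function and tensor names are my own illustration, not the exact code in the repo):

import math
import torch
import torch.nn.functional as F

def manual_causal_attention(q, k, v):
    # q, k, v: (batch, n_head, seq_len, head_dim)
    T = q.size(-2)
    # scaled dot-product attention scores
    att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
    # causal mask: each position attends only to itself and earlier positions
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
    att = att.masked_fill(~mask, float('-inf'))
    att = F.softmax(att, dim=-1)
    return att @ v  # (batch, n_head, seq_len, head_dim)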

The metrics with my original code (w/o Flash Attention, w/o torch.compile()):

[100] loss: 4.0315 time 23.7708
[200] loss: 4.0020 time 23.9010
[300] loss: 3.8115 time 23.9407
[400] loss: 3.7021 time 23.9785
[500] loss: 3.6626 time 24.0076
[600] loss: 3.7109 time 24.0060

The metrics after adding Flash Attention:

[100] loss: 4.1204 time 23.0655
[200] loss: 3.8950 time 23.2243
[300] loss: 3.9116 time 23.2714
[400] loss: 3.7837 time 23.2864
[500] loss: 3.8313 time 23.2993
[600] loss: 3.9138 time 23.3255
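Here "adding Flash Attention" means the same change nanoGPT makes: swapping the manual attention for torch.nn.functional.scaled_dot_product_attention (PyTorch >= 2.0), which dispatches to a fused FlashAttention kernel on supported GPUs. A minimal sketch:

import torch.nn.functional as F

def flash_causal_attention(q, k, v):
    # q, k, v: (batch, n_head, seq_len, head_dim)
    # is_causal=True applies the causal mask inside the fused kernel,
    # so no (seq_len, seq_len) mask tensor is ever materialized
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

The fused kernel mainly avoids writing the full attention matrix to GPU memory; the modest speedup here suggests attention was not the dominant cost at this model size.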

The metrics after adding Flash Attention and torch.compile():

[100] loss: 3.9969 time 14.8842
[200] loss: 3.8506 time 15.0004
[300] loss: 3.8702 time 15.0050
[400] loss: 3.7977 time 15.0061
[500] loss: 3.7374 time 15.0492
[600] loss: 3.6589 time 15.0661
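torch.compile(), by contrast, is a one-line wrapper around the whole model: it traces the forward (and backward) pass and emits fused kernels, cutting Python overhead and memory traffic across every layer, not just attention. A runnable sketch (the tiny Sequential model is just a stand-in for the GPT):

import torch

model = torch.nn.Sequential(     # stand-in for the nanoGPT model
    torch.nn.Linear(64, 64),
    torch.nn.GELU(),
)
compiled = torch.compile(model)  # requires PyTorch >= 2.0

x = torch.randn(8, 64)
y = compiled(x)  # first call pays a one-time compilation cost;
                 # later calls reuse the optimized code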

It seems "torch.compile()" is much more powerful than "Flash Attention" here: Flash Attention alone shaved the per-100-step time from roughly 24.0 s to 23.3 s (about 3%), while adding torch.compile() cut it to roughly 15.1 s, a ~1.6x speedup over the baseline, at essentially the same loss.
