For Kids: How FlashAttention Works 🧠
Regular Attention (Like Making Pizza 🍕)
Imagine making pizza for your whole class:
- Ask everyone what toppings they want (Query)
- Check all possible toppings (Key)
- Make a giant list of everyone's choices (big Score matrix)
- Choose most popular toppings (Softmax)
- Combine them into final pizzas (Value)
This takes lots of paper (memory) to track all choices!
FlashAttention (Smart Pizza Making)
FlashAttention is smarter:
- Split class into small groups 🧑‍🤝‍🧑
- Track favorites per group 📝
- Keep running total of toppings ➕
- Combine group totals at end 🤝
Uses less paper and finishes faster! 🚀
FlashAttention 2 (Supercharged Version)
Even better improvements:
- Organize toppings first 🥦🍗
- Reduce repetitive counting 🔄
- Work on multiple groups at once 👯
2-3x faster than original FlashAttention! ⚡
For Scientists: Mathematical Details 🔬
Standard Attention
For input matrices \( Q, K, V \in \mathbb{R}^{N \times d} \):
\[ S = \frac{QK^T}{\sqrt{d}}, \quad P = \text{softmax}(S), \quad O = PV \]
Memory: \( O(N^2) \) for storing \( S \) and \( P \).
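As a reference point, here is a minimal NumPy sketch of standard attention (single head, no masking or batching; the function name is just for illustration):

```python
import numpy as np

def standard_attention(Q, K, V):
    """Vanilla attention: materializes the full N x N matrices S and P."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                        # scores, shape (N, N)
    P = np.exp(S - S.max(axis=-1, keepdims=True))   # numerically stable softmax
    P = P / P.sum(axis=-1, keepdims=True)
    return P @ V                                    # output, shape (N, d)
```

The full \( N \times N \) matrices \( S \) and \( P \) live in memory at once, which is exactly the \( O(N^2) \) cost above.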
FlashAttention Algorithm
Tiling with block matrices \( Q_i, K_j, V_j \):
- Compute block scores: \[ S_{ij} = \frac{Q_i K_j^T}{\sqrt{d}} \]
- Online softmax with running row statistics: \[ m_{ij} = \max\left(m_{i(j-1)}, \text{rowmax}(S_{ij})\right) \] \[ l_{ij} = e^{m_{i(j-1)} - m_{ij}}\, l_{i(j-1)} + \text{rowsum}\left(e^{S_{ij} - m_{ij}}\right) \]
- Output blocks, normalized with the final statistics \( m_i, l_i \): \[ O_i = \frac{1}{l_i} \sum_j e^{S_{ij} - m_i}\, V_j \]
Memory: \( O(N) \) using streaming.
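A minimal NumPy sketch of the tiled forward pass under these definitions (the block size and function name are illustrative; the real kernel keeps each block in on-chip SRAM rather than in Python slices):

```python
import numpy as np

def flash_attention_forward(Q, K, V, block_size=64):
    """Tiled attention: only one score block S_ij is materialized at a time."""
    N_q, d = Q.shape
    N_k = K.shape[0]
    O = np.zeros((N_q, d))
    for i0 in range(0, N_q, block_size):
        Qi = Q[i0:i0 + block_size]
        m = np.full(Qi.shape[0], -np.inf)      # running row max  m_{ij}
        l = np.zeros(Qi.shape[0])              # running row sum  l_{ij}
        acc = np.zeros_like(Qi)                # unnormalized output accumulator
        for j0 in range(0, N_k, block_size):
            Kj, Vj = K[j0:j0 + block_size], V[j0:j0 + block_size]
            Sij = Qi @ Kj.T / np.sqrt(d)                   # block scores S_{ij}
            m_new = np.maximum(m, Sij.max(axis=-1))        # updated row max
            Pij = np.exp(Sij - m_new[:, None])             # exponentials under new max
            scale = np.exp(m - m_new)                      # correction for earlier blocks
            l = scale * l + Pij.sum(axis=-1)
            acc = scale[:, None] * acc + Pij @ Vj
            m = m_new
        O[i0:i0 + block_size] = acc / l[:, None]           # normalize once at the end
    return O
```

Only one block of scores exists at any moment; the running \( m \), \( l \), and the output accumulator are the only extra state kept per query block.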
FlashAttention 2 Optimizations
Key improvements:
- Loop reordering: the outer loop runs over query blocks \( Q_i \) and the inner loop streams the \( K_j, V_j \) blocks, so each output block stays in fast memory until it is finished
- Reduced non-matmul FLOPs: the output is rescaled and divided by \( l_i \) once per query block instead of after every \( K,V \) block, so more of the work runs as fast matrix multiplications
- More parallelism: independent query blocks are split across thread blocks and warps, in addition to the batch and head dimensions (see the sketch below)
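A conceptual sketch of that work split, with a Python thread pool standing in for GPU thread blocks (the helper names and the pool are illustrative, not the actual CUDA implementation):

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def attend_one_query_block(Qi, K, V, block_size=64):
    """Inner loop over K, V blocks for a single query block (same online softmax as above)."""
    d = Qi.shape[-1]
    m = np.full(Qi.shape[0], -np.inf)
    l = np.zeros(Qi.shape[0])
    acc = np.zeros_like(Qi)
    for j0 in range(0, K.shape[0], block_size):
        Kj, Vj = K[j0:j0 + block_size], V[j0:j0 + block_size]
        Sij = Qi @ Kj.T / np.sqrt(d)
        m_new = np.maximum(m, Sij.max(axis=-1))
        Pij = np.exp(Sij - m_new[:, None])
        scale = np.exp(m - m_new)
        l = scale * l + Pij.sum(axis=-1)
        acc = scale[:, None] * acc + Pij @ Vj
        m = m_new
    return acc / l[:, None]                    # single rescale/divide per query block

def fa2_style_forward(Q, K, V, block_size=64):
    """Outer loop over query blocks; each block is independent work, so they run in parallel."""
    starts = range(0, Q.shape[0], block_size)
    with ThreadPoolExecutor() as pool:
        blocks = pool.map(lambda s: attend_one_query_block(Q[s:s + block_size], K, V, block_size), starts)
        return np.concatenate(list(blocks), axis=0)
```

Because no statistics need to be shared between query blocks, this split maps naturally onto independent GPU thread blocks along the sequence dimension.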
Backward Pass
Gradient computation via recomputation:
With \( P = \text{softmax}(S) \) rebuilt block by block during the backward pass and \( dP = dO\, V^T \):
\[ dS = P \odot \left( dP - \text{rowsum}(dP \odot P) \right), \qquad \frac{\partial L}{\partial Q} = \frac{1}{\sqrt{d}}\, dS\, K \]
Only the per-row softmax statistics \( m_i, l_i \) from the forward pass are stored; \( S \) and \( P \) are never written to slow memory.
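A sketch of the recomputation idea in NumPy, written over full matrices for readability (the real kernel applies the same formulas block by block; `lse` denotes the saved per-row logsumexp \( m_i + \log l_i \), an equivalent way to store the statistics):

```python
import numpy as np

def attention_backward_recompute(Q, K, V, dO, lse):
    """Recompute S and P from Q, K and the saved per-row logsumexp
    (lse = m + log l) instead of reading stored N x N matrices."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                 # recomputed, not loaded from memory
    P = np.exp(S - lse[:, None])             # softmax rebuilt from saved statistics
    dV = P.T @ dO                            # gradient w.r.t. V
    dP = dO @ V.T
    dS = P * (dP - np.sum(dP * P, axis=-1, keepdims=True))
    dQ = dS @ K / np.sqrt(d)                 # gradient w.r.t. Q
    dK = dS.T @ Q / np.sqrt(d)               # gradient w.r.t. K
    return dQ, dK, dV
```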