Attention chuẩn tính S = Q·Kᵀ / √dₖ → P = softmax(S) → O = P·V. Bottleneck không phải FLOPs mà là HBM memory I/O: ma trận S và P có kích thước N×N (N là seq length), phải ghi/đọc từ HBM nhiều lần → chậm vì HBM bandwidth < SRAM.
Flash Attention (Dao et al. 2022, v2 2023, v3 2024) tối ưu I/O bằng 2 kỹ thuật:
1. Tiling — chia Q, K, V thành block nhỏ đủ vừa SRAM (on-chip cache). Load block K,V → tính attention một phần → cộng dồn vào output. Không bao giờ materialize ma trận N×N đầy đủ trong HBM.
2. Online softmax + recomputation — dùng công thức softmax ổn định tính incrementally khi lướt qua các block; backward pass recompute attention từ output + stats nhỏ thay vì lưu intermediate → giảm memory O(N²) xuống O(N).
Kết quả bit-exact (kết quả giống attention chuẩn về mặt số) nhưng:
- Training 2-4x nhanh hơn.
- Memory giảm từ O(N²) về O(N) → fit được context dài hơn 4-16x.
- Inference prefill nhanh hơn đáng kể.
Flash Attention 2 tối ưu parallelism (work partitioning giữa warp); Flash Attention 3 exploit H100 Tensor Core (FP8, async). Hiện default trong PyTorch scaled_dot_product_attention, vLLM, SGLang, TGI. Không áp dụng được cho architecture attention đặc biệt như sliding window, Dilated attention trừ khi có variant riêng.