Attention chuẩn tính S = Q·Kᵀ / √dₖ → P = softmax(S) → O = P·V. Bottleneck không phải FLOPs mà là HBM memory I/O: ma trận S và P có kích thước N×N (N là seq length), phải ghi/đọc từ HBM nhiều lần → chậm vì HBM bandwidth < SRAM.
Flash Attention (Dao et al. 2022, v2 2023, v3 2024) tối ưu I/O bằng 2 kỹ thuật:
1. Tiling — chia Q, K, V thành block nhỏ đủ vừa SRAM (on-chip cache). Load block K,V → tính attention một phần → cộng dồn vào output. Không bao giờ materialize ma trận N×N đầy đủ trong HBM.
2. Online softmax + recomputation — dùng công thức softmax ổn định tính incrementally khi lướt qua các block; backward pass recompute attention từ output + stats nhỏ thay vì lưu intermediate → giảm memory O(N²) xuống O(N).
Kết quả bit-exact (kết quả giống attention chuẩn về mặt số) nhưng:
- Training 2-4x nhanh hơn.
- Memory giảm từ O(N²) về O(N) → fit được context dài hơn 4-16x.
- Inference prefill nhanh hơn đáng kể.
Flash Attention 2 tối ưu parallelism (work partitioning giữa warp); Flash Attention 3 exploit H100 Tensor Core (FP8, async). Hiện default trong PyTorch scaled_dot_product_attention, vLLM, SGLang, TGI. Không áp dụng được cho architecture attention đặc biệt như sliding window, Dilated attention trừ khi có variant riêng.
Standard attention computes S = Q·Kᵀ / √dₖ → P = softmax(S) → O = P·V. The bottleneck is not FLOPs but HBM memory I/O: the S and P matrices are N×N (N = seq length), read/written to HBM many times → slow because HBM bandwidth < SRAM.
Flash Attention (Dao et al. 2022, v2 2023, v3 2024) optimizes I/O with 2 techniques:
1. Tiling — split Q, K, V into small blocks fitting in SRAM (on-chip cache). Load a K,V block → compute partial attention → accumulate into output. Never materializes the full N×N matrix in HBM.
2. Online softmax + recomputation — uses a numerically stable softmax computed incrementally across blocks; the backward pass recomputes attention from output + small stats instead of saving intermediates → memory drops from O(N²) to O(N).
Result is bit-exact (same as standard attention numerically) but:
- Training 2–4x faster.
- Memory drops from O(N²) to O(N) → fits 4–16x longer contexts.
- Inference prefill significantly faster.
Flash Attention 2 improves parallelism (warp-level work partitioning); v3 exploits H100 Tensor Cores (FP8, async). Default in PyTorch scaled_dot_product_attention, vLLM, SGLang, TGI. Doesn't directly apply to specialized attentions like sliding window or dilated without variants.