FlashAttention
Inference Optimization
An exact, hardware-aware attention algorithm that runs in O(N) memory instead of O(N²) by exploiting the GPU memory hierarchy (tiling between fast on-chip SRAM and slow HBM), delivering 2-4x wall-clock speedups with no approximation and no quality loss.
Standard attention computes Q×K^T (an N×N matrix), applies softmax, then multiplies by V, materializing the full N×N attention matrix in GPU HBM (slow global memory). FlashAttention (Dao et al., 2022) tiles the computation: it loads blocks of Q, K, and V into SRAM (fast on-chip memory), computes partial attention for each block using an online softmax (a numerically stable incremental formulation), and never materializes the full attention matrix. The result is O(N) memory, roughly 2-4x faster wall-clock time, and output identical to standard attention up to floating-point rounding. FlashAttention-2 (2023) improves parallelism and work partitioning; FlashAttention-3 (2024) adds FP8 support and targets Hopper GPUs. In PyTorch 2.0+ it is exposed through `torch.nn.functional.scaled_dot_product_attention`, which dispatches to the `flash` backend when the inputs and hardware support it.
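In practice that backend is reached through the standard SDPA call. Below is a minimal sketch, assuming PyTorch 2.2+ (which provides `torch.nn.attention.sdpa_kernel`) and a CUDA GPU with fp16/bf16 support; the shapes are purely illustrative.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative shapes: (batch, heads, sequence length, head dim).
batch, heads, seq_len, head_dim = 2, 16, 4096, 64
device, dtype = "cuda", torch.float16  # the flash kernel requires fp16/bf16 on GPU

q = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Restrict dispatch to the FlashAttention backend; PyTorch raises an error if
# that kernel cannot handle these inputs on this hardware.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # (2, 16, 4096, 64); the N x N score matrix is never materialized
```

Without the context manager, PyTorch simply picks the fastest backend it supports for the given inputs; restricting it is mainly useful for benchmarking or for verifying that the flash path is actually taken.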
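For intuition about the algorithm itself, here is a slow, single-head reference sketch of the tiling and online-softmax bookkeeping described above. The function name, block size, and absence of masking are illustrative choices, not the real kernel, which performs the same per-tile update entirely in SRAM.

```python
import torch

def tiled_attention(q, k, v, block_size=128):
    """Reference sketch of the FlashAttention idea for one head: process K/V in
    blocks and maintain a running softmax, so the full (N x N) score matrix is
    never stored -- only one (N x block_size) slice at a time."""
    n, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)                      # running (unnormalized) weighted sum of V
    row_max = torch.full((n, 1), float("-inf"))    # running max score per query row
    row_sum = torch.zeros(n, 1)                    # running softmax denominator

    for start in range(0, n, block_size):
        kb = k[start:start + block_size]           # (B, d) block of keys
        vb = v[start:start + block_size]           # (B, d) block of values
        scores = (q @ kb.T) * scale                # (N, B) partial scores only

        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)  # rescale previous partial results
        p = torch.exp(scores - new_max)            # stable exponentials for this block

        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ vb
        row_max = new_max

    return out / row_sum                           # normalize once at the end

# Sanity check against naive attention (float32, one head, small sizes).
q, k, v = (torch.randn(512, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) / 64 ** 0.5, dim=-1) @ v
assert torch.allclose(tiled_attention(q, k, v), ref, atol=1e-4)
```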
Why Does This Exist?
FlashAttention delivers a 2-4x wall-clock speedup and O(N) memory instead of O(N²) for the attention computation that dominates transformer inference and training cost. Because it is exact rather than approximate, the efficiency gain comes with zero quality tradeoff. Its near-universal adoption means virtually every modern model benefits from it; it effectively lowered the cost floor for the entire field.
By reducing attention's memory footprint from O(N²) to O(N), FlashAttention made long context windows practically feasible. Models with 128K-1M token contexts (GPT-4, Claude, Gemini) would be economically unviable without it. Longer context directly enables new capabilities: processing entire codebases, books, document collections, and multi-turn agent interactions that were previously impossible.
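To make that concrete, a back-of-envelope calculation (fp16 scores; the 32-head figure is an illustrative assumption, not any particular model's configuration):

```python
# Memory needed just to materialize the full N x N attention score matrix in fp16.
seq_len = 128 * 1024                              # 131,072-token context
full_matrix = seq_len ** 2 * 2                    # 2 bytes per fp16 score
print(full_matrix / 2**30, "GiB per head")        # 32.0 GiB for a single head
print(32 * full_matrix / 2**40, "TiB per layer")  # 1.0 TiB at an assumed 32 heads
```

With FlashAttention the scores live only in per-tile SRAM buffers, so this term never hits HBM at all.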