Flash Attention
- parallelism
FlashAttention parallelizes only over the batch size and the number of heads (one thread block per head per batch element).
FlashAttention-2 additionally parallelizes over the sequence-length dimension, which matters for long sequences where the batch size or number of heads is small, and gives a better work partition.
Within each thread block, FlashAttention-2 also reduces the amount of synchronization and shared-memory communication between warps:
FlashAttention splits K and V across 4 warps while keeping Q accessible by all warps.
FlashAttention-2 splits Q across 4 warps while keeping K and V accessible by all warps, so each warp computes its own slice of the output without exchanging intermediate results. FlashAttention-2 also supports head dimensions up to 256 and MQA.
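Below is a minimal NumPy sketch of the FlashAttention-2 loop structure, not the real CUDA kernel; the function name, block sizes, and shapes are assumptions for illustration. The point it shows: the outer loop runs over independent query blocks (this is what makes parallelizing over the sequence length possible), while the inner loop streams K/V blocks with an online softmax and defers the division by the softmax denominator to the end of the inner loop.

```python
import numpy as np

def flash_attention2_reference(Q, K, V, block_q=64, block_k=64):
    """Single-head attention computed block by block; Q, K, V: (seq_len, head_dim)."""
    seq_len, head_dim = Q.shape
    scale = 1.0 / np.sqrt(head_dim)
    O = np.zeros_like(Q, dtype=np.float64)

    # Each iteration of this outer loop is independent -> in the real kernel,
    # one thread block (CTA) per query block along the sequence dimension.
    for qs in range(0, seq_len, block_q):
        qe = min(qs + block_q, seq_len)
        q = Q[qs:qe] * scale
        m = np.full(qe - qs, -np.inf)          # running row-wise max
        l = np.zeros(qe - qs)                  # running softmax denominator
        acc = np.zeros((qe - qs, head_dim))    # unnormalized output accumulator

        for ks in range(0, seq_len, block_k):
            ke = min(ks + block_k, seq_len)
            s = q @ K[ks:ke].T                 # (block_q, block_k) score tile
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            correction = np.exp(m - m_new)     # rescale previous partial results
            l = l * correction + p.sum(axis=1)
            acc = acc * correction[:, None] + p @ V[ks:ke]
            m = m_new

        # Divide by the denominator once per query block instead of per K/V block.
        O[qs:qe] = acc / l[:, None]
    return O

# Usage: compare against naive softmax(QK^T / sqrt(d)) V on random data.
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 64)) for _ in range(3))
ref = flash_attention2_reference(Q, K, V)
S = (Q @ K.T) / np.sqrt(64)
naive = np.exp(S - S.max(axis=1, keepdims=True))
naive = (naive / naive.sum(axis=1, keepdims=True)) @ V
assert np.allclose(ref, naive, atol=1e-6)
```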
GQA / MQA: grouped-query and multi-query attention, where a group of query heads shares a single K/V head.
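A small NumPy sketch of the head-sharing idea behind MQA/GQA: there are fewer K/V heads than query heads, and each group of query heads attends against the same K/V head (num_kv_heads == 1 is MQA). The helper name and shapes are assumptions for illustration, not an API of the FlashAttention library.

```python
import numpy as np

def grouped_query_attention(Q, K, V):
    """Q: (num_q_heads, seq, d); K, V: (num_kv_heads, seq, d),
    with num_q_heads a multiple of num_kv_heads."""
    num_q_heads, seq, d = Q.shape
    num_kv_heads = K.shape[0]
    group = num_q_heads // num_kv_heads
    out = np.empty_like(Q)
    for h in range(num_q_heads):
        kv = h // group                              # query heads in a group share K/V
        s = Q[h] @ K[kv].T / np.sqrt(d)
        p = np.exp(s - s.max(axis=-1, keepdims=True))
        p /= p.sum(axis=-1, keepdims=True)
        out[h] = p @ V[kv]
    return out

# GQA example: 8 query heads, 2 KV heads -> groups of 4 query heads per KV head.
rng = np.random.default_rng(1)
Q = rng.standard_normal((8, 128, 64))
K = rng.standard_normal((2, 128, 64))
V = rng.standard_normal((2, 128, 64))
print(grouped_query_attention(Q, K, V).shape)   # (8, 128, 64)
```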
FlashAttention-2 optimization points in detail
- Fused kernel and matrix tiling
- Causal masking (skip K/V blocks that are entirely masked; see the sketch after this list)
- Non-matmul computation optimization (defer softmax rescaling and division to the end of the inner loop)
- Pipeline orchestration with asynchronous loads and double buffering
- Layout swizzle (avoid shared-memory bank conflicts)
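For the causal-masking item above, here is a minimal sketch (block sizes and the schedule helper are assumptions, not the real kernel) of the block-level optimization: for a given query block, K/V blocks that lie entirely above the diagonal are skipped outright, and the elementwise causal mask is applied only to blocks that straddle the diagonal.

```python
def causal_block_schedule(seq_len, block_q, block_k):
    """Yield (q_start, k_start, needs_mask) for the blocks a causal kernel visits."""
    for qs in range(0, seq_len, block_q):
        q_end = min(qs + block_q, seq_len) - 1      # last query index in the block
        for ks in range(0, seq_len, block_k):
            if ks > q_end:
                break                               # block fully in the future: skip entirely
            k_end = min(ks + block_k, seq_len) - 1
            needs_mask = k_end > qs                 # block straddles the diagonal
            yield qs, ks, needs_mask

# With seq_len=256 and 64x64 blocks, 6 of the 16 score blocks are skipped and
# only the 4 diagonal blocks need the elementwise causal mask.
schedule = list(causal_block_schedule(256, 64, 64))
print(len(schedule), sum(m for _, _, m in schedule))   # 10 visited blocks, 4 masked
```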