Flash Attention

  • parallelism
    FlashAttention parallelizes over the batch size and the number of heads
    FlashAttention-2 additionally parallelizes over the sequence-length dimension, which helps long sequences where the batch size or number of heads is small (see the launch-grid sketch after this list)
  • better work partitioning
    reduces the amount of synchronization and communication between warps within a thread block
    FlashAttention splits K and V across 4 warps while keeping Q accessible by all warps.
    FlashAttention-2 splits Q across 4 warps while keeping K and V accessible by all warps.
  • larger head dimensions and MQA/GQA support
    supports head dimensions up to 256
    supports multi-query attention (MQA) and grouped-query attention (GQA); a PyTorch sketch follows this list
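A minimal sketch of the launch-grid difference, assuming illustrative problem and tile sizes (all numbers and variable names below are assumptions, not the real kernel configuration):

```python
# Assumed toy sizes: 2 sequences, 16 heads, 8K tokens, 128-row query tiles.
batch_size, num_heads, seq_len, block_m = 2, 16, 8192, 128

# FlashAttention: one thread block per (batch, head) pair; the whole sequence is
# handled by a loop inside that block, so only 32 blocks are launched here.
grid_fa1 = batch_size * num_heads

# FlashAttention-2: additionally parallelize over tiles of the query sequence,
# so a small batch_size * num_heads no longer leaves most SMs idle.
num_q_tiles = (seq_len + block_m - 1) // block_m      # ceil-divide into 64 tiles
grid_fa2 = num_q_tiles * batch_size * num_heads       # 2048 thread blocks

print(grid_fa1, grid_fa2)                             # 32 2048
```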

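As a rough illustration of MQA/GQA (not the FlashAttention API; the toy shapes and the use of `repeat_interleave` with plain `scaled_dot_product_attention` are assumptions for this sketch):

```python
import torch
import torch.nn.functional as F

# Assumed toy shapes: 8 query heads share 2 key/value heads (GQA); n_kv_heads = 1 would be MQA.
batch, seq, n_q_heads, n_kv_heads, head_dim = 2, 16, 8, 2, 64

q = torch.randn(batch, n_q_heads, seq, head_dim)
k = torch.randn(batch, n_kv_heads, seq, head_dim)
v = torch.randn(batch, n_kv_heads, seq, head_dim)

# Each group of n_q_heads // n_kv_heads query heads attends to the same K/V head.
group = n_q_heads // n_kv_heads
k = k.repeat_interleave(group, dim=1)                 # (batch, n_q_heads, seq, head_dim)
v = v.repeat_interleave(group, dim=1)

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)                                      # torch.Size([2, 8, 16, 64])
```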
FlashAttention-2 Optimizations


FlashAttention-2 Optimizations in Detail

  • Fused kernel and matrix tiling (see the sketch after this list)
  • Causal Masking
  • Non-matmul computation optimization
  • Pipeline scheduling, asynchronous loads, and double buffering
  • Layout Swizzle
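
The sketch below is a plain PyTorch rendition of the fused, tiled forward pass with online softmax; it is only an assumption-laden illustration (block sizes, function name, and single-head shapes are made up), not the CUDA kernel. It also shows the FlashAttention-2 habit of deferring non-matmul work: previous tiles are rescaled incrementally and the final division happens once per query tile.

```python
import torch

def flash_attn_forward(q, k, v, block_m=64, block_n=64):
    """Single-head tiled attention with online softmax (illustrative only)."""
    seq_q, d = q.shape
    scale = d ** -0.5
    out = torch.empty_like(q)

    for i in range(0, seq_q, block_m):                    # one query tile ~= one thread block
        qi = q[i:i + block_m] * scale
        acc = torch.zeros(qi.shape[0], d)                 # un-normalized output accumulator
        m = torch.full((qi.shape[0],), float("-inf"))     # running row-wise max of the scores
        l = torch.zeros(qi.shape[0])                      # running softmax denominator

        for j in range(0, k.shape[0], block_n):           # stream K/V tiles through "shared memory"
            s = qi @ k[j:j + block_n].T                   # score tile
            m_new = torch.maximum(m, s.max(dim=-1).values)
            p = torch.exp(s - m_new[:, None])
            alpha = torch.exp(m - m_new)                  # rescale factor for the previous tiles
            l = alpha * l + p.sum(dim=-1)
            acc = alpha[:, None] * acc + p @ v[j:j + block_n]
            m = m_new

        out[i:i + block_m] = acc / l[:, None]             # single non-matmul rescale per query tile
    return out

# Compare against the naive reference on random data.
q, k, v = (torch.randn(256, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
print(torch.allclose(flash_attn_forward(q, k, v), ref, atol=1e-5))
```

Causal masking builds on the same tiling: key/value tiles that lie entirely above the diagonal can be skipped outright, and only the tiles straddling the diagonal need an element-wise mask.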