Understanding Flash Attention - Fueling Large Language Models

Yashu Gupta
6 min read · Jul 25, 2023


A detailed look at Flash Attention, which is now widely used in many Large Language Models.

👨🏾‍💻Github | Linkedin | GEN AI Chat GPT Detailed Architecture | QLORA and finetuning Falcon7B| Langchain QA | Transformers

Generative AI models, or Large Language Models (LLMs), such as OpenAI's GPT, Falcon 40B and Llama 2 are taking the world by storm. If you are new to this field, you can refer to my other blog posts pinned at the top. In this post we'll focus on Flash Attention (v1, v2) and ALiBi, which are widely used in many LLMs. Let us start with Flash Attention and some background.

Introduction

As we know, the original Transformer architecture uses self-attention to capture contextual dependencies between the elements of a sequence. Self-attention is the heart of the Transformer, and language models, including large language models, rely on it for exactly this purpose. But self-attention has a few drawbacks. Language models are trained with a fixed maximum sequence length (512 tokens for BERT, 2,048 for GPT-3), largely because self-attention scales quadratically with the input sequence length. Every one of the n tokens attends to every other token, which gives n × n = n² pairwise interactions. As the sequence length grows, this makes self-attention expensive in both compute and memory for long sequences.

Various methods have been proposed to reduce this cost. Many approximate attention methods, such as sparse or low-rank attention, bring the complexity down to linear in the sequence length, but they trade off some model quality and often do not deliver a real wall-clock speedup. Flash Attention takes a different route and has been widely adopted by models such as Falcon. It computes exact attention and still performs O(n²) FLOPs. However, it never stores the n × n attention matrix, so the memory it needs grows only linearly with the sequence length. It also greatly reduces slow GPU memory reads and writes, which makes attention faster in practice and reduces training time. So what the heck is Flash Attention? Let us understand it in detail.
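To make the quadratic cost concrete, here is a minimal NumPy sketch of standard (naive) attention. It assumes a single head, float64 inputs and the usual 1/√d scaling. The intermediate matrices S and P are N × N. The helper `attention_matrix_bytes` is hypothetical and only does the arithmetic for one N × N matrix stored in fp16 (2 bytes per element).

```python
import numpy as np

def standard_attention(Q, K, V):
    """Naive attention that materializes the full N x N matrices S and P."""
    d = Q.shape[1]
    S = Q @ K.T / np.sqrt(d)                    # (N, N) attention scores
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P = P / P.sum(axis=1, keepdims=True)        # (N, N) row-wise softmax
    return P @ V, S, P                          # output O is (N, d)

def attention_matrix_bytes(seq_len, bytes_per_elem=2):
    """Bytes for one N x N matrix (one head, one sequence), assuming fp16."""
    return seq_len * seq_len * bytes_per_elem

rng = np.random.default_rng(0)
N, d = 256, 64
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O, S, P = standard_attention(Q, K, V)
print(O.shape, S.shape)  # (256, 64) (256, 256)

for n in (512, 2048, 8192):
    print(n, attention_matrix_bytes(n) / 2**20, "MiB")  # 0.5, 8.0, 128.0
```

Each 4× increase in sequence length multiplies the size of S (and P) by 16, and that is for a single head of a single sequence.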

Flash Attention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Before understanding Flash Attention, we need to understand how self-attention interacts with the hardware. That interaction is where the cost of its N × N intermediate matrices really shows up. A modern GPU has several kinds of memory with different speeds and sizes: HBM (High Bandwidth Memory) and on-chip SRAM.

From the above image it is clear that SRAM is much faster than HBM. SRAM is also far smaller, because it is fast on-chip cache memory. For example, the paper notes that an NVIDIA A100 GPU has 40-80 GB of HBM with a bandwidth of 1.5-2.0 TB/s. Each of its 108 streaming multiprocessors, by contrast, has only 192 KB of on-chip SRAM, with a bandwidth estimated at around 19 TB/s. Tensors are stored in HBM, while computation is performed on data loaded into SRAM. In standard self-attention, each operation follows the same three steps:

  1. Read op: move a tensor from HBM to SRAM
  2. Compute op: perform the compute-intensive work in SRAM
  3. Write op: move the result back from SRAM to HBM

In the vanilla transformer block, many operations are performed, such as matmul, masking, softmax, dropout and ReLU. Let us see how these operations use the memory.

Standard attention with multiple read/write operations

From the above image it is clear that many read and write operations happen between HBM and SRAM. For example, computing the attention scores S = QKᵀ requires:

  1. Read op: load the query and key matrices (Q and K) from HBM into SRAM
  2. Compute op: compute the matmul in SRAM
  3. Write op: write the result S back to HBM
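The read/compute/write pattern can be made measurable with a toy model. This sketch is not real GPU code. `ToyHBM` is a hypothetical class that counts how many elements move in and out of "HBM". The unfused version follows the three-kernel sequence of standard attention (write S, read S, write P, read P). The fused version reads Q, K and V once and writes only O, under the deliberately unrealistic assumption that the whole N × N matrix fits on chip.

```python
import numpy as np

class ToyHBM:
    """Hypothetical stand-in for GPU HBM that counts elements read and written."""

    def __init__(self, tensors):
        self.store = dict(tensors)   # inputs already resident in HBM (not counted)
        self.reads = 0
        self.writes = 0

    def read(self, name):
        self.reads += self.store[name].size
        return self.store[name]

    def write(self, name, value):
        self.writes += value.size
        self.store[name] = value

def softmax_rows(S):
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return P / P.sum(axis=1, keepdims=True)

def unfused_attention(hbm):
    d = hbm.store["Q"].shape[1]
    # Kernel 1: read Q, K -> compute S = QK^T -> write the N x N matrix S to HBM
    hbm.write("S", hbm.read("Q") @ hbm.read("K").T / np.sqrt(d))
    # Kernel 2: read S -> softmax -> write the N x N matrix P to HBM
    hbm.write("P", softmax_rows(hbm.read("S")))
    # Kernel 3: read P, V -> compute O = PV -> write O to HBM
    hbm.write("O", hbm.read("P") @ hbm.read("V"))

def fused_attention(hbm):
    # One kernel: read Q, K, V once, keep S and P on chip, write only O.
    # Unrealistic on purpose: it assumes the whole N x N matrix fits on chip,
    # which is exactly the problem tiling solves.
    Q, K, V = hbm.read("Q"), hbm.read("K"), hbm.read("V")
    hbm.write("O", softmax_rows(Q @ K.T / np.sqrt(Q.shape[1])) @ V)

N, d = 1024, 64
rng = np.random.default_rng(0)
inputs = {name: rng.standard_normal((N, d)) for name in ("Q", "K", "V")}

traffic, outputs = {}, {}
for kernel in (unfused_attention, fused_attention):
    hbm = ToyHBM(inputs)
    kernel(hbm)
    traffic[kernel.__name__] = hbm.reads + hbm.writes
    outputs[kernel.__name__] = hbm.store["O"]

print(traffic)  # {'unfused_attention': 4456448, 'fused_attention': 262144}
```

The unfused version moves 4Nd + 4N² elements, while the fused version moves only 4Nd. The N² terms come entirely from writing and re-reading S and P.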

Every later operation (masking, softmax, the matmul with V and so on) repeats this pattern: it reads its input from HBM and writes its output back. This is where Flash Attention comes to the rescue. Flash Attention is IO-aware, meaning it is designed to remove these redundant HBM reads and writes. It reads the inputs, performs the compute operations in SRAM, and writes only the final result back to HBM. Below is the image for reference.

On the left is the standard attention mechanism, where many reads and writes happen between the two memories. On the right is Flash Attention, which fuses the operations so that intermediate results never go back to HBM.

The paper proposes a fused kernel, in which all of the attention operations run inside a single GPU kernel. This reduces the number of memory operations, which translates into a large speed-up during training.
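The paper makes the saving precise (its Theorem 2) by counting HBM accesses. In the bound below, N is the sequence length and d the head dimension. M is the SRAM size, with d ≤ M ≤ Nd. This is a restatement of the paper's asymptotic result, not a derivation:

```latex
\underbrace{\Theta\!\left(Nd + N^2\right)}_{\text{standard attention}}
\qquad\text{vs.}\qquad
\underbrace{\Theta\!\left(\frac{N^2 d^2}{M}\right)}_{\text{FlashAttention}}
```

When N is large, the N² terms dominate and the ratio of the two counts is roughly d²/M. The paper points out that for typical head dimensions (64 to 128) and an SRAM of around 100 KB, d² is many times smaller than M. This is why the fused, tiled kernel touches HBM far less often.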

To make such a fused kernel possible, the paper relies on two further techniques:

  1. Tiling: breaking the computation into blocks so that the full N×N matrix is never needed at once
  2. Recomputation: used only in the backward pass

Let us understand how tiling works. As mentioned, tiling means breaking (or chopping) the matrices into multiple blocks and performing the computation block by block in SRAM. This replaces moving a complete N×N matrix from HBM to SRAM. SRAM is very limited in capacity, so we can't just load the whole thing.
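The core idea of tiling can be shown on the score matrix alone. This NumPy sketch computes S = QKᵀ one tile at a time and checks that the result matches the full product. The function name `tiled_scores` and the block sizes 64 and 100 are illustrative choices, not values from the paper. Unlike Flash Attention, the sketch assembles the full S at the end, purely so it can be compared against `Q @ K.T`.

```python
import numpy as np

def tiled_scores(Q, K, block_rows, block_cols):
    """Compute S = Q K^T one (block_rows x block_cols) tile at a time.

    Each tile needs only block_rows rows of Q and block_cols rows of K,
    i.e. the small pieces a tiled kernel would bring into SRAM for that step.
    For illustration the tiles are assembled into a full S here; FlashAttention
    itself never stores the full N x N matrix.
    """
    S = np.empty((Q.shape[0], K.shape[0]))
    for i in range(0, Q.shape[0], block_rows):
        Qi = Q[i:i + block_rows]                     # one block of queries
        for j in range(0, K.shape[0], block_cols):
            Kj = K[j:j + block_cols]                 # one block of keys
            S[i:i + block_rows, j:j + block_cols] = Qi @ Kj.T
    return S

rng = np.random.default_rng(0)
Q, K = rng.standard_normal((300, 64)), rng.standard_normal((300, 64))
print(np.allclose(tiled_scores(Q, K, 64, 100), Q @ K.T))  # True
```

Python slicing handles a last block that is smaller than the block size, so N does not have to be a multiple of the block size.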

Through tiling, Flash Attention splits the inputs Q, K and V into blocks, loads them from the slow HBM into the fast SRAM, and computes the attention output with respect to those blocks. The complete algorithm is attached below; let us break it into steps.

Let us understand the above steps, starting with memory in HBM. Q, K and V are stored there, and memory for the output O is allocated there as well.

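Next, the block sizes are set based on M, the size of the on-chip SRAM. As we can see in the above image, the key quantity is M/4d. The block size for K and V is ⌈M/4d⌉ rows, and the block size for Q is min(⌈M/4d⌉, d) rows. Our query, key and value matrices are N×d. If M/4d comes out to 100, for example, each block of K and V holds 100 rows, that is, 100 key or value vectors. In short, we split the large N×d matrices into smaller row blocks so that a block of each of Q, K, V and O fits into the on-chip SRAM at the same time. That requirement is roughly where the factor of 4 comes from.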
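The block-size rule is simple arithmetic, so here it is as a small Python sketch. It follows the formulas in the paper's Algorithm 1: B_c = ⌈M/4d⌉, B_r = min(⌈M/4d⌉, d), and T_r = ⌈N/B_r⌉ Q blocks and T_c = ⌈N/B_c⌉ K/V blocks. The function name and the example value M = 25,600 elements are hypothetical. That value was picked only so that M/4d = 100 for d = 64.

```python
import math

def flash_block_sizes(M, d, N):
    """Block sizes and block counts following Algorithm 1 of the FlashAttention paper.

    M: SRAM size in elements, d: head dimension, N: sequence length.
    """
    B_c = math.ceil(M / (4 * d))          # rows of K and V per block
    B_r = min(math.ceil(M / (4 * d)), d)  # rows of Q (and O) per block
    T_r = math.ceil(N / B_r)              # number of Q blocks
    T_c = math.ceil(N / B_c)              # number of K/V blocks
    return B_r, B_c, T_r, T_c

print(flash_block_sizes(M=25_600, d=64, N=1024))  # (64, 100, 16, 11)
```

With these numbers, each K/V block has 100 rows and each Q block has 64 rows. A 1,024-token sequence therefore splits into 16 Q blocks and 11 K/V blocks, and the last K/V block is only partially full.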

We split Q, K and V into blocks using these block sizes. In the outer loop (red arrows in the above diagram), Flash Attention loops through the blocks of the K and V matrices and loads them into fast on-chip SRAM. For each K/V block, it loops over the blocks of the Q matrix (blue arrows in the above diagram), loading them into SRAM and writing the output of the attention computation back to HBM. The authors report up to a 7.6× speedup of the attention computation on GPT-2 from this approach.

The tricky part is the softmax. Its denominator needs the full row of scores, but each block only sees part of the row. Flash Attention therefore computes the softmax incrementally using summary statistics. As it proceeds from one block to the next, it keeps two statistics for each query row: a running row maximum m and a running softmax denominator ℓ. Whenever these change, it rescales the partial output O. By the time Flash Attention reaches the last block, ℓ contains the exact softmax denominator, so the result is identical to standard attention. The statistics {m, ℓ} and the output O are updated after each block.
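The loop structure and the summary statistics can be sketched in NumPy as follows. This is an illustration of the paper's Algorithm 1 (outer loop over K/V blocks, inner loop over Q blocks, running m and ℓ), not the actual CUDA kernel. The arrays `O`, `m` and `l` play the role of the HBM buffers. The usual 1/√d scaling is an assumption; Algorithm 1 as printed omits it. Masking and dropout are left out.

```python
import math
import numpy as np

def flash_attention_forward(Q, K, V, B_r, B_c):
    """Tiled exact attention in the spirit of FlashAttention's Algorithm 1 (NumPy sketch).

    O, l (running softmax denominator) and m (running row max) stand in for
    HBM buffers; each (i, j) step only touches one Q block and one K/V block.
    """
    N, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    O = np.zeros((N, d))
    l = np.zeros(N)
    m = np.full(N, -np.inf)
    for j in range(0, N, B_c):                      # outer loop: K, V blocks
        Kj, Vj = K[j:j + B_c], V[j:j + B_c]
        for i in range(0, N, B_r):                  # inner loop: Q blocks
            Qi = Q[i:i + B_r]
            Sij = Qi @ Kj.T * scale                 # (B_r, B_c) block of scores
            m_tilde = Sij.max(axis=1)               # block row max
            P_tilde = np.exp(Sij - m_tilde[:, None])
            l_tilde = P_tilde.sum(axis=1)           # block row sum
            m_old, l_old = m[i:i + B_r], l[i:i + B_r]
            m_new = np.maximum(m_old, m_tilde)
            l_new = np.exp(m_old - m_new) * l_old + np.exp(m_tilde - m_new) * l_tilde
            # Rescale the previous (normalized) output and add this block's contribution.
            O[i:i + B_r] = (
                (l_old * np.exp(m_old - m_new))[:, None] * O[i:i + B_r]
                + np.exp(m_tilde - m_new)[:, None] * (P_tilde @ Vj)
            ) / l_new[:, None]
            m[i:i + B_r], l[i:i + B_r] = m_new, l_new
    return O, m, l

rng = np.random.default_rng(0)
N, d = 300, 64
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O, m, l = flash_attention_forward(Q, K, V, B_r=64, B_c=100)

S = Q @ K.T / math.sqrt(d)                          # reference: full N x N matrices
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
print(np.allclose(O, P @ V))                        # True: exact, not approximate
```

On the first K/V block, m starts at -∞ and ℓ at 0, so the rescaling factor is exp(-∞) = 0 and the old (zero) output drops out cleanly.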

The other technique proposed in the paper is recomputation. So far we have discussed the forward pass, but the backward pass works differently. To compute the gradients with respect to Q, K and V and propagate them to previous layers, the backward pass needs the N×N attention matrices S and P. The problem is that these matrices are never materialized in HBM during the forward pass, so they have to be recomputed. Flash Attention stores the output O (N×d) and the softmax normalization statistics m and ℓ (N values each). From these, it recomputes S (N×N) and P (N×N) during the backward pass, directly from blocks of Q, K and V (N×d) in SRAM. This means Flash Attention performs more FLOPs than standard attention. Yet, even with more FLOPs, Flash Attention speeds up the backward pass because it greatly reduces HBM accesses.
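The recompute step itself is small. This NumPy sketch shows only how one block of P is rebuilt from a Q block, a K block and the saved per-row statistics m and ℓ; the gradient formulas of the backward pass are not shown. The function name `recompute_P_block` is hypothetical, and the 1/√d scaling matches the forward-pass assumption.

```python
import math
import numpy as np

def recompute_P_block(Qi, Kj, m_i, l_i):
    """Rebuild one (B_r x B_c) block of the attention matrix P for the backward pass.

    Only the Q/K blocks and the saved per-row statistics m (row max) and
    l (softmax denominator) are needed; the full N x N P was never stored.
    """
    Sij = Qi @ Kj.T / math.sqrt(Qi.shape[1])   # recompute scores for this block
    return np.exp(Sij - m_i[:, None]) / l_i[:, None]

rng = np.random.default_rng(1)
N, d = 256, 32
Q, K = rng.standard_normal((N, d)), rng.standard_normal((N, d))

S = Q @ K.T / math.sqrt(d)
m = S.max(axis=1)                              # O(N) statistics saved by the forward pass
l = np.exp(S - m[:, None]).sum(axis=1)
P = np.exp(S - m[:, None]) / l[:, None]        # full matrix, only to check the sketch

block = recompute_P_block(Q[0:64], K[100:200], m[0:64], l[0:64])
print(np.allclose(block, P[0:64, 100:200]))    # True
```

Storing 2N numbers instead of an N×N matrix is the memory trade-off. The price is one extra matmul per block in the backward pass.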

How Flash Attention Improves LLM Performance

With the help of Flash Attention we can increase the context window, because its memory grows only linearly with the sequence length and models can be trained on longer sequences. Flash Attention also increases training and inference speed. The paper reports 15% faster end-to-end training of BERT-large (seq. length 512) than the MLPerf 1.1 training speed record. It also reports up to 3× faster GPT-2 training (seq. length 1K) compared with baseline implementations from HuggingFace and Megatron-LM. These gains come from significantly reducing read/write operations compared to standard attention.

Flash Attention 2.0!

Recently, the authors introduced Flash Attention 2.0 (FlashAttention-2). Below are the key takeaways from the newer version:

  1. Better Parallelism
  2. Better Work Partitioning
  3. Support for head dimensions up to 256
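To illustrate the "better parallelism" point, here is a NumPy sketch of the loop order that FlashAttention-2 describes. It is not the real kernel. The outer loop now runs over Q blocks, and each iteration is independent, so on a GPU different Q blocks can be handled by different thread blocks, which adds parallelism along the sequence length. The output is also kept unnormalized and divided by ℓ only once at the end, instead of being rescaled at every step. "Better work partitioning" refers to how work is split among warps inside a GPU thread block, which this sketch does not model. The function name and block sizes are illustrative.

```python
import math
import numpy as np

def flash_attention2_forward(Q, K, V, B_r, B_c):
    """NumPy sketch of the FlashAttention-2 loop order (not the real CUDA kernel).

    Outer loop over Q blocks: every iteration is independent, so each Q block
    could be assigned to a different GPU thread block. The output accumulator
    stays unnormalized and is divided by l only once, at the end.
    """
    N, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    O = np.empty((N, d))
    for i in range(0, N, B_r):                      # independent per Q block
        Qi = Q[i:i + B_r]
        acc = np.zeros((Qi.shape[0], d))            # unnormalized output
        m = np.full(Qi.shape[0], -np.inf)           # running row max
        l = np.zeros(Qi.shape[0])                   # running denominator
        for j in range(0, N, B_c):                  # stream over K, V blocks
            Sij = Qi @ K[j:j + B_c].T * scale
            m_new = np.maximum(m, Sij.max(axis=1))
            P_tilde = np.exp(Sij - m_new[:, None])
            alpha = np.exp(m - m_new)               # rescale factor for old state
            l = alpha * l + P_tilde.sum(axis=1)
            acc = alpha[:, None] * acc + P_tilde @ V[j:j + B_c]
            m = m_new
        O[i:i + B_r] = acc / l[:, None]             # single normalization at the end
    return O

rng = np.random.default_rng(0)
N, d = 300, 64
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O2 = flash_attention2_forward(Q, K, V, B_r=64, B_c=100)

S = Q @ K.T / math.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
print(np.allclose(O2, P @ V))                       # True
```

Because each Q block keeps its own m, ℓ and accumulator, no state is shared between outer iterations. That independence is what allows more parallel work when the batch size or the number of heads is small.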

There is an article on Flash Attention 2.0 from Stanford CRFM: https://crfm.stanford.edu/2023/07/17/flash2.html

Conclusion

Hopefully, by the end of this article, you have a good understanding of Flash Attention.

Links, references and credits

Flash Attention paper: https://arxiv.org/abs/2205.14135

Written by Yashu Gupta

Lead Data Scientist |AI Researcher | NLP Evangelist
