SampleAttention: Near-Lossless Acceleration of
Long Context LLM Inference with
Adaptive Structured Sparse Attention

Qianchao Zhu, Jiangfei Duan, Chang Chen, Siran Liu, Xiuhong Li, Guanyu Feng§
Xin Lv§, Huanqi Cao, Chuanfu Xiao, Xingcheng Zhang, Dahua Lin‡∩, Chao Yang
Peking University  The Chinese University of Hong Kong
§Zhipu.AI  Tsinghua University  Shanghai AI Lab
Abstract

Large language models (LLMs) now support extremely long context windows, but the quadratic complexity of vanilla attention results in significantly long Time-to-First-Token (TTFT) latency. Existing approaches to address this complexity require additional pretraining or finetuning, and often sacrifice model accuracy. In this paper, we first provide both theoretical and empirical foundations for near-lossless sparse attention. We find that dynamically capturing head-specific sparse patterns at runtime with low overhead is crucial. To address this, we propose SampleAttention, an adaptive structured and near-lossless sparse attention. Leveraging observed significant sparse patterns, SampleAttention attends to a fixed percentage of adjacent tokens to capture local window patterns, and employs a two-stage query-guided key-value filtering approach, which adaptively selects a minimum set of key-values with low overhead, to capture column stripe patterns. Comprehensive evaluations show that SampleAttention can seamlessly replace vanilla attention in off-the-shelf LLMs with nearly no accuracy loss, and reduces TTFT by up to $2.42\times$ compared with FlashAttention.

1 Introduction

Recent advances [1, 2, 3, 4, 5] race to scale the context window of large language models (LLMs) [6, 7, 8] for more complex applications, including document analysis [9], code copilot [10, 11], and prolonged conversations [12, 13]. Popular LLMs like Gemini [14], Claude [15] and Kimi [16] now support context lengths exceeding 1 million tokens. However, the increase in context length makes it challenging to support live interactions due to the quadratic complexity of the attention mechanism. As illustrated in Figure 1, the attention computation time increases quadratically with sequence length, quickly dominating the Time-to-First-Token (TTFT) latency (i.e., prefill latency). For example, in a 1 million token context, the attention of ChatGLM-6B [17] takes 1555 seconds, constituting over 90% of the TTFT when evaluated on an A100 GPU.

Various solutions have been proposed to address the quadratic complexity of attention, but none of them can be seamlessly and practically applied to pretrained LLMs without finetuning or pretraining and without sacrificing model accuracy. Prior approaches approximate dense attention with static or dynamic sparse attention [18, 19, 20, 21, 22, 23, 24, 25, 26], low-rank matrices [27, 28, 29], and unified sparse and low-rank attention [30, 31]. Recurrent states [32, 33, 34] and external memory [35, 36] are also investigated to mitigate the complexity. However, these approaches require pretraining from scratch or additional finetuning, and cannot achieve the same accuracy as full attention. StreamingLLM [37] offers a tuning-free sparse attention for infinite generation scenarios, but it cannot effectively reduce TTFT without accuracy loss. Therefore, we ask the question,

Figure 1: Comparison of sparse attention patterns and TTFT latency speedup. SampleAttention features adaptive structured sparsity, in contrast to previous static and dynamic sparse attention. It achieves a significant reduction in TTFT compared with FlashAttention.

How can we reduce the TTFT for off-the-shelf long context LLMs with near-lossless model accuracy? (Footnote 1: Near-lossless means that model accuracy stays above $99\%$ of the baseline, according to MLPerf [38].)

In this paper, we first provide both theoretical and empirical foundations for near-lossless sparse attention. We find that the sparsity of the intermediate score matrix in long-context attention is inherently high, head-specific, and content-aware. Specifically, for a given long context prompt, some attention heads focus on only $0.2\%$ of the tokens, while others may need to attend to over half. From the dynamic sparse patterns, we also identify inherent local window and column stripe patterns, as illustrated in Figure 1. Besides adjacent tokens in the local window, some dynamic column stripes appear to be critical for near-lossless attention. This flexible sparsity indicates that sparse attention should dynamically capture the head-specific sparse patterns at runtime to be near-lossless. However, adaptive selection of essential elements involves significant overhead. The trade-off between efficiency and accuracy is a perennial topic in sparse attention design.

To address these challenges, we propose SampleAttention, an adaptive structured sparse attention that can be seamlessly integrated into off-the-shelf long context LLMs with near-lossless model accuracy. SampleAttention leverages the significant window and stripe sparse patterns, and is therefore structured and hardware-efficient. To handle the adaptive sparsity, SampleAttention attends to a fixed percentage of adjacent tokens to capture local window patterns, and employs a two-stage query-guided key-value filtering approach, which adaptively selects a minimum set of key-values with low overhead, to focus on column stripe patterns. SampleAttention significantly accelerates vanilla attention by reducing both I/O and computation requirements. We also implement hardware-efficient kernels. Notably, SampleAttention aims to reduce the computation overhead of attention; it is orthogonal to existing KV cache eviction approaches [39, 40, 41] and can be combined with them to further reduce memory consumption.

We evaluate SampleAttention on ChatGLM2 and InternLM2 with a suite of popular benchmarks covering various generative tasks across different sequence lengths. Experimental results show that SampleAttention achieves nearly no accuracy loss for different LLMs, significantly outperforming prior works, and reduces the TTFT by up to $2.42\times$ compared with FlashAttention.

2 Related Work

Approximate Attention. Many works have been proposed to approximate quadratic attention with lower complexity [18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 42, 40]. For example, BigBird [20] combines window, global, and random attention to capture long-range dependencies. Reformer [21] reduces computational cost via locality-sensitive hashing. LongNet [22] replaces full attention with dilated attention. Linformer [27] employs low-rank matrices to approximate attention. HyperAttention [26] utilizes locality-sensitive hashing to identify important entries in the attention map. However, these approaches use either static or coarse-grained sparse patterns, and often overlook the head-specific sparsity pattern. They cannot be losslessly applied to pretrained LLMs without additional finetuning or training.

KV Cache Compression. Long sequences come with substantial KV cache memory consumption. StreamingLLM [37] keeps attention sinks and several recent tokens for infinite length generation. H2O [39] dynamically retains a balance of recent and heavy-hitter tokens according to attention scores during decoding. FastGen [43] adaptively constructs the KV cache according to observed head-specific policies. Recent efforts also quantize the KV cache to lower precision to reduce memory consumption [44, 45, 46]. These works target reducing the memory consumption of the KV cache, while SampleAttention focuses on mitigating the long context computation overhead. SampleAttention can be combined with these approaches to further reduce the memory consumption of the KV cache.

3 Foundation of Near-Lossless Sparse Attention

We start with a regular full attention mechanism for one attention head, while the following contents can be seamlessly applied to multiple attention heads. Let $\textbf{Q}\in\mathbb{R}^{S_q\times d}$ and $\textbf{K},\textbf{V}\in\mathbb{R}^{S_k\times d}$ be the query and key-value tensors of one head, where $S_q$ and $S_k$ are the respective sequence lengths, and $d$ is the head dimension. The full attention output $\textbf{O}\in\mathbb{R}^{S_q\times d}$ can be formulated as,

$$\textbf{P}=\texttt{softmax}\left(\frac{\textbf{Q}\textbf{K}^{T}}{\sqrt{d}}\right)\in[0,1]^{S_q\times S_k},\quad \textbf{O}=\textbf{P}\textbf{V}\in\mathbb{R}^{S_q\times d}, \tag{1}$$

where softmax is applied row-wise, and $\textbf{P}$ is the attention score. We find that, in long context LLMs, the attention score matrix $\textbf{P}$ becomes extremely large, leading to inefficiencies. Moreover, applying softmax over long sequences tends to reduce the influence of smaller elements, making them less significant. This insight motivates us to investigate the inherent sparsity in the attention scores, which can potentially accelerate the attention mechanism without compromising accuracy.

3.1 Theoretical Foundation

We first present a theoretical foundation to explore the attention score sparsity. Suppose we apply an attention mask $\textbf{M}\in\{0,1\}^{S_q\times S_k}$ to the attention score $\textbf{P}$ to obtain a sparse attention. The sparse attention output $\tilde{\textbf{O}}\in\mathbb{R}^{S_q\times d}$ can be formulated as,

$$\tilde{\textbf{P}}=\textbf{M}*\textbf{P}\in[0,1]^{S_q\times S_k},\quad \tilde{\textbf{O}}=\tilde{\textbf{P}}\textbf{V}\in\mathbb{R}^{S_q\times d}, \tag{2}$$

where * represents the element-wise product. We give a theorem for near-lossless sparse attention.

Theorem 1.

(near-lossless sparse attention) Assume that the $L_1$-norms of the values $\textbf{V}$ are upper-bounded by $R>0$. Given $\epsilon>0$, there exists an attention mask $\textbf{M}$ such that $\|\tilde{\textbf{P}}-\textbf{P}\|_1\leq\frac{\epsilon}{R}$, and the following holds: $\|\tilde{\textbf{O}}-\textbf{O}\|_1\leq\epsilon$, where $\tilde{\textbf{O}}$ near-losslessly approximates the attention output $\textbf{O}$.

The proof of Theorem 1 is given in Appendix A.1. Theorem 1 suggests that we can always find an attention mask $\textbf{M}$ to achieve near-lossless approximate attention for a given threshold. We then define a key metric that helps understand and quantify the efficiency of a sparse attention.
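
The bound behind Theorem 1 can be checked numerically on a toy head. The following is a minimal NumPy sketch, not the authors' code: the tensor sizes, the random inputs, and the top-4-per-row mask are illustrative assumptions, and the norms are read row-wise (one query at a time), which is one plausible reading of the statement. Under that reading, the triangle inequality gives $\|\tilde{\textbf{O}}_i-\textbf{O}_i\|_1\leq\|\tilde{\textbf{P}}_i-\textbf{P}_i\|_1\cdot R$ for every query $i$.

```python
import numpy as np

def softmax_rows(x):
    """Numerically stable row-wise softmax."""
    x = x - x.max(axis=1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
s_q, s_k, d = 8, 8, 4                      # toy sizes (assumption)
q, k, v = (rng.standard_normal((n, d)) for n in (s_q, s_k, s_k))

p = softmax_rows(q @ k.T / np.sqrt(d))     # Eq. (1), causal mask omitted for brevity
o = p @ v

# Hypothetical mask M: keep the 4 largest scores in every row.
keep = 4
threshold = np.sort(p, axis=1)[:, -keep][:, None]
m = (p >= threshold).astype(p.dtype)

p_sparse = m * p                           # Eq. (2)
o_sparse = p_sparse @ v

r = np.abs(v).sum(axis=1).max()                 # R: largest L1 norm of a value row
score_err = np.abs(p_sparse - p).sum(axis=1)    # per-row ||P~ - P||_1
output_err = np.abs(o_sparse - o).sum(axis=1)   # per-row ||O~ - O||_1
bound_holds = bool(np.all(output_err <= score_err * r + 1e-12))
print("output error bounded by score error times R:", bound_holds)
```

The check illustrates why controlling the dropped attention mass is sufficient: the output error of each query can never exceed the dropped mass scaled by $R$.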

Definition 1.

The sparsity degree (SD) measures the maximum percentage of key-value elements that can be dropped while maintaining a specified CRA threshold $\alpha$, and is formulated as,

$$\textbf{SD}(\alpha)=\max_{\textbf{M}\in\{0,1\}^{S_q\times S_k}}\left\{1-\frac{\sum_{i,j}\textbf{M}_{ij}}{S_q\cdot S_k/2}\right\}\in[0,1],\quad\text{s.t.}\quad\textbf{CRA}(\textbf{M})\geq\alpha, \tag{3}$$

where CRA evaluates the degree to which the attention score matrix can be recovered.

Definition 2.

The cumulative residual attention (CRA) is defined as the minimum, over all queries, of the sum of the attention probabilities that remain after sparsification with $\textbf{M}$, and is formulated as,

$$\textbf{CRA}(\textbf{M})=\min_{i\in\{0,\cdots,S_q-1\}}\sum_{k=0}^{i}\tilde{\textbf{P}}_{ik},\quad\text{where}\quad\tilde{\textbf{P}}=\textbf{M}*\textbf{P} \tag{4}$$

Here we use minimum for CRA because we want to ensure even the row with minimal residual attention score can be near-losslessly recovered. We can show that the CRA of near-lossless sparse attention has a lower bound.

Lemma 1.

Given $\epsilon>0$, the following holds for near-lossless sparse attention: $\textbf{CRA}(\textbf{M})\geq 1-\frac{\epsilon}{R}$.

Lemma 1 can be easily proved since $\|\tilde{\textbf{P}}-\textbf{P}\|_1=1-\textbf{CRA}(\textbf{M})$.
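
To make Definitions 1 and 2 concrete, the sketch below computes CRA for a given mask and finds the mask that attains SD for a small hand-written causal score matrix. It is an illustration under stated assumptions rather than the paper's implementation: the $4\times4$ matrix and $\alpha=0.85$ are made up, the function names are hypothetical, and the greedy rule relies on the fact that the constraint in Equation 3 decouples across rows, so each row independently keeps its largest scores until their sum reaches $\alpha$.

```python
import numpy as np

def cra(p, m):
    """Eq. (4): smallest per-query sum of retained attention probability.

    For a causal P the entries beyond column i are zero, so a full row sum
    equals the sum over k = 0..i.
    """
    return float((m * p).sum(axis=1).min())

def minimal_mask(p, alpha):
    """Per row, keep the largest scores until their cumulative sum reaches alpha."""
    order = np.argsort(-p, axis=1)                  # column indices, largest score first
    sorted_p = np.take_along_axis(p, order, axis=1)
    cum = np.cumsum(sorted_p, axis=1)
    # Keep an entry while the mass accumulated before it is still below alpha.
    keep_sorted = (cum - sorted_p) < alpha
    m = np.zeros_like(p)
    np.put_along_axis(m, order, keep_sorted.astype(p.dtype), axis=1)
    return m

def sparsity_degree(m):
    """Eq. (3): fraction of dropped entries, normalised by S_q * S_k / 2."""
    s_q, s_k = m.shape
    return 1.0 - m.sum() / (s_q * s_k / 2)

# Toy causal attention scores: every row sums to 1, future positions are 0.
p = np.array([
    [1.00, 0.00, 0.00, 0.00],
    [0.90, 0.10, 0.00, 0.00],
    [0.05, 0.90, 0.05, 0.00],
    [0.70, 0.10, 0.10, 0.10],
])
alpha = 0.85
m = minimal_mask(p, alpha)
sd = sparsity_degree(m)
print("kept entries:", int(m.sum()), "SD:", sd, "CRA:", cra(p, m))
```

Note that this oracle needs the full score matrix $\textbf{P}$, so it is useful for measuring sparsity offline but not for accelerating prefill.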

Takeaway: By discovering an effective attention mask $\textbf{M}$ that meets a desired CRA threshold $\alpha$, we can effectively decrease the attention computation time with near-lossless sparse attention by minimizing I/O and computation. A higher sparsity degree ($\textbf{SD}(\alpha)$) brings greater acceleration.

Figure 2: Statistics of ChatGLM-6B (Model1, 28 layers $\times$ 32 heads) and InternLM2-7B (Model2, 32 layers $\times$ 32 heads), evaluated over different tasks at the prefill stage. (a) The trend of SD($\alpha$=0.95) across each layer under real prompt questions with different sequence lengths, demonstrating inherently high attention sparsity. (b) The results of SD($\alpha$=0.95) as the sequence length extends in the "Needle in a Haystack" task, indicating that an increase in sequence length intensifies the sparsity. (c) The variation of SD($\alpha$=0.95) across different heads under a 90K sequence, indicating significant disparities in sparsity among the heads. (d) Different contexts cause the same head to display varied sparse structures, while numerous attention heads follow two primary patterns: column stripe and local window. (e) Relationship between the ratio of selected top-k stripes and CRA. The high similarity of the row-wise numerical distributions enables a small number of critical column stripes to cover the majority of the values of the full attention score matrix. Further details are presented in Appendix A.3 and A.4.

3.2 Empirical Foundation of Adaptive Sparsity in Attention

Section 3.1 uncovers the possibility of approximating the attention output using near-lossless sparse attention, with the key lying in finding an effective attention mask. In this section, we present our empirical findings that reveal the inherently high, head-specific, and content-aware adaptive sparsity and the significant patterns. These can be leveraged to achieve efficient near-lossless sparse attention.

Inherently-High Sparsity Degree. Our observations reveal that LLMs inherently exhibit a significant sparsity degree when using near-lossless sparse attention. In Figure 2(a), the average sparsity degree across different layers of various LLMs is depicted, with a threshold of $\alpha=0.95$ for near-lossless model accuracy. We find that most layers exhibit a remarkably high sparsity degree, surpassing 90%, regardless of the input length. Notably, the first layer has a lower sparsity degree.

To further quantify the variation in sparsity degree with increasing sequence length, we conduct a scaling evaluation on the "Needle in a Haystack" [47] task, as illustrated in Figure 2(b). Our findings indicate that as the context becomes longer, there is a corresponding increase in the sparsity degree.

Adaptive Sparsity. The attention sparsity is head-specific and content-aware. The sparsity degree and structure vary across different attention heads and input contexts. Figure 2(c) demonstrates that certain heads in most layers exhibit lower SD($\alpha$=0.95) values, even when processing sequences as long as 90K. For example, one head in the first layer has a sparsity degree as low as $27.4\%$, while the highest degree can reach $99.8\%$. This suggests that different heads may have distinct roles in processing long sequences, indicating that uniform compression across all heads may not be optimal. Figure 2(d) shows that different contents of similar length result in noticeable variations in sparse patterns within the same layer and head. This indicates that regions with higher attention scores change significantly based on the given scenario, such as different user prompts.

Significant Window and Stripe Patterns. We identify two significant sparse patterns that substantially contribute to the attention score, as depicted in Figure 2(d). The local window pattern captures recent context information, while the column stripe pattern embodies the key global contextual information. By adaptively combining these two patterns, LLMs can effectively handle both fine-grained information and key contextual cues. Figure 2(e) demonstrates that selecting a small number of critical column stripes is able to cover the majority of the values of the full attention score matrix, thus achieving a high CRA. This indicates a high similarity of the numerical distribution across rows.

Although similar patterns have been observed in recent works [43, 37, 39], they focus on reducing KV cache memory consumption during decoding. Directly migrating these approaches to accelerate prefill attention requires computing the full attention score, which is unaffordable in long context. How to effectively exploit these patterns for near-lossless acceleration of prefill remains challenging.

Takeaway: Attention sparsity is inherently high, head-specific, and content-aware, and exhibits significant local window and column stripe patterns. This adaptive sparsity indicates that sparse attention should dynamically capture the adaptive sparse patterns at runtime to be near-lossless.

4 SampleAttention

In this section, we introduce our approach to efficiently discover effective attention masks with observed significant sparse patterns and accelerate the attention with near-lossless sparse attention.

4.1 Problem Formulation

As discussed, the key to utilizing near-lossless sparse attention is to find an attention mask $\textbf{M}$ with the following properties to achieve superior performance: 1) near-lossless: meets a desired CRA threshold $\alpha$, 2) adaptive: varies across different heads, layers and contents, 3) hardware-efficient: maximizes hardware efficiency, 4) efficiently discoverable: can be found with minimal overhead. A static mask clearly cannot meet these criteria, and these properties pose significant challenges.

Selecting an attention mask $\textbf{M}\in\{0,1\}^{S_q\times S_k}$ directly from the $S_q\times S_k$ attention score grid during runtime is hardware-inefficient and incurs high overhead due to the grid size and potentially random pattern. Thus, we first utilize the observed significant sparse patterns to simplify and reformulate the problem, aiming to discover a hardware-efficient structured sparse pattern mask $\hat{\textbf{M}}$,

$$\hat{\textbf{M}}\coloneqq\textbf{M}_{window}(w)\cup\textbf{M}_{stripe}(I_{KV}),\quad w\in\{1,\cdots,S_k\},\quad I_{KV}\subseteq\{0,\cdots,S_k-1\}, \tag{5}$$

where $w$ is the window size for the window mask $\textbf{M}_{window}(w)$, and $I_{KV}$ is the set of key-value indices of interest for the stripe mask $\textbf{M}_{stripe}(I_{KV})$ (as illustrated in Figure 3). We leverage the fixed sparse pattern, which is hardware-efficient, and adaptively determine the size and indices during runtime according to the context. Moreover, the structured attention mask $\hat{\textbf{M}}$ maintains the near-lossless property:

Theorem 2.

The hardware-efficient structured sparse pattern mask $\hat{\textbf{M}}$ maintains the near-lossless property of sparse attention.

The proof of Theorem 2 is given in Appendix A.1. Given the formulation, the problem now is to find $w$ and $I_{KV}$ for each head to meet the required properties during runtime.
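
The structured mask of Equation 5 can be written down as a dense boolean matrix for illustration. The sketch below is an assumption-laden reading, not the paper's kernel: the function name is hypothetical, the window is taken to be the $w$ most recent keys of each query, both components are intersected with the causal mask, and a real implementation would never materialise the $S_q\times S_k$ grid.

```python
import numpy as np

def structured_mask(s_q, s_k, w, i_kv):
    """Union of a causal local-window mask and a column-stripe mask (Eq. 5)."""
    rows = np.arange(s_q)[:, None]
    cols = np.arange(s_k)[None, :]
    # Align the last query with the last key, so s_q == s_k is ordinary causal prefill.
    offset = s_k - s_q
    causal = cols <= rows + offset
    # Window: the w most recent keys visible to each query (assumed definition).
    window = causal & (cols > rows + offset - w)
    # Stripes: whole key columns listed in I_KV, restricted to causal positions.
    stripe = np.zeros((1, s_k), dtype=bool)
    stripe[0, np.asarray(list(i_kv), dtype=int)] = True
    return window | (stripe & causal)

# Toy example: 6 tokens, window of 2, and key 0 selected as a stripe.
m_hat = structured_mask(s_q=6, s_k=6, w=2, i_kv=[0])
print(m_hat.astype(int))
```

Because the mask is fully described by one integer $w$ and one index set $I_{KV}$ per head, the search space is far smaller than that of an arbitrary $\textbf{M}\in\{0,1\}^{S_q\times S_k}$.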

4.2 Method

Tuned Window Size $w$. High attention scores tend to occur in local windows of varying sizes, depending on the context, as shown in Figure 2(d). However, dynamically determining the local window size for each token incurs high overhead and is hardware-inefficient. Instead, we set the window size $w$ to a fixed percentage of the sequence length ($\lceil r_w\%\times S_k\rceil$), where $S_k$ is the sequence length of the input request. The percentage is tuned to be large enough to capture important local windows, and it also accommodates dynamic window sizes across various context lengths. While previous works have explored window attention [20, 39, 37], they typically rely on a fixed window size, which cannot adequately capture local dependencies across various context lengths.
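
As a small worked example of the window-size rule $w=\lceil r_w\%\times S_k\rceil$, the snippet below evaluates it for two context lengths. The value $r_w=8$ and the function name are hypothetical choices for illustration only; exact rational arithmetic is used so that floating-point round-off cannot push the ceiling up by one.

```python
import math
from fractions import Fraction

def window_size(s_k, r_w_percent):
    """w = ceil(r_w% * S_k), computed exactly to avoid float round-off."""
    return math.ceil(Fraction(str(r_w_percent)) / 100 * s_k)

# The window grows with the prompt instead of staying at a fixed token count.
for s_k in (25_000, 96_000):
    print(s_k, "->", window_size(s_k, 8))
```

With a fixed percentage, a prompt that is roughly four times longer also gets a window that is roughly four times wider, which is the behaviour a fixed-size window cannot provide.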

KV Indices of Interest $I_{KV}$. The stripe pattern exhibits more dynamics, and our remaining problem is to efficiently and dynamically select a minimum set of key-value indices of interest $I_{KV}$ for the input prompt and desired CRA threshold $\alpha$,

$$\operatorname*{arg\,min}_{I_{KV}}|I_{KV}|\quad\text{s.t.}\quad\min_{i\in\{0,\cdots,S_q-1\}}\sum_{j\in I_{KV}}\textbf{P}_{ij}\geq\alpha. \tag{6}$$

Ideally, computing the entire attention score matrix $\textbf{P}$ and then selecting $I_{KV}$ would be optimal, but this incurs unaffordable quadratic overhead in both computation and memory consumption. Fortunately, the similar distribution of large numerical values across rows, as observed in Section 3.2, can be leveraged to simplify the index selection process. SampleAttention introduces a two-stage query-guided key-value filtering approach to approximate the solution. A PyTorch-style algorithm is given in Appendix A.7. Our evaluations show that the approximation performs well (Section 5).

Stage-1: Query-Guided Attention Sampling. SampleAttention first samples the attention score matrix by computing exact scores for a few queries (Figure 3①). This is motivated by the significant column stripe sparse pattern: a high score for $\textbf{P}_{ik}$ suggests a high probability that $\textbf{P}_{jk}$ ($j\neq i$) is also high. Therefore we can select a minimum subset of queries $\{i_1,\cdots,i_l\}\subseteq\{0,\cdots,S_q-1\}$ to approximate the attention score with low overhead. SampleAttention performs stride sampling along the rows based on a predefined sampling ratio $r_{row}=\frac{l}{S_q}$. Experiments show that this simple approach is effective: sampling a small number of rows can accurately approximate the real CRA. Further details can be found in Appendix A.5.

Stage-2: Score-Based Key-Value Filtering. SampleAttention then filters the key-value indices of interest based on the sampled attention scores. Exactly solving Equation 6 for the sampled queries is inefficient due to the long sequence length. To resolve this, SampleAttention filters key-values based on the attention scores accumulated along each column (Figure 3②), which is a more statistical approximation of the attention score. After column-wise reduction, SampleAttention separately selects the top-k key-value indices that meet the desired CRA threshold $\alpha$ for each head. Attention sinks can also be discovered in this way.
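
The two stages can be sketched for a single head as follows. This is a simplified NumPy illustration and not the PyTorch-style algorithm of Appendix A.7 or the fused kernel: the function name is hypothetical, prefill with $S_q=S_k$ is assumed, the strided sample is chosen to always include the last query, and the threshold $\alpha$ is checked against the share of column-accumulated mass over the sampled rows, which is one reading of the description above.

```python
import numpy as np

def sample_kv_indices(q, k, alpha, r_row):
    """Approximate I_KV for one head: stride-sample queries, then filter keys by column mass."""
    s_q, d = q.shape
    s_k = k.shape[0]

    # Stage 1: exact causal attention scores for a strided subset of query rows.
    stride = max(1, round(1 / r_row))
    rows = np.arange(s_q - 1, -1, -stride)[::-1]     # always includes the last query
    logits = q[rows] @ k.T / np.sqrt(d)
    future = np.arange(s_k)[None, :] > rows[:, None]
    logits = np.where(future, -np.inf, logits)       # causal mask on the sampled rows
    logits -= logits.max(axis=1, keepdims=True)
    p = np.exp(logits)
    p /= p.sum(axis=1, keepdims=True)

    # Stage 2: accumulate scores along columns, then take the smallest top-k
    # whose share of the sampled attention mass reaches alpha.
    col_mass = p.sum(axis=0)
    order = np.argsort(-col_mass)
    share = np.cumsum(col_mass[order]) / col_mass.sum()
    top_k = int(np.searchsorted(share, alpha)) + 1
    return np.sort(order[:min(top_k, s_k)])

rng = np.random.default_rng(0)
q = rng.standard_normal((64, 16))                    # toy head: 64 tokens, d = 16
k = rng.standard_normal((64, 16))
i_kv = sample_kv_indices(q, k, alpha=0.95, r_row=0.125)
print(len(i_kv), "of", k.shape[0], "keys kept")
```

Only $l=r_{row}\cdot S_q$ rows of the score matrix are ever computed, so the selection cost scales with $l\cdot S_k$ rather than $S_q\cdot S_k$. Random inputs such as these have no stripe structure, so the sketch shows the mechanics rather than the sparsity observed on real prompts.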

Figure 3: SampleAttention replaces the original full attention with a two-stage implementation. In the first stage, attention scores are computed by performing stride sampling across multiple rows and accumulating the scores along the column. In the second stage, the indices $I_{KV}$ that meet the CRA threshold $\alpha$ are selected via a top-k operation for each head. The obtained $I_{KV}$ is then merged with the masks of the local window and bottom area to enable sparse computation of the attention.

Table 1: The meaning of hyperparameters and tuning approach.

| Hyperparameter | Description | Tuning |
| --- | --- | --- |
| $\alpha$ | The desired CRA threshold | Offline profiling separately |
| $r_{row}$ | The sampling ratio in stage-1 | Offline profiling separately |
| $r_w\%$ | The ratio of local window size | Offline profiling separately |

Hyperparameter Tuning. SampleAttention needs to tune several hyperparameters, as listed in Table 1. These hyperparameters affect both model accuracy and inference latency. For example, a large $\alpha$ reduces the speedup in latency but improves model accuracy, while a large $r_{row}$ increases the sampling overhead but reduces the attention approximation error. We find that fixed hyperparameters, obtained by lightweight offline profiling, perform well for an LLM across different tasks. Thus we use a small dataset that contains 22 requests ranging from 25K-96K context length to determine these hyperparameters. The detailed effects of varying these hyperparameters are studied in Section 5.3.

4.3 Hardware-efficient Implementation

To achieve substantial speedup in wall-clock time, SampleAttention is implemented with IO-awareness to maximize hardware-efficiency. First, the query-guided key-value filtering involves a series of small operators (bmm, softmax, reduction) that read and write large intermediate results. SampleAttention significantly reduces IO overhead by fusing these operators. Second, SampleAttention implements an efficient adaptive structured sparse attention kernel by modifying FlashAttention [48]. These hardware-aware optimizations enhance speed performance significantly.

5 Experiments

5.1 Setup

Backbones. We evaluate our method on two widely used open-source LLM variants: ChatGLM2-6B with a 96K context window based on GLM [17], and InternLM2-7B [49] with a 200K context window based on LLAMA2 [8]. All utilized models are decoder-only transformers [50], and are pre-trained via causal language modeling. They encompass similar architectural components, such as rotary positional encoding [51] and grouped-query attention [52]. At the same time, there are notable differences, e.g., the former augments the context window capacity via continued training with an extended sequence length, whereas the latter achieves length extrapolation through RoPE scaling. We only replace the full attention implementation during the prompt prefill stage with SampleAttention and various baselines, while maintaining an uncompressed KV cache in the decode phase.

Tasks. We evaluate SampleAttention and other methods’ understanding capabilities in long-context scenarios on three distinct tasks: LongBench [53], BABILong [54], and Needle in a Haystack [47]. LongBench, a multi-task benchmark, comprises single and multi-document QA, summarization, few-shot learning, synthetic tasks, and code completion. It offers over 4,750 test cases with task lengths from 4K-35K. BABILong is a generative benchmark test designed to assess long-context inferencing capability, consisting of 20 different tasks. Given its generative nature, task lengths can be flexibly set from 4K-88K. Additionally, the "Needle in a Haystack" stress test challenges models to accurately extract information from a specific sentence buried within a lengthy document at a random position. We have set the number of depth intervals at 32, with lengths ranging from 10K-96K. Note that in these tasks, each case is evaluated after the model provides an output. This output is compared against a standard answer or judged by more advanced models, such as GPT-4 [55], for scoring.

5.2 Accuracy Results

Table 2: Accuracy comparison across various sparse methods on LongBench and BABILong. The best results are highlighted in Bold while the second best results are marked with an Underline.

| Model | Baseline | Single-Doc QA | Multi-Doc QA | Summarization | Few-shot Learning | Synthetic Tasks | Code Completion | LongBench Total Score | BABILong Total Score |
|---|---|---|---|---|---|---|---|---|---|
| ChatGLM2-6B | Full Attention | 161.15 | 147.76 | 98.64 | 243.66 | 87.00 | 99.20 | 837.40 | 30.20 |
| ChatGLM2-6B | SampleAttention ($\alpha=0.95$) | 158.75 | 147.18 | 98.03 | 242.93 | 87.82 | 98.20 | 833.00 | 31.04 |
| ChatGLM2-6B | BigBird | 158.47 | 131.40 | 94.96 | 243.45 | 41.96 | 95.70 | 765.94 | 27.68 |
| ChatGLM2-6B | StreamingLLM | 82.96 | 94.91 | 91.67 | 159.01 | 6.10 | 84.62 | 519.27 | 14.60 |
| ChatGLM2-6B | HyperAttention | 85.76 | 80.71 | 85.18 | 175.28 | 8.82 | 73.19 | 508.94 | 17.00 |
| ChatGLM2-6B | Hash-Sparse | 73.30 | 52.61 | 75.46 | 84.35 | 10.87 | 67.90 | 364.49 | 11.20 |
| InternLM2-7B | Full Attention | 73.25 | 75.15 | 98.10 | 257.40 | 53.38 | 128.18 | 685.46 | 35.24 |
| InternLM2-7B | SampleAttention ($\alpha=0.95$) | 77.53 | 76.01 | 98.52 | 254.95 | 53.02 | 126.83 | 686.86 | 36.88 |
| InternLM2-7B | BigBird | 72.55 | 73.16 | 95.59 | 254.87 | 19.88 | 120.99 | 637.04 | 34.12 |
| InternLM2-7B | StreamingLLM | 31.49 | 26.44 | 35.32 | 133.53 | 3.33 | 89.44 | 319.55 | 5.96 |
| InternLM2-7B | HyperAttention | 87.98 | 33.40 | 38.52 | 95.78 | 3.09 | 77.80 | 336.57 | 16.64 |
| InternLM2-7B | Hash-Sparse | 20.12 | 11.37 | 24.32 | 49.88 | 5.87 | 45.28 | 156.84 | 2.82 |

Figure 4: Scores of different methods on the "Needle in a Haystack" task at various lengths.

Baselines and settings. We consider full attention (as the gold baseline), BigBird [20], StreamingLLM [37], HyperAttention [26] and Hash-Sparse [24] as baselines to compare model accuracy across different tasks. To maintain consistency, we assign the same window size ratio of 8% to SampleAttention, BigBird, and StreamingLLM. BigBird retains a global ratio of 8%. StreamingLLM sets its initial attention sink at 4 tokens. HyperAttention sets both the bucket size and the number of sampled columns to 256, and Hash-Sparse uses a bucket number of 16. The sampling ratio $r_{row}$ and the threshold $\alpha$ for SampleAttention are set to 5% and 0.95, respectively, through offline profiling.
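To make the shared window setting concrete, the following sketch builds a causal local-window mask from a window ratio given in percent, and a variant that additionally keeps a few initial "sink" tokens. It only illustrates the settings described above; the function names are hypothetical, and how each baseline rounds the window width is an assumption (ceiling is used here).

```python
def local_window_mask(seq_len, window_pct):
    """Causal mask in which query i sees the w most recent keys, itself included.

    w is window_pct percent of seq_len, rounded up (assumed rounding rule).
    """
    w = max(1, (seq_len * window_pct + 99) // 100)  # integer ceiling of seq_len * pct / 100
    return [[0 <= i - j < w for j in range(seq_len)] for i in range(seq_len)]

def sink_plus_window_mask(seq_len, window_pct, n_sink=4):
    """Local window plus the first n_sink tokens, still respecting causality."""
    base = local_window_mask(seq_len, window_pct)
    return [[base[i][j] or (j < n_sink and j <= i) for j in range(seq_len)]
            for i in range(seq_len)]

window = local_window_mask(100, 8)       # 8% window: each late query sees 8 keys
sinks = sink_plus_window_mask(100, 8)    # the same window plus 4 sink tokens
last_row_window = sum(window[99])
last_row_sinks = sum(sinks[99])
```

With a ratio-based window the absolute width grows with the sequence length, whereas the sink count stays fixed at 4 tokens.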

Main results. Table 2 and Figure 4 display the accuracy results of the models on three downstream tasks. Detailed results are listed in Appendix A.2. The results show that:

  • SampleAttention's accuracy is consistently robust across all benchmarks (including subdomains), models, and sequence lengths. Compared with full attention, which serves as the gold standard, SampleAttention consistently achieves scores above 99% of full attention, demonstrating near-lossless accuracy.

  • BigBird exhibits varying degrees of performance degradation across different tasks, with "Synthetic Tasks" presenting a significant challenge. Nonetheless, on average, BigBird still attains scores that are approximately 91% of those achieved by full attention.

  • StreamingLLM, HyperAttention and Hash-Sparse result in performance degradation across all tasks, demonstrating that these techniques fail to capture critical KV elements in long sequences at the prefill stage.

5.3 Hyperparameter Ablation Study

We conducted further tests on the impact of three critical hyperparameters in SampleAttention on the accuracy of downstream tasks. These experiments adhered to the settings outlined in Section 5.2, with only one hyperparameter changed at a time. Detailed results under different hyperparameter configurations are provided in Table 3.

Table 3: Results of varying the three hyperparameters in the SampleAttention on the ChatGLM2-6B. The best results are highlighted in Bold while the second best results are marked with an Underline.

Column groups: CRA threshold $\alpha$ (0.80 to 0.98), local window $r_w\%$ (4 and 8), and sample ratio $r_{row}$ (2% to 10%).

| Task | Full attention | $\alpha=0.80$ | $\alpha=0.90$ | $\alpha=0.95$ | $\alpha=0.98$ | $r_w=4$ | $r_w=8$ | $r_{row}=2\%$ | $r_{row}=5\%$ | $r_{row}=10\%$ |
|---|---|---|---|---|---|---|---|---|---|---|
| LongBench | 837.40 | 820.30 | 824.98 | 833.00 | 829.80 | 792.87 | 833.00 | 809.34 | 833.00 | 831.14 |
| BABILong | 30.20 | 27.28 | 29.08 | 31.04 | 31.16 | 31.12 | 31.04 | 28.92 | 31.04 | 30.64 |
| Needle in a Haystack | 2235 | 2130 | 2090 | 2239 | 2231 | 2084 | 2239 | 2106 | 2239 | 2231 |

CRA threshold $\alpha$. Setting $\alpha$ too low leads to performance degradation due to excessive filtering of KV elements, reflecting a clear trade-off between accuracy and speedup. However, even with $\alpha$ set to 0.80, SampleAttention's average score still exceeds 94.5% of that of full attention, which supports the effectiveness of SampleAttention. Additionally, performance stabilizes once $\alpha$ reaches a sufficiently high threshold. Thus, profiling to determine an appropriate $\alpha$ for a given model is essential.
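One simple way to turn such profiling results into a choice of $\alpha$ is sketched below: pick the smallest threshold whose score stays within a tolerance of full attention, since a smaller $\alpha$ filters more KV elements. The selection rule and the 99% tolerance are assumptions made for illustration, not the authors' stated procedure; the scores are the LongBench totals reported in Table 3.

```python
def pick_threshold(scores_by_alpha, full_score, min_fraction=0.99):
    """Return the smallest alpha whose score is at least min_fraction of full attention.

    Hypothetical selection rule; returns None when no candidate qualifies.
    """
    for alpha, score in sorted(scores_by_alpha.items()):
        if score >= min_fraction * full_score:
            return alpha
    return None

# LongBench total scores for ChatGLM2-6B, taken from Table 3
longbench = {0.80: 820.30, 0.90: 824.98, 0.95: 833.00, 0.98: 829.80}
chosen = pick_threshold(longbench, full_score=837.40)
relaxed = pick_threshold(longbench, full_score=837.40, min_fraction=0.97)
```

Under the 99% tolerance this rule lands on the same $\alpha=0.95$ used in the main experiments, while a looser 97% tolerance would already accept $\alpha=0.80$.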

Local window size and sampling ratio. Setting excessively small local window ratios or sampling ratios also degrades performance. Specifically, halving the local window ratio ($r_w=4$) results in a performance decline of over 6% in the LongBench and "Needle in a Haystack" tasks. This confirms the high significance of KV elements within the local window area. Additionally, reducing the sampling ratio to 2% results in an approximate 4.5% performance loss. However, performance stabilizes once the sampling ratio reaches a certain threshold, as the top-k results for the approximate attention become stable.

5.4 Acceleration Speedup Benchmarking

We conducted micro-benchmarks on a single NVIDIA A100 GPU (80GB) to evaluate the speed of the attention operation during prefill as well as the TTFT metric. The baselines selected were PyTorch's scaled_dot_product_attention (denoted SDPA) and FlashAttention2. All tests were conducted using the configuration from ChatGLM2-6B (32 heads, $d=128$), with synthetic data from the "Needle in a Haystack" benchmark as input. We fix the batch size of the input data to 1 to support longer sequence lengths.

Speedup and sampling overhead. Figure 5 displays the profiling results conducted on the model's full 28 layers using generated data ranging from 8K to 96K. Figure 5(a), focusing on the GPU performance of the attention module, indicates that neither SampleAttention ($\alpha=0.95$) nor SampleAttention ($\alpha=0.80$) exhibits a speed advantage over FlashAttention2 at shorter lengths, due to sampling overhead and small batch sizes. However, for longer sequences such as 96K, substantial savings in KV memory transfers enable the attention operations of SampleAttention ($\alpha=0.95$) and SampleAttention ($\alpha=0.80$) to achieve speedups of $2.20\times$ and $5.12\times$ over FlashAttention2, while reducing TTFT by $1.62\times$ and $2.28\times$, respectively. Furthermore, Figure 5(b) demonstrates that as sequence lengths increase, the proportion of sampling overhead decreases, suggesting that SampleAttention can offer greater acceleration benefits for longer sequences.
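The relation between the attention speedup and the TTFT reduction at 96K can be checked with an Amdahl-style estimate. The sketch below assumes that only the attention share of TTFT is accelerated and that this share is the same in both runs; the function names are illustrative and the model is a simplification, not part of the paper's methodology.

```python
def ttft_speedup(attn_fraction, attn_speedup):
    """End-to-end speedup when only the attention share of TTFT is accelerated."""
    return 1.0 / ((1.0 - attn_fraction) + attn_fraction / attn_speedup)

def implied_attn_fraction(attn_speedup, observed_ttft_speedup):
    """Attention share of baseline TTFT implied by a pair of reported speedups."""
    return (1.0 - 1.0 / observed_ttft_speedup) / (1.0 - 1.0 / attn_speedup)

# Reported at 96K for alpha = 0.95: attention 2.20x faster, TTFT 1.62x faster
share = implied_attn_fraction(2.20, 1.62)   # roughly 0.70
# Predict the TTFT speedup for alpha = 0.80 from its 5.12x attention speedup
predicted = ttft_speedup(share, 5.12)       # close to the reported 2.28x
```

The two reported pairs are mutually consistent under this model, which implies that attention accounts for about 70% of the FlashAttention2 TTFT at 96K.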

Figure 5: (a) Latency comparison for the self-attention module. (b) The proportion of time spent on sampling and sparse computation in SampleAttention. (c) Comparison for the TTFT metric.

Scaling the sequence length to 1M. We extended the GPU performance evaluation to sequence lengths of up to 1 million, based on profiling results from the first layer. Since SampleAttention is content-aware, for sequences longer than 128K we derived the average attention latency per layer from the first-layer results of SampleAttention combined with the model sparsity analysis, to avoid memory issues. Figure 6 shows that when the sequence length scales to 1M, thresholds of 0.95 and 0.80 reduce the TTFT metric by $2.27\times$ and $4.62\times$, respectively.

Figure 6: (a) and (b) compare the latency of attention and TTFT metrics as the sequence scales from 8K to 1M, respectively. The numbers represent the speedup compared with FlashAttention2.

6 Conclusion

In this paper, we first present both theoretical and empirical foundations for near-lossless sparse attention, and then leverage the observed significant sparse patterns to design SampleAttention, an adaptive structured sparse attention that can seamlessly replace FlashAttention in long context LLMs with nearly no accuracy loss. SampleAttention significantly reduces the TTFT of long context requests. Limitations and future work are discussed in Appendix A.6.

References

  • [1] Wenhan Xiong, Jingyu Liu, Igor Molybog, Hejia Zhang, Prajjwal Bhargava, Rui Hou, Louis Martin, Rashi Rungta, Karthik Abinav Sankararaman, Barlas Oguz, et al. Effective long-context scaling of foundation models. arXiv preprint arXiv:2309.16039, 2023.
  • [2] Xiaoran Liu, Hang Yan, Shuo Zhang, Chenxin An, Xipeng Qiu, and Dahua Lin. Scaling laws of rope-based extrapolation. arXiv preprint arXiv:2310.05209, 2023.
  • [3] Yukang Chen, Shengju Qian, Haotian Tang, Xin Lai, Zhijian Liu, Song Han, and Jiaya Jia. Longlora: Efficient fine-tuning of long-context large language models. In The Twelfth International Conference on Learning Representations, 2023.
  • [4] Dacheng Li, Rulin Shao, Anze Xie, Ying Sheng, Lianmin Zheng, Joseph E. Gonzalez, Ion Stoica, Xuezhe Ma, and Hao Zhang. How long can open-source llms truly promise on context length?, June 2023.
  • [5] Shouyuan Chen, Sherman Wong, Liangjian Chen, and Yuandong Tian. Extending context window of large language models via positional interpolation. arXiv preprint arXiv:2306.15595, 2023.
  • [6] Tom B. Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh, Daniel M. Ziegler, Jeffrey Wu, Clemens Winter, Christopher Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish, Alec Radford, Ilya Sutskever, and Dario Amodei. Language models are few-shot learners. In Hugo Larochelle, Marc’Aurelio Ranzato, Raia Hadsell, Maria-Florina Balcan, and Hsuan-Tien Lin, editors, Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual, 2020.
  • [7] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  • [8] Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, et al. Llama 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288, 2023.
  • [9] Tianyi Zhang, Faisal Ladhak, Esin Durmus, Percy Liang, Kathleen McKeown, and Tatsunori Hashimoto. Benchmarking large language models for news summarization. Transactions of the Association for Computational Linguistics, 12, 2024.
  • [10] Mark Chen, Jerry Tworek, Heewoo Jun, Qiming Yuan, Henrique Ponde de Oliveira Pinto, Jared Kaplan, Harri Edwards, Yuri Burda, Nicholas Joseph, Greg Brockman, et al. Evaluating large language models trained on code. arXiv preprint arXiv:2107.03374, 2021.
  • [11] Baptiste Roziere, Jonas Gehring, Fabian Gloeckle, Sten Sootla, Itai Gat, Xiaoqing Ellen Tan, Yossi Adi, Jingyu Liu, Tal Remez, Jérémy Rapin, et al. Code llama: Open foundation models for code. arXiv preprint arXiv:2308.12950, 2023.
  • [12] Wei-Lin Chiang, Zhuohan Li, Zi Lin, Ying Sheng, Zhanghao Wu, Hao Zhang, Lianmin Zheng, Siyuan Zhuang, Yonghao Zhuang, Joseph E. Gonzalez, Ion Stoica, and Eric P. Xing. Vicuna: An open-source chatbot impressing gpt-4 with 90%* chatgpt quality, March 2023.
  • [13] Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li, Carlos Guestrin, Percy Liang, and Tatsunori B. Hashimoto. Stanford alpaca: An instruction-following llama model. https://github.com/tatsu-lab/stanford_alpaca, 2023.
  • [14] Gemini Team, Rohan Anil, Sebastian Borgeaud, Yonghui Wu, Jean-Baptiste Alayrac, Jiahui Yu, Radu Soricut, Johan Schalkwyk, Andrew M Dai, Anja Hauth, et al. Gemini: a family of highly capable multimodal models. arXiv preprint arXiv:2312.11805, 2023.
  • [15] Anthropic. Claude. https://www.anthropic.com/claude, 2023.
  • [16] Moonshot. Kimi chat. https://kimi.moonshot.cn/, 2023.
  • [17] Zhengxiao Du, Yujie Qian, Xiao Liu, Ming Ding, Jiezhong Qiu, Zhilin Yang, and Jie Tang. Glm: General language model pretraining with autoregressive blank infilling. arXiv preprint arXiv:2103.10360, 2021.
  • [18] Joshua Ainslie, Santiago Ontanon, Chris Alberti, Vaclav Cvicek, Zachary Fisher, Philip Pham, Anirudh Ravula, Sumit Sanghai, Qifan Wang, and Li Yang. ETC: Encoding long and structured inputs in transformers. In Bonnie Webber, Trevor Cohn, Yulan He, and Yang Liu, editors, Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP), pages 268–284, Online, November 2020. Association for Computational Linguistics.
  • [19] Iz Beltagy, Matthew E Peters, and Arman Cohan. Longformer: The long-document transformer. arXiv preprint arXiv:2004.05150, 2020.
  • [20] Manzil Zaheer, Guru Guruganesh, Kumar Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, et al. Big bird: Transformers for longer sequences. Advances in neural information processing systems, 33:17283–17297, 2020.
  • [21] Nikita Kitaev, Łukasz Kaiser, and Anselm Levskaya. Reformer: The efficient transformer. arXiv preprint arXiv:2001.04451, 2020.
  • [22] Jiayu Ding, Shuming Ma, Li Dong, Xingxing Zhang, Shaohan Huang, Wenhui Wang, Nanning Zheng, and Furu Wei. Longnet: Scaling transformers to 1,000,000,000 tokens. arXiv preprint arXiv:2307.02486, 2023.
  • [23] Rewon Child, Scott Gray, Alec Radford, and Ilya Sutskever. Generating long sequences with sparse transformers. arXiv preprint arXiv:1904.10509, 2019.
  • [24] Matteo Pagliardini, Daniele Paliotta, Martin Jaggi, and François Fleuret. Faster causal attention over large sequences through sparse flash attention. arXiv preprint arXiv:2306.01160, 2023.
  • [25] Aurko Roy, Mohammad Saffar, Ashish Vaswani, and David Grangier. Efficient content-based sparse attention with routing transformers. Transactions of the Association for Computational Linguistics, 9:53–68, 2021.
  • [26] Insu Han, Rajesh Jayaram, Amin Karbasi, Vahab Mirrokni, David Woodruff, and Amir Zandieh. Hyperattention: Long-context attention in near-linear time. In The Twelfth International Conference on Learning Representations, 2023.
  • [27] Sinong Wang, Belinda Z Li, Madian Khabsa, Han Fang, and Hao Ma. Linformer: Self-attention with linear complexity. arXiv preprint arXiv:2006.04768, 2020.
  • [28] Krzysztof Marcin Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamas Sarlos, Peter Hawkins, Jared Quincy Davis, Afroz Mohiuddin, Lukasz Kaiser, et al. Rethinking attention with performers. In International Conference on Learning Representations, 2020.
  • [29] Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. Transformers are rnns: Fast autoregressive transformers with linear attention. In International conference on machine learning, pages 5156–5165. PMLR, 2020.
  • [30] Beidi Chen, Tri Dao, Eric Winsor, Zhao Song, Atri Rudra, and Christopher Ré. Scatterbrain: Unifying sparse and low-rank attention. Advances in Neural Information Processing Systems, 34:17413–17426, 2021.
  • [31] Beidi Chen, Tri Dao, Kaizhao Liang, Jiaming Yang, Zhao Song, Atri Rudra, and Christopher Re. Pixelated butterfly: Simple and efficient sparse training for neural network models. In International Conference on Learning Representations, 2021.
  • [32] Albert Gu and Tri Dao. Mamba: Linear-time sequence modeling with selective state spaces. arXiv preprint arXiv:2312.00752, 2023.
  • [33] Bo Peng, Eric Alcaide, Quentin Anthony, Alon Albalak, Samuel Arcadinho, Huanqi Cao, Xin Cheng, Michael Chung, Matteo Grella, Kranthi Kiran GV, et al. Rwkv: Reinventing rnns for the transformer era. arXiv preprint arXiv:2305.13048, 2023.
  • [34] Aydar Bulatov, Yury Kuratov, and Mikhail Burtsev. Recurrent memory transformer. Advances in Neural Information Processing Systems, 35:11079–11091, 2022.
  • [35] Yuhuai Wu, Markus N Rabe, DeLesley Hutchins, and Christian Szegedy. Memorizing transformers. arXiv preprint arXiv:2203.08913, 2022.
  • [36] Sebastian Borgeaud, Arthur Mensch, Jordan Hoffmann, Trevor Cai, Eliza Rutherford, Katie Millican, George Bm Van Den Driessche, Jean-Baptiste Lespiau, Bogdan Damoc, Aidan Clark, et al. Improving language models by retrieving from trillions of tokens. In International conference on machine learning, pages 2206–2240. PMLR, 2022.
  • [37] Guangxuan Xiao, Yuandong Tian, Beidi Chen, Song Han, and Mike Lewis. Efficient streaming language models with attention sinks. arXiv preprint arXiv:2309.17453, 2023.
  • [38] Vijay Janapa Reddi, Christine Cheng, David Kanter, Peter Mattson, Guenther Schmuelling, Carole-Jean Wu, Brian Anderson, Maximilien Breughe, Mark Charlebois, William Chou, et al. Mlperf inference benchmark. In 2020 ACM/IEEE 47th Annual International Symposium on Computer Architecture (ISCA), pages 446–459. IEEE, 2020.
  • [39] Zhenyu Zhang, Ying Sheng, Tianyi Zhou, Tianlong Chen, Lianmin Zheng, Ruisi Cai, Zhao Song, Yuandong Tian, Christopher Ré, Clark Barrett, et al. H2o: Heavy-hitter oracle for efficient generative inference of large language models. Advances in Neural Information Processing Systems, 36, 2024.
  • [40] Luka Ribar, Ivan Chelombiev, Luke Hudlass-Galley, Charlie Blake, Carlo Luschi, and Douglas Orr. Sparq attention: Bandwidth-efficient llm inference. arXiv preprint arXiv:2312.04985, 2023.
  • [41] Jesse Mu, Xiang Li, and Noah Goodman. Learning to compress prompts with gist tokens. Advances in Neural Information Processing Systems, 36, 2024.
  • [42] Lei Zhu, Xinjiang Wang, Zhanghan Ke, Wayne Zhang, and Rynson WH Lau. Biformer: Vision transformer with bi-level routing attention. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 10323–10333, 2023.
  • [43] Suyu Ge, Yunan Zhang, Liyuan Liu, Minjia Zhang, Jiawei Han, and Jianfeng Gao. Model tells you what to discard: Adaptive kv cache compression for llms. arXiv preprint arXiv:2310.01801, 2023.
  • [44] Haojie Duanmu, Zhihang Yuan, Xiuhong Li, Jiangfei Duan, Xingcheng Zhang, and Dahua Lin. Skvq: Sliding-window key and value cache quantization for large language models. arXiv preprint arXiv:2405.06219, 2024.
  • [45] Guangxuan Xiao, Ji Lin, Mickael Seznec, Hao Wu, Julien Demouth, and Song Han. Smoothquant: Accurate and efficient post-training quantization for large language models. In International Conference on Machine Learning, pages 38087–38099. PMLR, 2023.
  • [46] Yilong Zhao, Chien-Yu Lin, Kan Zhu, Zihao Ye, Lequn Chen, Size Zheng, Luis Ceze, Arvind Krishnamurthy, Tianqi Chen, and Baris Kasikci. Atom: Low-bit quantization for efficient and accurate llm serving. arXiv preprint arXiv:2310.19102, 2023.
  • [47] G Kamradt. Needle in a haystack–pressure testing llms, 2023.
  • [48] Tri Dao, Dan Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. Flashattention: Fast and memory-efficient exact attention with io-awareness. Advances in Neural Information Processing Systems, 35:16344–16359, 2022.
  • [49] Zheng Cai, Maosong Cao, Haojiong Chen, Kai Chen, Keyu Chen, Xin Chen, Xun Chen, Zehui Chen, Zhi Chen, Pei Chu, et al. Internlm2 technical report. arXiv preprint arXiv:2403.17297, 2024.
  • [50] Alec Radford, Karthik Narasimhan, Tim Salimans, Ilya Sutskever, et al. Improving language understanding by generative pre-training. 2018.
  • [51] Jianlin Su, Murtadha Ahmed, Yu Lu, Shengfeng Pan, Wen Bo, and Yunfeng Liu. Roformer: Enhanced transformer with rotary position embedding. Neurocomputing, 568:127063, 2024.
  • [52] Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, and Sumit Sanghai. Gqa: Training generalized multi-query transformer models from multi-head checkpoints. arXiv preprint arXiv:2305.13245, 2023.
  • [53] Yushi Bai, Xin Lv, Jiajie Zhang, Hongchang Lyu, Jiankai Tang, Zhidian Huang, Zhengxiao Du, Xiao Liu, Aohan Zeng, Lei Hou, et al. Longbench: A bilingual, multitask benchmark for long context understanding. arXiv preprint arXiv:2308.14508, 2023.
  • [54] Yuri Kuratov, Aydar Bulatov, Petr Anokhin, Dmitry Sorokin, Artyom Sorokin, and Mikhail Burtsev. In search of needles in a 10m haystack: Recurrent memory finds what llms miss. arXiv preprint arXiv:2402.10790, 2024.
  • [55] Josh Achiam, Steven Adler, Sandhini Agarwal, Lama Ahmad, Ilge Akkaya, Florencia Leoni Aleman, Diogo Almeida, Janko Altenschmidt, Sam Altman, Shyamal Anadkat, et al. Gpt-4 technical report. arXiv preprint arXiv:2303.08774, 2023.

Appendix A Appendix

A.1 Proof of Theorems

For Theorem 1,

Proof.

In the worst case, we can set the attention mask $\textbf{M}$ to all ones, ensuring that the sparse attention score $\tilde{\textbf{P}}$ is identical to the original attention score $\textbf{P}$. Therefore, it is always possible to find an attention mask $\textbf{M}$ that satisfies the condition $||\tilde{\textbf{P}}-\textbf{P}||_{1}\leq\frac{\epsilon}{R}$. With that attention mask,

$$||\tilde{\textbf{O}}-\textbf{O}||_{1}=||\tilde{\textbf{P}}\textbf{V}-\textbf{P}\textbf{V}||_{1}=||(\tilde{\textbf{P}}-\textbf{P})\textbf{V}||_{1}\leq||\tilde{\textbf{P}}-\textbf{P}||_{1}\cdot||\textbf{V}||_{1}. \tag{7}$$

Therefore, $||\tilde{\textbf{O}}-\textbf{O}||_{1}\leq\epsilon$, completing the proof. ∎
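The final step of the proof of Theorem 1 can be spelled out by chaining the two bounds. This assumes that $R$ denotes an upper bound on $||\textbf{V}||_{1}$, which is how the condition on the mask appears to be used; that reading should be checked against the statement of Theorem 1.

```latex
||\tilde{\textbf{O}}-\textbf{O}||_{1}
  \leq ||\tilde{\textbf{P}}-\textbf{P}||_{1}\cdot||\textbf{V}||_{1}
  \leq \frac{\epsilon}{R}\cdot R
  = \epsilon .
```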

For Theorem 2,

Proof.

In the worst case, we can also set the attention mask $\hat{\textbf{M}}$ to all ones, ensuring that the sparse attention score $\tilde{\textbf{P}}$ is identical to the original attention score $\textbf{P}$. In this case, $\hat{\textbf{M}}$ maintains the decomposed structured sparse pattern. Therefore, following the proof of Theorem 1 completes the proof. ∎

A.2 Detailed results

Figure 7 and Figure 8 report the detailed scores of the two evaluated models on the BABILong and "Needle in a Haystack" tasks across different sequence lengths, respectively. For settings of the baselines and overall score statistics, please refer to Section 5.2.

[Figure 7 panels: (a) Full attention, (b) SampleAttention, (c) BigBird, (d) StreamingLLM, (e) Full attention, (f) SampleAttention, (g) BigBird, (h) StreamingLLM]
Figure 7: Detailed results of the evaluations on the BABILong benchmark: (a), (b), (c), and (d) are based on the ChatGLM2-6B model, while (e), (f), (g), and (h) are based on the InternLM2-7B model.
[Figure 8 panels: (a) Full attention, (b) SampleAttention, (c) BigBird, (d) StreamingLLM, (e) Full attention, (f) SampleAttention, (g) BigBird, (h) StreamingLLM]
Figure 8: Detailed results of the evaluations on the "Needle in a Haystack" task: (a), (b), (c), and (d) are based on the ChatGLM2-6B model, while (e), (f), (g), and (h) are based on the InternLM2-7B model.

Table 4 displays the sequence scaling results at the prefill stage based on the text-generation-interface serving framework, using the ChatGLM2-6B model with 8× NVIDIA A100 GPUs. The parallelism configuration is TP=4 and PP=2, and the sequence is processed in chunks for memory efficiency. The profiled TTFT (Time To First Token) metric and the share of time spent in the self-attention modules demonstrate the influence of the attention mechanism's quadratic complexity. As sequence lengths increase, the latency of the attention module rises sharply, and its share of TTFT approaches 90% at a sequence length of 1 million.

Table 4: Latency breakdown at the prefill stage (based on ChatGLM2-6B).

| Sequence Length | TTFT (ms) | Full Attention (ms) | Percent (%) |
|---|---|---|---|
| 32K | 1273.4 | 410.4 | 32.2 |
| 64K | 2917.3 | 1538.1 | 52.7 |
| 128K | 7756.5 | 4403.9 | 56.8 |
| 256K | 23403.7 | 16839.5 | 72.0 |
| 512K | 51084.3 | 43477.0 | 85.1 |
| 1M | 169653.0 | 148774.1 | 87.7 |
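The percentage column of Table 4 and the growth of attention latency per doubling of sequence length can be recomputed directly from the reported timings. The short sketch below does so; the variable names are illustrative and the numbers are copied from the table.

```python
# (sequence length, TTFT in ms, full-attention time in ms), copied from Table 4
rows = [
    ("32K", 1273.4, 410.4),
    ("64K", 2917.3, 1538.1),
    ("128K", 7756.5, 4403.9),
    ("256K", 23403.7, 16839.5),
    ("512K", 51084.3, 43477.0),
    ("1M", 169653.0, 148774.1),
]

# Share of TTFT spent in attention, in percent
attention_share = {name: 100.0 * attn / ttft for name, ttft, attn in rows}

# Growth factor of attention latency each time the sequence length doubles
growth = [rows[i + 1][2] / rows[i][2] for i in range(len(rows) - 1)]
```

In this data each doubling multiplies the attention latency by a factor between roughly 2.5 and 4, compared with the factor of 4 that purely quadratic scaling would give and the factor of 2 expected from linear scaling.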

A.3 Visualization of attention

Figures 9 and 10 present the sparse patterns across various heads in the ChatGLM2-6B model (28 layers × 32 heads) at a sequence length of 61K. We conducted row-by-row filtering based on the full attention softmax weights, using a CRA threshold of $\alpha=0.95$, and randomly selected four heads from different layers for display.

According to the visualization results on the majority of heads, we observed two distinct and prominent patterns prevalent in the heatmap of attention weight: column stripes and local windows. Column stripe patterns embody the global contextual information whereas diagonal window patterns capture local information.
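The row-by-row filtering used for these visualizations can be sketched as follows: for one query row, keep the largest softmax weights until their sum reaches $\alpha$ and drop the rest. This is a minimal per-row illustration with hypothetical function names; treating the dropped fraction as that row's sparsity is a simplification of the head-level quantities reported in the paper.

```python
def min_keys_for_threshold(weights, alpha):
    """Smallest number of keys whose softmax weights sum to at least alpha."""
    total = 0.0
    for count, w in enumerate(sorted(weights, reverse=True), start=1):
        total += w
        if total >= alpha:
            return count
    return len(weights)

def row_sparsity(weights, alpha):
    """Fraction of keys in this row that can be dropped while keeping mass >= alpha."""
    return 1.0 - min_keys_for_threshold(weights, alpha) / len(weights)

# One softmax row over 8 keys (sums to 1); values chosen to be exact in binary
row = [0.0078125, 0.5, 0.015625, 0.25, 0.0625, 0.125, 0.03125, 0.0078125]
kept_95 = min_keys_for_threshold(row, 0.95)
sd_95 = row_sparsity(row, 0.95)
sd_80 = row_sparsity(row, 0.80)
```

For this row, five of the eight keys are needed to reach a cumulative weight of 0.95, whereas three suffice for 0.80, which mirrors the trade-off between the threshold and the amount of filtering.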

[Figure 9 panels: (a)-(d) Layer 0, (e)-(h) Layer 4, (i)-(l) Layer 8, (m)-(p) Layer 12]
Figure 9: Attention visualization at a content length of 61K, showing the sparse patterns for randomly chosen heads from layers 0, 4, 8 and 12.
[Figure 10 panels: (a)-(d) Layer 16, (e)-(h) Layer 20, (i)-(l) Layer 24]
Figure 10: Attention visualization at a content length of 61K, showing the sparse patterns for randomly chosen heads from layers 16, 20 and 24.

A.4 Sparsity analysis

To further quantify how sparsity grows with sequence length, we conducted scalability tests on the ChatGLM2-6B model using the "Needle in a Haystack" task. The results are presented in Table 5. As the sequence length increases, sparsity becomes more pronounced: with each doubling of length, the proportion of KV elements needed to maintain the same threshold $\alpha$ decreases by approximately 20%. Meanwhile, a smaller threshold filters out more KV elements, which may also reduce task accuracy. Additionally, Figure 11 illustrates the frequency statistics of the retained KV elements in the $Sk$ dimension for heads exhibiting different degrees of sparsity.

Table 5: Sparsity analysis for ChatGLM2-6B model as sequence length scales.
| Sequence length | Average SD ($\alpha=0.90$) | Average SD ($\alpha=0.95$) | Average SD ($\alpha=0.98$) |
|---|---|---|---|
| 4K | 91.27% | 88.00% | 79.17% |
| 8K | 93.68% | 90.74% | 83.43% |
| 16K | 95.84% | 92.52% | 86.37% |
| 32K | 96.34% | 93.88% | 88.68% |
| 64K | 96.91% | 94.89% | 90.70% |
| 128K | 97.44% | 95.84% | 92.43% |
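The "approximately 20% per doubling" figure can be checked against the $\alpha=0.95$ column of Table 5 by converting each average SD into the share of KV elements that must be retained and comparing consecutive lengths. The sketch below reads SD as the percentage of KV elements that can be dropped, which is an assumption about the notation; the variable names are illustrative.

```python
# Average SD (%) at alpha = 0.95 for 4K, 8K, ..., 128K, copied from Table 5
sd_095 = [("4K", 88.00), ("8K", 90.74), ("16K", 92.52),
          ("32K", 93.88), ("64K", 94.89), ("128K", 95.84)]

# Share of KV elements that must be kept, assuming SD is the share dropped
retained = [100.0 - sd for _, sd in sd_095]

# Relative reduction of the retained share for each doubling of sequence length
reductions = [1.0 - after / before for before, after in zip(retained, retained[1:])]
mean_reduction = sum(reductions) / len(reductions)
```

The per-doubling reductions fall between about 16% and 23%, with a mean near 19%, in line with the approximate 20% stated above.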
Figure 11: Frequency statistics of the retained KV elements along the $Sk$ dimension for two randomly selected heads. At a sequence length of 61K, the SD ($\alpha=0.95$) is 41.2% for the left head and 97.5% for the right head.

A.5 Effectiveness of sampling

To verify the effectiveness of this sampling method, we conducted tests on different heads using two distinct sampling ratios $r_{row}$. We applied different ratios of top-k stripes combined with a tuned window mask to the full attention matrices to observe the changes in CRA. The results, as shown in Table 6, indicate that the CRA achieved by selecting top-k stripes at a 5% sampling ratio is remarkably close to that obtained from the full attention score. This confirms that SampleAttention's simple sampling method is highly effective.

Table 6: The CRA percentages that can be achieved by selecting different ratios of top-k stripes under different sampling ratios for each head. The sequence length of tested content is 61K.

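Each column header gives the ratio of top-k stripes followed by the sampling ratio.

| Head | 2.5% / 100% | 2.5% / 5% | 5% / 100% | 5% / 5% | 10% / 100% | 10% / 5% | 20% / 100% | 20% / 5% | 40% / 100% | 40% / 5% | 80% / 100% | 80% / 5% |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Layer0-Head0 | 10.60% | 10.31% | 17.85% | 17.74% | 29.49% | 28.83% | 47.09% | 46.14% | 71.19% | 70.15% | 97.12% | 96.65% |
| Layer13-Head0 | 75.29% | 65.62% | 80.57% | 74.89% | 86.33% | 81.58% | 92.09% | 89.98% | 97.07% | 95.21% | 99.85% | 98.68% |
| Layer13-Head13 | 98.24% | 97.85% | 98.63% | 98.29% | 99.02% | 98.73% | 99.41% | 99.12% | 99.76% | 99.66% | 100.00% | 99.80% |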
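The experiment behind Table 6 can be mimicked on a toy attention matrix: rank key columns by their attention mass over either all query rows or a strided sample, keep the top-k columns as stripes, add a local window, and measure how much attention mass the resulting mask retains. The sketch below is illustrative only; the toy matrix, the function names, and the choice to report CRA as the minimum retained mass over query rows are assumptions, not the authors' evaluation code.

```python
def toy_attention(n):
    """Causal toy matrix with a strong first column (stripe) and a strong diagonal (window)."""
    P = []
    for i in range(n):
        raw = [(8.0 if j == 0 else 4.0 if j == i else 1.0) if j <= i else 0.0
               for j in range(n)]
        z = sum(raw)
        P.append([x / z for x in raw])
    return P

def top_k_stripes(P, row_stride, k):
    """Pick the k key columns with the largest mass over every row_stride-th query row."""
    n_cols = len(P[0])
    col_score = [sum(P[i][j] for i in range(0, len(P), row_stride)) for j in range(n_cols)]
    ranked = sorted(range(n_cols), key=lambda j: col_score[j], reverse=True)
    return set(ranked[:k])

def cra(P, stripes, window):
    """Minimum over query rows of the attention mass kept by the stripe-plus-window mask."""
    return min(
        sum(p for j, p in enumerate(row) if j in stripes or 0 <= i - j < window)
        for i, row in enumerate(P)
    )

P = toy_attention(8)
stripes_full = top_k_stripes(P, row_stride=1, k=1)      # uses every query row
stripes_sampled = top_k_stripes(P, row_stride=4, k=1)   # uses rows 0 and 4 only
cra_with_stripe = cra(P, stripes_sampled, window=1)
cra_window_only = cra(P, set(), window=1)
```

In this toy case the sampled rows select the same stripe as the full matrix, and adding that single stripe raises the retained mass of the worst row from 4/18 to 12/18.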

A.6 Limitations and Future Work

We discuss limitations of SampleAttention and future directions in this subsection.

Other patterns and sampling.

We also identified additional diagonal structures in heads with lower sparsity levels. Although SampleAttention is capable of covering these areas by selecting an adequate proportion of KVs, accurately capturing these patterns could potentially lead to further performance enhancements. Additionally, considering the time overhead associated with sampling, how to further improve sampling efficiency to achieve acceleration even at shorter sequence lengths remains an important challenge for future research.

Hyperparameter tuning.

The experimental results demonstrate that hyperparameters substantially influence the trade-off between task performance and speedup. Consequently, swiftly determining efficient hyperparameters for a specific model emerges as a critical challenge. In the future, we aim to implement autotuning of these hyperparameters during task runtime, enabling SampleAttention to consistently achieve high accuracy and low latency across diverse sequence lengths and scenarios.

Serving.

After integrating SampleAttention into a distributed serving framework, we found that requests with ultra-long sequences (≥128K) or large batch sizes cause memory issues. More engineering effort is required to achieve memory efficiency, potentially through strategies such as pipeline or sequence parallelism and chunking along the sequence dimension.

A.7 PyTorch-Style Implementation Algorithm

Algorithm 1 presents succinct PyTorch-style pseudocode for the implementation of SampleAttention. A link to the source code, based on PyTorch and Triton, along with scripts to reproduce the main experimental results, will be provided in the camera-ready version.

Input: $\mathbf{Q} \in \mathbb{R}^{S_q \times d}$, $\mathbf{K} \in \mathbb{R}^{S_k \times d}$, $\mathbf{V} \in \mathbb{R}^{S_k \times d}$, $\alpha \in (0,1)$, $r_{row} \in (0,1)$, $r_{w} \in \mathbb{N}$
# Stage1: Query-Guided Attention Sampling
SampleWeight = sample_bmm_softmax_reduction(Q, K, r_row)
SortedWeight = SampleWeight.sort(dim=-1)
WeightSum = SortedWeight.sum(dim=-1)
# Stage2: Score-Based Key-Value Filtering
# example prefixsum_sample_list=[0.0125, 0.025, 0.05, 0.1, 0.2, 0.4, 0.8, 1.0] * Sk
SD_sample_list = SortedWeight[::,:prefixsum_sample_list].sum()/WeightSum
KV_ratio_per_head = searchsorted(SD_sample_list, α)
I_KV_per_head = gather_KV_Index(SortedWeight.idx, KV_ratio_per_head)
# Sparse computation of the attention
# combined mask of I_KV and tuned window size k
M_Merged = merge_mask(I_KV_per_head, r_w)
Output = sparse_flash_attn(Q, K, V, M_Merged)
Algorithm 1 Implementation of SampleAttention
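For readers who want to trace the two stages on a small input, the following is a minimal single-head sketch in plain Python. It is not the authors' PyTorch/Triton implementation. It assumes causal prefill with $S_q = S_k$, strided row sampling, a descending sort of the accumulated column weights, and a window given as an integer token count of at least 1. The function names `causal_probs` and `sample_attention` are hypothetical, and the final masked attention is computed densely as a stand-in for the sparse kernel.

```python
import math
import random

def causal_probs(q, K, i, allowed=None):
    """Softmax for query i over keys j <= i, optionally restricted by a mask."""
    js = [j for j in range(i + 1) if allowed is None or allowed(i, j)]
    s = [sum(a * b for a, b in zip(q, K[j])) / math.sqrt(len(q)) for j in js]
    m = max(s)
    e = [math.exp(x - m) for x in s]
    z = sum(e)
    return js, [x / z for x in e]

def sample_attention(Q, K, V, alpha, r_row, window,
                     ratios=(0.0125, 0.025, 0.05, 0.1, 0.2, 0.4, 0.8, 1.0)):
    Sq, Sk = len(Q), len(K)
    # Stage 1: score a strided subset of queries and accumulate their
    # attention probabilities per key column.
    stride = max(1, round(1 / r_row))
    weight = [0.0] * Sk
    for i in range(Sq - 1, -1, -stride):
        js, p = causal_probs(Q[i], K, i)
        for j, pj in zip(js, p):
            weight[j] += pj
    order = sorted(range(Sk), key=lambda j: weight[j], reverse=True)
    total = sum(weight)
    # Stage 2: take the smallest candidate ratio whose top columns cover
    # at least `alpha` of the sampled attention mass.
    n_keep = Sk
    for r in ratios:
        n_keep = max(1, math.ceil(r * Sk))
        if sum(weight[j] for j in order[:n_keep]) / total >= alpha:
            break
    i_kv = set(order[:n_keep])

    # Merged mask: selected key-value columns plus a local window.
    def allowed(i, j):
        return j in i_kv or i - j < window

    out = []
    for i in range(Sq):
        js, p = causal_probs(Q[i], K, i, allowed)
        out.append([sum(pj * V[j][c] for j, pj in zip(js, p))
                    for c in range(len(V[0]))])
    return out, i_kv

random.seed(0)
S, d = 64, 8
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(S)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(S)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(S)]
out, kept = sample_attention(Q, K, V, alpha=0.95, r_row=0.05, window=4)
print(f"kept {len(kept)} of {S} key-value columns")
```

Random inputs like these have no pronounced stripes, so a large share of columns may be kept; the savings described in the paper rely on the sparsity observed in real attention heads.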