Memory-efficient Transformers via Top-k Attention

Abstract

Following the success of dot-product attention in Transformers, numerous approximations have been recently proposed to address its quadratic complexity with respect to the input length. While these variants are memory and compute efficient, it is not possible to directly use them with popular pre-trained language models trained using vanilla attention, without an expensive corrective pre-training stage. In this work, we propose a simple yet highly accurate approximation for vanilla attention. We process the queries in chunks, and for each query, compute the top-k scores with respect to the keys. Our approach offers several advantages: (a) its memory usage is linear in the input size, similar to linear attention variants such as Performer and RFA, (b) it is a drop-in replacement for vanilla attention that does not require any corrective pre-training, and (c) it can also lead to significant memory savings in the feed-forward layers after casting them into the familiar query-key-value framework. We evaluate the quality of the top-k approximation for multi-head attention layers on the Long Range Arena Benchmark, and for feed-forward layers of T5 and UnifiedQA on multiple QA datasets. We show our approach leads to accuracy that is nearly identical to vanilla attention in multiple setups including training from scratch, fine-tuning, and zero-shot inference.

1 Introduction

The Transformer architecture [48] has been successful in a wide range of natural language processing tasks, including machine translation [11], language modeling [40], question-answering [18], and many more. Transformers pre-trained on large amounts of text with a language modeling (LM) objective have become the standard in NLP, exhibiting surprising amounts of linguistic and world knowledge [30, 10, 31, 17, 39].

The contextualizing component of the Transformer is the attention layer where all positions in an input sequence of length L aggregate information from the entire sequence in parallel. At its core, given L query, key, and value vectors Q, K, V respectively, the dot-product attention function outputs $\mathrm{softmax}(QK^\top)V$, where the softmax function is applied row-wise on the matrix $QK^\top \in \mathbb{R}^{L \times L}$ of similarity scores of the query-key pairs, leading to an expensive $\Omega(L^2)$ memory requirement.

To alleviate this, past work proposed approximation methods for the computation of $\mathrm{softmax}(QK^\top)$. One major line of research focused on sparse attention variants, where only a few similarity scores are computed per query, and the rest are ignored. Methods differ by which query-key pairs are selected [5, 52, 32, 40, 21, 2, 15, 49]. A second line of research explored dense variants [19, 50, 1, 44] (cf. [46] for a survey). For example, instead of computing the attention scores exactly for only a small number of query-key pairs, [6] compute an approximation of scores for all pairs.

In this work, we adopt the sparse attention approach, but rather than approximating the k most similar key vectors per query vector, we compute this quantity exactly. Specifically, we propose top-k attention where, for each query vector, we only keep its k largest similarity scores with respect to the L keys, where $k \ll L$. We show that top-k attention can be implemented in a memory-efficient manner by (a) chunking the query vectors and computing the output $\mathrm{softmax}(QK^\top)V$ one chunk at a time, and (b) a custom implementation of the forward and backward pass that does not require caching activations while processing chunks in the forward pass.

Compared to prior methods, top-k attention has multiple attractive properties:

• Top-k attention has the same memory footprint as Performer [6], a state-of-the-art attention variant with linear time and memory complexity, on very long inputs (orange curve, Fig. 1, top-right), while being as fast as vanilla attention, and even faster than linear variants on inputs of length up to 4K (Fig. 1, bottom-left). This allows us, e.g., to train a typical 12-layer Transformer decoder over 32K-long inputs on a 30GiB GPU (Figure 3a).

• Top-k attention also reduces memory consumption in Transformer feed-forward layers, by casting this layer into the familiar query-key-value framework using ReLU instead of the row-wise softmax [42]. This is specifically appealing in models such as T5 [34] and GPT-3 [3], where for short inputs, the memory consumption is dominated by the feed-forward layers, as the number of keys, corresponding to the feed-forward hidden dimension size, is as large as 65K. Conversely, methods that rely on random feature approximations of attention, such as Performer [6] and RFA [29], do not admit an efficient approximation for the ReLU activation [53].

• Top-k attention is a highly accurate approximation to vanilla attention and is a plug-and-play replacement at both multi-head attention and feed-forward layers of a Transformer. This is unlike past attention variants [19, 6, 29] that require an expensive corrective pre-training stage to adjust model weights to the new variant, which can be prohibitive for large models. We show top-k attention can replace vanilla attention in a zero-shot inference setup and at fine-tuning time without any corrective pre-training.

Figure 1: Memory and time required for a forward and backward pass on a single BERT-base multi-head self-attention layer with causal masking on short (left) and long (right) inputs. Details are in §3.
Figure 3: Memory and time required for a combined forward and backward pass (details in §3).

We extensively evaluate top-k attention on a wide range of tasks and demonstrate the advantages listed above. First, training from scratch, we show top-k attention performs as well as vanilla self-attention on Long Range Arena, a benchmark dedicated to evaluating the ability of Transformers to handle long sequences, and in a language modeling task (WikiText-103). Second, we show top-k attention can be used as a drop-in replacement for vanilla attention at inference time without any additional training at the feed-forward layer of the UnifiedQA model [20] on 12 different question answering (QA) datasets, reducing the number of keys used per query by more than 99%. Last, we show top-k attention obtains similar performance to vanilla attention on a wide range of QA tasks when fine-tuning T5 [34], without the need for any corrective pre-training.

Overall, our results demonstrate that top-k attention is a simple and effective method for dramatically reducing the memory footprint of Transformers without loss of accuracy, which can allow resource-constrained researchers to enjoy the benefits of large pre-trained Transformer-based models. Our code is available at https://github.com/ag1988/top_k_attention.

2 Efficient Transformer Through Top-K Attention

In this section, we briefly review the Transformer architecture and its sparse approximations, and show how to cast the feed-forward layer into the query-key-value framework (§2.1). We then describe top-k attention and our memory-efficient implementation for it (§2.2).

2.1 Attention In Transformers

A Transformer [48] is a stack of layers each consisting of multi-head attention and feed-forward sub-layers. Its contextualizing component is the multi-head attention defined as follows.

Multi-head Attention Given a query $Q \in \mathbb{R}^{L_Q \times d}$, key $K \in \mathbb{R}^{L_K \times d}$ and value $V \in \mathbb{R}^{L_K \times d}$, the output in $\mathbb{R}^{L_Q \times d}$ of dot-product attention is defined as:

$$\mathrm{Attention}(Q, K, V) = \text{row-softmax}\left(\frac{QK^\top}{\lambda}\right) V \tag{1}$$

where λ is an optional temperature typically fixed as $\sqrt{d}$.
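As an illustration of Eq. 1, the following is a minimal NumPy sketch of single-head dot-product attention with a row-wise softmax. The function names, the default temperature of √d, and the random test shapes are assumptions made for this example; they are not taken from the authors' code.

```python
import numpy as np

def row_softmax(x):
    # Subtract the row maximum before exponentiating for numerical stability.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V, temperature=None):
    # Q: [L_Q, d], K: [L_K, d], V: [L_K, d] -> output: [L_Q, d]
    d = Q.shape[-1]
    lam = np.sqrt(d) if temperature is None else temperature  # assumed default
    scores = Q @ K.T / lam        # [L_Q, L_K]: the quadratic-size intermediate
    return row_softmax(scores) @ V

rng = np.random.default_rng(0)
Q = rng.standard_normal((5, 8))
K = rng.standard_normal((7, 8))
V = rng.standard_normal((7, 8))
out = attention(Q, K, V)          # shape (5, 8)
```

The `scores` matrix is the only intermediate whose size grows with the product of the number of queries and keys, which is the cost the rest of this section targets.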

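Sparse attention variants (§1) compute only a few entries of $QK^\top$ and mask out the rest. For a mask $B \in \{0, -\infty\}^{L_Q \times L_K}$,

$$\mathrm{SparseAttention}(Q, K, V, B) = \text{row-softmax}\left(QK^\top + B\right) V \tag{2}$$

The sparsity of B can be leveraged via customized implementations of matrix product [5, 2], and thus Eq. 2 can be significantly cheaper to compute compared to Eq. 1.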

Feed-forward as attention In the feed-forward layer, a 1-hidden layer fully-connected network is applied identically to every input token. As observed in past work [42, 41, 13], a feed-forward layer can be cast into the query-key-value framework as:

$$\mathrm{FF}_{K,V}(Q) = \mathrm{ReLU}\left(QK^\top\right) V \tag{3}$$

In this case, the queries $Q \in \mathbb{R}^{L_Q \times d_{model}}$ are the inputs to the layer with $L_Q$ tokens, similar to self-attention. However, the keys $K = W_K \in \mathbb{R}^{L_K \times d_{model}}$ and values $V = W_V \in \mathbb{R}^{L_K \times d_{model}}$ are learned parameters that are independent of the input. The number of keys $L_K$ here is known as the feed-forward dimension and can be as large as 65K for wide models such as T5 [34] and GPT-3 [3]. In the common case where the input sequences are relatively short, memory consumption is dominated by the feed-forward sub-layer and not the self-attention sub-layer.
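The correspondence between a feed-forward layer and the query-key-value form can be checked numerically. The sketch below assumes a bias-free two-matrix feed-forward network; the weight names `W1` and `W2` and all sizes are illustrative choices, not taken from the paper.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def feed_forward(x, W1, W2):
    # Standard 1-hidden-layer network without biases (assumption):
    # x: [L_Q, d_model], W1: [d_model, d_ff], W2: [d_ff, d_model].
    return relu(x @ W1) @ W2

def feed_forward_as_attention(Q, K, V):
    # Same computation in query-key-value form: K and V are [L_K, d_model]
    # with L_K = d_ff, and ReLU takes the place of the row-wise softmax.
    return relu(Q @ K.T) @ V

rng = np.random.default_rng(0)
d_model, d_ff, n_tokens = 16, 64, 10
x = rng.standard_normal((n_tokens, d_model))
W1 = rng.standard_normal((d_model, d_ff))
W2 = rng.standard_normal((d_ff, d_model))

ff_out = feed_forward(x, W1, W2)
qkv_out = feed_forward_as_attention(x, W1.T, W2)   # keys = W1^T, values = W2
```

Here each hidden unit of the feed-forward layer plays the role of one key-value pair, so the hidden dimension `d_ff` is the number of keys.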

Unlike top-k attention, past approaches for approximating attention are incompatible with feed-forward layers. Most approximate attention variants, such as Sparse Transformer [5], LongFormer [2], BigBird [54], and Sinkhorn attention [44], rely on a locality bias in sequences, where the key vectors indexed close to each other in K are assumed to have similar representations. This is irrelevant for keys in a feed-forward layer, which are permutation-equivariant and do not have any local structure. Dense attention variants relying on random Fourier features for approximating the softmax function are also not applicable, since it is known that ReLU cannot be approximated using such features [53].

2.2 Top-K Attention

In this work we propose top-k attention, where for each query, we mask out all but its k largest dot products with the keys; that is, in each row of $QK^\top$ we keep only its k largest elements and mask out the rest:

$$\text{top-}k\text{-Attention}(Q, K, V) = \mathrm{activation}\left(\text{top-}k\left(QK^\top\right)\right) V \tag{4}$$

where activation can be softmax, ReLU, or any other activation, and top-k($QK^\top$) denotes a sparse matrix consisting only of the row-wise top-k elements of $QK^\top$. A naïve approach for computing top-k($QK^\top$) would be to first compute $QK^\top$ and then apply a row-wise top-k operation. Unfortunately, computing $QK^\top \in \mathbb{R}^{L_Q \times L_K}$ explicitly would require $\Omega(L_Q \cdot L_K)$ memory.
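For reference, the naïve approach can be sketched in NumPy as below. This is an illustrative version that materializes the full score matrix, which is exactly the cost described above; the function names and the use of `np.argpartition` for the row-wise top-k are choices made for this example.

```python
import numpy as np

def row_softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def naive_topk_attention(Q, K, V, k, activation=row_softmax):
    dots = Q @ K.T                                          # [L_Q, L_K], the costly part
    # Column indices of the k largest entries in each row (unordered).
    top_idx = np.argpartition(dots, -k, axis=-1)[:, -k:]    # [L_Q, k]
    top_dots = np.take_along_axis(dots, top_idx, axis=-1)   # [L_Q, k]
    top_actv = activation(top_dots)                         # [L_Q, k]
    # Scatter the activations into an otherwise-zero [L_Q, L_K] matrix,
    # so that masked-out keys contribute nothing to the output.
    actv = np.zeros_like(dots)
    np.put_along_axis(actv, top_idx, top_actv, axis=-1)
    return actv @ V

rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 8))
K = rng.standard_normal((6, 8))
V = rng.standard_normal((6, 8))
out = naive_topk_attention(Q, K, V, k=2)                    # shape (4, 8)
```

With `k` equal to the number of keys the result coincides with unmasked attention, and with `k = 1` and a softmax activation each query simply copies the value of its best-matching key.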

We now describe our approach, which avoids this high cost.

Query chunking A simple way to implement attention and reduce its peak memory consumption is to chunk queries: instead of processing all the queries at once, we partition the queries into chunks and process them sequentially, one chunk at a time. For a chunk size C, the rows of Q are grouped into $L_Q / C$ contiguous chunks of size C and the attention function (Eq. 1, 3, 4) is computed using $Q_C$, K, V as inputs, where $Q_C$ denotes the subset of Q corresponding to chunk C.

During inference, once a query chunk is fully processed, the intermediate activations produced during its processing can be discarded and, hence, the peak memory required to process all $L_Q$ queries is bounded by the memory required to process a single chunk. Therefore, modulo the storage required for Q, K, V and the outputs themselves, the peak memory usage reduces from $\Omega(L_Q \cdot L_K)$ to $O(C \cdot L_K)$, which is linear with respect to $L_K$ for a fixed chunk size C.
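Query chunking at inference time can be sketched as follows. This is a minimal NumPy illustration with softmax attention; the function names are hypothetical, and the last chunk is allowed to be shorter than the chunk size.

```python
import numpy as np

def row_softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    return row_softmax(Q @ K.T) @ V

def chunked_attention(Q, K, V, chunk_size):
    out = np.empty((Q.shape[0], V.shape[1]), dtype=np.result_type(Q, K, V))
    for start in range(0, Q.shape[0], chunk_size):
        Q_c = Q[start:start + chunk_size]                # at most [C, d]
        # Only an (at most) [C, L_K] score matrix exists inside this call;
        # it is released before the next chunk is processed.
        out[start:start + chunk_size] = attention(Q_c, K, V)
    return out

rng = np.random.default_rng(0)
Q = rng.standard_normal((10, 8))
K = rng.standard_normal((12, 8))
V = rng.standard_normal((12, 8))
full = attention(Q, K, V)
chunked = chunked_attention(Q, K, V, chunk_size=4)       # chunks of 4, 4, 2 rows
```

Because each output row depends only on its own query, the chunked result matches the unchunked one up to floating-point rounding.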

Chunk size provides a simple way to trade off between the maximum memory usage and the slowdown due to the sequential processing of chunks. Fig. 2 shows memory and time for different chunk sizes for a single BERT-base self-attention layer over a sequence of length 65,536. We observe that chunk sizes $2^9$ and $2^{10}$ yield a good trade-off between time and memory.

Figure 2: Memory and time required for a forward pass on a single BERT-base multi-head self-attention layer on inputs of length 65,536.

Input checkpointing While query chunking provides a straightforward approach for bounding the peak memory usage of attention during inference, it is not so straightforward to employ it during training. Let d(A) denote the gradient of the loss with respect to a tensor A. For a given query chunk $Q_C$, the intermediate activations produced during the computation of the output $o_C = \mathrm{Attention}(Q_C, K, V)$ are required for computing $d(Q_C)$ from $d(o_C)$ via backpropagation.

Unfortunately, for the above bound on the peak memory usage to hold, we cannot afford to cache these activations for all the chunks, as done by standard automatic differentiation packages.

Taking inspiration from gradient checkpointing [4], we observe that if the inputs $Q_C$, K, V are available during the backward pass, we can re-compute $o_C$ and then use the produced intermediate activations to compute $d(Q_C)$ from $d(o_C)$. Once $d(Q_C)$ is computed, we can again discard the intermediate activations and gradients produced during this step and move on to the next chunk. This ensures that the peak memory usage during the backward pass through the attention layer is bounded by the memory required to backpropagate through a single chunk.

To summarize, a customized backward pass allows us to utilize query chunking, both during forward and backward passes, and only requires us to cache the inputs to the attention function. For a stack of N attention layers and fixed d, this reduces the peak memory usage from $\Omega(L_Q \cdot L_K \cdot N)$ to $O((L_Q + L_K) \cdot N + C \cdot L_K)$.
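To give a sense of scale, the short calculation below compares the size of one full score matrix with the size of one chunk of it. It assumes 32-bit floats, a single attention head, and the sequence length and chunk size used elsewhere in the paper (65,536 and 1,024); it counts only the score matrix and ignores all other activations.

```python
BYTES_PER_FLOAT32 = 4
GIB = 2 ** 30

def score_matrix_bytes(num_queries, num_keys):
    # Memory for one [num_queries, num_keys] float32 matrix of dot products.
    return num_queries * num_keys * BYTES_PER_FLOAT32

L = 65_536   # sequence length (L_Q = L_K)
C = 1_024    # query chunk size

full = score_matrix_bytes(L, L)    # all queries at once
chunk = score_matrix_bytes(C, L)   # one query chunk

print(full / GIB)    # 16.0
print(chunk / GIB)   # 0.25
```

Under these assumptions the per-chunk score matrix is 64 times smaller than the full one, which is the ratio of L to C.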

As described above, the combination of query chunking and input checkpointing provides a simple method for reducing the memory-footprint of vanilla attention, independent of top-k attention. Indeed, our benchmarking experiments in §3 demonstrate this. However, a drawback of this approach is that, during the backward pass, an implicit second forward pass is performed to re-compute the intermediate activations as described above. This can potentially increase the compute (FLOPs) required for a combined forward and backward pass by 50%. We now describe how to further improve both compute and memory by combining query chunking and input checkpointing with top-k attention.

Improving efficiency through top-k attention We now show that one can avoid re-computing activations in the case of top-k-Attention (Eq. 4). At a high level, top-k($QK^\top$) provides a highly compressed but accurate representation of $QK^\top$ and requires only $O(L_Q \cdot k)$ storage, compared to $\Omega(L_Q \cdot L_K)$ for $QK^\top$, where we assume $k \ll L_K$. Hence, we can cache it in addition to Q, K, V without incurring a significant increase in memory usage.

In the pseudo-code below, we show the forward and backward pass for top-k attention over a query chunk with input checkpointing. The steps of the forward pass that we re-compute during the backward pass are the application of activation on the output of top-k($QK^\top$) and forming its matrix representation for subsequent operations (compare Lines 6, 7 and Lines 18, 20). Therefore, the number of FLOPs spent on re-computation in our implementation of top-k attention is at most $O(L_Q \cdot (k + L_K) \cdot N)$, typically much lower than $\Omega(L_Q \cdot L_K \cdot d \cdot N)$, as there is no need to re-compute the dot-products.

Moreover, our benchmarking experiments (§3) show that top-k attention leads to improved memory usage compared to query chunking. This is because in vanilla attention (Eq. 1), while performing the re-computation in the backward pass of a query chunk $Q_C$, we first re-compute $Q_C K^\top$, then apply activation and backpropagate through this operation to compute $d(Q_C K^\top)$. This implies that at this point there are three $C \times L_K$ matrices in memory. In top-k attention, activation is applied only on a $C \times k$ matrix and at any given time during the backward pass shown below there is at most one $C \times L_K$ matrix in memory: either actv (Line 20) or d_dots (Line 23). As our experiments show, this can lead to a much smaller memory footprint for small values of k and N.

    def forward(Q, K, V, k, activation):
        # Q: query chunk, K: keys, V: values              [C, d], [L_K, d], [L_K, d]
        dots = matrix_prod(Q, transpose(K))               # [C, L_K]
        top_dots, top_indices = row_wise_topk(dots, k)    # [C, k], [C, k]
        del dots
        top_actv = activation(top_dots)                   # [C, k]
        actv = matrix(top_actv, top_indices)
        out = matrix_prod(actv, V)                        # [C, d]
        to_cache(Q, K, V, top_dots, top_indices, activation)
        return out

    def backward(d_out):
        # d_out: grad of loss w.r.t. out                  [C, d]
        Q, K, V, top_dots, top_indices, activation = from_cache()
        d_top_actv = matrix_prod(d_out, transpose(V), out_indices=top_indices)  # [C, k]
        # did not cache top_actv so re-compute it to backpropagate
        with compute_grads():
            top_actv = activation(top_dots)               # [C, k]
        d_top_dots = top_actv.backpropagate(d_top_actv)   # [C, k]
        actv = matrix(top_actv, top_indices)
        d_V = transpose(matrix_prod(transpose(d_out), actv))   # [L_K, d]
        del actv
        d_dots = matrix(d_top_dots, top_indices)
        d_Q = matrix_prod(d_dots, K)                      # [C, d]
        d_K = transpose(matrix_prod(transpose(Q), d_dots))     # [L_K, d]
        return d_Q, d_K, d_V
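A runnable approximation of the pseudo-code above, for a single query chunk with a softmax activation, might look as follows in NumPy. This is a sketch under several assumptions rather than the authors' implementation: the cache is passed explicitly instead of through `to_cache`/`from_cache`, the softmax gradient is written out by hand, sparse matrices are formed densely by a hypothetical `to_dense` helper, and `d_out @ V.T` is computed densely before gathering rather than with an `out_indices` argument. A finite-difference check on random inputs is included to validate the hand-written gradients.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def to_dense(values, idx, num_cols):
    # Scatter [C, k] values into a zero [C, num_cols] matrix at columns idx.
    dense = np.zeros((values.shape[0], num_cols), dtype=values.dtype)
    np.put_along_axis(dense, idx, values, axis=-1)
    return dense

def forward(Q, K, V, k):
    dots = Q @ K.T                                           # [C, L_K]
    top_idx = np.argpartition(dots, -k, axis=-1)[:, -k:]     # [C, k]
    top_dots = np.take_along_axis(dots, top_idx, axis=-1)    # [C, k]
    del dots                                                 # keep only the [C, k] tensors
    actv = to_dense(softmax(top_dots), top_idx, K.shape[0])  # [C, L_K]
    return actv @ V, (Q, K, V, top_dots, top_idx)

def backward(d_out, cache):
    Q, K, V, top_dots, top_idx = cache
    # Simplification: dense product followed by a gather of the top-k columns.
    d_top_actv = np.take_along_axis(d_out @ V.T, top_idx, axis=-1)   # [C, k]
    top_actv = softmax(top_dots)                             # re-computed, not cached
    # Gradient of a row-wise softmax: y * (g - sum(g * y)).
    d_top_dots = top_actv * (d_top_actv - (d_top_actv * top_actv).sum(-1, keepdims=True))
    actv = to_dense(top_actv, top_idx, K.shape[0])
    d_V = actv.T @ d_out                                     # [L_K, d]
    del actv
    d_dots = to_dense(d_top_dots, top_idx, K.shape[0])       # [C, L_K]
    return d_dots @ K, d_dots.T @ Q, d_V                     # d_Q, d_K, d_V

def numeric_grad(X, loss, eps=1e-6):
    # Central finite differences of loss() with respect to X (perturbed in place).
    grad = np.zeros_like(X)
    for i in np.ndindex(*X.shape):
        old = X[i]
        X[i] = old + eps
        plus = loss()
        X[i] = old - eps
        minus = loss()
        X[i] = old
        grad[i] = (plus - minus) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal(s) for s in [(3, 4), (6, 4), (6, 4)])
d_out = rng.standard_normal((3, 4))                          # upstream gradient
out, cache = forward(Q, K, V, k=3)
d_Q, d_K, d_V = backward(d_out, cache)
loss = lambda: float((forward(Q, K, V, k=3)[0] * d_out).sum())
num_d_Q, num_d_K, num_d_V = (numeric_grad(X, loss) for X in (Q, K, V))
```

The finite-difference check relies on the selected top-k set staying fixed under tiny perturbations, which holds as long as no two scores in a row are nearly tied.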

3 Benchmarking

In this section, we benchmark top-k attention in terms of time and memory, and compare it to vanilla attention, query-chunking without the top-k operation, and to Performer [6], as a representative of state-of-the-art linear attention variants. We separately benchmark (a) a single self-attention layer, (b) a single feed-forward layer, and (c) a 12-layer model.

Experimental details For all models, we benchmark by running a forward and backward pass over random inputs. Each measurement is an average over multiple runs on an Nvidia A100 GPU and is discarded if memory usage exceeds 30GiB. We use causal masking for self-attention layers to highlight the simplicity of our approach that can seamlessly handle arbitrary attention masks, unlike other methods [50, 19, 6], where implementing causal masking requires customized CUDA implementations. For Performer, we use 256 random features, and the CUDA implementation from [19].

Multi-head attention layer: We benchmark a single multi-head attention layer over long sequences in a configuration similar to BERT-base: $d_{model}$ is 768, 12 heads of size 64, and feed-forward dimension 3072. Fig. 1 shows the results when setting k to 128 and the query chunk size to 1024, which was shown to provide a good time-memory trade-off in §2.2.

We observe that top-k attention has the same device-memory usage as the Performer (top) for sequences as long as 65K tokens, while being as fast as vanilla attention, and even faster than Performer on inputs of length up to 4K. With vanilla attention, we cannot fit even a single multi-head attention layer over a sequence of more than 10K tokens, while top-k uses less than 10GiB of memory over sequences of length 65K. Lastly, we observe improvement in both time and memory when comparing top-k attention to query chunking over vanilla attention, where using top-k leads to a 3× memory reduction for sequences of length 65K.

Feed-forward layer: While considerable effort has been dedicated to devising efficient models for long contexts, a large feed-forward dimension is useful for knowledge-intensive tasks such as open-domain QA [39, 3], and efforts have been made to reduce its complexity [12]. We benchmark the resource usage of top-k attention at a single feed-forward layer for different feed-forward dimensions using batch size 512 and input length 512, which results in $2^{18}$ queries per batch.

Top-k attention (Figure 3b), for k = 512 and query chunk size $2^{14}$, dramatically improves device-memory usage compared to vanilla attention: it allowed us to use a feed-forward dimension 65K with 11GiB, while vanilla attention uses the same amount of memory with a feed-forward dimension 2K. Fitting a linear curve to the memory usage of vanilla attention and top-k attention, we estimate that top-k attention can handle feed-forward dimension 205K compared to 7K for vanilla attention on a 30GiB machine. Moreover, comparing top-k attention to query chunking, we again observe a 3× improvement in memory usage when the number of keys is 65K. Lastly, we observe only a minor slowdown in top-k attention compared to vanilla attention.
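The sizes involved in this feed-forward setting can be worked out directly. The calculation below is a rough illustration, assuming 32-bit float scores and 64-bit integer indices for the cached top-k entries; it accounts only for the score matrix and the top-k cache, so it is not meant to reproduce the measured totals.

```python
BYTES_FLOAT32 = 4
BYTES_INT64 = 8            # assumption: top-k indices stored as 64-bit integers
GIB = 2 ** 30

num_queries = 512 * 512    # batch size x input length = 2**18
num_keys = 65_536          # feed-forward dimension (2**16)
chunk = 2 ** 14            # query chunk size
k = 512

full_scores = num_queries * num_keys * BYTES_FLOAT32     # all queries at once
chunk_scores = chunk * num_keys * BYTES_FLOAT32          # one query chunk
topk_cache = num_queries * k * (BYTES_FLOAT32 + BYTES_INT64)  # values + indices

print(full_scores / GIB)    # 64.0
print(chunk_scores / GIB)   # 4.0
print(topk_cache / GIB)     # 1.5
```

Under these assumptions, the full score matrix alone would exceed a 30GiB budget, while a single chunk plus the cached top-k entries stays well within it.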

12-Layer Model:

We benchmark a 12-layer model to examine the cumulative utility of not caching $QK^\top$ in all N layers compared to the Performer. We use the same architecture as BERT-base with batch size 1 and vary the input length. We use a Transformer decoder with top-64 attention and chunk size 1,024 at the self-attention layers, and simple query chunking with chunk size 4,096 at the feed-forward layers.

We easily fit a 32K-long input on a 30GiB GPU, improving memory consumption by more than 8× compared to vanilla Transformer and 2× compared to Performer. Moreover, top-k attention outperforms query chunking in terms of both memory and runtime. As top-k attention targets memory consumption but not runtime, a current limitation is that runtime, unlike Performer, is still quadratic. Thus, running multi-layer models on long sequences is reasonable in a fine-tuning or zero-shot inference setup, but further work is required for training 12-layer models from scratch over large datasets that contain long sequences.

Overall, our benchmarking results over multi-head attention, feed-forward, and multi-layer Transformer establish top-k attention as a strong baseline for future work on efficient Transformers that dramatically improves memory consumption. Next, we evaluate top-k attention on downstream tasks and show that top-k attention can be used as a drop-in replacement for vanilla attention without additional pre-training, which can allow resource-constrained research groups to experiment with Transformers over long sequences or models with a large feed-forward dimension.

4 Experimental Evaluation Of Top-K Attention

Having established top-k attention as a memory-efficient alternative to vanilla attention, we now show that, even for small values of k, top-k attention provides a high-quality approximation of vanilla attention, both at the multi-head attention and feed-forward layers. We empirically show this in a wide range of setups including (a) training from scratch on tasks that require handling long-range dependencies (§4.1) and on language modeling (§4.2), (b) fine-tuning pre-trained language models (T5) on multiple QA datasets (§4.5), and (c) performing zero-shot inference using pre-trained language models (UNIFIEDQA) without any training (§4.3).

4.1 Long Range Arena

Long Range Arena [45] is a recently established benchmark for evaluating the ability of Transformer variants to handle long sequences. It comprises multiple text classification tasks with inputs containing thousands of tokens (Table 1). In ListOps [28], given a sequence of operations on single-digit integers, the model predicts a single-digit solution modeled as 10-way classification. IMDb movie reviews [25] is a character-level binary sentiment classification task. Lastly, in the ACL Anthology Network (AAN) [33] task, a character-level model classifies if there is a citation between a pair of papers.

Table 1. Not extracted; please refer to original document.

For each task, we downloaded and directly used the vanilla Transformer code offered by the authors [45] (https://github.com/google-research/long-range-arena) and compared the performance before and after replacing the multi-head attention layers with top-128 attention, using identical hyperparameters for both cases (details in §A.1). Test accuracy measured at the training checkpoint with the highest accuracy on the development set is reported in Table 1 and the learning curves on the development and test sets are shown in Fig. 4. On IMDb and AAN, the performance of top-128 is comparable to or better than vanilla attention. For ListOps, there is a minor drop in performance (1.5 points), but learning curves (Figure 4a) exhibit similar behaviour.

Figure 4: Learning curves of vanilla and top-128 attention on Long Range Arena (§4.1).

Thus, top-k attention, even for k as small as 3% of the number of keys, results in a performance very similar to that of vanilla attention. This shows that an exact and sparse top-k solution is a high-quality approximation for vanilla attention at multi-head attention layers.

4.2 Language Modeling

We further ascertain the findings of §4.1 via language modeling on WikiText-103 [26] using a 6-layer Transformer decoder with 156M parameters. Using an input length of 1024, we trained two models with vanilla and top-64 attentions at the self-attention layers, obtaining test perplexity scores of 30.96 and 30.51 respectively, slightly better in case of top-64 (details in §A.3).

4.3 Zero-Shot Inference With Unifiedqa

We have established that the performance of top-k attention is comparable to vanilla attention when training the model from scratch. In this set-up, several recently-proposed approaches have also reported competitive performances [45] . Now, we consider a different and more practical setup, where the starting point is using an already pre-trained language model [10, 34] . As such models were trained using vanilla attention, replacing it with a new attention variant typically requires a corrective pre-training stage to allow the model weights to adjust to the new variant, which can be expensive for large models. For example, [16, 29] have shown that using random features without corrective pre-training leads to high error rates in a language modeling task. Moreover, as explained in §2.1, most past methods are incompatible with feed-forward layers. In the subsequent experiments we show that it is possible to replace vanilla with top-k attention, at multi-head attention and feed-forward layers, and perform inference and fine-tuning without any need for such correction.

First, we compare the performance of UNIFIEDQA [20] before and after replacing its feed-forward layers with our implementation of top-k attention and directly performing inference on 12 different question answering (QA) datasets without any training. UNIFIEDQA is a T5-based model with 11B parameters [34], fine-tuned on a weighted mixture of QA datasets. The 12 datasets include diverse domains, such as science questions, factoid questions over Wikipedia, commonsense questions, etc. Details regarding the datasets and metrics can be found in §A.2. Table 2 shows the results for increasing values of k, where the feed-forward dimension of the model is 65,536. We observe that already when k = 256 and k = 512, i.e., less than 1% of the number of keys, performance is comparable to vanilla Transformer. When k = 4,096 (6% of the number of keys), performance is equal to or better than vanilla Transformer on all tasks. This highlights the plug-and-play property of top-k attention, which can be used without any additional training.
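The fractions of keys quoted above follow from the feed-forward dimension of 65,536. The short calculation below just makes the arithmetic explicit; the variable names are illustrative.

```python
FF_DIM = 65_536   # number of keys in each feed-forward layer

# Percentage of keys kept per query for each value of k mentioned in the text.
fractions = {k: 100 * k / FF_DIM for k in (256, 512, 4096)}

for k, pct in fractions.items():
    print(f"k = {k}: {pct:.2f}% of keys kept")
# k = 256: 0.39% of keys kept
# k = 512: 0.78% of keys kept
# k = 4096: 6.25% of keys kept
```

For k = 256 and k = 512, more than 99% of the keys are therefore masked out for every query.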

Table 2: Exact match scores of UNIFIEDQA on development sets with top-k attention at feed-forward layers. Notation: AI2 science elementary (AI2 elem.), AI2 science middle (AI2 mid.), ARC challenging (ARC chal.), CommonsenseQA (CSQA), NarrativeQA (NarQA), OpenbookQA (OBQA).

4.4 Zero-Shot Inference With Bert

To verify that the plug-and-play property of top-k attention also holds at self-attention layers, we downloaded a BERT-large-uncased-whole-word-masking checkpoint [10] already fine-tuned on SQuAD v1 [36] and evaluated its performance on the development set before and after replacing its self-attention layers with top-k attention. For k as low as 16 (4% of input length), we only saw a minor decrease in the exact match scores (86.9 → 86.2). Moreover, to empirically verify that dense approximations of vanilla attention (Performer, RFA, etc) indeed require corrective pre-training, we repeated the measurement using Performer attention with 256 features, obtaining a score of 0.38.

4.5 T5 Finetuning

Having established the plug-and-play property of top-k attention in zero-shot inference (§4.3, §4.4), we now show the effectiveness of top-k attention when fine-tuning a model, and that there are no unforeseen issues stemming from training under high sparsity. Here, we use T5-base rather than T5-11B and evaluate on five QA datasets (and not 12) due to computational constraints.

Similar to §4.3, we replace the feed-forward layers of T5-base, which has feed-forward dimension 3072, with our implementation of top-256 attention and fine-tune on multiple QA datasets. As summarized in Table 3, we found that the performance of top-256 attention was again comparable to vanilla attention on BoolQ, CommonsenseQA and ROPES with a minor loss in performance on MCTest (81.2 → 79.4) and OpenbookQA (58.8 → 58.0).

Table 3: Exact match scores on development sets. “Finetuning” denotes model was finetuned on the dataset, else it was evaluated directly without any training. All models use vanilla feed-forward layers except the ones that say top-256 (§4.5).

5 Discussion

Related work Our work follows a long line of works on efficient Transformers (see §1). Our method employs three main ideas: (a) computing the top-k attention scores for each query, (b) grouping the queries into chunks and processing these sequentially, and (c) caching only a part of the activations for the backward pass. The top-k operation was used at self-attention layers by [55] to show improved model performance, attributed to the removal of irrelevant information in the context. We use it to reduce the resource usage of multi-head attention and feed-forward layers. Processing query chunks sequentially without caching activations was also used in Reformer [21], but there it is achieved by replacing the vanilla residual connections in the Transformer with reversible connections [14]. Similar to the explanation provided in §2.2, these require an extra implicit forward pass during the backward pass and do not provide the compute and memory savings we get from our top-k specific backward pass (§2.2). Secondly, replacing residual connections with reversible ones changes the function computed by the model and would require corrective pre-training to be used with BERT, T5, etc. (§4.3-§4.5).

Limitations and future work As our method requires computing inner products of all queries and keys, it has a quadratic compute requirement. As seen in our pseudo-code ( §2.2), there are four matrix products (Lines 8, 15, 21, 25) involving a large sparse matrix and a small dense one. Our current implementation does not leverage this sparsity and hence is as slow as vanilla attention. While future devices might allow faster sparse-dense products, in the immediate future, one can leverage block-sparse kernels [5, 47] which have been successfully used for such products [37] .

Conclusion In this work, we proposed a memory-efficient and accurate sparse approximation of the primary sub-layers of a Transformer, benchmarked the resulting resource savings, and verified its quality and unique advantages, on a wide range of downstream tasks and evaluation set-ups.

A.1 Details Of Long Range Arena

ListOps [28] aims to diagnose the capability of modelling hierarchically structured data. Given a sequence of operations on single-digit integers, the model predicts the solution, also a single-digit integer modeled as a 10-way classification. Character-level text classification with the IMDb movie review dataset [25] is a binary sentiment classification task. In the character-level document retrieval with the ACL Anthology Network (AAN) [33] , the model classifies if there is a citation between a pair of papers.

We used the code and pre-processed data provided by the authors of Long Range Arena [45] and default model configurations. For each task, we used identical hyperparameters for vanilla and top-k attentions (Table 4) and used at most two Nvidia A100 GPUs for each run.

Table 4: Hyperparameters for LRA tasks (§4.1). Other hyperparameters were used as provided at https://github.com/google-research/long-range-arena.

A.2 Details Of Unifiedqa Inference & T5 Finetuning

We used Hugging Face's Transformers library [51] for these experiments. Authors of UNIFIEDQA collected and pre-processed several QA datasets into a common format: "QUESTION \n CHOICES \n CONTEXT". We downloaded this data by following the instructions provided by the authors (https://github.com/allenai/unifiedqa) and used it for the UNIFIEDQA inference experiments (§4.3). Some statistics are shown in Table 5. Longer inputs were truncated to 512 tokens.

Table 5: Statistics of the pre-processed datasets used in §4.3 and §4.5. Columns: training samples, eval samples, T5 tokens per sample (90th percentile).

Given an instance from the pre-processed data, we computed the exact match score of a prediction with respect to the list of provided answers via the SQuAD v1 evaluation script [36]. For the T5 experiments (§4.5), we used a slightly different input format. Given an instance in the UNIFIEDQA format, we formed the modified instance as "question: QUESTION context: CHOICES \n CONTEXT".

A.3 Details Of Language Modeling

WikiText-103 is a language modeling task based on English Wikipedia. We used the language modeling framework provided by Fairseq (https://github.com/pytorch/fairseq) and hyperparameters in Table 7. The details of the Adam optimizer are β1 = 0.9, β2 = 0.98, weight decay: 0.01, gradient clipping: none, LR schedule: inverse square root. During evaluation on the test set, the dataset is chunked into segments of length 1024 and perplexity is computed over each segment normally without access to other segments.
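The segment-wise evaluation described above can be sketched in plain Python. The scoring callback `token_nlls_fn` is a hypothetical stand-in for a language model that returns per-token negative log-likelihoods (in nats) for one segment in isolation, and aggregating over all tokens before exponentiating is an assumption about how the final perplexity is formed.

```python
import math

def segment_perplexity(token_nlls_fn, tokens, segment_len=1024):
    total_nll, total_tokens = 0.0, 0
    for start in range(0, len(tokens), segment_len):
        segment = tokens[start:start + segment_len]
        # Each segment is scored on its own, with no context from other segments.
        nlls = token_nlls_fn(segment)
        total_nll += sum(nlls)
        total_tokens += len(segment)
    return math.exp(total_nll / total_tokens)

# Toy check with a "model" that is uniform over a 50-word vocabulary.
segment_lengths = []

def uniform_nlls(segment, vocab_size=50):
    segment_lengths.append(len(segment))
    return [math.log(vocab_size)] * len(segment)

ppl = segment_perplexity(uniform_nlls, list(range(2500)))
```

For the uniform toy model, the perplexity equals the vocabulary size, and the 2,500 tokens are split into segments of 1024, 1024 and 452 tokens.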

Table 7. Not extracted; please refer to original document.

A.4 Benchmarking Details

Benchmarking ( §3) was done in PyTorch 1.8.1. For each run, we sampled a batch of random 32-bit input vectors and a backward pass was performed using the mean of the output elements as the loss. The part of code that was timed was enclosed within torch.cuda.synchronize() to ensure all CUDA threads finished. Memory usage was measured using torch.cuda.max_memory_reserved(). On Nvidia A100, any internal casting to TF32 was explicitly disabled.

We considered the option of performing matrix products involving large sparse matrices (Lines 8, 15, 21, 25 in our pseudo-code (§2.2)) by representing them in torch.sparse_coo_tensor format and using the torch.sparse framework to explicitly leverage the sparsity. Unfortunately, we could not obtain encouraging results even for k = 1% of the number of keys (Figure 5) and plan on experimenting with block-sparse kernels [47] in the near future.

Figure 5: Memory and time required for a combined forward and backward pass on a single feed-forward layer using random inputs.

Footnote 1 (§2.1): We omit the temperature λ in the rest of our presentation, as Q can be scaled by 1/λ beforehand.

https://github.com/google-research/long-range-arena

https://github.com/allenai/unifiedqa

https://github.com/pytorch/fairseq