Value-aware Approximate Attention

Abstract

Following the success of dot-product attention in Transformers, numerous approximations have recently been proposed to address its quadratic complexity with respect to the input length. However, all approximations thus far have ignored the contribution of the value vectors to the quality of the approximation. In this work, we argue that research efforts should be directed towards approximating the true output of the attention sub-layer, which includes the value vectors. We propose a value-aware objective, and show theoretically and empirically that an optimal approximation of the value-aware objective substantially outperforms an optimal approximation that ignores values, in the context of language modeling. Moreover, we show that the choice of kernel function for computing attention similarity can substantially affect the quality of sparse approximations, where kernel functions that are less skewed are more affected by the value vectors.

1 Introduction

The Transformer architecture (Vaswani et al., 2017) has been successful across a wide range of natural language processing tasks, including machine translation (Edunov et al., 2018), language modeling (Roy et al., 2020), question answering (Karpukhin et al., 2020), and many more. Transformers pre-trained on large amounts of text with a language modeling (LM) objective have become the standard in NLP, exhibiting surprising amounts of linguistic and world knowledge (Peters et al., 2018; Devlin et al., 2019; Petroni et al., 2019; Hewitt and Manning, 2019; Roberts et al., 2020).

The contextualizing component of the Transformer is the attention layer, where all positions in an input sequence of length $L$ aggregate information from the entire sequence in parallel. At its core, given $L$ query, key, and value vectors, the dot-product attention function outputs $\mathrm{softmax}(QK^\top)V$,¹ where the softmax function is applied row-wise to the matrix $QK^\top \in \mathbb{R}^{L \times L}$ consisting of the similarity scores of the query-key pairs. Unfortunately, computing $\Omega(L \cdot L)$ similarity scores is prohibitive for long sequences.

To alleviate this, past work proposed to compute an approximation of $\mathrm{softmax}(QK^\top)$. One major line of research focused on sparse attention variants, where only a few similarity scores are computed per position and the rest are ignored. Methods differ in which query-key pairs are selected (Ye et al., 2019; Roy et al., 2020; Kitaev et al., 2020; Beltagy et al., 2020; Gupta and Berant, 2020). A second line of research explored dense variants (Katharopoulos et al., 2020; Wang et al., 2020; Tay et al., 2020a) (cf. Tay et al., 2020b, for a survey). E.g., instead of computing the attention scores exactly for only a few query-key pairs, Choromanski et al. (2020) compute an approximation of the scores for all pairs.

In this work, we point to a lacuna in current research on efficient Transformers. While recent work has focused on approximating the attention scores $\mathrm{softmax}(QK^\top)$, the true target of approximation should be the output of the attention sub-layer, namely $H = \mathrm{softmax}(QK^\top)V$, which also includes the value vectors $V$. We show that ignoring the value vectors leads to unwarranted consequences, both theoretically and empirically.

To demonstrate the importance of value-aware approximation, we analyze optimal sparse attention, that is, the case where, in hindsight, the model computes dot-product similarity only with the most similar key vectors, while still ignoring the value vectors. We show that in the popular masked language modeling (MLM) setup, optimal sparse attention dramatically under-performs compared to an optimal approximation of the true output of the attention sub-layer, $H$, leading to an error increase of 8-20 points. Next, focusing theoretically on the case where each query computes similarity to only its single most similar key vector, we show that approximating $\mathrm{softmax}(QK^\top)$ is equivalent to approximating $H$ when the value vectors $V$ satisfy strong orthogonality and norm constraints. Conversely, when they do not, ignoring $V$ can lead to an unbounded approximation error.

Second, we discuss the kernel-based view of attention, where efficiency is gained by replacing the exponential kernel (corresponding to softmax) with other kernel functions (Katharopoulos et al., 2020). We theoretically show that while in the exponential-kernel case the effect of the norms of the value vectors is potentially small, switching to other kernels can dramatically increase the importance of the value vectors. We empirically test this by comparing optimal sparse attention under different kernel functions, and indeed see that approximation quality decreases when the exponential kernel is replaced.

To conclude, we show theoretically and empirically that approximating the attention score matrix alone is insufficient, and propose that the research community should instead approximate the true output of the attention sub-layer, which importantly includes the value vectors. Our code and trained models are available at https://github.com/ag1988/value_aware_attn.

2 Background

We review the kernel-based view of attention (Tsai et al., 2019), which will be instructive in §3.

Generalized Attention Let $\kappa(x, y) = \langle \Phi(x), \Phi(y) \rangle \ge 0$ be a kernel function with feature map $\Phi : \mathbb{R}^d \rightarrow \mathcal{H}$ for some implicit reproducing kernel Hilbert space (RKHS) $\mathcal{H}$. Given a query vector $q$, keys $k_1, \ldots, k_L$, and values $v_1, \ldots, v_L$, all in $\mathbb{R}^d$:

$$\mathrm{att}_\kappa(q, k_1, \ldots, k_L, v_1, \ldots, v_L) = \frac{\sum_{i=1}^{L} \kappa(q, k_i)\, v_i}{\sum_{i=1}^{L} \kappa(q, k_i)}, \qquad (1)$$

where the normalization induces a probability distribution $\alpha$ over the value vectors, with $\alpha_i = \kappa(q, k_i) / \sum_j \kappa(q, k_j)$.

The most popular use case is the exponential kernel $\kappa(x, y) = \exp(\langle x, y \rangle)$, referred to as dot-product attention in Transformers. Other examples include the degree-2 polynomial kernel $\kappa(x, y) = \langle x, y \rangle^2$ and the recently proposed elu kernel $\langle \Phi(x), \Phi(y) \rangle$ with $\Phi(\cdot) = 1 + \mathrm{ELU}(\cdot)$ (Katharopoulos et al., 2020).
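To make the notation concrete, here is a minimal NumPy sketch of Eq. 1 for a single query under the three kernels mentioned above. The function names and the plain-NumPy setting are illustrative, not the paper's implementation.

```python
import numpy as np

def elu(x, alpha=1.0):
    # Element-wise ELU activation.
    return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))

def kernel_scores(q, K, kind="exp"):
    """Similarity scores kappa(q, k_i) for one query q (d,) against keys K (L, d)."""
    if kind == "exp":     # exponential kernel, i.e. standard softmax attention (no stabilization)
        return np.exp(K @ q)
    if kind == "poly2":   # degree-2 polynomial kernel <q, k>^2
        return (K @ q) ** 2
    if kind == "elu":     # <Phi(q), Phi(k)> with Phi(x) = 1 + ELU(x)
        return (1.0 + elu(K)) @ (1.0 + elu(q))
    raise ValueError(f"unknown kernel: {kind}")

def generalized_attention(q, K, V, kind="exp"):
    """Eq. 1: sum_i alpha_i v_i with alpha_i = kappa(q, k_i) / sum_j kappa(q, k_j)."""
    scores = kernel_scores(q, K, kind)
    alpha = scores / scores.sum()
    return alpha @ V, alpha
```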

Given $L \gg d$ queries, the attention function (Eq. 1) requires computing $L \cdot L$ similarity scores for the query-key pairs, which is prohibitive for long sequences. Sparse attention variants relax this requirement and compute only a few similarity scores per query, ignoring the rest:

$$\mathrm{att}^{S}_\kappa(q, k_1, \ldots, k_L, v_1, \ldots, v_L) = \frac{\sum_{i \in S} \kappa(q, k_i)\, v_i}{\sum_{i \in S} \kappa(q, k_i)}, \qquad (2)$$

for some $S \subseteq \{1, \ldots, L\}$ with $|S| \ll L$. Methods differ in how $S$ is determined given the queries and keys, and include the use of locality bias (Beltagy et al., 2020), global memory (Gupta and Berant, 2020), and LSH hashing (Kitaev et al., 2020), among others. Conversely, instead of exactly computing the attention scores for only a few query-key pairs, dense variants compute an approximation of the true kernel values for all pairs. Such methods output $\sum_i \beta_i v_i$ for some approximation $\beta$ of the true attention distribution $\alpha$ (Choromanski et al., 2020; Peng et al., 2021).
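Continuing the sketch above (and reusing its hypothetical kernel_scores helper), a sparse variant in the sense of Eq. 2 simply restricts the sum to a subset S of key indices and renormalizes over it:

```python
import numpy as np

def sparse_attention(q, K, V, S, kind="exp"):
    """Eq. 2: attend only to the key/value pairs indexed by S, ignoring the rest."""
    scores = kernel_scores(q, K[S], kind)   # similarity scores only for the selected keys
    alpha = scores / scores.sum()           # renormalize over the subset S
    return alpha @ V[S]
```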

3 Optimal Sparse Attention

Prior methods for approximating attention ignored the contribution of the value vectors $V$. As the true output of the attention sub-layer also depends on $V$, a natural question is whether it is possible to design better approximation methods by incorporating $V$, and if so, how much improvement is even possible?

To answer this, we focus on sparse attention and analyze the difference between an oracle sparse approximation that considers the value vectors and an oracle approximation that does not. That is, we compare the two approximations from the perspective of expressivity, ignoring any memory and computational constraints. We denote an optimal value-aware approximation that uses $r$ key vectors per query by optimal-v-aware-r, and an optimal approximation that ignores value vectors by optimal-v-oblivious-r. We define optimal-v-oblivious-r as the output of Eq. 2 in which $S$ is selected to be the $r$ indices with the highest attention scores $\alpha_i$. This is a natural baseline, since it is what current sparse methods try to emulate. We now explicitly derive and analyze the value-aware objective.
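As a reference point for the experiments below, the optimal-v-oblivious-r oracle can be sketched directly from its definition: keep the r keys with the highest attention weights and renormalize. This again reuses the hypothetical kernel_scores helper from the §2 sketch.

```python
import numpy as np

def optimal_v_oblivious(q, K, V, r, kind="exp"):
    """Oracle baseline: select the r indices with the highest alpha_i (ignoring the
    value vectors), renormalize the weights, and aggregate the corresponding values."""
    scores = kernel_scores(q, K, kind)
    S = np.argsort(-scores)[:r]             # top-r keys by attention weight
    alpha = scores[S] / scores[S].sum()
    return alpha @ V[S]
```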

Value-aware objective Let $o = \sum_{i=1}^{L} \alpha_i v_i$ be a convex combination of $v_1, \ldots, v_L \in \mathbb{R}^d$, corresponding to the true output of the attention sub-layer. Let $C_r = \{\sum_{i=1}^{L} \beta_i v_i : \forall i\ \beta_i \ge 0,\ \sum_i \beta_i = 1,\ |\{\beta_i : \beta_i > 0\}| \le r\}$ denote the set of points in the polytope of the $v_i$'s that can be expressed as a convex combination of at most $r$ value vectors. The goal of value-aware approximation is to find the point in the constrained region $C_r$ closest to the true output $o$, i.e., $\mathrm{argmin}_{\tilde{o} \in C_r} \|o - \tilde{o}\|^2$. As mentioned, this solution is termed optimal-v-aware-r. We consider two extreme cases of $r$: $r = 1$ and $r \ge d + 1$. For $r \ge d + 1$, Carathéodory's theorem (Bárány and Karasev, 2012) states that $o = \sum_i \alpha_i v_i$ can be expressed as a convex combination of at most $d + 1$ of the $v_i$'s. Hence, if $r \ge d + 1$ then $o \in C_r$ and the optimal approximation error is 0. In most popular architectures, such as BERT (Devlin et al., 2019), $d = 64 \ll L$. This means that, from the point of view of expressivity, optimal-v-aware-65 can obtain a perfect approximation. Conversely, we will show in §4 that the performance of optimal-v-oblivious-65 is substantially lower.
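The $r \ge d + 1$ case can also be made constructive. The sketch below is not from the paper; it is a minimal NumPy illustration of the standard Carathéodory reduction, which repeatedly removes one support point using a null-space direction until at most $d + 1$ value vectors remain, leaving the output unchanged up to floating-point error.

```python
import numpy as np

def caratheodory_reduce(V, alpha, tol=1e-10):
    """Rewrite o = sum_i alpha_i v_i as a convex combination of at most d + 1 of the
    rows of V (L, d), leaving o unchanged (Caratheodory's theorem, made constructive)."""
    alpha = alpha.astype(float).copy()
    d = V.shape[1]
    while (alpha > tol).sum() > d + 1:
        idx = np.where(alpha > tol)[0]
        # Stack [v_i ; 1] for the active support: a null-space vector gamma satisfies
        # sum_i gamma_i v_i = 0 and sum_i gamma_i = 0, so alpha - t * gamma keeps both
        # the output o and the total weight fixed.
        A = np.vstack([V[idx].T, np.ones(len(idx))])
        gamma = np.linalg.svd(A)[2][-1]          # (approximate) null-space direction
        if not (gamma > tol).any():
            gamma = -gamma
        pos = gamma > tol
        t = np.min(alpha[idx][pos] / gamma[pos]) # largest step keeping all weights >= 0
        alpha[idx] -= t * gamma
        alpha[alpha < tol] = 0.0                 # at least one weight is zeroed out
    return alpha / alpha.sum()

# With d = 64 and L = 512 random value vectors, the reduced combination touches at most
# 65 of them yet reproduces the full attention output up to numerical error.
rng = np.random.default_rng(0)
V = rng.normal(size=(512, 64))
alpha = rng.random(512); alpha /= alpha.sum()
beta = caratheodory_reduce(V, alpha)
assert (beta > 0).sum() <= 65
assert np.allclose(beta @ V, alpha @ V, atol=1e-6)
```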

At the other extreme, when $r = 1$ (a single value vector), the above objective is equivalent to $\mathrm{argmin}_{i \in \{1, \ldots, L\}} \|o - v_i\|^2$ and can be simplified as

$$\mathrm{argmin}_i\ \|o\|^2 + \|v_i\|^2 - 2\langle v_i, o \rangle = \mathrm{argmin}_i\ \|v_i\|^2 - 2\Big\langle v_i, \sum_j \alpha_j v_j \Big\rangle = \mathrm{argmin}_i\ \|v_i\|^2(0.5 - \alpha_i) - \sum_{j \ne i} \alpha_j \langle v_i, v_j \rangle. \qquad (3)$$

This equation induces a ranking over the value vectors that depends on the value vectors themselves, in contrast to a value-oblivious ranking induced solely by the attention weights $\alpha$.

If $v_1, \ldots, v_L$ are orthogonal, the above equation further simplifies to $\mathrm{argmin}_i\ \|v_i\|^2(0.5 - \alpha_i) - \sum_{j \ne i} \alpha_j \cdot 0 = \mathrm{argmin}_i\ \|v_i\|^2(0.5 - \alpha_i)$. In this case, if some $\alpha_i \ge 0.5$ or if $v_1, \ldots, v_L$ have equal norms, this further simplifies to $\mathrm{argmax}_i\ \alpha_i$ and is therefore independent of the value vectors $v_i$, implying that a value-oblivious approximation would work well.

But such assumptions on $v_1, \ldots, v_L$ do not hold in general, and thus an approximation that depends only on the $\alpha_i$'s can be sub-optimal. E.g., let $v_1, v_2, v_3$ be the orthogonal vectors $(1, 0, 0), (0, 2, 0), (0, 0, 3)$, respectively, and let $\alpha_1, \alpha_2, \alpha_3$ be $0.25, 0.35, 0.4$. Then $v_3$, the vector with the highest attention weight $\alpha_3$, has a squared distance of 3.79 from the true output $\sum_i \alpha_i v_i$, whereas $v_1$, the vector with the lowest attention weight $\alpha_1$, has a squared distance of only 2.49. In this case, optimal-v-aware-1 induces exactly the opposite ranking of value vectors compared to optimal-v-oblivious-1. Moreover, if we increase the entry 3 in $v_3$ towards infinity, the approximation error grows without bound. This example, and Eq. 3 in general, also show that the optimal ranking can differ significantly from the one induced by $\alpha_i \|v_i\|$, proposed recently by Kobayashi et al. (2020) for better interpretability of attention models.
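The numbers in this example are easy to verify; the small NumPy check below (illustrative only) computes the squared distances and the two induced rankings.

```python
import numpy as np

V = np.array([[1.0, 0.0, 0.0],    # v_1
              [0.0, 2.0, 0.0],    # v_2
              [0.0, 0.0, 3.0]])   # v_3
alpha = np.array([0.25, 0.35, 0.40])
o = alpha @ V                                  # true attention output: (0.25, 0.7, 1.2)

sq_dists = ((V - o) ** 2).sum(axis=1)          # ||o - v_i||^2 per value vector
print(np.round(sq_dists, 2))                   # [2.49 3.19 3.79]
print(np.argsort(sq_dists))                    # value-aware ranking:     v_1, v_2, v_3
print(np.argsort(-alpha))                      # value-oblivious ranking: v_3, v_2, v_1
```

The single closest value vector is $v_1$, the one with the lowest attention weight, so the two oracles indeed pick opposite ends of the ranking.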

Effect of kernel function Recently, the Linear Transformer (Katharopoulos et al., 2020) proposed to replace the exponential kernel with more efficient kernels. We now show that replacing the exponential kernel with a polynomial kernel can lead to a drop in quality for current sparse approximation methods.

Intuitively, because the kernel function affects the skewness of $\alpha$, it also affects the difference between the ranking induced by the optimal value-aware approximation and the optimal value-oblivious one. For simplicity, consider the case of orthogonal value vectors, in which Eq. 3 simplifies to $\mathrm{argmin}_i\ \|v_i\|^2(0.5 - \alpha_i)$. From Eq. 1, we have $\alpha_i = \kappa(q, k_i) / \sum_j \kappa(q, k_j)$, which is $\langle q, k_i \rangle^C / \sum_j \langle q, k_j \rangle^C$ for the degree-$C$ polynomial kernel. For $C = 0$, we have $\alpha_i = 1/L$, which gives $\mathrm{argmin}_i\ \|v_i\|^2$: when $\alpha$ is uniform, the value vectors become crucial. On the other hand, assuming distinct inner products, for $C \gg 0$ we obtain $\max_i \alpha_i \ge 0.5$, reducing the objective to $\mathrm{argmax}_i\ \alpha_i$, where the value vectors do not affect the approximation. The complexity of the Transformer grows exponentially with the degree $C$, so in practice a low $C$ must be used (e.g., a degree-2 polynomial). In that case, $\alpha$ is likely to be less skewed than with the exponential kernel and more likely to induce a sub-optimal ranking.
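The skewness effect can be illustrated on synthetic scores: the degree-2 polynomial distribution is invariant to rescaling the query-key inner products (a positive constant factor cancels in the normalization), whereas the exponential (softmax) distribution grows sharper as the dynamic range of the scores grows. The snippet below is a standalone illustration with Gaussian scores, not a measurement on the trained models.

```python
import numpy as np

rng = np.random.default_rng(0)
L = 512
s = rng.normal(size=L)                      # standardized inner products <q, k_i>

def top8_mass(alpha):
    return np.sort(alpha)[-8:].sum()

for scale in [1.0, 2.0, 4.0]:               # larger scale = larger dynamic range of scores
    logits = scale * s
    a_exp = np.exp(logits - logits.max()); a_exp /= a_exp.sum()   # exponential kernel
    a_p2 = logits ** 2; a_p2 /= a_p2.sum()                        # degree-2 polynomial kernel
    # The polynomial distribution is unchanged by the scale, while the exponential one
    # concentrates more and more mass on the top keys as the scale grows.
    print(scale, round(top8_mass(a_exp), 3), round(top8_mass(a_p2), 3))
```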

In the next section, we empirically verify the above observations and show a significant performance gap between value-oblivious approximations and value-aware ones.

4 Experiments

We empirically verify our observations in the context of training causal and masked language models, which are known to correlate strongly with performance on downstream applications (Devlin et al., 2019).

Figure 1: Evaluation of MLM error of ROBERTA-4096 after replacing vanilla attention with approximation schemes. Dashed line denotes error using vanilla attention.
Figure 2: Evaluation loss (base e) of optimal-v-oblivious-r oracle on the causal LM task for distinct kernel functions.

Masked LM task We form examples by sampling sequences and replacing sub-words with mask tokens following the procedure in Devlin et al. (2019). The model is trained to maximize the log probability of the masked-out tokens, and we evaluate the model error as the percentage of masked tokens predicted incorrectly. As approximate attention becomes increasingly relevant for long sequences, we train ROBERTA-4096 on sequences of length 4096 (Fig. 3). Training was warm-started from ROBERTA-base (Liu et al., 2019). Full details on the experimental setup are in §A.1. After training the model for ∼2.5M steps, its error on the evaluation set (i.e., the proportion of incorrect predictions) was 24.2, compared to 26.6 for analogous training on 512-long sequences, confirming that tokens in ROBERTA-4096 indeed attend over longer distances and produce higher-quality representations. We then replace the attention function of the trained model with various approximation schemes and evaluate the resulting model on the evaluation set.
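For concreteness, here is a minimal sketch of BERT-style masking; the 15% / 80-10-10 split follows Devlin et al. (2019), while the function name and the -100 label convention are illustrative choices of this note, not the paper's code.

```python
import random

def mask_tokens(token_ids, mask_id, vocab_size, mask_prob=0.15, seed=None):
    """BERT-style masking: pick ~15% of positions; of those, 80% become the mask token,
    10% a random token, 10% are left unchanged. Returns (input_ids, labels), with
    labels = -100 at unmasked positions so they can be ignored by the loss."""
    rng = random.Random(seed)
    input_ids, labels = list(token_ids), [-100] * len(token_ids)
    for i, tok in enumerate(token_ids):
        if rng.random() < mask_prob:
            labels[i] = tok
            roll = rng.random()
            if roll < 0.8:
                input_ids[i] = mask_id
            elif roll < 0.9:
                input_ids[i] = rng.randrange(vocab_size)
            # else: keep the original token
    return input_ids, labels
```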

Figure 3: Evaluation error on the MLM task using vanilla attention (computing the full attention matrix).

We first compare optimal-v-oblivious-r to optimal-v-aware-r. We know that the approximation error of the value-aware approximation is 0 for r > 64. For r = 1, we exhaustively go over all possible values and choose the one that minimizes the value-aware objective. As seen in Fig. 1 and Table 1, there is a substantial gap between the two approximations. For instance, optimal-v-oblivious-65 gives an MLM error of 43.5, whereas the error of optimal-v-aware-65 is 24.2, since it can perfectly approximate full attention. Moreover, we compare optimal-v-oblivious-r to existing approximations: (a) sliding-window-r, where a position attends to r/2 positions to its left and r/2 positions to its right, (b) LSH attention (Kitaev et al., 2020), and (c) ORF attention (Choromanski et al., 2020). Fig. 1 shows that sliding-window-r trails behind optimal-v-oblivious-r. LSH attention, which tries to emulate optimal-v-oblivious-r, requires either a large number of hash rounds or a large chunk size. Similarly, the dense approximation, ORF, provides an unbiased approximation of the exponential kernel but suffers from high variance in practice.

Causal LM task To investigate the effect of the kernel function on the quality of value-oblivious methods, we train a 6-layer Transformer LM over 512 tokens on WikiText-103 (Merity et al., 2017) (details in §A.2). We train three models with identical hyperparameters using the exponential, degree-2 polynomial, and elu kernels, respectively, and evaluate the trained models with value-aware and value-oblivious approximations. Again, optimal-v-aware-r substantially outperforms optimal-v-oblivious-r (Table 2), pointing to the potential of approximating the value-aware objective. More importantly, comparing the approximation quality across the different kernel functions (Fig. 2), we see that the gap between the three kernels is small when using full attention (512 keys). However, convergence is much slower for the elu kernel, and especially for the degree-2 polynomial, demonstrating that an approximation based on the top-r key vectors is sub-optimal when switching to a less skewed kernel, which is more affected by the value vectors.
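As a small companion to the sliding-window-r baseline, its key selection per position is just a clipped window; whether a position also attends to itself is an implementation detail assumed here, not stated in the text.

```python
import numpy as np

def sliding_window_indices(pos, L, r):
    """Keys attended to by position `pos` under sliding-window-r: r//2 positions on
    each side, clipped at the sequence boundaries (self-attention assumed included)."""
    lo, hi = max(0, pos - r // 2), min(L, pos + r // 2 + 1)
    return np.arange(lo, hi)
```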

Table 1: MLM error of ROBERTA-4096 on the evaluation set using the approximate attention schemes described in §4. OVO: optimal-v-oblivious, OVA: optimal-v-aware. In LSH, each query attends to a total of r keys per hash round.
Table 2: Evaluation perplexity of models using approximate attention. OVO: optimal-v-oblivious, OVA: optimal-v-aware.

5 Conclusion

In this work, we provide theoretical and empirical evidence against the current practice of approximating the attention matrix in Transformers while ignoring the value vectors. We propose a value-aware objective and argue that efforts to develop more efficient Transformers should consider this objective function as a research target.

A Supplemental Material

A.1 Masked LM task

The instances for the MLM task (§4) were formed separately from the corpora listed in Table 3. For each dataset, after appending an end-of-document token to each document, the documents were arranged in a random order and concatenated into a single long text, which was then tokenized into a list of sub-words. Depending on the final input sequence length L of the experiment (512/4096), this list was chunked into full-length L − 2 sequences, which were then masked randomly following Devlin et al. (2019) and enclosed within the start and end special tokens. To handle sequences longer than 512 tokens, positional embeddings were used following Gupta and Berant (2020). The learning curves of ROBERTA-512 and ROBERTA-4096 are in Fig. 3.

(name missing)                 2.67B   1.53B   1.02M
BookCorpus (Zhu et al., 2015)  1.06B   1.02B   1.02M
ArXiv (Cohan et al., 2018)     1.78B   1.53B   1.02M
PubMed (Cohan et al., 2018)    0.47B   510M    1.02M
PG19 (Rae et al., 2020)        3.06B   510M    1.02M

Hyperparameters For convenience, we denote the training hyperparameters using the following abbreviations. INS: number of training instances, BSZ: number of instances in a batch, ISZ: instance size, SQL: final input sequence length after rearranging BSZ instances each of length ISZ, LR: learning rate, WRM: linear LR warm-up proportion, EP: number of epochs, STP: number of optimizer steps, GAC: gradient accumulation steps, POSq: whether (y/n) the q part is included in the positional embeddings. The hyperparameters are listed in Table 4.

model          init     BSZ  ISZ  SQL   LR    WRM  EP  STP     POSq
ROBERTA-512    ROBERTA  8    512  512   5e-6  0.1  2   2.476M  n
ROBERTA-4096   ROBERTA  8    512  4096  5e-6  0.1  2   2.476M  y

Table 3: Number of tokens in the datasets used for MLM training.
Table 4: Training hyperparameters. Common parameters: INS=10M, dropout-rate=0.0, optimizer=Bert-Adam, β1=0.9, β2=0.98, weight-decay=0.01, max-grad-norm=5.0, seed=42, GAC=1.

Details of LSH attention Given $L$ queries and $L$ keys in $\mathbb{R}^d$, in each hash round we sample a new matrix $R \in \mathbb{R}^{C/2 \times d}$ of standard Gaussians and hash the queries and keys as $H_R(x) = \mathrm{argmax}([-Rx; Rx]) \in \{1, \ldots, C\}$. We rearrange the queries (and similarly the keys) according to their hash value, breaking ties using the original position, and then chunk them into $L/B$ chunks of $B$ vectors each. Denoting these chunks by $Q_1, \ldots, Q_{L/B}$ and $K_1, \ldots, K_{L/B}$, for each query in $Q_i$ we compute its similarity scores with respect to all keys in $K_{i-1}$ and $K_i$, i.e., in each hash round a query attends to $r = 2B$ keys. For each query, these similarity scores are accumulated over the hash rounds and, at the end, normalized by their sum to obtain normalized attention scores over the keys. As recommended in the original paper (Kitaev et al., 2020), we use $C = 2L/B = 4L/r$, which in practice can be sub-optimal, as the rearrangement destroys the original locality structure.
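A sketch of the bucketing step in plain NumPy (the function names are illustrative; bucket ids are 0-based here):

```python
import numpy as np

def lsh_hash(X, R):
    """Bucket ids H_R(x) = argmax([-Rx ; Rx]) in {0, ..., C-1}, where R is a
    (C/2) x d matrix of standard Gaussians and X is (n, d)."""
    proj = X @ R.T
    return np.argmax(np.concatenate([-proj, proj], axis=-1), axis=-1)

def lsh_chunks(Q, K, R, B):
    """Sort queries and keys by (bucket, original position) and split them into chunks
    of size B; each query chunk Q_i is compared against key chunks K_{i-1} and K_i,
    so a query attends to r = 2B keys per hash round."""
    q_order = np.lexsort((np.arange(len(Q)), lsh_hash(Q, R)))   # bucket first, then position
    k_order = np.lexsort((np.arange(len(K)), lsh_hash(K, R)))
    return np.array_split(q_order, len(Q) // B), np.array_split(k_order, len(K) // B)
```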

Details of ORF attention Given $L$ queries and $L$ keys in $\mathbb{R}^d$, we divide each vector by $d^{1/4}$ to account for the temperature term in dot-product attention. For a given number $F$ of features, we sample a random orthogonal matrix $R \in \mathbb{R}^{F \times d}$ as described in (Saxe et al., 2013) and provided as a tensor initialization option in PyTorch. We then map each vector to the feature space as $\Phi(x) = \frac{1}{\sqrt{F}} \exp\!\big(Rx - \frac{\|x\|^2}{2}\big) \in \mathbb{R}^F$, where the subtraction and $\exp$ are applied element-wise. The similarity score of a query-key pair $(q, k)$ is computed as $\langle \Phi(q), \Phi(k) \rangle$ and is normalized by the sum of the similarity scores of $q$ with all the keys. Computing this directly leads to numerical instability, so in practice the exponents are stabilized before normalization. This construction approximates the exponential kernel because, for $w \sim \mathcal{N}(0, I_d)$, $\mathbb{E}_w[\exp(\langle w, x \rangle)] = \exp(\|x\|^2/2)$; appropriately scaling both sides gives $\mathbb{E}_w\big[\exp\big(\langle w, q \rangle - \frac{\|q\|^2}{2}\big) \cdot \exp\big(\langle w, k \rangle - \frac{\|k\|^2}{2}\big)\big] = \exp(\langle q, k \rangle)$, which is exactly the term for the exponential kernel.
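A sketch of the feature-map computation. For simplicity it uses a plain Gaussian R, for which the estimator above is unbiased; the setup described in this section instead uses an orthogonal R following Saxe et al. (2013), which reduces variance, and the numerical stabilization is omitted here.

```python
import numpy as np

def random_feature_attention(Q, K, V, F, rng):
    """Dense random-feature attention: phi(x) = exp(R x - ||x||^2 / 2) / sqrt(F)."""
    d = Q.shape[-1]
    Q, K = Q / d ** 0.25, K / d ** 0.25           # fold the softmax temperature into q and k
    R = rng.normal(size=(F, d))                   # F random feature directions
    phi_q = np.exp(Q @ R.T - 0.5 * (Q ** 2).sum(-1, keepdims=True)) / np.sqrt(F)
    phi_k = np.exp(K @ R.T - 0.5 * (K ** 2).sum(-1, keepdims=True)) / np.sqrt(F)
    scores = phi_q @ phi_k.T                      # E_R[scores_ij] = exp(<q_i, k_j>)
    alpha = scores / scores.sum(-1, keepdims=True)
    return alpha @ V
```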

A.2 Causal LM task

For this task, we used the language modeling framework provided by Fairseq.²

Model and training details Number of decoder layers: 6, hidden size: 512, head size: 64, number of model parameters: 156M, dataset: WikiText-103, training examples: 1,801,350, input sequence length: 512, β1 = 0.9, β2 = 0.98, weight decay: 0.01, gradient clip-norm: none, learning rate: 0.0005, learning rate schedule: inverse square root, number of warmup updates: 4000, batch size: 128, epochs: 20, number of steps: 31,520, minimum context window during evaluation on the test set: 400.

¹ Usually, the term is $\mathrm{softmax}(QK^\top/\sqrt{d})V$, but the $\sqrt{d}$ can be dropped by scaling the queries and keys.

² https://github.com/pytorch/fairseq