Transformer Feed-Forward Layers Are Key-Value Memories

Abstract

Feed-forward layers constitute two-thirds of a transformer model's parameters, yet their role in the network remains under-explored. We show that feed-forward layers in transformer-based language models operate as key-value memories, where each key correlates with textual patterns in the training examples, and each value induces a distribution over the output vocabulary. Our experiments show that the learned patterns are human-interpretable, and that lower layers tend to capture shallow patterns, while upper layers learn more semantic ones. The values complement the keys' input patterns by inducing output distributions that concentrate probability mass on tokens likely to appear immediately after each pattern, particularly in the upper layers. Finally, we demonstrate that the output of a feed-forward layer is a composition of its memories, which is subsequently refined throughout the model's layers via residual connections to produce the final output distribution.

1 Introduction

Transformer-based language models (Vaswani et al., 2017) are at the core of state-of-the-art natural language processing (Devlin et al., 2019; Brown et al., 2020), largely due to the success of self-attention. While much literature has been devoted to analyzing the function of self-attention layers (Voita et al., 2019; Clark et al., 2019; Vig and Belinkov, 2019), they account for only a third of a typical transformer's parameters (4d^2 per layer, where d is the model's hidden dimension). Most of the parameter budget is spent on position-wise feed-forward layers (8d^2 per layer), yet their role remains under-explored. What, then, is the function of feed-forward layers in a transformer language model?


Figure 1: An illustration of how a feed-forward layer emulates a key-value memory. Input vectors (here, x5) are multiplied by keys to produce memory coefficients (e.g., the memory coefficient for v1 is 0.2), which then weigh distributions over the output vocabulary, stored in the values. The feed-forward layer’s output is thus the weighted sum of its values.

We show that feed-forward layers emulate neural memories (Sukhbaatar et al., 2015), where the first parameter matrix in the layer corresponds to keys, and the second parameter matrix to values. Figure 1 shows how the keys (first parameter matrix) interact with the input to produce coefficients, which are then used to compute a weighted sum of the values (second parameter matrix) as the output. While the theoretical similarity between feed-forward layers and key-value memories has previously been suggested by Sukhbaatar et al. (2019), we take this observation one step further, and analyze the "memories" that the feed-forward layers store.

We find that each key correlates with a specific set of human-interpretable input patterns, such as n-grams or semantic topics. For example, k_2 in Figure 1 is triggered by inputs that describe a period of time and end with "a". Simultaneously, we observe that each value can induce a distribution over the output vocabulary, and that this distribution correlates with the next-token distribution of the corresponding keys in the upper layers of the model. In the above example, the corresponding value v_2 represents a distribution that puts most of its probability mass on the word "while". Lastly, we analyze how the language model, as a whole, composes its final prediction from individual memories. We observe that each layer combines hundreds of active memories, creating a distribution that is qualitatively different from each of its component memories' values. Meanwhile, the residual connection between layers acts as a refinement mechanism, gently tuning the prediction at each layer while retaining most of the residual's information.

In conclusion, our work sheds light on the function of feed-forward layers in transformer-based language models. We show that feed-forward layers act as pattern detectors over the input across all layers, and that the final output distribution is gradually constructed in a bottom-up fashion.

2 Feed-Forward Layers As Unnormalized Key-Value Memories

Feed-forward layers A transformer language model (Vaswani et al., 2017) is made of intertwined self-attention and feed-forward layers. Each feed-forward layer is a position-wise function, processing each input vector independently. Let x ∈ R^d be a vector corresponding to some input text prefix. We can express the feed-forward layer FF(·) as follows (bias terms are omitted):

FF(x) = f(x · K^⊤) · V        (1)

Here, K, V ∈ R^{d_m × d} are parameter matrices, and f is a non-linearity such as ReLU.

Neural memory A neural memory (Sukhbaatar et al., 2015) consists of d_m key-value pairs, which we call memories. Each key is represented by a d-dimensional vector k_i ∈ R^d, and together the keys form the parameter matrix K ∈ R^{d_m × d}; likewise, we define the value parameters as V ∈ R^{d_m × d}. Given an input vector x ∈ R^d, we compute a distribution over the keys, and use it to compute the expected value:

p(k_i | x) ∝ exp(x · k_i)

MN(x) = Σ_{i=1}^{d_m} p(k_i | x) · v_i

With matrix notation, we arrive at a more compact formulation:

MN(x) = softmax(x · K^⊤) · V        (2)

Feed-forward layers emulate neural memory Comparing Equations 1 and 2 shows that feed-forward layers are almost identical to key-value neural memories; the only difference is that neural memory uses softmax as the non-linearity f(·), while the canonical transformer does not use a normalizing function in the feed-forward layer.
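To make the comparison concrete, the following is a minimal sketch (toy dimensions and random parameters, not the trained model) that computes both formulations; the only difference is whether the memory coefficients pass through ReLU or softmax.

```python
# Minimal sketch (toy dimensions, random parameters) contrasting Equation 1
# (feed-forward layer) with Equation 2 (softmax key-value memory).
import numpy as np

d, d_m = 8, 32                       # hidden size and number of memories (illustrative)
rng = np.random.default_rng(0)
x = rng.normal(size=d)               # input vector for a single position
K = rng.normal(size=(d_m, d))        # keys: first parameter matrix
V = rng.normal(size=(d_m, d))        # values: second parameter matrix

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

ff_out = relu(x @ K.T) @ V           # Eq. 1: unnormalized, non-negative coefficients
mn_out = softmax(x @ K.T) @ V        # Eq. 2: coefficients form a distribution over memories

print(ff_out.shape, mn_out.shape)    # both outputs are d-dimensional
```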

The hidden dimension d_m is essentially the number of memories in the layer, and the activation m = f(x · K^⊤), commonly referred to as the hidden layer, is a vector containing an unnormalized non-negative coefficient for each memory. We refer to each m_i as the memory coefficient of the i-th memory cell. Sukhbaatar et al. (2019) make an analogous observation, and incorporate the parameters of the feed-forward layers as persistent memory cells in the self-attention layers. While this reparameterization works in practice, the experiment does not tell us much about the role of feed-forward layers in the canonical transformer. If transformer feed-forward layers are indeed key-value memories, then what memories do they store?

We conjecture that each key vector k_i captures a particular pattern (or set of patterns) in the input sequence (Section 3), and that its corresponding value vector v_i represents the distribution of tokens that follows said pattern (Section 4).

3 Keys Capture Input Patterns

We posit that the key vectors K in feed-forward layers act as pattern detectors over the input sequence, where each individual key vector k_i corresponds to a specific pattern over the input prefix x_1, . . . , x_j. To test our claim, we analyze the keys of a trained language model's feed-forward layers. We first retrieve the training examples (prefixes of a sentence) most associated with a given key, that is, the input texts where the memory coefficient is highest. We then ask humans to identify patterns within the retrieved examples. For almost every key k_i in our sample, a small set of well-defined patterns, recognizable by humans, covers most of the examples associated with the key.

Table 1: Examples of human-identified patterns that trigger different memory keys.

Key k^1_449 (pattern: ends with "substitutes"; shallow). Example trigger prefixes:
At the meeting, Elton said that "for artistic reasons there could be no substitutes
In German service, they were used as substitutes
Two weeks later, he came off the substitutes

Key k^6_2546 (pattern: military, ends with "base"/"bases"; shallow + semantic). Example trigger prefixes:
On 1 April the SRSG authorised the SADF to leave their bases
Aircraft from all four carriers attacked the Australian base
Bombers flying missions to Rabaul and other Japanese bases

3.1 Experiment

We conduct our experiment over the language model of Baevski and Auli (2019), a 16-layer transformer language model trained on WikiText-103. We randomly sample 10 keys per layer (160 in total).

Retrieving Trigger Examples

We assume that patterns stored in memory cells originate from examples the model was trained on. Therefore, given a key k_i that corresponds to the i-th hidden dimension of the ℓ-th feed-forward layer, we compute the memory coefficient ReLU(x_j · k_i) for every prefix x_1, . . . , x_j of every sentence from WikiText-103's training set.¹ Then, we retrieve the top-t trigger examples, that is, the t prefixes whose representation at layer ℓ yielded the highest inner product with k_i.
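A possible implementation of this retrieval step is sketched below; `prefix_reprs` and `prefix_texts` are hypothetical inputs holding the layer-ℓ representation and surface form of every sentence prefix, obtained by running the trained model over the training set.

```python
# Hypothetical sketch of trigger-example retrieval for one key.
# `prefix_reprs` (n_prefixes x d) holds the layer-l representation x_j of every
# sentence prefix, and `prefix_texts` the corresponding text; both are assumed
# to come from a forward pass of the trained model over WikiText-103.
import numpy as np

def top_trigger_examples(k_i, prefix_reprs, prefix_texts, t=25):
    """Return the t prefixes with the highest memory coefficient ReLU(x_j . k_i)."""
    coeffs = np.maximum(prefix_reprs @ k_i, 0.0)     # memory coefficient per prefix
    top_idx = np.argsort(-coeffs)[:t]                # indices of the t largest coefficients
    return [(prefix_texts[j], float(coeffs[j])) for j in top_idx]
```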

Pattern analysis For each key, we provide human experts (NLP graduate students) with the retrieved top-25 prefixes and ask them to (a) identify repetitive patterns that occur in at least 3 prefixes, (b) describe each recognized pattern, and (c) classify each recognized pattern as "shallow" (e.g. recurring n-grams) or "semantic" (recurring topic).

To assure that every pattern is grounded in at least 3 prefixes, we instruct the experts to specify, for each of the top-25 prefixes, which pattern(s) it contains. A prefix may be associated with multiple (shallow or semantic) patterns. Table 1 shows example patterns. A fully-annotated example of the top-25 prefixes from a single memory key is shown in Appendix A.


3.2 Results

Memories are associated with human-recognizable patterns Experts were able to identify at least one pattern for every key, with an average of 3.6 identified patterns per key. Furthermore, the vast majority of retrieved prefixes (65%-80%) were associated with at least one identified pattern (Figure 2). Thus, the top examples triggering each key share clear patterns that humans can recognize.

Shallow layers detect shallow patterns Comparing the number of prefixes associated with shallow patterns to those associated with semantic patterns (Figure 2), we find that the lower layers (layers 1-9) are dominated by shallow patterns, often with prefixes that share the last word (e.g. k^1_449 in Table 1). In contrast, the upper layers (layers 10-16) are characterized by more semantic patterns, with prefixes from similar contexts but without clear surface-form similarities (e.g. k^16_1935 in Table 1). This observation corroborates recent findings that lower (upper) layers in deep contextualized models encode shallow (semantic) features of the inputs (Peters et al., 2018; Jawahar et al., 2019; Liu et al., 2019).

Figure 2: Breakdown of the labels experts assigned to trigger examples in each layer. Some examples were not associated with any pattern (“not-covered”).

To further test this hypothesis, we sample 1600 random keys (100 keys per layer) and apply local modifications to the top-50 trigger examples of every key. Specifically, we remove either the first, last, or a random token from the input, and measure how this mutation affects the memory coefficient. Figure 3 shows that the model considers the end of an example as more salient than the beginning for predicting the next token. In upper layers, removing the last token has less impact, supporting our conclusion that upper-layer keys are less correlated with shallow patterns.

Figure 3: Relative change in memory coefficient caused by removing the first, the last, or a random token from the input.
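This perturbation test can be sketched as follows; `encode_prefix` is a hypothetical stand-in for a forward pass of the trained model that returns the layer-ℓ representation of a token sequence.

```python
# Hypothetical sketch of the token-removal test behind Figure 3.
# `encode_prefix(tokens)` is a stand-in for running the trained model and
# returning the layer-l representation of the prefix.
import random

def coefficient_change(tokens, k_i, encode_prefix, position="last"):
    """Relative change in ReLU(x . k_i) after removing one token from the prefix."""
    base = max(float(encode_prefix(tokens) @ k_i), 0.0)
    if position == "first":
        mutated = tokens[1:]
    elif position == "last":
        mutated = tokens[:-1]
    else:                                        # remove a random token
        j = random.randrange(len(tokens))
        mutated = tokens[:j] + tokens[j + 1:]
    new = max(float(encode_prefix(mutated) @ k_i), 0.0)
    return (new - base) / (base + 1e-9)          # negative when the coefficient drops
```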

4 Values Represent Distributions

After establishing that keys capture patterns in training examples, we turn to analyze the information stored in their corresponding values. We show that each value v_i can be viewed as a distribution over the output vocabulary, and demonstrate that this distribution complements the patterns in the corresponding key k_i in the model's upper layers (see Figure 1). We begin by converting each value vector v_i into a probability distribution over the vocabulary by multiplying it by the output embedding matrix E and applying a softmax:²

p_i = softmax(v_i · E).

The probability distribution p_i is uncalibrated, since the value vector v_i is typically multiplied by the input-dependent memory coefficient m_i, changing the skewness of the output distribution. That said, the ranking induced by p_i is invariant to the coefficient, and can still be examined. This conversion assumes (naïvely) that all of the model's layers operate in the same embedding space.

For every layer and memory dimension i, we compare the top-ranked token according to v_i, argmax(p_i), to the next token w_i in the top-1 trigger example of k_i (the example whose memory coefficient for k_i is the highest). Figure 4 shows the agreement rate, i.e. the fraction of memory cells (dimensions) where the value's top prediction matches the key's top trigger example (argmax(p_i) = w_i). The agreement rate is close to zero in the lower layers (1-10), but from layer 11 onwards it rises quickly, reaching 3.5%, indicating higher agreement between keys and values on the identity of the top-ranked token.

Figure 4: Agreement rate between the top-ranked token based on the value vector v_i, and the next token of the top-ranked trigger example associated with the key vector k_i.
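The agreement rate can be sketched as follows (hypothetical interface; `E` is the output embedding matrix and `next_tokens[i]` is assumed to hold the id of w_i, the next token of key i's top-1 trigger example).

```python
# Sketch of the per-layer key-value agreement rate (Figure 4). `V` is the layer's
# value matrix (d_m x d), `E` the output embedding matrix (vocab x d), and
# `next_tokens[i]` the assumed id of w_i for memory cell i.
import numpy as np

def agreement_rate(V, E, next_tokens):
    """Fraction of cells where the value's top token equals the key's next token w_i."""
    logits = V @ E.T                       # v_i . E for every memory cell
    top_tokens = logits.argmax(axis=1)     # argmax(p_i); softmax does not change the argmax
    return float((top_tokens == np.asarray(next_tokens)).mean())
```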

Next, we take the next token of k_i's top-1 trigger example (w_i), and find where it ranks in the value vector's distribution p_i. Figure 5 shows that the rank of the next token of a trigger example improves (moves toward the top of the distribution) through the layers, meaning that w_i tends to receive higher probability in the upper layers.

Figure 5: Distribution of the rank of the next token in the top-1 trigger example of k_i (w_i), according to the ranking induced by the value vector v_i. We cut the tail of the distribution, which stretches up to the vocabulary size (∼270K tokens).
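The rank statistic behind Figure 5 can be computed as in this small sketch, under the same hypothetical interface as above.

```python
# Sketch of the rank statistic in Figure 5: the position of w_i in the ranking
# induced by the value vector v_i (rank 0 means w_i is the value's top token).
import numpy as np

def next_token_rank(v_i, E, w_i):
    logits = v_i @ E.T                           # scores over the output vocabulary
    order = np.argsort(-logits)                  # token ids from highest to lowest score
    return int(np.nonzero(order == w_i)[0][0])   # where w_i lands in that ranking
```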

To examine whether we can automatically detect values with a high agreement rate, we analyze the probability of each value's top prediction, i.e., max(p_i). Figure 6 shows that although these distributions are not calibrated, distributions with higher maximum probabilities are more likely to agree with their key's top trigger example. We then take the 100 values with the highest top-prediction probability across all layers and dimensions (97 of the 100 are in the upper layers, 11-16), and for each value v_i, analyze the top-50 trigger examples of k_i. We find that for almost half of the values (46 out of 100), there is at least one trigger example that agrees with the value's top prediction. Examples are provided in Table 2.

Figure 6: Agreement rate (between the top-ranked token based on the value vector v_i and the next token of the top-ranked trigger example associated with the key vector k_i) as a function of the maximal probability assigned by the value vector.
Table 2: Example values, their top prediction, and the fraction of their key’s top-50 trigger examples that agree with their prediction.

When viewed as distributions over the output vocabulary, these results show that values in the upper layers tend to assign higher probability to the next token of examples triggering the corresponding keys. This suggests that memory cells often store information on how to directly predict the output (the distribution of the next word) from the input (patterns in the prefix). Conversely, the lower layers do not exhibit a clear correlation between the patterns stored in keys and the corresponding distributions induced by the values. A possible explanation is that the lower layers do not operate in the same embedding space, and therefore, projecting values onto the vocabulary using the output embeddings does not produce distributions that follow the trigger examples. However, our results imply that some intermediate layers do operate in the same or a similar space to the upper layers (exhibiting some agreement), which in itself is non-trivial. We leave further exploration of this phenomenon to future work.


5 Aggregating Memories

So far, our discussion has been about the function of a single memory cell in feed-forward layers. How does the information from multiple cells in multiple layers aggregate to form a model-wide prediction? We show that every feed-forward layer combines multiple memories to produce a distribution that is qualitatively different from each of its component memories' value distributions (Section 5.1). These layer-wise distributions are then combined via residual connections in a refinement process, where each feed-forward layer updates the residual's distribution to finally form the model's output (Section 5.2).

5.1 Intra-Layer Memory Composition

The feed-forward layer's output can be defined as the sum of value vectors weighted by their memory coefficients, plus a bias term:

y = Σ_i ReLU(x · k_i) · v_i + b.

If each value vector v i contains information about the target token's distribution, how is this information aggregated into a single output distribution? To find out, we analyze the behavior of 4,000 randomly-sampled prefixes from the validation set. Here, the validation set is used (rather than the training set used to find trigger examples) since we are trying to characterize the model's behavior at inference time, not find the examples it "memorizes" during training.

We first measure the fraction of "active" memories (cells with a non-zero coefficient). Figure 7 shows that a typical example triggers hundreds of memories per layer (10%-50% of 4096 dimensions), but the majority of cells remain inactive. Interestingly, the number of active memories drops towards layer 10, which is the same layer in which semantic patterns become more prevalent than shallow patterns, according to expert annotations (see Section 3, Figure 2) . While there are cases where a single memory cell dominates the output of a layer, the majority of outputs are clearly compositional. We count the number of instances where the feed-forward layer's top prediction is different from all of the memories' top predictions. Formally, we denote:

Figure 7: The fraction of active memories (i.e., with positive memory coefficient) out of 4096 memories in every layer, for a random sample of 4,000 examples.
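A sketch of this measurement is given below; `X` is assumed to hold the layer's input vectors for a sample of validation prefixes, and `K` the layer's key matrix.

```python
# Sketch of the "active memories" measurement (Figure 7). `X` (n_examples x d) is
# assumed to hold the layer's input vectors for a sample of validation prefixes,
# and `K` (d_m x d) is the layer's key matrix.
import numpy as np

def active_memory_fraction(X, K):
    """Mean fraction of memory cells with a strictly positive coefficient per example."""
    coeffs = np.maximum(X @ K.T, 0.0)          # ReLU memory coefficients, (n_examples x d_m)
    return float((coeffs > 0).mean(axis=1).mean())
```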

top(h) = argmax(h · E)

as a generic shorthand for the top prediction from the vocabulary distribution induced by the vector h, and compute the number of examples where the following condition holds:

∀i : top(v_i) ≠ top(y)

Figure 8 shows that, for any layer in the network, the layer's final prediction is different from every one of the memories' predictions in at least ∼68% of the examples. Even in the upper layers, where the memories' values are more correlated with the output space (Section 4), the layer-level prediction is typically not the result of a single dominant memory cell, but a composition of multiple memories. We further analyze cases where at least one memory cell agrees with the layer's prediction, and find that (a) in 60% of the examples the target token is a common stop word in the vocabulary (e.g. "the" or "of"), and (b) in 43% of the cases the input prefix has fewer than 5 tokens. This suggests that very common patterns in the training data might be "cached" in individual memory cells, and do not require compositionality.

Figure 8: The fraction of examples in a random sample of 4,000 examples where the layer's prediction is different from the prediction of all of its memories.
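A possible way to check the compositionality condition above for a single example is sketched below (toy interface; `E` is the output embedding matrix and `b` the layer's bias vector). The check is restricted here to active cells, since inactive cells contribute nothing to the layer's output.

```python
# Sketch of the intra-layer compositionality check: is the layer's top prediction
# different from the top prediction of every active memory cell?
import numpy as np

def is_compositional(x, K, V, b, E):
    coeffs = np.maximum(x @ K.T, 0.0)              # memory coefficients
    y = coeffs @ V + b                             # layer output: weighted sum of values + bias
    top_y = (y @ E.T).argmax()                     # top(y)
    active = np.nonzero(coeffs)[0]                 # indices of active memory cells
    top_values = (V[active] @ E.T).argmax(axis=1)  # top(v_i) for every active cell
    return bool(np.all(top_values != top_y))       # True when no memory predicts top(y)
```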

5.2 Inter-Layer Prediction Refinement

While a single feed-forward layer composes its memories in parallel, a multi-layer model uses the residual connection r^ℓ to sequentially compose predictions and produce the model's final output (the residual propagates information from previous layers, including the transformer's self-attention layers):

x^ℓ = LayerNorm(r^ℓ)
y^ℓ = FF(x^ℓ)
o^ℓ = y^ℓ + r^ℓ

We hypothesize that the model uses the sequential composition apparatus as a means to refine its prediction from layer to layer, often deciding what the prediction will be at one of the lower layers. To test our hypothesis, we first measure how often the probability distribution induced by the residual vector r^ℓ matches the model's final output o^L (L being the total number of layers):

top(r^ℓ) = top(o^L)

Figure 9 shows that roughly a third of the model's predictions are determined in the bottom few layers. This number grows rapidly from layer 10 onwards, implying that the majority of "hard" decisions occur before the final layer. We also measure the probability mass p^ℓ_w that each layer's residual vector r^ℓ assigns to the model's final prediction w:

w = top(o^L)
p^ℓ = softmax(r^ℓ · E)
p^ℓ_w = p^ℓ[w]

Figure 10 shows a similar trend, but emphasizes that it is not only the top prediction's identity that is refined as we progress through the layers; the model's confidence in its decision is refined as well.

Figure 9: Fraction of examples in each layer where the residual's top prediction matches the model's output.
Figure 10: Probability of the token output by the model according to the residual of each layer.
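Both measurements can be sketched jointly as below; `residuals` is assumed to hold the residual vector r^ℓ of one example at every layer, `o_L` its final output vector, and `E` the output embedding matrix.

```python
# Sketch of the refinement measurements (Figures 9 and 10) for one example.
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def refinement_stats(residuals, o_L, E):
    w = int((o_L @ E.T).argmax())                 # w = top(o_L), the model's final prediction
    matches, masses = [], []
    for r in residuals:
        p = softmax(r @ E.T)                      # distribution induced by the residual
        matches.append(int(p.argmax() == w))      # does top(r) already equal top(o_L)?
        masses.append(float(p[w]))                # probability mass on the final prediction
    return matches, masses
```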

To better understand how the refinement process works at each layer, we measure how often the residual's top prediction changes following its interaction with the feed-forward layer (top(r^ℓ) ≠ top(o^ℓ)), and whether this change results from the feed-forward layer overriding the residual (top(o^ℓ) = top(y^ℓ)) or from a true composition (top(r^ℓ) ≠ top(o^ℓ) ≠ top(y^ℓ)). Figure 11 shows the breakdown of different cases per layer. In the vast majority of examples, the residual's top prediction ends up being the model's prediction (residual+agreement). In most of these cases, the feed-forward layer predicts something different (residual). Perhaps surprisingly, when the residual's prediction does change (composition+ffn), it rarely changes to the feed-forward layer's prediction (ffn). Instead, we observe that composing the residual's distribution with that of the feed-forward layer produces a "compromise" prediction, which is equal to neither (composition). This behavior is similar to the intra-layer composition we observe in Section 5.1. A possible conjecture is that the feed-forward layer acts as an elimination mechanism to "veto" the top prediction in the residual, and thus shifts probability mass towards one of the other candidate predictions in the head of the residual's distribution.

Figure 11: Breakdown of examples by prediction cases: the layer’s output prediction matches the residual’s prediction (residual), matches the feed-forward layer’s prediction (ffn), matches both of the predictions (agreement), or none of the predictions (composition). By construction, there are no cases where the residual’s prediction matches the feed-forward layer’s prediction and does not match the output’s prediction.
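Given the three top predictions at a layer, the case breakdown of Figure 11 can be sketched as follows.

```python
# Sketch of the case breakdown in Figure 11, given the top predictions induced by
# the residual r, the feed-forward output y, and the layer output o = y + r.
def refinement_case(top_r, top_y, top_o):
    if top_o == top_r and top_o == top_y:
        return "agreement"      # residual and feed-forward agree with the output
    if top_o == top_r:
        return "residual"       # output keeps the residual's prediction
    if top_o == top_y:
        return "ffn"            # feed-forward layer overrides the residual
    return "composition"        # output matches neither: a "compromise" prediction
```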

Finally, we manually analyze 100 random cases of last-layer composition, where the feed-forward layer modifies the residual output in the final layer. We find that in most cases (66 examples), the output changes to a semantically distant word (e.g., "people" → "same") and in the rest of the cases (34 examples), the feed-forward layer's output shifts the residual prediction to a related word (e.g. "later" → "earlier" and "gastric" → "stomach"). This suggests that feed-forward layers tune the residual predictions at varying granularity, even in the last layer of the model.

6 Related Work

A lot of attention has been given to demystifying the operation of neural NLP models. An extensive line of work targets neuron functionality in general, extracting the properties that neurons and subsets of neurons capture (Durrani et al., 2020; Dalvi et al., 2019; Rethmeier et al., 2020; Mu and Andreas, 2020; Vig et al., 2020) , irrespective of the model architecture or neurons' position in it. Jacovi et al. (2018) analyzed CNN architectures in text classification and showed that they extract key n-grams from the inputs.

The study of the transformer architecture has focused on the role and function of self-attention layers (Voita et al., 2019; Clark et al., 2019; Vig and Belinkov, 2019) and on inter-layer differences (i.e. lower vs. upper layers) (Tenney et al., 2019; Jawahar et al., 2019) . To date, the role of feedforward layers remains under-explored.

Our method of characterizing the functionality of memory cells based on examples that trigger maximal activations has been used previously in NLP (Rethmeier et al., 2020) and vision (Erhan et al., 2009) .

7 Discussion And Conclusion

Understanding how and why transformers work is crucial to many aspects of modern NLP, including model interpretability, data security, and development of better models. Feed-forward layers account for most of a transformer's parameters, yet little is known about their function in the network.

In this work, we propose that feed-forward layers emulate key-value memories, and provide a set of experiments showing that: (a) keys are correlated with human-interpretable input patterns; (b) values, mostly in the model's upper layers, induce distributions over the output vocabulary that correlate with the next-token distribution of patterns in the corresponding key; and (c) the model's output is formed via an aggregation of these distributions, whereby they are first composed to form individual layer outputs, which are then refined throughout the model's layers using residual connections.

Our findings raise several key questions:

Layer embedding space. We observe a correlation between value distributions over the output vocabulary and key patterns, which increases from lower to upper layers (Section 4). Is this because the layer's output space transforms across layers? If so, how and why? We note that this possible transformation cannot be explained solely by the function of feed-forward layers as key-value memories: if the model only performed a series of key-value look-ups and aggregated value distributions via weighted addition, then a single, unifying embedding space would appear more natural. Thus, the transformation might have to do with the interplay between feed-forward layers and self-attention layers, and would seem to underlie some of the core aspects of the transformer's operation.

Beyond language modeling. In this work, we focused on a single transformer-based language-model decoder. Our formulation of feed-forward networks as key-value memories generalizes to any transformer model, e.g. BERT encoders and neural translation models. Yet, more experiments are needed to verify that our empirical observations hold in these settings.

Overall, this work raises new research questions about the inner workings of transformers, opening the door to further research that will shed light on the capabilities and limitations of modern NLP models.

Appendix A

Table 3: A pattern annotation of trigger examples for the memory cell k^5_895. Trigger examples are annotated with repetitive patterns (upper table), which are classified as "shallow" or "semantic" (bottom table).

Trigger prefixes (annotated pattern IDs in brackets):
It requires players to press [1]
The video begins at a press [1]
The first player would press [1]
Ivy, disguised as her former self, interrupts a Wayne Enterprises press [1]
The video then cuts back to the press [1]
The player is able to press
Leto switched [1]
In the Nintendo DS version, the player can choose to press [1]
In-house engineer Nick Robbins said Shields made it clear from the outset that he (Robbins) "was just there to press [1]
She decides not to press [1]
she decides not to press [1]
Originally Watson signaled electronically, but show staff requested that it press [1]
At post-game press [1]
In the buildup to the game, the press [2]
Hard to go back to the game after that news [1]
In post-trailer interviews, Bungie staff members told gaming press
Space Gun was well received by the video game [1]
As Bong Load struggled to press
At Michigan, Clancy started as a quarterback, switched [1]
Crush used his size advantage to perform a Gorilla press [1,2]
Groening told the press [1]
Creative director Gregoire argued that existing dance games were merely instructing players to press [1,2]
Mattingly would be named most outstanding player that year by the press [1]
At the post-match press [1,2]
The company receives bad press

ID | Description | Type
1 | Ends with the word "press" | shallow
2 | Press/news related | semantic

¹ We segment training examples into sentences to simplify the annotation task and later analyses.

² This is a simplification; in practice, we use the adaptive softmax (Baevski and Auli, 2019) to compute probabilities.
