A few versions of attention have been proposed, but this post will consider the version discussed in Attention Is All You Need and the subsequent transformer papers (BERT, GPT, GPT-2, etc.). This post will only discuss attention, rather than the full transformer architecture.

Attention, as used in transformer models, is often known as self-attention, or intra-attention, and is used to combine a sequence of vectors into a single vector. It is called “attention” due to the mechanism used in this combination – a weight is calculated for each vector in the sequence, such that all weights sum to \(1\). Each vector is scaled by its weight, and the result is the sum of these weighted vectors. Thus, “more attention” can be paid to vectors by assigning them larger weights.
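As a rough sketch of this weighting idea (the vectors and weights below are random, purely for illustration):

import torch

vectors = torch.randn(5, 8)                 # a sequence of 5 vectors, each of dimension 8
weights = torch.softmax(torch.randn(5), 0)  # 5 non-negative weights that sum to 1
combined = (weights.unsqueeze(-1) * vectors).sum(dim=0)  # a single 8-dimensional vector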

The Problem

Let’s assume we have a sequence \(\textbf{w} = (w_1, w_2, \ldots, w_n)\) of tokens (which we can think of as words, for convenience). We can map these tokens to a sequence \(\textbf{x} = (x_1, x_2, \ldots, x_n)\) of continuous-valued embedding vectors – each of our \(n\) words has an embedding vector. Many tasks in language processing require us to combine these vectors in some way. We may want to predict word-level properties, like whether a token corresponds to a named entity. We may want to predict sequence-level properties, like the positive/negative sentiment of some text. We might want to produce a summary vector, for use in an auto-regressive language model, or to generate a summary of the text. All of these applications (and many more) require us to combine a sequence of word embeddings somehow – if we didn’t, any downstream task would only “know” about individual words. Attention is one way of performing this combination.
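For concreteness, here is a minimal sketch of this token-to-embedding mapping (the vocabulary size, embedding dimension, and token ids are made up):

import torch
from torch import nn

vocab_size, d_model = 10000, 512                # illustrative sizes
embedding = nn.Embedding(vocab_size, d_model)   # maps integer token ids to embedding vectors

token_ids = torch.tensor([12, 7, 345, 12, 99])  # w = (w_1, ..., w_n), as integer ids
x = embedding(token_ids)                        # x = (x_1, ..., x_n), shape (5, 512)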

Recurrent Neural Networks

The traditional way of combining sequential vectors is to use a Recurrent Neural Network (assuming you are using neural networks, and that something as recent as RNNs can be considered traditional). These are sequential models, and their use makes some intuitive sense – humans read and understand text sequentially. They also have the nice property of being able to deal with sequences of variable length, which is important when dealing with natural language. However, their sequential nature also presents some problems. Most obviously, the number of sequential operations grows linearly with the length of the sequence, and these operations cannot be parallelised across the sequence. The sequential nature also directly leads to problems like vanishing/exploding gradients when training with backpropagation through time.

Self Attention

Attention is a method for combining vectors that solves some of the problems that RNNs suffer from (e.g. slowness due to their sequential nature), while retaining good performance on NLP tasks and allowing for variable length sequences. (It should be noted that a number of methods have been proposed to reduce sequential operations in these sequential tasks – the paper cites Extended Neural GPU, ByteNet and ConvS2S as examples.) Somewhat confusingly, attention was originally applied in conjunction with RNNs, with “Attention Is All You Need” being the first paper (as far as I’m aware) to rely exclusively on attention.

Attention maps a sequence of continuous query \(\textbf{q} = (q_1, q_2, \ldots, q_n), q_i \in \mathbb{R}^{d_\text{k}}\), key \(\textbf{k} = (k_1, k_2, \ldots, k_n), k_i \in \mathbb{R}^{d_\text{k}}\), and value vectors \(\textbf{v} = (v_1, v_2, \ldots, v_n), v_i \in \mathbb{R}^{d_\text{v}}\), to a sequence of continuous output vectors \(\textbf{z} = (z_1, z_2, \ldots, z_n), z_i \in \mathbb{R}^{d_\text{v}}\). Each output \(z_i\) is effectively a weighted sum of \(\textbf{v}\), where the weights are computed by a compatibility function applied between \(q_i\) and each \(k_j\). The specific compatibility function used in the transformer model is “scaled dot product attention”. This takes the dot product between a query-key pair, and scales the result by the square root of the dimension of the vectors – this scaling stops the dot products from growing in magnitude with the dimension, which would otherwise push the \(\text{softmax}\) into regions with vanishingly small gradients. The scaled dot products are normalised with a \(\text{softmax}\), and the result treated as weights in a weighted sum. For a single output, \(z_j\), this looks like:

\[\begin{align} \textbf{c} &= \left(\frac{q_j \cdot k_1}{\sqrt{d_k}}, \frac{q_j \cdot k_2}{\sqrt{d_k}}, \ldots, \frac{q_j \cdot k_n}{\sqrt{d_k}}\right)\\\
\text{softmax}(\textbf{c})_i &= \frac{e^{c_i}}{\sum_{m = 1}^n e^{c_m}}\\\
\textbf{a} &= \left(\text{softmax}(\textbf{c})_1, \text{softmax}(\textbf{c})_2, \ldots, \text{softmax}(\textbf{c})_n \right)\\\
z_j &= \sum_{i=1}^{n} a_i v_i \end{align}\]
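As a sanity check, here is a direct (and deliberately inefficient) transcription of these equations for a single output – a sketch, not how attention is implemented in practice:

import math
import torch

def single_output(q_j, keys, values):
    # c_i = (q_j . k_i) / sqrt(d_k), for each key k_i
    d_k = q_j.shape[-1]
    c = torch.stack([q_j @ k_i for k_i in keys]) / math.sqrt(d_k)
    # a = softmax(c): the attention weights, non-negative and summing to 1
    a = torch.softmax(c, dim=0)
    # z_j = sum_i a_i * v_i
    return sum(a_i * v_i for a_i, v_i in zip(a, values))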

To make the calculation simpler, we can create three input matrices by stacking our vector inputs:

\[\begin{align} Q &\in \mathbb{R}^{n \times d_\text{k}} \\\
K &\in \mathbb{R}^{n \times d_\text{k}} \\\
V &\in \mathbb{R}^{n \times d_\text{v}} \end{align}\]

This allows us to write the basic equation of attention as: \[\text{Attention}(Q, K, V) = \text{softmax} \left( \frac{Q K^T}{\sqrt{d_k}} \right) V \]

Or, in PyTorch, as:

import math

import torch.nn.functional as F

def attn(Q, K, V):
    # scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
    a = F.softmax((Q @ K.t()) / math.sqrt(Q.shape[-1]), dim=-1)
    return a @ V

What does this mean?

Let’s think back to our original sequence of embeddings, \(\textbf{x}\). We can choose \(Q\), \(K\), and \(V\) to all be equal to \(\textbf{x}\). In the paper, this is called “encoder self-attention” (one of three ways attention is used). Let’s say our input tokens are “the man walked the dog”. To compute the new vector for each word, we take the dot product of that word’s embedding with each of the embeddings in the sequence (including its own), scale each by \(\frac{1}{\sqrt{d_k}}\), normalise these scores so they sum to \(1\), and then sum all of the word embeddings, each multiplied by its corresponding weight. Thus we get a vector that is a convex combination of the input embeddings, and whose magnitude therefore cannot exceed that of the largest input embedding. Interestingly, this means the “new” vector for a word, \(z_{i}\), will give more weight in the sum to the embeddings, \(x_{j}\), that are similar to its own. To make this a bit more explicit, suppose we want to compute a new vector for the first word, “the”. To do so, we compute the similarity between “the”’s embedding vector and the embedding vectors for “the”, “man”, “walked”, “the”, “dog”. These similarities are normalised so they sum to \(1\) (note that the two “the” tokens receive identical scores, and that a vector’s dot product with itself is at least as large as its dot product with any other vector of the same or smaller norm). The new vector is therefore:

\[\begin{align} z_{\text{the}} &= \text{attention\_weight}(\text{“the”}, \text{“the”}) \times \text{embedding}(\text{“the”}) \\\
&+ \text{attention\_weight}(\text{“the”}, \text{“man”}) \times \text{embedding}(\text{“man”}) \\\
&+ \text{attention\_weight}(\text{“the”}, \text{“walked”}) \times \text{embedding}(\text{“walked”}) \\\
&+ \ldots \end{align}\]
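A small sketch of this, reusing the attn function from above (random vectors stand in for learned embeddings, and the two “the” tokens share the same one):

import torch

tokens = ["the", "man", "walked", "the", "dog"]
d_k = 64
lookup = {tok: torch.randn(d_k) for tok in set(tokens)}  # toy embedding table
x = torch.stack([lookup[tok] for tok in tokens])         # (5 x 64)

z = attn(x, x, x)  # encoder self-attention: Q = K = V = x, output shape (5 x 64)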

Multi-Headed Attention

The paper notes that “additive attention” outperforms unscaled dot-product attention for larger values of \(d_k\), though dot-product attention is much faster in practice. Additive attention uses a more complicated compatibility function – namely a small feed-forward neural network. Naturally, this is slower than a highly optimized dot product. Also naturally, it is more expressive: a single dot product can only give us the “overall” similarity between two vectors, while a feed-forward layer can consider combinations of different subspaces.
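For comparison, one common form of additive attention scores a query-key pair with a single hidden layer, roughly like this (the hidden size d_a is arbitrary, and this is a sketch rather than the paper’s implementation):

import torch
from torch import nn

class AdditiveScore(nn.Module):
    """Scores one query against one key with a small feed-forward network."""
    def __init__(self, d_k, d_a=64):  # d_a: arbitrary hidden size
        super().__init__()
        self.Wq = nn.Linear(d_k, d_a, bias=False)
        self.Wk = nn.Linear(d_k, d_a, bias=False)
        self.v = nn.Linear(d_a, 1, bias=False)

    def forward(self, q, k):
        # score(q, k) = v^T tanh(W_q q + W_k k)
        return self.v(torch.tanh(self.Wq(q) + self.Wk(k)))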

As a solution, multiple attention “heads” are used in parallel, each of which can focus on a different subspace. We parameterise the number of heads with \(h\).

\[\begin{align} \text{Multihead}(Q, K, V) &= \left( \overset{h}{\underset{i = 1}{\parallel}} \text{head}_i \right) W^{O} \\\
\text{head}_i &= \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V) \end{align}\]

Where \(\parallel\) indicates the concatenation of vectors. Note that \(Q, K, V\) differ from the single headed version, and all are in \( \mathbb{R}^{n \times d_{\text{model}}}\).

Compared to the version of attention discussed previously, each attention head has three additional matrices, used to learn a linear projection of each of \(Q, K,\) and \(V\). Additionally, \(W^{O}\) is used to project the concatenated output of all heads back into \(\mathbb{R}^{d_\text{model}}\).

\[\begin{align} W_{i}^Q &\in \mathbb{R}^{d_\text{model} \times d_\text{k}} \\\
W_{i}^K &\in \mathbb{R}^{d_\text{model} \times d_\text{k}} \\\
W_{i}^V &\in \mathbb{R}^{d_\text{model} \times d_\text{v}} \\\
W^O &\in \mathbb{R}^{ h d_\text{v} \times d_\text{model} } \end{align}\]

A naive PyTorch implementation looks like this:

import math
import torch
import torch.nn.functional as F
from torch import nn


class MultiHeadNaive(nn.Module):
    def __init__(self, d_model, d_k, d_v, h):
        super().__init__()

        # one AttnHead per head, plus W^O to project the concatenated heads back to d_model
        self.heads = nn.ModuleList([AttnHead(d_model, d_k, d_v) for _ in range(h)])
        self.Wo = nn.Linear(h * d_v, d_model)

    def forward(self, Q, K, V):
        head_results = [head(Q, K, V) for head in self.heads]
        return self.Wo(torch.cat(head_results, dim=-1))


class AttnHead(nn.Module):
    def __init__(self, d_model, d_k, d_v):
        super().__init__()

        # the per-head projections W^Q_i, W^K_i and W^V_i
        self.Wq = nn.Linear(d_model, d_k)
        self.Wk = nn.Linear(d_model, d_k)
        self.Wv = nn.Linear(d_model, d_v)

    def forward(self, Q, K, V):
        return attn(self.Wq(Q), self.Wk(K), self.Wv(V))


def attn(Q, K, V):
    # scaled dot-product attention for unbatched (n x d) inputs
    a = F.softmax((Q @ K.t()) / math.sqrt(Q.shape[-1]), dim=-1)
    return a @ V

Most implementations are slightly more complicated for reasons of computational efficiency. It is possible to perform all of the \(W^{*}_{i}\) multiplications at once, by choosing a larger matrix:

import math

import torch
import torch.nn.functional as F
from torch import nn


class MultiHead(nn.Module):
    def __init__(self, h, d_model):
        super().__init__()

        assert d_model % h == 0, f"d_model, {d_model} not divisible by number of heads, {h}"
        self.d_model = d_model
        self.d_k = d_model // h
        self.h = h

        # These W^{*} linear layers contain _all_ W^{*}_{i} matrices for each head.
        self.WQ = nn.Linear(d_model, d_model)
        self.WK = nn.Linear(d_model, d_model)
        self.WV = nn.Linear(d_model, d_model)
        self.WO = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V):
        bs = Q.shape[0]

        # (bs x n x d_model) -> (bs x n x h x d_k) -> (bs x h x n x d_k)
        Qs = self.WQ(Q).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        Ks = self.WK(K).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        Vs = self.WV(V).view(bs, -1, self.h, self.d_k).transpose(1, 2)

        batched_attn = attn(Qs, Ks, Vs)

        # (bs x h x n x d_k) -> (bs x n x h x d_k) -> (bs x n x d_model)
        concatted_attn = batched_attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)

        return self.WO(concatted_attn)


def attn(Q, K, V):
    # Q.matmul(K.transpose(-2, -1)): (bs x h x n x d_k) x (bs x h x d_k x n) -> (bs x h x n x n)
    a = F.softmax((Q.matmul(K.transpose(-2, -1))) / math.sqrt(Q.shape[-1]), dim=-1)

    # a.matmul(V): (bs x h x n x n) x (bs x h x n x d_k) -> (bs x h x n x d_k)
    return a.matmul(V)

This implementation allows more operations to be parallelised, though it is a bit more fiddly in terms of keeping track of dimensions. This is a simplified version of the implementation found here.
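As a quick shape check of the above (a usage sketch with arbitrary sizes):

import torch

mha = MultiHead(h=8, d_model=512)
x = torch.randn(2, 10, 512)  # (batch_size x sequence_length x d_model)
out = mha(x, x, x)           # self-attention: Q = K = V = x
print(out.shape)             # torch.Size([2, 10, 512])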

Notes

It is worth pointing out that neither of these versions of attention actually knows about the positions of the inputs within a sequence – changing the order of the sequence would have no impact on the outcome (other than permuting the outputs correspondingly). The full transformer model deals with this by adding “positional encodings” to the input embeddings. It is also worth pointing out that attention alone doesn’t contain a non-linearity (though the \(\text{softmax}\) used to compute attention weights is non-linear, the weighted sum itself is a linear combination of the value vectors).
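For reference, the sinusoidal positional encodings used in the paper can be sketched roughly as follows (assuming an even \(d_\text{model}\)); these are added to the embeddings \(\textbf{x}\) before attention is applied:

import math
import torch

def positional_encoding(n, d_model):
    # PE(pos, 2i) = sin(pos / 10000^(2i / d_model)), PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model))
    pos = torch.arange(n).unsqueeze(1).float()                                      # (n x 1)
    div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe = torch.zeros(n, d_model)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe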

An interesting way to think about attention is as an \(n \times n\) matrix. This could be interpreted as describing weighted edges in a graph, where each node is a token. Each attention head learns how to construct this graph for different inputs.
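As a tiny illustration of this view (again with random stand-in embeddings), the matrix of attention weights can be computed directly, and row \(i\) read as the weights of the edges leaving token \(i\):

import math
import torch

x = torch.randn(5, 64)                                           # 5 token embeddings
A = torch.softmax((x @ x.t()) / math.sqrt(x.shape[-1]), dim=-1)  # (5 x 5) matrix of attention weights
# A[i, j] is the weight of the edge from token i to token j; each row sums to 1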