Writing an LLM from scratch, part 13 -- the 'why' of attention, or: attention heads are dumb
Now that I've finished chapter 3 of Sebastian Raschka's book "Build a Large Language Model (from Scratch)" -- having worked my way through multi-head attention in the last post -- I thought it would be worth pausing to take stock before moving on to Chapter 4.
There are two things I want to cover: the "why" of self-attention, and some thoughts on context lengths. This post is on the "why" -- that is, why do the particular set of matrix multiplications described in the book do what we want them to do?
As always, this is something I'm doing primarily to get things clear in my own head -- with the possible extra benefit of it being of use to other people out there. I will, of course, run it past multiple LLMs to make sure I'm not posting total nonsense, but caveat lector!
Let's get into it. As I wrote in part 8 of this series:
I think it's also worth noting that [what's in the book is] very much a "mechanistic" explanation -- it says how we do these calculations without saying why. I think that the "why" is actually out of scope for this book, but it's something that fascinates me, and I'll blog about it soon.
That "soon" is now :-)
Attention heads are dumb
I think that my core problem with getting my head around why these equations work was that I was overestimating what a single attention head could do. In part 6, I wrote, of the phrase "the fat cat sat on the mat":
So while the input embedding for "cat" just means "cat in position 3", the context vector for "cat" in this sentence also has some kind of overtones about it being a cat that is sitting, perhaps less strongly that it's a specific cat ("the" rather than "a"), and hints of it being sitting on a mat.
The thing that I hadn't understood was that this is true as far as it goes, but only for the output of the attention mechanism as a whole -- not for a single attention head.
Each individual attention head is really dumb, and what it's doing is much simpler than that!
The two things that combine to make the mechanism as a whole smart are multi-head attention and layering. The book has gone over multi-head attention in detail, so let's drill down on that second part.
Layers
Right at the start, in part 1, I wrote:
One other thing that Raschka mentions that confuses me a little is that apparently the original transformer architecture had six encoder and six decoder blocks, and GPT-3 has 96 transformer layers. That doesn't fit very comfortably with my model of how this all works. Both encoders and decoders seem like stand-alone things that accept inputs (tokens/embeddings) and produce outputs (embeddings/tokens). What would you do with multiple layers of them?
Now that we've covered how attention works, that's become a bit clearer. A multi-head attention block gets a set of input embeddings, one per token in the input sequence, and produces a set of the same number of context vectors. There's nothing stopping us from treating those context vectors as the input embeddings for another attention block and doing the same thing again.
(That also explains why Raschka mentions that the number of dimensions in the context vectors often matches the number in the input embeddings; it makes it easier to use the same "shape" of multi-head attention calculations for each layer.)
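To make that concrete, here's a minimal sketch -- my own, not from the book, with made-up names and sizes -- showing that because a block maps a sequence of n embeddings to a sequence of n same-sized vectors, you can just feed one block's output into the next:

```python
import torch
import torch.nn as nn

# A stand-in for a multi-head attention block: it maps
# (batch, num_tokens, d_model) -> (batch, num_tokens, d_model).
class ToyAttentionBlock(nn.Module):
    def __init__(self, d_model, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

    def forward(self, x):
        out, _ = self.attn(x, x, x)  # self-attention: queries, keys and values all come from x
        return out

d_model = 64
layers = nn.Sequential(*[ToyAttentionBlock(d_model) for _ in range(6)])  # six stacked layers

x = torch.randn(1, 7, d_model)  # one batch, 7 tokens ("the fat cat sat on the mat"), 64-dim embeddings
print(layers(x).shape)          # torch.Size([1, 7, 64]) -- same shape out as in, so stacking just works
```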
In my mind, this is similar to the way an image-processing network -- say, a CNN -- works. In those, the first layer might detect edges, the second might detect lines at certain orientations, the next particular shapes, and then somewhere later on, the nth might recognise dogs' faces.
So the representation of the token "cat" that I described above would not be part of the output of one attention head, and perhaps even the first layer of the attention mechanism might not have anything that rich. But it might be in the output of the third layer of multi-head attention, or the fourth, or something like that.
By the 96th of those layers in GPT-3, what's represented in the context vectors is going to be super-enriched, and have lots of information spread across the different tokens. And realising this was a bit of an epiphany for me as well.
No more fixed-length bottleneck
If you cast your mind back to part 5, a big problem with encoder/decoder RNNs that did not have attention mechanisms was the fixed-length bottleneck. You would run your input sequence into an encoder RNN, which would try to represent its meaning in its hidden state -- a vector of a particular fixed length -- ready to pass it on to the decoder. Easy with a short input, but increasingly hard and eventually impossible as it gets longer, because you'd be trying to pack more and more information into the same "space".
But with attention, this super-enriched and combined representation of the input sequence that comes out of the last attention layer is proportional in length to the number of tokens in the input! You're still limited by available memory, of course (and other things -- see the next post), but the more tokens you have, the larger this "hidden state" of the context vectors.
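A quick shape check (my own illustration, with made-up sizes) makes the contrast concrete: however long the input gets, the RNN encoder's final hidden state stays the same size, while the set of context vectors coming out of an attention block grows with the number of tokens:

```python
import torch
import torch.nn as nn

rnn = nn.GRU(input_size=64, hidden_size=128, batch_first=True)
attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)

for num_tokens in (5, 50, 500):
    x = torch.randn(1, num_tokens, 64)
    _, hidden = rnn(x)        # the fixed-length bottleneck: always (1, 1, 128)
    ctx, _ = attn(x, x, x)    # one context vector per token: (1, num_tokens, 64)
    print(num_tokens, tuple(hidden.shape), tuple(ctx.shape))
```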
That's pretty cool.
So, using multi-head attention plus layers allows us to build up complex representations even when each individual attention head is dumb. But, going back to the core of this post, why do these dumb attention heads use the specific calculations that they do?
Why dumb attention heads work
Let's use an example.
A heads-up/warning first: attention heads learn their own representations and patterns to match on as part of gradient-descent training of the network as a whole -- so whatever they learn will probably be weird and alien, and won't relate in any obvious way to grammar as we understand it. But for this example, let's pretend that isn't the case, and that we've got an attention head that has learned how to match articles (like "a", "an", and "the") up with their associated nouns.
How would that work? Let's take "the fat cat sat on the mat", and ignore everything apart from the two "the"s, and the nouns "cat" and "mat". We'll say that our attention head wants to produce a context vector for "cat" that combines it with the first "the" (meaning that it will contain the concept that we're talking about a specific cat rather than just "a" cat), and similarly it wants to blend the second "the" into "mat".
Now, remember that our input sequence is a series of input embeddings, which are combinations of the token embeddings (which are vectors in a space that point to some abstract "meaning" of the tokens) and position embeddings (which represent their position in the sequence).
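In code, that combination is just a sum of two learned embeddings -- here's a quick sketch with GPT-2-ish sizes (the token IDs are arbitrary stand-ins, not real BPE tokens):

```python
import torch
import torch.nn as nn

vocab_size, context_length, d_model = 50257, 1024, 768   # GPT-2-ish sizes, purely for illustration
tok_emb = nn.Embedding(vocab_size, d_model)      # "what the token means"
pos_emb = nn.Embedding(context_length, d_model)  # "where it is in the sequence"

token_ids = torch.tensor([[11, 22, 33, 44, 55, 66, 77]])  # seven arbitrary IDs standing in for our seven tokens
positions = torch.arange(token_ids.shape[1])

input_embeddings = tok_emb(token_ids) + pos_emb(positions)
print(input_embeddings.shape)  # torch.Size([1, 7, 768])
```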
Taking "mat" as our example, we project its input embedding, which means "the token 'mat' at position 7"1 into query space. The breakthrough to me was that query space is another embedding space, just like the original one for the input embeddings, but with different representations for the values.
Let's say that in this new embedding space, representations are much simpler -- they don't have as much detail as the original one. It just represents "this is an article" or "this is not an article", and some information about positioning -- that is, the embedding for an article at position 1 is close to the embedding at position 2, but is not very close to the one for an article at position 69,536. And other things that are not articles would be somewhere even further away.
In this example, perhaps the projection that our attention head has learned will map "'mat' at position 7" to an embedding pointing in the direction of "some article -- the or a -- at position 6 or lower, probably quite close". In other words, the projection into query space turns an input embedding for a token into the kind of thing this attention head is looking for when it's handling that token. Likewise, "'cat' at position 3" would be projected into an embedding vector meaning "some article at position 2 or lower, probably quite close".
Now, as well as projecting the input embeddings into the query space, we're also projecting them into the key space. In that case, our imaginary article-matching head would create a projection that would turn the first "the" into something meaning "an article at position 1", and the second into one meaning "an article at position 6".
So, the query weights have projected our input embeddings into this "low-resolution" embedding space to point in a direction meaning "this is what I'm interested in", and the key weights have projected the input embeddings into the same embedding space in a direction meaning "this is what I am".
That means that when we do our dot product, the query vector for "mat" will point in a very similar direction to the key vector for the second "the", and so the dot product will be high -- remember, so long as vectors are roughly the same length, the dot product is an indication of how similar they are.
What's important about this is that the shared embedding space that the query and key vectors use can actually be pretty impoverished compared to the rich space that the input embeddings live in. In our case, all the head cares about is whether a token is a noun, an article, or something else, plus its position.
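Here's a tiny numerical sketch of that matching step, with hand-made three-dimensional query/key vectors -- not learned, just chosen to tell the story. The query for "mat" lines up best with the key for the second "the", so that's where the dot product is largest:

```python
import torch

# Hand-crafted vectors in an imaginary 3-dim query/key space:
# dimensions roughly mean [is-an-article, is-a-noun, position-ish].
q_mat        = torch.tensor([1.0, 0.0, 0.55])  # query for "mat": "an article, somewhere a bit before position 7"
k_the_first  = torch.tensor([1.0, 0.0, 0.10])  # key for the first "the": "an article, at position 1"
k_the_second = torch.tensor([1.0, 0.0, 0.60])  # key for the second "the": "an article, at position 6"
k_sat        = torch.tensor([0.0, 0.0, 0.40])  # key for "sat": "not an article at all"

for name, k in [("the (pos 1)", k_the_first), ("the (pos 6)", k_the_second), ("sat", k_sat)]:
    print(name, torch.dot(q_mat, k).item())
# the (pos 1) ≈ 1.06 -- a partial match: right kind of word, wrong sort of position
# the (pos 6) ≈ 1.33 -- the strongest match: an article, close by
# sat         ≈ 0.22 -- barely any match at all
```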
Let's take an example. Here's the set of attention scores that I imagined the attention mechanism might come up with back in part 6 (modified to be causal, so that tokens don't pay any attention to tokens in their "future"):
Token | ω("The") | ω("fat") | ω("cat") | ω("sat") | ω("on") | ω("the") | ω("mat") |
---|---|---|---|---|---|---|---|
The | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
fat | 0.2 | 1 | 0 | 0 | 0 | 0 | 0 |
cat | 0.6 | 0.8 | 1 | 0 | 0 | 0 | 0 |
sat | 0.1 | 0 | 0.85 | 1 | 0 | 0 | 0 |
on | 0 | 0.1 | 0.4 | 0.6 | 1 | 0 | 0 |
the | 0 | 0 | 0 | 0 | 0.1 | 1 | 0 |
mat | 0 | 0 | 0.2 | 0.8 | 0.7 | 0.6 | 1 |
Each row gives, for the token in the first column, the attention scores for all of the other words. It's based on my own intuition about the importance of words, and it's the kind of thing you might imagine a clever attention head coming up with. (Remember that ω is the notation the book uses for raw attention scores -- the numbers before softmax normalisation turns them into attention weights.)
But our closer-to-real-world example of an article-noun matching head is really dumb, so it might come up with something more like this:
Token | ω("The") | ω("fat") | ω("cat") | ω("sat") | ω("on") | ω("the") | ω("mat") |
---|---|---|---|---|---|---|---|
The | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
fat | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
cat | 0.8 | 0 | 1 | 0 | 0 | 0 | 0 |
sat | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
on | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
the | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
mat | 0.1 | 0 | 0 | 0 | 0 | 0.8 | 1 |
All it has done is decide to pay attention to the "the"s when considering the nouns -- and it's even paying a bit of attention to the first "the" when considering "mat", because it doesn't know that it should only match with the closest "the".2
Now, as I said earlier, the real attention heads, having been trained by gradient descent over billions of tokens, will probably have learned something weird and abstract and not related to the way we think of language, grammar and the parts of speech.
But taken individually, they will be really dumb, because the equation is doing something really simple: when considering a particular kind of thing, look for this other kind of thing. Each token is projected into a shared embedding space by the query weights ("what I'm looking for") and into the same space by the key weights ("what I am"), and the dot product does the comparison to find matches.
Of course, that doesn't mean we lose any information. This impoverished embedding space is only used for the matching that works out our attention scores. When we work out the context vectors, we use the projection into value space, which can be as rich as we like.
It's worth noting that although the example Raschka uses in the book has the same dimensionality for the shared query/key space and the value space, there's actually no need for that. I've seen specs for LLMs where the query/key space has fewer dimensions than the value space -- which makes sense, at least going by this trivial example.
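Putting all of that together, here's a minimal sketch of a single causal attention head along these lines -- my own toy version in the spirit of the book's code rather than a copy of it -- with a deliberately small query/key matching space and a richer value space:

```python
import torch
import torch.nn as nn

class DumbAttentionHead(nn.Module):
    """One causal attention head: a small Q/K space for matching, a larger V space for content."""
    def __init__(self, d_in, d_qk, d_v):
        super().__init__()
        self.W_q = nn.Linear(d_in, d_qk, bias=False)  # "what I'm looking for"
        self.W_k = nn.Linear(d_in, d_qk, bias=False)  # "what I am"
        self.W_v = nn.Linear(d_in, d_v, bias=False)   # "what I actually contribute"
        self.d_qk = d_qk

    def forward(self, x):                             # x: (batch, num_tokens, d_in)
        q, k, v = self.W_q(x), self.W_k(x), self.W_v(x)
        scores = q @ k.transpose(-2, -1)               # dot-product matching in the shared Q/K space
        num_tokens = x.shape[1]
        mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))        # causal: no attending to the "future"
        weights = torch.softmax(scores / self.d_qk ** 0.5, dim=-1)
        return weights @ v                             # context vectors: (batch, num_tokens, d_v)

head = DumbAttentionHead(d_in=768, d_qk=16, d_v=64)    # tiny matching space, richer value space
x = torch.randn(1, 7, 768)                             # "the fat cat sat on the mat"
print(head(x).shape)                                   # torch.Size([1, 7, 64])
```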
It's also worth noting that while this key/query space is impoverished in our example, in a real "alien" learned head it could actually be quite complex and rich -- just much harder to understand. Ultimately, the nature of that embedding space is learned in the same way as everything else, and will match whatever the head in question has learned to do.
The elegance of dumb attention
So, that is (right now) my understanding of how scaled dot product attention works. We're just doing simple pattern matching, where each token's input embedding is projected by the query weights into a (learned) embedding space that is able to represent what it is "looking for" in some sense. It's also projected by the key weights into the same space, but this time in a way that makes it point to what it "is" in the same sense. Then the dot product matches those up so that we can associate input embeddings with each other to work out our attention scores.
That all makes sense in my head, and I hope it does in at least a few other people's :-)
I'll wrap this one up here; next time I'll post about what I currently understand the material we've covered in the book so far to mean for context lengths. We've seen the upside of a hidden state that grows as the input sequence does -- what are the downsides?
1. I'll one-index for this. ↩
2. I can't think of any way a single head could, TBH. It's considering all other tokens in parallel, so when it's looking at the first "the" it doesn't know that there's another, closer one. ↩