202409212318
Status: #idea
Tags: #ai #transformers #nlp
# Decoder-Only Transformers as Differentiable n-gram Models
## $n$-gram Model Basics
[$n$-gram models](https://en.wikipedia.org/wiki/Word_n-gram_language_model) use the previous $n-1$ tokens to predict the next token. Typically, a conditional probability distribution is built by finding all occurrences of a particular $(n-1)$-token context, counting the tokens that follow it, and dividing each count by the total number of occurrences of that context. In this way, we have a conditional probability distribution for every context that appears in our dataset. Below is an example of calculating the probability of the sentence "I saw the red house" using a bigram model.
![[Pasted image 20240227215748.png]]
Formally, this looks like:
![[Pasted image 20240227223331.png]]
where the conditional probability can be calculated as:
![[Pasted image 20240227223411.png]]
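To make the counting recipe above concrete, here is a minimal Python sketch of a bigram model estimated from a toy corpus. The corpus, the whitespace tokenization, and the `<s>` start token are illustrative assumptions, not something specified above.

```python
from collections import defaultdict

# Toy corpus; real n-gram models are estimated from much larger datasets.
corpus = [
    "<s> I saw the red house",
    "<s> I saw the dog",
    "<s> the red house is old",
]

# Count bigrams and the contexts they start from.
bigram_counts = defaultdict(int)
context_counts = defaultdict(int)
for sentence in corpus:
    tokens = sentence.split()
    for prev, curr in zip(tokens, tokens[1:]):
        bigram_counts[(prev, curr)] += 1
        context_counts[prev] += 1

def bigram_prob(prev, curr):
    """P(curr | prev) = count(prev, curr) / count(prev)."""
    if context_counts[prev] == 0:
        return 0.0
    return bigram_counts[(prev, curr)] / context_counts[prev]

# Probability of "I saw the red house" under the bigram model:
# P(I|<s>) * P(saw|I) * P(the|saw) * P(red|the) * P(house|red)
tokens = "<s> I saw the red house".split()
prob = 1.0
for prev, curr in zip(tokens, tokens[1:]):
    prob *= bigram_prob(prev, curr)
print(prob)
```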
## Transformer Model Basics
The base of the transformer model is the attention mechanism, pictured below.
![[Pasted image 20240227214848.png]]
Intuitively, attention can be understood as a continuous key-value store. The attention operation is performed by taking an $n \times d$ query matrix $Q$ and multiplying it by the transpose of an $n \times d$ key matrix $K$. In this case, $n$ is the max number of tokens in a sequence and $d$ is the dimension of the embedding vector for each token.
This gives the dot product of every query with every key. Query vectors that point in the same (or a similar) direction as a key vector produce larger dot products, which gives us a measure of *similarity* between each query and every possible key. We then apply $\texttt{softmax}$ to turn each query's row of scores into a probability distribution over the keys, so that the most similar keys are weighted most heavily. (Since we use $\texttt{softmax}$ instead of $\texttt{argmax}$, the transformer's key-selection step remains differentiable.)
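In symbols (and assuming the scaled dot-product form from the original transformer paper, which divides by $\sqrt{d}$ before the softmax even though the scaling is not mentioned above), this step produces

$$A = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right),$$

where the softmax is applied row-wise, so each row of $A$ sums to 1.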
![[Pasted image 20240227214635.png]]
The output of the above step is the **attention matrix**. The attention matrix is essentially a probability distribution over the keys for each query in our input sequence. By multiplying the attention matrix with the value matrix, we obtain a matrix of **weighted sums of all the values**, where row $i$ of the resulting matrix is the weighted sum of the value vectors for position $i$ in the input sequence, weighted by row $i$ of the attention matrix.
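Putting the pieces together, below is a minimal NumPy sketch of single-head, unmasked attention with the $n \times d$ shapes described above. The $\sqrt{d}$ scaling follows the standard scaled dot-product formulation, and the random inputs are purely illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row-wise max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Single-head, unmasked scaled dot-product attention.

    Q, K, V: (n, d) arrays. Returns an (n, d) array whose row i is the
    weighted sum of the rows of V, weighted by row i of the attention matrix.
    """
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)   # (n, n) query-key similarities
    A = softmax(scores, axis=-1)    # attention matrix: each row sums to 1
    return A @ V                    # weighted sums of the value vectors

# Illustrative random inputs: n = 4 tokens, d = 8-dimensional embeddings.
rng = np.random.default_rng(0)
n, d = 4, 8
Q, K, V = rng.normal(size=(3, n, d))
out = attention(Q, K, V)
print(out.shape)  # (4, 8)
```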