The Math Behind the Machine: A Deep Dive into the Transformer Architecture
The transformer architecture was introduced in the paper “Attention Is All You Need” by Vaswani et al., a team of researchers at Google. In the paper, they demonstrated state-of-the-art performance on machine translation with a novel architecture that relies on neither recurrence nor convolution. Although transformers were initially designed for NLP problems, they now also represent the state of the art in computer vision (the Vision Transformer).
Before the deep dive, let’s take a high-level look at the architecture. Transformers consist of two primary components: an encoder and a decoder. The encoder includes multi-head self-attention, layer normalization, a feed-forward network, and residual connections. The decoder is similar to the encoder, except that it uses two attention blocks: masked multi-head self-attention and multi-head cross-attention.
This blog walks through the transformer in the context of a machine translation task, explaining each of its key components.
Data preprocessing
1. Adding special tokens:
Machine translation is a sequence-to-sequence task: it involves translating a sequence of text from one language (the source) to another (the target). One characteristic of this task is that the lengths of the input sequence (source) and the output sequence (target) can vary.
So, as a preprocessing step, two special tokens, <SOS> and <EOS>, are added to the source and target sequences. These tokens help the model understand the structure of the input and output sequences. The <EOS> token signals the encoder that it has received the complete source sequence. For the decoder, the <EOS> token is even more significant, as it enables the generation of variable-length sequences: the decoder emits <EOS> to tell us when it has finished generating tokens. Without an <EOS> token, we would have no way of knowing when the decoder is done.
Note that the decoder is autoregressive: the <SOS> token signals it to start generating the next tokens. Adding these tokens shifts the target sequence by one position to the right. The example below demonstrates how the special tokens are appended to the sequences.
Encoder input: Hello, how are you? <EOS>
Decoder input: <SOS> Namaskar, kemon achen? <EOS>
Target prediction: Namaskar, kemon achen? <EOS>
2. Tokenization and Token embeddings
Transformers cannot process raw text directly. Therefore, the source and target sentences are first broken down into tokens.
For example, a whitespace tokenizer splits the input text on white space, while a subword tokenizer interpolates between word-level and character-level tokenization. Below is an example of the Byte-Pair Encoding (BPE) tokenizer used in GPT-2.
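The snippet below is a small sketch of this tokenization step, assuming the Hugging Face transformers library is available (any BPE implementation would behave the same way):

```python
# Minimal sketch: tokenizing the source sentence with the GPT-2 BPE tokenizer.
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

text = "Hello, how are you?"
tokens = tokenizer.tokenize(text)                    # ['Hello', ',', 'Ġhow', 'Ġare', 'Ġyou', '?']
token_ids = tokenizer.convert_tokens_to_ids(tokens)  # unique numerical ID for each token
print(tokens, token_ids)
```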
In this example, the source sentence “Hello, how are you?” is broken down into 6 tokens. Note that the prefix “Ġ” signifies that a space precedes the token. Each token is then assigned a unique numerical token ID, and each token ID is mapped into a 512-dimensional vector space such that similar tokens lie close to each other and each dimension of the vector captures some semantic aspect of the token. These vectors are called embeddings.
Note: these are not word2vec or GloVe embeddings; if they were, the embedding layer would simply be a fixed look-up table! The embedding of each token is initialized randomly and then refined via backpropagation during training.
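For illustration, here is a minimal sketch of such a learned embedding layer in PyTorch (the vocabulary size and token IDs below are placeholders, not values from this post):

```python
import torch
import torch.nn as nn

vocab_size, d_model = 50257, 512                 # assumed vocabulary size; d_model = 512 as in the paper

embedding = nn.Embedding(vocab_size, d_model)    # weights are initialized randomly

token_ids = torch.tensor([[101, 7, 42, 9, 13, 30]])   # placeholder IDs, shape (1, 6)
token_embeddings = embedding(token_ids)                # shape (1, 6, 512); refined by backprop
```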
Feature engineering
The transformer takes in the source and target sequences all at once, but language is inherently sequential! So, we must find a way to incorporate ordering information into the token embeddings.
Okay, this seems easy, right? How about we add a 512-dimensional vector filled with the value ‘i + 1’ to the token embedding of the token at index ‘i’? So, the first token’s embedding gets a vector filled with 1s, the second a vector of 2s, and so on. This is a naive way to add positional information to the token embeddings, but do you see the problem here?
For a token at index 50, the original token embedding would be completely drowned out, because the positional values would dominate. So, this is not quite helpful! This problem is solved using positional encodings.
Positional encodings:
The authors employed a clever trick to incorporate sequential information, known as positional encoding. But before delving into it, let’s establish some nomenclature: ‘pos’ is the position of the token (in our example, ‘Hello’ is at position 0, ‘,’ at position 1, and so on), ‘i’ is the index along the token-embedding vector, and ‘d_model’ is the embedding dimension, set to 512.
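With this nomenclature, the sinusoidal positional encoding defined in the paper is:

$$
PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)
$$

Even embedding indices use the sine term and odd indices the cosine term, so every position receives a unique pattern of values bounded between -1 and 1, which avoids the “dominating values” problem described above.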
Imagine that d_model is 128 and the maximum position is 60. Below is a visualization of the positional encodings.
Each row in the plot above is the positional encoding for the corresponding position. But to really appreciate positional encodings, we need to look at the cosine similarity matrix of the PE matrix, shown below.
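Here is a small sketch, using NumPy (an assumption, not this post’s original code), of how the PE matrix and its cosine-similarity matrix can be reproduced:

```python
import numpy as np

d_model, max_pos = 128, 60

pos = np.arange(max_pos)[:, None]                  # positions, shape (60, 1)
i = np.arange(d_model // 2)[None, :]               # embedding index pairs, shape (1, 64)
angles = pos / np.power(10000, 2 * i / d_model)    # shape (60, 64)

pe = np.zeros((max_pos, d_model))
pe[:, 0::2] = np.sin(angles)                       # even dimensions
pe[:, 1::2] = np.cos(angles)                       # odd dimensions

# Cosine similarity between every pair of positions.
norms = np.linalg.norm(pe, axis=1, keepdims=True)
cos_sim = (pe @ pe.T) / (norms @ norms.T)          # shape (60, 60)
```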
Look closely: the cosine similarity of the positional encoding at a given position “p” is highest with the positional encodings of neighbouring positions, indicating greater similarity with surrounding tokens. In the original paper, the authors state:
We chose this function because we hypothesized it would allow the model to easily learn to attend by relative positions.
Encoder block
Step-1: Multi-Head Attention (MHA):
Attention is at the core of the transformer architecture, and the MHA mechanism is highly efficient on modern hardware. The encoder performs self-attention (scaled dot-product attention) on the input tokens to understand the relationships between them. MHA extends this idea by running multiple attention heads in parallel, allowing the model to focus on different aspects of the input sequence simultaneously and to capture complex relationships and dependencies in the data.
Scaled Dot-Product Attention
For the input tensor, three fully connected layers (let’s call their weight matrices Wq, Wk and Wv) are trained to produce three matrices, called the Query (Q), Key (K) and Value (V) matrices.
Before going any further let’s understand the shapes.
Query (Q) = Dot_product(Input, Wq)
Key (K) = Dot_product(Input, Wk)
Value (V) = Dot_product(Input, Wv)
Now the attention score is computed as follows,
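This is the scaled dot-product attention defined in the original paper:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V
$$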
Here d_k is the dimension of the key vectors; dividing by square_root(d_k) keeps the dot products from growing too large. Also, note that the softmax is applied along the horizontal axis (row-wise) of the matrix.
But what are these Q, K and V matrices? What is the point of learning these three matrices?
The inspiration for computing these three matrices comes from retrieval systems. For example, when you search for a video on YouTube, the search engine maps your query (q), a text string, against a set of keys (k), such as video titles, descriptions and comments, and then presents you with the best-matching values (v). Think of Q, K and V as follows.
Q: how closely is some other token related to me?
K: how closely am I related to some other token?
V: the actual information the tokens carry.
Note that the attention scores weight the value (V) matrix.
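As a concrete illustration, here is a minimal PyTorch sketch of scaled dot-product attention (the function name and shapes are illustrative, not this post’s own implementation):

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v):
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5   # (..., seq_len, seq_len)
    weights = F.softmax(scores, dim=-1)             # softmax applied row-wise
    return weights @ v                              # attention scores weight V

q = k = v = torch.randn(1, 6, 512)                  # (batch, seq_len, d_model)
out = scaled_dot_product_attention(q, k, v)         # (1, 6, 512)
```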
Now let’s talk about multi-head attention. Multi-head attention groups several scaled dot-product attention mechanisms to better understand the context. First, Wq, Wk and Wv are used to compute the Q, K and V matrices, and then these three matrices are split along the last dimension into H heads. Note that this requires the embedding dimension to be divisible by H.
For each head, scaled dot-product attention is computed, and the results are concatenated. The whole process is parallelizable and helps the model capture richer context, as each head can learn to attend to a different aspect of the input.
The authors state: “Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.” Mathematically, MHA is expressed in the paper as follows.
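$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\,W^O,
\quad \text{where } \text{head}_i = \text{Attention}(QW_i^Q,\; KW_i^K,\; VW_i^V)
$$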
Step-2: Residual connection and Layer Normalization:
The input flows through the MHA block, and the output tensor has the same shape as the input. This output is added to the MHA block’s input tensor via a residual connection. This helps mitigate the vanishing gradient problem, but not only that: the residual connection also reminds the transformer of the original input representation, so that the contextual embedding stays anchored to the tokens it represents.
Remember that we added positional encodings to the token embeddings; we do not want this information to get lost during the attention computation. After the residual connection, the representation is normalized using layer normalization.
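A minimal sketch of this “Add & Norm” step in PyTorch (the function name here is illustrative):

```python
import torch
import torch.nn as nn

d_model = 512
norm = nn.LayerNorm(d_model)

def add_and_norm(x, sublayer_output):
    # Residual connection followed by layer normalization.
    return norm(x + sublayer_output)

x = torch.randn(2, 6, d_model)                      # (batch, seq_len, d_model)
out = add_and_norm(x, torch.randn(2, 6, d_model))   # same shape as x
```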
Step-3: Feedforward network:
Multi-head attention weights the value (V) matrix so that important tokens receive high values and the rest are suppressed. However, we want the encoder to output a global representation of the context. To achieve this, the output tensor of the previous layer-normalization block is passed through a feed-forward network, which learns to capture the global relationships among the input tokens, helping the model learn the best representation for the given input.
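A sketch of this position-wise feed-forward network (d_ff = 2048 is the inner dimension used in the original implementation):

```python
import torch
import torch.nn as nn

d_model, d_ff = 512, 2048

feed_forward = nn.Sequential(
    nn.Linear(d_model, d_ff),
    nn.ReLU(),
    nn.Linear(d_ff, d_model),
)

x = torch.randn(2, 6, d_model)   # (batch, seq_len, d_model)
out = feed_forward(x)            # applied to each position independently, same shape
```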
Step-4: stack of encoders [6x]
Instead of using a single encoder, the transformer uses a stack of encoders. The input tensor and the output of each encoder have the same shape, which makes the computation easy even when multiple encoders are stacked. In the original implementation, 6 such encoders were stacked.
Stacking layers is what makes deep learning models powerful: with just one encoder, the transformer would not be able to capture the complexity needed to model an entire language. Stacking encoders lets the model learn more complex patterns in language, which improves accuracy.
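As a stand-in sketch (using PyTorch’s built-in modules rather than a hand-written encoder), stacking looks like this:

```python
import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8,
                                           dim_feedforward=2048, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)   # 6 stacked encoders

x = torch.randn(2, 10, 512)   # (batch, seq_len, d_model)
out = encoder(x)              # same shape, so encoder layers compose freely
```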
Decoder block
The decoder is autoregressive. The target tokens are modified by prepending an <SOS> token and appending an <EOS> token at the end; this shifts the target labels one position to the right. The reason for shifting is to signal the decoder to start generating predictions when it is given the <SOS> token during inference.
The decoder shares most of its components with the encoder, with a few differences: unlike the encoder, it uses masked multi-head self-attention and multi-head cross-attention. Let’s understand, step by step, how data flows through the decoder.
Step-1: Masked Multi-Head Attention
The multi-head attention block used in the decoder is identical to the one used in the encoder. The only distinction is the addition of a mask before computing the softmax.
But why do we need masking?
The transformer decoder operates autoregressively during inference, predicting tokens from left to right, with each token’s prediction influenced by the preceding ones. However, training happens in parallel: both the input and target sequences are provided to the model simultaneously. To preserve the autoregressive behaviour during training, the trick is to prevent the decoder from attending to future tokens.
Below is how data flows through the masked MHA block. The mask is broadcast to match the tensor dimensions.
BS = batch size
E = embedding dimension
num_heads = 8
head_dimension = E // num_heads
SL = sequence length
The decoder decodes one token at a time during inference, so during training it must not have access to tokens that “do not exist yet”. The mask used here is called a causal mask.
In a nutshell, masked MHA is simply self-attention over the target sequence with causal masking, so the decoder cannot access future tokens.
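A minimal sketch of building and applying such a causal mask in PyTorch (shapes are illustrative):

```python
import torch

SL = 5                                                             # sequence length
causal_mask = torch.triu(torch.ones(SL, SL, dtype=torch.bool), diagonal=1)

# scores: (BS, num_heads, SL, SL); the mask broadcasts over batch and head dimensions.
scores = torch.randn(2, 8, SL, SL)
scores = scores.masked_fill(causal_mask, float("-inf"))   # hide future positions
weights = torch.softmax(scores, dim=-1)                   # -inf becomes probability 0
```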
Step-2: Residual connection with Layer normalization
[Similar to the step used in the encoder]
Step-3: Multi-head attention [Cross attention, Decoder]
Machine translation is a complex task: the model has to learn the relationship between source and target tokens to perform translation. To achieve this, the decoder has a cross-attention block where the Q matrix is learned from the target tokens, while the other two (K, V) come from the output of the last encoder block.
Everything else remains the same, except that the attention matrix is again masked so that the model does not attend to future tokens. This mask is a combination of the causal mask used in masked MHA and the encoder padding mask: we do not want the model to attend to future tokens or to padding. Cross-attention also makes the model more interpretable; the cross-attention weights can be visualized to understand the relationship between source and target tokens during inference.
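A sketch of cross-attention using PyTorch’s built-in MultiheadAttention as a stand-in (queries from the decoder, keys and values from the encoder output):

```python
import torch
import torch.nn as nn

cross_attn = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

decoder_states = torch.randn(2, 7, 512)    # (batch, target_len, d_model)
encoder_output = torch.randn(2, 10, 512)   # (batch, source_len, d_model)

out, attn_weights = cross_attn(query=decoder_states,
                               key=encoder_output,
                               value=encoder_output)
# attn_weights has shape (batch, target_len, source_len) and can be visualized
# to inspect which source tokens each target token attends to.
```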
Step-4: Linear layer and Softmax
The output of the decoder has shape [batch_size, sequence_length, d_model]. The linear layer transforms it into logits of shape [batch_size, sequence_length, vocab_size], where vocab_size is the number of unique tokens in the target vocabulary.
The softmax function is applied along the vocab_size dimension to get probability scores, and the predicted token is obtained with torch.argmax().
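A minimal sketch of this final projection and prediction step (vocab_size below is an illustrative value):

```python
import torch
import torch.nn as nn

d_model, vocab_size = 512, 32000

projection = nn.Linear(d_model, vocab_size)

decoder_output = torch.randn(2, 7, d_model)   # (batch, seq_len, d_model)
logits = projection(decoder_output)           # (batch, seq_len, vocab_size)
probs = torch.softmax(logits, dim=-1)         # probabilities over the vocabulary
predicted_ids = torch.argmax(probs, dim=-1)   # (batch, seq_len)
```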
Step-5: stack of decoders [6x]
Similar to the encoder, the decoder is also stacked to learn complex patterns between the source and target sequences. Stacking increases the number of weights but reduces the bias of the model.
Final step: Computing Categorical Cross Entropy Loss:
The model minimizes the categorical cross-entropy loss, where each unique token in the vocabulary is treated as a separate class that the model aims to predict. So, the more diverse and varied your target vocabulary, the greater the number of classes. It’s as if the transformer is playing a guessing game where each possible token is a class; the cross-entropy loss tells the model how far off its guesses are from the actual target sequence, nudging it to improve over time.
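A sketch of the loss computation in PyTorch (PAD_ID is an assumed padding-token ID that is ignored when averaging the loss):

```python
import torch
import torch.nn.functional as F

PAD_ID, vocab_size = 0, 32000

logits = torch.randn(2, 7, vocab_size)            # decoder predictions: (batch, seq_len, vocab_size)
targets = torch.randint(0, vocab_size, (2, 7))    # shifted target token IDs: (batch, seq_len)

loss = F.cross_entropy(logits.view(-1, vocab_size),   # flatten to (batch*seq_len, vocab_size)
                       targets.view(-1),
                       ignore_index=PAD_ID)
```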
Why are transformers considered superior to RNNs?
Unlike an RNN, there is no recurrence in the transformer. Instead, the entire source and target sequences are fed to the model simultaneously, the cross-entropy loss is computed, and the gradients are back-propagated to adjust the weights. This makes training highly efficient, which is a significant advantage of the transformer, as the model can be scaled up easily. Beyond that, transformers handle long-range dependencies more effectively than RNNs.
Have a look at the video below, where Andrej Karpathy explains the Transformer in an intuitive way.
References:
- Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., … Polosukhin, I. (2017). “Attention is All You Need.” Advances in Neural Information Processing Systems, 30. Curran Associates, Inc. Retrieved from https://arxiv.org/abs/1706.03762v1
- “Transformer (machine learning model).” Wikipedia. Retrieved from https://en.wikipedia.org/wiki/Transformer_(machine_learning_model)
- Various discussions on the Transformer model. (n.d.). AI Stack Exchange. Retrieved from https://ai.stackexchange.com/questions/tagged/transformer