Understanding the core innovation behind modern AI language models
The attention mechanism is a fundamental component of modern transformer-based language models like GPT, BERT, and others. It allows models to dynamically focus on different parts of the input sequence when producing each part of the output sequence.
Unlike traditional sequence models such as RNNs, which process the input step by step through a fixed-size hidden state, attention mechanisms enable models to learn which parts of the input are most relevant at each step of processing.
Attention mechanisms compute a weighted sum of values (V), where the weights are determined by the compatibility between queries (Q) and keys (K). This allows the model to focus on different parts of the input sequence dynamically.
Queries (Q) represent what the model is "looking for" at the current position. They are learned representations that help determine which parts of the input to focus on.
Keys (K) represent what each input element "contains" or "offers". They are compared against queries to determine the attention weights.
Values (V) contain the actual information that will be aggregated based on the attention weights. They represent what gets passed forward.
The separation of Q, K, and V gives the attention mechanism its flexibility and expressive power: because each is produced by its own learned projection, the model can represent what a position is searching for, what it offers for matching, and what it actually passes forward as three distinct roles.
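To make these roles concrete, here is a minimal, self-contained sketch with arbitrary made-up numbers (not taken from any real model): one query is scored against three keys, and the resulting weights mix the three values.

import torch

# One query vector and three key/value vectors (toy numbers, d = 4)
q = torch.tensor([1.0, 0.0, 1.0, 0.0])
K = torch.tensor([[1.0, 0.0, 1.0, 0.0],   # very similar to q
                  [0.0, 1.0, 0.0, 1.0],   # orthogonal to q
                  [0.5, 0.5, 0.5, 0.5]])  # somewhere in between
V = torch.tensor([[10.0, 0.0],
                  [0.0, 10.0],
                  [5.0, 5.0]])

# Compatibility scores between the query and each key, scaled by sqrt(d)
scores = K @ q / (q.numel() ** 0.5)
weights = torch.softmax(scores, dim=-1)  # the first key gets the largest weight
output = weights @ V                     # weighted sum of the values
print(weights, output)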
In transformer models, Q, K, and V are all derived from the same input sequence through learned linear transformations:
# Python example of creating Q, K, V
import torch
import torch.nn as nn

# Suppose we have input embeddings of shape (batch_size, seq_len, d_model)
batch_size = 32
seq_len = 10
d_model = 512
input_embeddings = torch.randn(batch_size, seq_len, d_model)

# Create linear projection layers
q_proj = nn.Linear(d_model, d_model)  # Query projection
k_proj = nn.Linear(d_model, d_model)  # Key projection
v_proj = nn.Linear(d_model, d_model)  # Value projection

# Project inputs to get Q, K, V
Q = q_proj(input_embeddings)  # Shape: (batch_size, seq_len, d_model)
K = k_proj(input_embeddings)  # Shape: (batch_size, seq_len, d_model)
V = v_proj(input_embeddings)  # Shape: (batch_size, seq_len, d_model)
In practice, the dimensions are often split into multiple "heads" (multi-head attention), where each head learns different attention patterns. This allows the model to attend to different aspects of the input simultaneously.
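For example, with d_model = 512 and 8 heads (the head count here is an illustrative assumption, not from the original example), each head works with 64-dimensional slices. A quick sketch of the reshaping, continuing from the tensors above:

num_heads = 8
depth = d_model // num_heads  # 512 // 8 = 64

# Reshape (batch, seq_len, d_model) -> (batch, num_heads, seq_len, depth)
Q_heads = Q.view(batch_size, seq_len, num_heads, depth).transpose(1, 2)
print(Q_heads.shape)  # torch.Size([32, 8, 10, 64])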
The core computation in attention mechanisms is the scaled dot-product attention, which can be implemented as follows:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: Query tensor (batch_size, ..., seq_len_q, d_k)
    K: Key tensor (batch_size, ..., seq_len_k, d_k)
    V: Value tensor (batch_size, ..., seq_len_k, d_v)
    mask: Optional mask tensor for masking out certain positions
    """
    # Compute dot products between Q and K
    matmul_qk = torch.matmul(Q, K.transpose(-2, -1))  # (..., seq_len_q, seq_len_k)

    # Scale by square root of dimension
    d_k = Q.size(-1)
    scaled_attention_logits = matmul_qk / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))

    # Apply mask if provided (for decoder self-attention)
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)

    # Softmax to get attention weights
    attention_weights = torch.softmax(scaled_attention_logits, dim=-1)

    # Multiply weights by values
    output = torch.matmul(attention_weights, V)  # (..., seq_len_q, d_v)

    return output, attention_weights
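As a quick sanity check, the function can be applied to the Q, K, and V projections created earlier (no mask); the shapes below are what it produces for that setup:

output, attention_weights = scaled_dot_product_attention(Q, K, V)
print(output.shape)             # torch.Size([32, 10, 512])
print(attention_weights.shape)  # torch.Size([32, 10, 10])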
The scaling factor (√dₖ) is crucial because dot products grow large in magnitude as the dimension increases. This can push the softmax function into regions where it has extremely small gradients, making learning difficult. Scaling by √dₖ counteracts this effect.
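A small experiment with arbitrary random tensors (an illustration, not part of the original example) shows the effect: for vectors with unit-variance components, the dot products have a standard deviation of roughly √dₖ, and without scaling the softmax collapses toward a one-hot distribution.

torch.manual_seed(0)
d_k = 512
q = torch.randn(d_k)
keys = torch.randn(8, d_k)

logits = keys @ q
print(logits.std())                                # roughly sqrt(512) ~ 22
print(torch.softmax(logits, dim=-1))               # nearly one-hot -> tiny gradients
print(torch.softmax(logits / d_k ** 0.5, dim=-1))  # much softer distribution

Multi-head attention builds directly on this primitive: the module below projects the inputs, splits them into num_heads heads, applies scaled dot-product attention to each head in parallel, and concatenates the results.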
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.depth = d_model // num_heads

        # Linear layers for Q, K, V projections
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

        self.dense = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth)"""
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.transpose(1, 2)  # (batch_size, num_heads, seq_len, depth)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)

        # Linear projections
        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)
        v = self.wv(v)

        # Split into multiple heads
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        # Scaled dot-product attention
        scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)

        # Concatenate heads
        scaled_attention = scaled_attention.transpose(1, 2)  # (batch_size, seq_len, num_heads, depth)
        concat_attention = scaled_attention.contiguous().view(batch_size, -1, self.d_model)

        # Final linear layer
        output = self.dense(concat_attention)

        return output, attention_weights
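A brief usage sketch, running the module as self-attention over random embeddings with the same illustrative sizes used earlier (d_model = 512, 8 heads):

mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(32, 10, 512)  # (batch_size, seq_len, d_model)
out, attn = mha(x, x, x)      # self-attention: q, k, and v are the same sequence
print(out.shape)   # torch.Size([32, 10, 512])
print(attn.shape)  # torch.Size([32, 8, 10, 10])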
Self-attention: Q, K, and V all come from the same sequence, allowing each position to attend to all positions in that sequence.
Masked (causal) attention: Used in the decoder to prevent positions from attending to subsequent positions, preserving the autoregressive property; see the mask sketch after this list.
Cross-attention: Q comes from one sequence, while K and V come from another (e.g., encoder-decoder attention).
Sparse attention: Attends only to a subset of positions to reduce computational complexity (e.g., local, strided, or global attention).
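As an illustration of the masked case, a causal mask for the scaled_dot_product_attention function above can be built as follows (this follows that function's convention that 1 marks a blocked position; the sequence length and tensor sizes are arbitrary):

seq_len = 5

# Upper-triangular mask: position i may not attend to positions j > i
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
print(causal_mask)

# The (seq_len, seq_len) mask broadcasts over batch (and head) dimensions
x = torch.randn(2, seq_len, 64)
out, weights = scaled_dot_product_attention(x, x, x, mask=causal_mask)
print(weights[0])  # entries above the diagonal are (effectively) zero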
"Attention Is All You Need" - Introduced the transformer architecture with scaled dot-product attention.
arXiv:1706.03762"Speech and Language Processing" - Comprehensive chapter on attention and transformer models.
Stanford NLP TextbookJay Alammar's visual explanation of transformer attention mechanisms.
jalammar.github.ioLine-by-line implementation guide with PyTorch.
Harvard NLP TutorialTay et al. review various attention variants for efficiency.
arXiv:2009.06732