Project 1: Build a Toy Transformer from Scratch in PyTorch
2. Model Components
2.1 Positional Encoding (sinusoidal; easy and effective)
Sinusoidal positional encoding is a fundamental technique in transformers that injects information about token positions into the model. Without positional encoding, self-attention has no notion of token order: permuting the input tokens simply permutes the outputs, so the model cannot distinguish different orderings of a sequence.
The sinusoidal approach uses sine and cosine functions of different frequencies to create unique position vectors. Each position is encoded as a distinct pattern across the embedding dimensions, allowing the model to learn both absolute and relative positions.
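Concretely, with pos the token position and i indexing pairs of embedding dimensions, the standard formulation (the one implemented below) is:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))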
Key advantages of sinusoidal encoding:
- No additional parameters to learn
- Theoretically allows extrapolation to sequences longer than those seen during training
- Creates smooth transitions between positions
You can later swap this for RoPE (Rotary Position Embedding), which directly encodes relative position information into the attention calculation through rotation matrices, often showing better performance on longer sequences and more efficient extrapolation beyond training lengths.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F  # shared imports for all components in this section

class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=4096):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe)  # [max_len, d_model]

    def forward(self, x):  # x: [B,T,C]
        T = x.size(1)
        return x + self.pe[:T]
This part of the code implements sinusoidal positional encoding, a technique essential in transformer models to provide information about token positions. Let's break it down:
The class inherits from PyTorch's nn.Module and provides two primary methods:
Initialization Method: The constructor takes two parameters - d_model (embedding dimension) and max_len (maximum sequence length, defaulting to 4096). It creates position encodings as follows:
- Creates an empty tensor pe of shape [max_len, d_model] filled with zeros
- Generates position indices from 0 to max_len-1 as a column vector (pos)
- Calculates frequency divisors (div) using the formula 10000^(-2i/d_model), one per pair of embedding dimensions
- Applies the sine function to even indices (0, 2, 4, ...) of the embedding dimension
- Applies the cosine function to odd indices (1, 3, 5, ...) of the embedding dimension
- Registers the encodings as a buffer named "pe" (meaning it's part of the model but not a trainable parameter)
Forward Method: This method adds positional information to input embeddings:
- Takes an input tensor x with shape [Batch, Time, Channels]
- Extracts the sequence length T from the input's second dimension
- Adds the pre-computed positional encodings to the input embeddings
- Returns the sum of the embeddings and their positional encodings
The mathematical intuition behind this approach is that each position is encoded as a unique pattern of sine and cosine waves at different frequencies, allowing the model to learn relative positions.
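If you later try the RoPE swap mentioned above, one common formulation (the rotate-half variant) rotates pairs of query/key features by a position-dependent angle before the dot product. The sketch below is illustrative only; apply_rope and its defaults are my own naming, not part of the project code:

import torch

def apply_rope(x, base=10000.0):
    # x: [B, h, T, dh] query or key tensor; dh must be even (rotate-half convention)
    B, h, T, dh = x.shape
    half = dh // 2
    # one frequency per rotated pair: theta_i = base^(-2i/dh) for i = 0..dh/2-1
    freqs = base ** (-torch.arange(half, dtype=torch.float, device=x.device) / half)
    angles = torch.arange(T, dtype=torch.float, device=x.device)[:, None] * freqs[None, :]  # [T, half]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    # rotate each (x1_i, x2_i) pair by its position-dependent angle
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

Inside the attention module, you would call q = apply_rope(q) and k = apply_rope(k) right after the head split, and drop the additive sinusoidal encoding at the input.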
2.2 Multi-Head Self-Attention (causal)
This component is the heart of the transformer architecture. Multi-Head Self-Attention allows the model to focus on different parts of the input sequence simultaneously, creating rich representations that capture complex relationships between tokens. The "causal" aspect ensures that predictions for each position can only depend on known tokens (those that come before it).
Key characteristics of this implementation:
- Splits the embedding dimension across multiple attention heads, allowing each head to focus on different aspects of the sequence
- Implements the scaled dot-product attention mechanism (dividing by √d_k to stabilize gradients)
- Uses a causal mask to enforce autoregressive behavior, preventing information leakage from future tokens
- Includes dropout for regularization to improve generalization
The implementation below transforms the input through query (Q), key (K), and value (V) projections before computing attention scores and applying the mask to ensure causality.
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.h = n_heads
        self.dh = d_model // n_heads
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        q = self.q(x).view(B, T, self.h, self.dh).transpose(1, 2)  # [B,h,T,dh]
        k = self.k(x).view(B, T, self.h, self.dh).transpose(1, 2)
        v = self.v(x).view(B, T, self.h, self.dh).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.dh)       # [B,h,T,T]
        # Causal mask
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        att = att.masked_fill(mask, float("-inf"))
        w = F.softmax(att, dim=-1)
        w = self.drop(w)
        y = w @ v                                                  # [B,h,T,dh]
        y = y.transpose(1, 2).contiguous().view(B, T, C)           # [B,T,C]
        return self.out(y)
Here's a breakdown of the MultiHeadSelfAttention class:
Class Definition and Initialization:
The class inherits from PyTorch's nn.Module and implements multi-head self-attention, which is a key component of transformer architectures. The initialization method takes three parameters:
- d_model: The dimensionality of the input embeddings
- n_heads: The number of attention heads
- dropout: Dropout probability for regularization (defaults to 0.0)
The assertion assert d_model % n_heads == 0 ensures that the embedding dimension is divisible by the number of heads, which is necessary for splitting the embedding into equal parts across heads.
The class initializes several components:
- self.h: Stores the number of attention heads
- self.dh: Calculates the dimension per head (d_model divided by n_heads)
- Linear projections for query (q), key (k), and value (v) transformations
- Output projection layer (out) to combine the multi-head outputs
- Dropout layer for regularization
Forward Method:
The forward method implements the actual attention mechanism:
- Input Unpacking and Projections:
- Extracts the batch size (B), sequence length (T), and channel/embedding dimension (C) from the input shape
- Projects the input through the query, key, and value linear layers
- Reshapes and transposes the projections to separate the heads dimension, resulting in tensors of shape [B, h, T, dh]
- Attention Score Calculation:
- Computes the scaled dot-product attention: (q @ k.transpose(-2,-1)) / math.sqrt(self.dh)
- The scaling factor (1/√dh) stabilizes gradients during training
- The result is an attention matrix of shape [B, h, T, T] where each element represents the attention score between two positions
- Causal Masking:
- Creates an upper triangular mask using torch.triu with diagonal=1
- This mask ensures causality: each position can only attend to itself and previous positions
- Sets the masked positions to negative infinity (float("-inf")), which become zero after softmax
- Attention Weights and Application:
- Applies softmax to convert attention scores to probabilities (along the last dimension)
- Applies dropout to the attention weights for regularization
- Computes the weighted sum of values: w @ v, resulting in a tensor of shape [B, h, T, dh]
- Output Transformation:
- Transposes the result back to [B, T, h, dh]
- Reshapes to [B, T, C] by concatenating the heads
- Applies the output projection to get the final output
This implementation enforces the autoregressive property crucial for generative language models by ensuring each token can only attend to previous tokens in the sequence, which is achieved through the causal mask.
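As a quick sanity check, the masked-softmax attention above should agree numerically with PyTorch's built-in fused kernel (assuming PyTorch 2.0+, where F.scaled_dot_product_attention is available). The shapes and values below are illustrative:

import math
import torch
import torch.nn.functional as F

B, h, T, dh = 2, 4, 8, 16
q, k, v = (torch.randn(B, h, T, dh) for _ in range(3))

# manual causal attention, mirroring the module above (dropout disabled)
att = (q @ k.transpose(-2, -1)) / math.sqrt(dh)                    # [B,h,T,T]
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
manual = F.softmax(att.masked_fill(mask, float("-inf")), dim=-1) @ v

# built-in kernel with an implicit causal mask
builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(manual, builtin, atol=1e-5))                  # expected: True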
2.3 Feedforward (GELU now; try SwiGLU later)
The Feedforward Network (FFN) is a critical component in transformer architectures that processes the output from the self-attention mechanism. It adds non-linearity and increases the model's representational capacity by projecting inputs to a higher dimension before projecting back.
In this implementation, the FeedForward class:
- Takes input with dimension d_model, projects it to a higher dimension d_ff
- Applies GELU (Gaussian Error Linear Unit) activation function, a smoother alternative to ReLU
- Processes through dropout for regularization
- Projects back to the original dimension d_model
- Applies a final dropout layer
The structure follows a common pattern in transformer designs where the hidden dimension (d_ff) is typically 4x the model dimension, allowing the network to learn more complex patterns.
The suggested upgrade to SwiGLU is a more advanced feed-forward variant that combines the Swish (SiLU) activation with a GLU (Gated Linear Unit). The mathematical form silu(W1x) * W2x → proj allows for more effective information flow through a gating mechanism and has been shown to improve performance in larger language models.
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x): return self.net(x)
Here's a breakdown of the FeedForward class in the code:
Class Definition and Initialization:
The constructor takes three parameters:
- d_model: The input/output dimension (same as the model's embedding dimension)
- d_ff: The inner dimension of the feed-forward network, typically 4x larger than d_model
- dropout: A regularization parameter controlling the dropout probability (defaults to 0.0)
Network Architecture:
The feed-forward network is implemented as a sequential container with the following layers:
- A linear projection from d_model to the larger dimension d_ff
- A GELU (Gaussian Error Linear Unit) activation function, which is a smooth alternative to ReLU
- A dropout layer for regularization
- A linear projection back from d_ff to d_model
- A final dropout layer
Forward Method:
The forward method is concisely implemented as a single line that passes the input x through the sequential network and returns the result:
def forward(self, x): return self.net(x)
Purpose in the Transformer:
This feed-forward network serves several important functions:
- It adds non-linearity to the model, allowing it to learn complex patterns
- It increases the model's representational capacity by projecting to a higher dimension before returning to the model dimension
- It processes each position independently, complementing the self-attention mechanism that models relationships between positions
The code notes suggest upgrading this implementation to SwiGLU (a combination of the Swish/SiLU activation and a Gated Linear Unit) for better scaling behavior. SwiGLU would replace the current GELU activation with a gating mechanism of the form silu(W1x) * W2x → proj, which has been shown to improve performance in larger language models.
Upgrade idea: Replace with SwiGLU, silu(W1x) * W2x → proj, for better scaling behavior.
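For reference, here is one way SwiGLU could look as a drop-in replacement for the FeedForward module above. This is a minimal sketch under my own naming (SwiGLUFeedForward, w1, w2, proj), not code from the project:

import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)    # gate branch
        self.w2 = nn.Linear(d_model, d_ff)    # value branch
        self.proj = nn.Linear(d_ff, d_model)  # projection back to d_model
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # silu(W1 x) * W2 x -> proj, as described above
        return self.drop(self.proj(F.silu(self.w1(x)) * self.w2(x)))

Because SwiGLU uses three weight matrices instead of two, implementations often shrink d_ff (for example to roughly two thirds of the usual 4x d_model) to keep the parameter count comparable.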
2.4 Transformer Block (Pre-Norm)
The Transformer Block is the core processing unit of the model: it wires together the multi-head self-attention and feed-forward components defined above, each wrapped in a residual connection. This implementation uses a "Pre-Norm" layout, applying layer normalization before each sub-layer rather than after it (as in the original "Post-Norm" transformer design).
Key characteristics of this implementation:
- Two LayerNorm layers (ln1, ln2) normalize the inputs to the attention and feed-forward sub-layers
- Each sub-layer's output is added back to its input through a residual connection
- The Pre-Norm arrangement keeps the residual path free of normalization, which research has shown gives more stable training in deeper networks
A detailed breakdown of the class follows the code below.
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))  # residual
        x = x + self.ff(self.ln2(x))    # residual
        return x
Here's a comprehensive breakdown of the TransformerBlock class:
The TransformerBlock class implements a core building block of modern transformer architectures using a "Pre-Norm" approach. This class combines self-attention and feed-forward operations with residual connections and layer normalization.
Class Structure:
The class inherits from PyTorch's nn.Module and takes four parameters:
- d_model: The embedding dimension used throughout the model
- n_heads: Number of attention heads for the multi-head self-attention
- d_ff: Dimension of the feed-forward network's hidden layer
- dropout: Dropout probability for regularization (defaults to 0.0)
Components:
The class initializes four main components:
- self.ln1: First LayerNorm that normalizes inputs to the attention layer
- self.attn: MultiHeadSelfAttention layer that processes relationships between tokens
- self.ln2: Second LayerNorm that normalizes inputs to the feed-forward layer
- self.ff: FeedForward network that adds non-linearity and transforms representations
Forward Method:
The forward method implements two sequential sub-layers, each with a residual connection:
- Self-Attention Sub-layer:
x = x + self.attn(self.ln1(x)) # residual
- First normalizes the input using layer normalization
- Passes the normalized input through multi-head attention
- Adds the attention output to the original input (residual connection)
- Feed-Forward Sub-layer:
x = x + self.ff(self.ln2(x)) # residual
- Normalizes the output from the attention block
- Passes it through the feed-forward network
- Adds the result to the input of this sub-layer (residual connection)
Key Design Choices:
This implementation uses Pre-Norm architecture where normalization is applied before each sub-layer rather than after. Research shows Pre-Norm leads to more stable training, especially in deeper networks, by ensuring the residual path remains unobstructed for gradient flow.
The residual connections (adding the input to the sub-layer output) are crucial as they help mitigate the vanishing gradient problem and allow for effective training of deeper networks.
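To see how these components fit together, here is a minimal, illustrative sketch of a decoder-only language model built from the classes above. The class name ToyTransformerLM, the final LayerNorm ln_f, the output head, and all default hyperparameters are assumptions for this example rather than part of the project code:

import torch
import torch.nn as nn

class ToyTransformerLM(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_heads=4, d_ff=1024,
                 n_layers=4, max_len=4096, dropout=0.0):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, d_model)               # token embeddings
        self.pos = SinusoidalPositionalEncoding(d_model, max_len)  # from 2.1
        self.blocks = nn.ModuleList(
            [TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)]
        )
        self.ln_f = nn.LayerNorm(d_model)            # final norm, usual for pre-norm stacks
        self.head = nn.Linear(d_model, vocab_size)   # logits over the vocabulary

    def forward(self, idx):                          # idx: [B, T] token ids
        x = self.pos(self.tok(idx))                  # [B, T, C]
        for blk in self.blocks:
            x = blk(x)
        return self.head(self.ln_f(x))               # [B, T, vocab_size]

# quick shape check
model = ToyTransformerLM(vocab_size=100)
logits = model(torch.randint(0, 100, (2, 16)))       # expected shape: [2, 16, 100]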