Project 1: Build a Toy Transformer from Scratch in PyTorch
5. Text Generation
Now that we've set up our model architecture and training loop, it's time to implement text generation. This function allows our model to produce new text by predicting tokens one by one.
Text Generation Function Breakdown:
The generate() function takes a trained model and produces new text from a given prompt. It's decorated with @torch.no_grad() to disable gradient calculations since we're only running inference, not training.
- Parameters: The function accepts a model, starting prompt, desired length of generation, temperature for controlling randomness, and top_k for filtering to only the most likely tokens.
- Process: The function encodes the initial prompt, then repeatedly:
- Takes the most recent context (up to block_size tokens)
- Runs it through the model to get next-token probabilities
- Applies temperature scaling and top-k filtering
- Samples the next token based on the resulting distribution
- Adds the new token to the sequence
- Temperature control: The temperature parameter (set to 0.9 by default) controls randomness in generation. Lower values (e.g., 0.5) make the model more conservative and deterministic, while higher values (e.g., 1.5) increase diversity but might reduce coherence.
- Top-k sampling: This technique restricts token selection to only the k most probable next tokens (50 by default). This prevents the model from selecting extremely unlikely tokens while maintaining some creative variation.
After generating the requested number of tokens, the function decodes the entire sequence back to text and returns it. The code example then demonstrates generating 160 new tokens starting from the prompt "In the".
@torch.no_grad()
def generate(model, prompt="In the", max_new_tokens=120, temperature=0.9, top_k=50):
    model.eval()
    idx = encode(prompt).unsqueeze(0).to(device)           # [1, T]
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]                    # crop context
        logits = model(idx_cond)[:, -1, :] / max(1e-6, temperature)
        # top-k
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float("inf")
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)  # [1, 1]
        idx = torch.cat([idx, next_id], dim=1)
    return decode(idx[0])

print(generate(model, "In the", 160))
Let's break down this text generation function line by line:
The @torch.no_grad() decorator tells PyTorch not to track gradients during execution of this function. This is important for inference since we don't need to calculate gradients, which saves memory and speeds up computation.
The function generate() takes several parameters:
- model: The trained TinyGPT model
- prompt: Initial text to start generation (defaults to "In the")
- max_new_tokens: Number of new tokens to generate (defaults to 120)
- temperature: Controls randomness in generation (defaults to 0.9)
- top_k: Limits selection to only the k most probable tokens (defaults to 50)
model.eval() puts the model in evaluation mode, which disables dropout and other training-specific behaviors.
idx = encode(prompt).unsqueeze(0).to(device) converts the text prompt into token IDs, adds a batch dimension with unsqueeze(0), and moves the tensor to the appropriate device (CPU/GPU).
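For reference, encode() and decode() come from the character-level tokenizer built earlier in the project. The sketch below is only a reminder of what they might look like; it assumes a variable text holding the training corpus, and the names stoi/itos are illustrative rather than taken from the original code.

# Hypothetical recap of a char-level tokenizer (names and details are assumptions).
chars = sorted(set(text))                       # unique characters in the corpus
stoi = {ch: i for i, ch in enumerate(chars)}    # char -> token id
itos = {i: ch for ch, i in stoi.items()}        # token id -> char

def encode(s):
    # returns a 1-D LongTensor of token ids, shape [T]
    return torch.tensor([stoi[c] for c in s], dtype=torch.long)

def decode(ids):
    # accepts a 1-D tensor of token ids and returns the corresponding string
    return "".join(itos[int(i)] for i in ids)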
The main generation loop runs max_new_tokens times, producing one new token in each iteration:
idx_cond = idx[:, -block_size:] crops the context to the most recent block_size tokens, which is necessary because the model has a limited context window.
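A quick way to see what that slice does, with made-up numbers and a made-up block size (reusing the torch import from earlier):

# Suppose block_size were 4 and we had already accumulated 6 tokens.
demo_block_size = 4
demo_idx = torch.tensor([[10, 11, 12, 13, 14, 15]])   # shape [1, 6]
print(demo_idx[:, -demo_block_size:])                 # tensor([[12, 13, 14, 15]]) -- only the last 4 tokens are kept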
logits = model(idx_cond)[:, -1, :] / max(1e-6, temperature) gets predictions for the next token by:
- Running the cropped context through the model
- Selecting only the last position's logits with [:, -1, :]
- Applying temperature scaling by dividing by temperature (the divisor is clamped to at least 1e-6 so a temperature of 0 cannot cause division by zero)
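To make the effect concrete, here is a small illustrative experiment with made-up logits (reusing the torch and F imports from earlier); it shows how dividing by different temperatures reshapes the softmax distribution:

# Made-up logits over a 4-token vocabulary, purely to illustrate temperature.
demo_logits = torch.tensor([2.0, 1.0, 0.5, 0.1])
for t in (0.5, 1.0, 1.5):
    print(t, F.softmax(demo_logits / t, dim=-1))
# Lower t sharpens the distribution (the top token dominates);
# higher t flattens it, giving rarer tokens more chance to be sampled.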
The top-k filtering section restricts token selection to only the most likely candidates:
- v, _ = torch.topk(logits, min(top_k, logits.size(-1))) finds the values of the top k logits
- logits[logits < v[:, [-1]]] = -float("inf") sets every logit smaller than the k-th largest to negative infinity, effectively removing those tokens from consideration
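Here is the same masking trick on a toy example with invented logits and k=2:

# Invented logits over a 5-token vocabulary; keep only the top k=2.
demo_logits = torch.tensor([[0.2, 3.0, 1.5, -1.0, 2.2]])
k = 2
v, _ = torch.topk(demo_logits, k)                      # v = [[3.0, 2.2]]
demo_logits[demo_logits < v[:, [-1]]] = -float("inf")  # everything below 2.2 is masked out
print(demo_logits)                    # tensor([[  -inf, 3.0000,   -inf,   -inf, 2.2000]])
print(F.softmax(demo_logits, dim=-1)) # all probability mass goes to the two surviving tokens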
probs = F.softmax(logits, dim=-1) converts the modified logits into a probability distribution using softmax.
next_id = torch.multinomial(probs, num_samples=1) randomly samples the next token ID based on the calculated probabilities. Using multinomial sampling rather than just picking the highest probability token adds creativity to the generation.
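For contrast, greedy decoding would take torch.argmax(probs, dim=-1) at every step and always produce the same continuation for a given prompt. A small illustration with a made-up distribution:

# Made-up next-token distribution over a 3-token vocabulary.
demo_probs = torch.tensor([[0.6, 0.3, 0.1]])
print(torch.argmax(demo_probs, dim=-1))                                # always token 0 (greedy)
print(torch.multinomial(demo_probs, num_samples=5, replacement=True))  # mostly 0, sometimes 1 or 2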
idx = torch.cat([idx, next_id], dim=1) appends the newly generated token to our sequence of tokens.
After the loop finishes, return decode(idx[0]) converts the complete sequence of token IDs (the prompt plus all generated tokens) back to text and returns it.
Finally, print(generate(model, "In the", 160)) demonstrates the function by generating 160 new tokens starting with the prompt "In the" and printing the result.
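The defaults are only one reasonable choice, so it is worth experimenting with the knobs. For example (illustrative calls to the same function):

# More conservative: lower temperature, tighter top-k.
print(generate(model, "In the", max_new_tokens=160, temperature=0.5, top_k=20))
# More adventurous: higher temperature, top-k filtering disabled.
print(generate(model, "In the", max_new_tokens=160, temperature=1.3, top_k=None))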
This implementation showcases several important techniques in text generation:
- Autoregressive generation (one token at a time)
- Context management for transformer models
- Temperature scaling to control randomness
- Top-k sampling to filter unlikely tokens
- Probability-based sampling for diverse outputs