Sequence Modeling & Language Models - Predicting One Token at a Time
Helpful context:
- Recurrent Neural Networks - Memory Hidden in the Hidden State
- Probability as a Language - The Grammar of Uncertainty
- Entropy & Information Theory - The Mathematics of Surprise
“The cat sat on the ___.”
Your brain fills in “mat” (or “floor” or “sofa”) without thinking. You don’t consciously enumerate possibilities and weigh them - the prediction just happens. What you’re doing, without realizing it, is running a language model. You are estimating the probability of each possible next word given every word that came before it.
This is the central task of language modeling: predict the next token given all previous tokens. It sounds almost too simple. It will turn out to be almost universally useful - not just for generating text, but as a pre-training objective that forces a model to develop representations of syntax, semantics, world knowledge, and reasoning.
A Probability Distribution Over Sequences
A language model is, precisely, a probability distribution over sequences of words (or tokens).
For a sequence $w_1, w_2, \ldots, w_n$, the model assigns a probability:
$$P(w_1, w_2, \ldots, w_n).$$
This joint distribution over all possible sequences captures everything about the language: which sentences are grammatical, which are plausible, which are bizarre.
But estimating this joint distribution directly is hopeless. For a vocabulary of $V = 50{,}000$ words and sequences of length $n = 20$, the number of possible sequences is $50000^{20} \approx 9.5 \times 10^{93}$, a number with 94 digits. No training corpus can contain more than a vanishing fraction of them, so almost every sequence would never be observed even once.
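The size of that space is easy to check with Python's arbitrary-precision integers. This is a throwaway calculation, not part of any model:

```python
# Exact integer arithmetic: how many length-20 sequences over a 50,000-word vocabulary?
V = 50_000
n = 20
num_sequences = V ** n
num_digits = len(str(num_sequences))
print(num_digits)  # 94
```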
The key is the chain rule of probability. Any joint distribution factors as:
$$P(w_1, \ldots, w_n) = P(w_1) \cdot P(w_2 \mid w_1) \cdot P(w_3 \mid w_1, w_2) \cdots P(w_n \mid w_1, \ldots, w_{n-1}).$$
Every autoregressive language model, including every model in this section, uses this factorization. The job reduces to estimating the conditional distribution $P(w_t \mid w_1, \ldots, w_{t-1})$, the probability of the next word given all previous context. Estimate this well and you have a language model.
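Here is a minimal sketch of the chain rule as code. It assumes a hypothetical `cond_prob(history, word)` function that returns $P(w_t \mid w_1, \ldots, w_{t-1})$ from whatever model you have. The toy probabilities are made up. The sketch works in log space because multiplying many small probabilities quickly underflows to zero in floating point, while summing their logs does not.

```python
import math

def sequence_log_prob(tokens, cond_prob):
    """log P(w_1..w_n) via the chain rule: sum over t of log P(w_t | w_1..w_{t-1})."""
    total = 0.0
    for t, word in enumerate(tokens):
        history = tuple(tokens[:t])          # all previous tokens
        total += math.log(cond_prob(history, word))
    return total

# Toy conditional distribution (made-up numbers, for illustration only).
TOY = {
    ((), "the"): 0.5,
    (("the",), "cat"): 0.1,
    (("the", "cat"), "sat"): 0.3,
}

def toy_cond_prob(history, word):
    return TOY[(history, word)]

lp = sequence_log_prob(["the", "cat", "sat"], toy_cond_prob)
print(math.exp(lp))  # ≈ 0.015, i.e. 0.5 * 0.1 * 0.3
```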
N-gram Models
The oldest approach is to approximate $P(w_t \mid w_1, \ldots, w_{t-1})$ using only the last $n-1$ words. This is the Markov assumption: the future depends on the past only through a fixed window.
$$P(w_t \mid w_1, \ldots, w_{t-1}) \approx P(w_t \mid w_{t-n+1}, \ldots, w_{t-1}).$$
Unigram ($n=1$): ignore context entirely.
$$P(w_t) \approx \frac{C(w_t)}{N}$$
where $C(w_t)$ is the count of $w_t$ in the training corpus and $N$ is the total number of tokens. A unigram model says “dog” is more likely than “xylophone” everywhere, regardless of what came before.
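A maximum-likelihood unigram model is just normalized counting. The sketch below uses a made-up toy corpus and a hypothetical helper name, `unigram_model`:

```python
from collections import Counter

def unigram_model(corpus_tokens):
    """Maximum-likelihood unigram estimate: P(w) = C(w) / N."""
    counts = Counter(corpus_tokens)
    total = len(corpus_tokens)              # N, the total number of tokens
    return {word: c / total for word, c in counts.items()}

corpus = "the dog saw the cat and the cat saw the dog".split()
probs = unigram_model(corpus)
print(probs["the"])                     # 4/11 ≈ 0.364, whatever the context
print(probs.get("xylophone", 0.0))      # 0.0: never seen, so probability zero
```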
Bigram ($n=2$): condition only on the immediately preceding word.
$$P(w_t \mid w_{t-1}) \approx \frac{C(w_{t-1}, w_t)}{C(w_{t-1})}.$$
The fraction of times $w_t$ follows $w_{t-1}$ in training text. This already produces something recognizable as language. Given “the”, the model knows “cat” is more probable than “quickly”.
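The bigram estimate applies the same idea to adjacent pairs. In this sketch (helper name and corpus are illustrative), $C(w_{t-1})$ is counted only at positions that have a following word. That makes each conditional distribution sum exactly to 1. On a large corpus the difference from the plain unigram count is negligible.

```python
from collections import Counter

def bigram_model(corpus_tokens):
    """MLE bigram estimate: P(w_t | w_{t-1}) = C(w_{t-1}, w_t) / C(w_{t-1})."""
    pair_counts = Counter(zip(corpus_tokens, corpus_tokens[1:]))
    # C(w_{t-1}) counted over positions that have a successor, so each
    # conditional distribution sums to exactly 1.
    history_counts = Counter(corpus_tokens[:-1])
    return {(prev, word): count / history_counts[prev]
            for (prev, word), count in pair_counts.items()}

corpus = "the cat sat on the mat the cat ran".split()
bigrams = bigram_model(corpus)
print(bigrams[("the", "cat")])              # 2/3: "the" is followed by "cat" 2 of 3 times
print(bigrams.get(("the", "quickly"), 0))   # 0: never observed after "the"
```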
Trigram ($n=3$): condition on the two preceding words.
$$P(w_t \mid w_{t-2}, w_{t-1}) \approx \frac{C(w_{t-2}, w_{t-1}, w_t)}{C(w_{t-2}, w_{t-1})}.$$
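Bigram and trigram models are both instances of one counting scheme. This sketch uses a hypothetical `ngram_model(corpus_tokens, n)` helper and a toy corpus. It applies the Markov assumption by keeping only the last $n-1$ words of the history, and it shows what happens to a trigram that never occurred in training:

```python
from collections import Counter

def ngram_model(corpus_tokens, n):
    """MLE n-gram model: P(w_t | context) = C(context, w_t) / C(context)."""
    grams = Counter(tuple(corpus_tokens[i:i + n])
                    for i in range(len(corpus_tokens) - n + 1))
    contexts = Counter()
    for gram, count in grams.items():
        contexts[gram[:-1]] += count        # C(context), counted where a next word follows

    def prob(history, word):
        # Markov assumption: keep only the last n-1 words of the history.
        ctx = tuple(history[-(n - 1):]) if n > 1 else ()
        if contexts[ctx] == 0:
            return 0.0                      # unseen context: MLE undefined, report 0 here
        return grams[ctx + (word,)] / contexts[ctx]

    return prob

corpus = "the cat sat on the mat and the dog sat on the rug".split()
p = ngram_model(corpus, n=3)                        # trigram model
print(p("the dog sat on the".split(), "mat"))       # 0.5: only "on the" matters
print(p("the dog sat on the".split(), "sofa"))      # 0.0: trigram never seen
```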
Longer n-grams capture more context and produce better estimates - but they face a hard wall.
The Problems With N-grams
Data sparsity. In a vocabulary of $V = 50{,}000$ words, there are $V^3 = 1.25 \times 10^{14}$ possible trigrams. A training corpus of a billion words contains at most a billion distinct trigrams, a fraction of at most $10^9 / (1.25 \times 10^{14}) = 8 \times 10^{-6}$, under 0.001%. Many trigrams you encounter at test time will therefore have count zero in training. A maximum-likelihood model assigns them probability zero, which is clearly wrong.
Smoothing techniques (Laplace smoothing, Kneser-Ney) redistribute probability mass to unseen n-grams, but they are workarounds, not solutions.
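The simplest of these workarounds is add-one (Laplace) smoothing for bigrams, $P(w_t \mid w_{t-1}) = \frac{C(w_{t-1}, w_t) + 1}{C(w_{t-1}) + V}$, where $V$ is the vocabulary size. The sketch below is illustrative: the helper name and toy corpus are assumptions. Kneser-Ney is considerably more involved and is not shown.

```python
from collections import Counter

def laplace_bigram(corpus_tokens):
    """Add-one (Laplace) smoothed bigram: (C(prev, w) + 1) / (C(prev) + V)."""
    vocab = set(corpus_tokens)
    V = len(vocab)
    pair_counts = Counter(zip(corpus_tokens, corpus_tokens[1:]))
    history_counts = Counter(corpus_tokens[:-1])   # C(prev) where a successor exists

    def prob(prev, word):
        # Every pair gets one pseudo-count, so unseen bigrams are never zero.
        return (pair_counts[(prev, word)] + 1) / (history_counts[prev] + V)

    return prob, vocab

corpus = "the cat sat on the mat".split()
prob, vocab = laplace_bigram(corpus)
print(prob("the", "cat"))   # (1 + 1) / (2 + 5) = 2/7
print(prob("the", "sat"))   # (0 + 1) / (2 + 5) = 1/7, no longer zero
```

Every bigram now gets nonzero probability. When $V$ is large, though, add-one smoothing takes a lot of mass away from observed events, which is one reason Kneser-Ney tends to work better in practice.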
Fixed and short context window. A trigram model uses the last two words, and that is all it will ever use. Consider “The bank by the river refused to grant a loan because it was ___.” To fill the blank sensibly, you need to know what “it” refers to. “Broke” fits if “it” is the bank; “too risky” fits if “it” is the loan. Both candidates appear well before the blank. A trigram model sees only “it was” and can only predict whatever commonly follows “it was” in general.
N-gram models hit a ceiling. The ceiling is low.
Perplexity: The Standard Evaluation Metric
How do you measure whether one language model is better than another? You need a test corpus the model hasn’t seen, and you ask: how surprised was the model by this text?
Perplexity quantifies this surprise:
$$\text{PP}(W) = \exp\left(-\frac{1}{N} \sum_{t=1}^{N} \log P(w_t \mid w_1, \ldots, w_{t-1})\right).$$
The inner quantity $-\frac{1}{N} \sum \log P$ is the average negative log-likelihood per token. It measures how much information, on average, the model needed to encode each token. With the natural log used here the unit is nats; with $\log_2$ (and base-2 exponentiation instead of $\exp$) it is bits. Exponentiating undoes the log, so perplexity is the inverse of the geometric mean of the per-token probabilities.
The intuition: a model with perplexity 100 is, on average, “as confused as if it were choosing uniformly from 100 options at each step.” A model that guesses uniformly over a 50,000-word vocabulary has perplexity exactly 50,000. A trigram model might achieve perplexity around 100 to 300 on standard corpora, and modern neural language models achieve far lower values. Exact numbers depend heavily on the corpus and the tokenization, so perplexities are only comparable between models evaluated on the same test text with the same vocabulary.
Under that condition, lower is better. A perplexity of 1 would mean the model assigned probability 1 to every actual next word: it predicted the text perfectly.
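Computing perplexity only needs the probability the model assigned to each actual next token. Here is a minimal sketch (the function name is illustrative), using natural logs to match the perplexity formula:

```python
import math

def perplexity(token_probs):
    """Perplexity from token_probs[t] = P(w_t | w_1, ..., w_{t-1}) on held-out text."""
    n = len(token_probs)
    avg_nll = -sum(math.log(p) for p in token_probs) / n   # nats per token
    return math.exp(avg_nll)

# A model that is uniform over 100 options at every step has perplexity 100.
print(perplexity([1 / 100] * 20))   # ≈ 100
# Probability 1 for every actual token gives the minimum possible perplexity.
print(perplexity([1.0] * 20))       # 1.0
```

A single zero probability makes `math.log` raise an error here, and mathematically sends perplexity to infinity. That is one more reason an unsmoothed n-gram model cannot be evaluated on text containing unseen n-grams.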
Discomfort check. Perplexity is measured on held-out test text, so a model cannot cheat by memorizing the training corpus. But the test corpus still has to be drawn from the same distribution as training. Perplexity on a different domain (say, training on Wikipedia, testing on medical notes) will look worse - not because the model is bad, but because it’s out of distribution. Always match the evaluation corpus to the deployment domain.
RNN Language Models
Recurrent Neural Networks offer a different approach to the problem of variable-length context. Instead of truncating history to the last $n-1$ words, an RNN maintains a hidden state $h_t$ that is supposed to summarize all of the context seen so far.
At each step $t$:
- Read the current word embedding $x_t$.
- Update the hidden state: $h_t = \tanh(W_h h_{t-1} + W_x x_t + b)$.
- Predict the next word: $P(w_{t+1} \mid w_1, \ldots, w_t) = \text{softmax}(W_o h_t)$.
The hidden state $h_t$ is a fixed-size vector (say, 512 or 1024 dimensions) that must summarize all of the previous text. The same parameters are applied at every step, and the state is carried forward from one step to the next. So an RNN can process sequences of any length and can, in principle, retain information from arbitrarily far back.
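The RNN update steps can be sketched as a forward pass in NumPy. This is a toy, untrained Elman-style RNN with random weights and made-up sizes (vocabulary 10, embedding 8, hidden 16); the class and method names are hypothetical. The point is the shape of the computation: one sequential update per token, and a final state whose size does not depend on the input length.

```python
import numpy as np

def softmax(z):
    z = z - z.max()                 # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

class TinyRNNLM:
    """Minimal Elman-style RNN language model: forward pass only, random weights."""

    def __init__(self, vocab_size, embed_dim=8, hidden_dim=16, seed=0):
        rng = np.random.default_rng(seed)
        self.E = rng.normal(0, 0.1, (vocab_size, embed_dim))     # word embeddings
        self.W_x = rng.normal(0, 0.1, (hidden_dim, embed_dim))
        self.W_h = rng.normal(0, 0.1, (hidden_dim, hidden_dim))
        self.b = np.zeros(hidden_dim)
        self.W_o = rng.normal(0, 0.1, (vocab_size, hidden_dim))
        self.hidden_dim = hidden_dim

    def step(self, h_prev, token_id):
        x_t = self.E[token_id]                                      # read embedding
        h_t = np.tanh(self.W_h @ h_prev + self.W_x @ x_t + self.b)  # update state
        p_next = softmax(self.W_o @ h_t)                            # P(w_{t+1} | w_1..w_t)
        return h_t, p_next

    def run(self, token_ids):
        h = np.zeros(self.hidden_dim)
        dists = []
        for tok in token_ids:       # strictly sequential: h_t needs h_{t-1}
            h, p = self.step(h, tok)
            dists.append(p)
        return h, dists

model = TinyRNNLM(vocab_size=10)
h_final, dists = model.run([1, 4, 2, 7])
print(h_final.shape, len(dists), round(float(dists[-1].sum()), 6))  # (16,) 4 1.0
```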
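In practice, RNNs struggle to maintain information over long distances. Backpropagating through $k$ steps multiplies $k$ per-step derivatives together. When these factors are consistently smaller than 1 in magnitude, the gradient linking an output to an event 100 steps earlier shrinks exponentially toward zero (the vanishing gradient problem), and the network effectively forgets. LSTMs and GRUs mitigate this with gating mechanisms, but the fundamental limitation remains: everything must still pass through a fixed-size state, one step at a time.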
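To see the vanishing gradient concretely, take a scalar RNN $h_t = \tanh(w h_{t-1} + x_t)$. By the chain rule, $\partial h_T / \partial h_0$ is a product of $T$ factors $w(1 - h_t^2)$, each at most $|w|$ in magnitude. The sketch below computes that product for increasing $T$. The weight 0.9 and the constant input 0.5 are arbitrary choices for illustration.

```python
import math

def hidden_state_gradient(w, inputs, h0=0.0):
    """dh_T/dh_0 for the scalar RNN h_t = tanh(w * h_{t-1} + x_t).

    By the chain rule it is the product of the per-step derivatives
    dh_t/dh_{t-1} = w * (1 - h_t**2).
    """
    h, grad = h0, 1.0
    for x in inputs:
        h = math.tanh(w * h + x)
        grad *= w * (1.0 - h * h)   # local derivative of this step
    return grad

for T in (1, 10, 50, 100):
    # Every factor has magnitude <= |w| = 0.9, so |grad| <= 0.9**T.
    print(T, hidden_state_gradient(0.9, [0.5] * T))
```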
Training With Teacher Forcing
How do you train an RNN language model? The objective is next-token prediction: at every position, the model should assign high probability to the correct next word.
During training, you use teacher forcing: at each step, you feed the model the true previous word, not whatever the model would have predicted.
Concretely, for the sequence “the cat sat on the mat”:
- Input sequence: [BOS, the, cat, sat, on, the]
- Target sequence: [the, cat, sat, on, the, mat]
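The target at each position is simply the next true word, so the targets are the inputs shifted by one position. The loss is the cross-entropy between the model's predicted distribution and the true next word, summed over all positions.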
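The shift-by-one construction and the summed cross-entropy fit in a few lines. This is a sketch that assumes a hypothetical `cond_prob(history, word)` model interface. The uniform model in the demo stands in for an untrained network.

```python
import math

def teacher_forcing_pairs(tokens, bos="BOS"):
    """Inputs: the true sequence shifted right by one (BOS in front). Targets: the sequence."""
    inputs = [bos] + tokens[:-1]
    targets = list(tokens)
    return inputs, targets

def teacher_forced_loss(tokens, cond_prob, bos="BOS"):
    """Summed cross-entropy: -sum_t log P(target_t | true inputs up to t)."""
    inputs, targets = teacher_forcing_pairs(tokens, bos)
    loss = 0.0
    for t, target in enumerate(targets):
        history = inputs[: t + 1]   # always the ground-truth prefix, never model output
        loss -= math.log(cond_prob(history, target))
    return loss

sentence = "the cat sat on the mat".split()
inputs, targets = teacher_forcing_pairs(sentence)
print(inputs)    # ['BOS', 'the', 'cat', 'sat', 'on', 'the']
print(targets)   # ['the', 'cat', 'sat', 'on', 'the', 'mat']

def uniform(history, word):
    # Hypothetical untrained model: uniform over a 5-word vocabulary.
    return 1 / 5

loss = teacher_forced_loss(sentence, uniform)
print(loss)      # 6 * log(5) ≈ 9.657
```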
Teacher forcing makes training stable and efficient. But it creates a problem called exposure bias: at inference time (when generating text), the model feeds its own predictions as inputs, not the true previous words. If the model makes an error at step $t$, that wrong word becomes the input at step $t+1$, and errors can cascade. The model was never trained to handle its own mistakes.
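At inference time, the loop feeds the model's own choice back in. This sketch uses greedy decoding and a hand-made, bigram-style table. The table is hypothetical and has one deliberate flaw: after “the” it slightly prefers “the” again. Under teacher forcing, that flaw only costs some loss at positions after “the”. In free-running generation, the first mistake becomes the next input, and the error repeats.

```python
def generate_greedy(next_token_dist, prefix, max_len=10, eos="EOS"):
    """Free-running (autoregressive) decoding: each prediction becomes the next input.

    next_token_dist(history) returns a dict {token: probability}; it is a
    hypothetical stand-in for a trained model.
    """
    history = list(prefix)
    for _ in range(max_len):
        dist = next_token_dist(history)
        token = max(dist, key=dist.get)   # greedy: pick the most probable token
        history.append(token)             # the model's own output is fed back in
        if token == eos:
            break
    return history

TOY = {
    "BOS": {"the": 0.9, "a": 0.1},
    "the": {"the": 0.55, "cat": 0.45},    # the flaw: slightly prefers repeating "the"
    "cat": {"sat": 0.8, "EOS": 0.2},
    "sat": {"EOS": 1.0},
}

def toy_dist(history):
    return TOY[history[-1]]               # conditions only on the last token

out = generate_greedy(toy_dist, ["BOS"], max_len=5)
print(out)   # ['BOS', 'the', 'the', 'the', 'the', 'the']
```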
Seq2Seq: Encoder-Decoder Models
Language modeling is about predicting the next word in a single sequence. But many tasks require mapping one sequence to another - translation, summarization, question answering. The sequence-to-sequence (seq2seq) framework handles these.
A seq2seq model has two components:
Encoder: reads the entire input sequence and produces a context vector. The encoder is an RNN (or later, a transformer) that processes the source tokens one by one. At the end, the final hidden state $c$ is the context vector - a compressed representation of the entire input.
Decoder: generates the output sequence, one token at a time, conditioned on $c$. The decoder starts with a special [BOS] token and the context vector $c$ as its initial hidden state, then predicts tokens autoregressively: each predicted token becomes the input for the next step.
For machine translation: the encoder reads a French sentence, produces $c$, and the decoder generates the English translation.
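Below is a shape-level sketch of the encoder-decoder loop in NumPy. Everything here is an assumption for illustration: the sizes, the random untrained weights, the token ids 0 for [BOS] and 1 for a hypothetical end-of-sequence token EOS that stops decoding, and the omission of bias terms. It shows the structural point: the encoder's final hidden state `c` has the same 16 dimensions whether the source has 3 tokens or 100, and the decoder sees the source only through `c`.

```python
import numpy as np

rng = np.random.default_rng(0)
V_SRC, V_TGT, D, H = 20, 20, 8, 16   # hypothetical vocab sizes and dimensions
BOS, EOS = 0, 1                      # hypothetical special token ids

# Randomly initialised (untrained) parameters, for shape illustration only.
E_src = rng.normal(0, 0.1, (V_SRC, D))
E_tgt = rng.normal(0, 0.1, (V_TGT, D))
enc_Wh, enc_Wx = rng.normal(0, 0.1, (H, H)), rng.normal(0, 0.1, (H, D))
dec_Wh, dec_Wx = rng.normal(0, 0.1, (H, H)), rng.normal(0, 0.1, (H, D))
W_out = rng.normal(0, 0.1, (V_TGT, H))

def encode(src_ids):
    """Run the encoder RNN; its final hidden state is the context vector c."""
    h = np.zeros(H)
    for tok in src_ids:
        h = np.tanh(enc_Wh @ h + enc_Wx @ E_src[tok])
    return h

def decode_greedy(c, max_len=10):
    """Decoder RNN initialised with c; each predicted token is the next input."""
    h, tok, out = c, BOS, []
    for _ in range(max_len):
        h = np.tanh(dec_Wh @ h + dec_Wx @ E_tgt[tok])
        tok = int(np.argmax(W_out @ h))   # greedy; argmax of logits = argmax of softmax
        if tok == EOS:
            break
        out.append(tok)
    return out

c_short = encode([3, 5, 7])
c_long = encode(list(rng.integers(2, V_SRC, size=100)))
print(c_short.shape, c_long.shape)   # (16,) (16,): same size however long the input
print(decode_greedy(c_short))        # arbitrary ids: the weights are untrained
```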
The Information Bottleneck Problem
There is a fundamental flaw in this architecture, and it becomes obvious once you think about it carefully.
The entire source sequence is compressed into a single fixed-size vector $c$. A sentence with 100 words, each with complex semantic and syntactic properties, must be squeezed into a vector of dimension 512 (or 1024, or whatever you choose). Information is inevitably lost.
For short sentences (fewer than 20 words), this works reasonably well. For longer sentences, translation quality degrades sharply. The encoder’s final hidden state simply cannot hold everything relevant. The model forgets the beginning of the sentence by the time it reaches the end.
This is not an accident of implementation. It is a structural limitation: you cannot losslessly compress arbitrarily long sequences into a fixed-size vector. Something must be discarded.
Discomfort check. “Next-token prediction sounds too simple to produce capable models.” Here is the counterargument. Consider the sentence: “The bank closed at ___.” To predict “5pm” correctly, you need to know: (1) that “bank” here means a financial institution, not a riverbank, which means disambiguating word sense from context; (2) that financial institutions are businesses; (3) that businesses have closing times; (4) what typical business hours are. That is world knowledge, semantic disambiguation, and pragmatic reasoning, all required just to fill one short blank correctly. Now scale this to trillions of tokens. Predicting the next token well at that scale forces the model to absorb a great deal of knowledge about language and the world. The simplicity of the objective belies the richness of what the model must learn to satisfy it.
What’s Still Broken
RNN language models and seq2seq models were state of the art from roughly 2015 to 2017. They were impressive, but they had three structural problems that no amount of tuning could fix.
Sequential computation. An RNN must process tokens one at a time: $h_1$ must be computed before $h_2$, which must be computed before $h_3$. You cannot parallelize across time steps. Training on a sequence of length 1000 requires 1000 sequential steps. Modern hardware has thousands of processing cores. Batching and the matrix multiplies inside each step use some of them, but the time dimension stays serial, so much of that parallelism sits idle.
Long-range dependencies. Even LSTMs struggle to maintain information across more than a few hundred tokens. The hidden state is a bottleneck: everything from the past must be routed through a vector of fixed dimension, and gradient signals for distant dependencies decay to near zero during backpropagation.
No way to look back. The hidden state has a fixed size regardless of how long the input is, and there is no mechanism for the model to “look back” at specific earlier parts of the input when generating an output. The context is smeared into $h_t$, and the smearing is lossy.
These three problems share a solution: attention mechanisms. Instead of compressing the past into a single hidden state, let the model look directly at all past positions and decide what to attend to. Once attention carries the context, the recurrence itself can be dropped, and attention over all positions can be computed in parallel instead of step by step.
Summary
| Concept | Definition |
|---|---|
| Language model | $P(w_1, \ldots, w_n) = \prod_t P(w_t \mid w_1, \ldots, w_{t-1})$ |
| N-gram model | Approximates context with last $n-1$ words |
| Bigram estimate | $P(w_t \mid w_{t-1}) = C(w_{t-1}, w_t) / C(w_{t-1})$ |
| Perplexity | $\exp(-\frac{1}{N} \sum \log P(w_t \mid \text{context}))$ |
| RNN language model | Hidden state $h_t$ summarizes all past context |
| Teacher forcing | Train with true previous token; creates exposure bias |
| Seq2seq | Encoder compresses input to $c$; decoder generates from $c$ |
| Information bottleneck | Fixed-size $c$ loses information for long sequences |
Language modeling starts as a probability problem - factor the joint, estimate the conditional. N-grams make it tractable but shallow. RNNs make it deep but slow and forgetful. The bottleneck problem in seq2seq is not a bug; it is the precise diagnosis that leads directly to attention.