
Four explanations for Grokking

Grokking: the network generalizes, but only long after it has already fit the training data. The paper is "Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets" by Power et al. [1]. I came across it maybe a week after it was posted and didn't really know what to do with it for a while.

Why this is a puzzle

Standard stories about generalization do not predict a long delay. The VC and Rademacher story says generalization either happens or doesn't as a function of the capacity of the hypothesis class relative to the amount of data, with no reference to how long you keep training. The implicit-bias stories say SGD lands in a minimum that generalizes, but that minimum should be found at roughly the same time the training loss converges, not thousands of steps later.

Grokking says training has a second stage that is not driven by the training loss: something other than the loss gradient keeps moving the weights ($\theta$) during the stretch where the training loss is already near zero.

Four explanations

The first explanation - and the one the original paper points toward - is that weight decay ($\lambda \|\theta\|_2^2$) acts as a slow regularizer. The paper reports weight decay as particularly effective at getting these networks to generalize. Once the training loss is near zero its gradient is tiny, but the weight-decay term keeps pulling the norm of the weights down. The result is a slow drift toward a lower-norm solution, which may be the one that generalizes. The second explanation comes via mechanistic interpretability.
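To make the weight-decay drift concrete, here is a minimal sketch that assumes plain gradient descent with the L2 penalty added to the loss (not decoupled decay as in AdamW) and a loss gradient that is exactly zero. The learning rate `eta` and coefficient `lam` are hypothetical. Under those assumptions each step multiplies $\theta$ by $1 - 2\eta\lambda$, so the norm decays geometrically, and slowly when $\eta\lambda$ is small.

```python
import numpy as np

# Hypothetical hyperparameters for illustration, not the paper's settings.
eta, lam = 1e-3, 1.0          # learning rate and weight-decay coefficient
rng = np.random.default_rng(0)
theta0 = rng.standard_normal(1000)

def step(theta, loss_grad):
    # One plain gradient step on L(theta) + lam * ||theta||_2^2.
    return theta - eta * (loss_grad + 2 * lam * theta)

theta = theta0.copy()
norms = [np.linalg.norm(theta)]
for _ in range(5000):
    # Training loss is already ~0, so model its gradient as exactly zero.
    theta = step(theta, np.zeros_like(theta))
    norms.append(np.linalg.norm(theta))

# With a zero loss gradient the update is theta <- (1 - 2*eta*lam) * theta,
# so the norm follows a geometric decay.
predicted = norms[0] * (1 - 2 * eta * lam) ** np.arange(len(norms))
half_life = np.log(2) / -np.log(1 - 2 * eta * lam)
print(f"norm: {norms[0]:.1f} -> {norms[-1]:.4f}, half-life ~ {half_life:.0f} steps")
```

With these made-up values the norm halves roughly every 346 steps; nothing here depends on the data, which is the point of the "second stage" story.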

Figure 1 of Power et al. [1] (arXiv 2201.02177). Training and validation accuracy on modular arithmetic versus optimization step, on a log scale. Training accuracy saturates early; validation accuracy stays at chance for orders of magnitude of additional steps, then jumps to one hundred percent.

Neel Nanda and collaborators [2] identified specific circuits inside small grokking networks that implement modular arithmetic via a Fourier ($\hat{f}(\omega)$) decomposition: inputs are mapped to sines and cosines at a few key frequencies, and trigonometric identities turn addition into rotation about a circle. The grokking transition comes once those circuits are in place and the memorizing components they replace are cleaned away. Before the transition the network relies on brute-force memorization. After it, the network actually computes. The circuits-level framing this work builds on is laid out in Elhage et al., which is where I'd send anyone who wants to understand what it means to talk about a 'circuit' inside a transformer.
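The Fourier algorithm can be written out directly, with no training, to show why it computes modular addition. This is a hand-coded sketch, not the learned network: the modulus `p` and the list `key_freqs` are hypothetical picks, and a trained network learns its own frequencies and spreads this computation across its layers. For each frequency $w_k = 2\pi k/p$, angle addition gives $\cos(w_k(a+b))$ and $\sin(w_k(a+b))$ from products of the inputs, and the logit for candidate $c$ is $\sum_k \cos(w_k(a+b-c))$, which is maximal exactly when $c \equiv a+b \pmod p$.

```python
import numpy as np

p = 113                      # modulus (hypothetical choice; p prime keeps the argmax unique)
key_freqs = [14, 35, 41]     # hypothetical key frequencies k, with w_k = 2*pi*k/p

def fourier_logits(a, b):
    """Logits over c in 0..p-1 computed the Fourier way, with no lookup table."""
    c = np.arange(p)
    logits = np.zeros(p)
    for k in key_freqs:
        w = 2 * np.pi * k / p
        # "Embed" a and b as points on a circle at frequency w.
        ca, sa = np.cos(w * a), np.sin(w * a)
        cb, sb = np.cos(w * b), np.sin(w * b)
        # Angle addition: cos/sin of w(a+b) from products of the inputs.
        cab = ca * cb - sa * sb
        sab = sa * cb + ca * sb
        # cos(w(a+b-c)) = cos(w(a+b))cos(wc) + sin(w(a+b))sin(wc);
        # it equals 1 exactly when c == (a+b) mod p, and is below 1 otherwise.
        logits += cab * np.cos(w * c) + sab * np.sin(w * c)
    return logits

a, b = 100, 50
print(int(np.argmax(fourier_logits(a, b))), (a + b) % p)   # both 37
```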

The third explanation, due to Liu, Michaud, and Tegmark in Omnigrok, is geometric: the generalizing solution lies in a narrow "Goldilocks zone" of weight norms, and grokking is what you see when the optimizer starts outside that zone (for example, from a large-norm initialization) and is slowly walked into it. Weight decay is what does the walking, which is consistent with explanation one. The fourth angle steps back and reads all of this as a single process viewed at different resolutions: grokking is double descent, resolved over training time rather than over model or dataset size.
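A rough sketch of why starting outside the zone turns into a long delay, assuming the only force acting is the plain-GD weight-decay shrink $\|\theta\| \leftarrow (1 - 2\eta\lambda)\|\theta\|$. The base norm `r_base`, the zone edge `r_zone`, and the hyperparameters are all hypothetical; in Omnigrok the zone depends on the task and model and is not a constant. The question here is how many steps it takes an initialization scaled by $\alpha$ to reach the zone.

```python
import numpy as np

# Hypothetical numbers: decay factor and zone edge are illustrative only.
eta, lam = 1e-3, 1.0
decay = 1 - 2 * eta * lam        # per-step norm shrink when the loss gradient is ~0
r_base = 10.0                    # norm of a "standard" initialization (hypothetical)
r_zone = 5.0                     # upper edge of a Goldilocks zone (hypothetical)

def steps_to_zone(alpha):
    """Steps of pure weight-decay drift for norm alpha * r_base to fall to r_zone."""
    r0 = alpha * r_base
    if r0 <= r_zone:
        return 0                 # already inside the zone: no delay
    # Smallest t with r0 * decay**t <= r_zone.
    return int(np.ceil(np.log(r_zone / r0) / np.log(decay)))

for alpha in [1, 2, 4, 8, 16]:
    print(alpha, steps_to_zone(alpha))
```

Under these assumptions each doubling of $\alpha$ adds one weight-decay half-life (about 346 steps here), so the delay grows with $\log \alpha$.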

Figure 2 of Nanda et al. [2]. Left: histogram over neurons of the fraction of variance explained by degree-2 polynomials; the grokked network's neurons are well explained by them. Right: heatmap of the components of $W_L$ for frequency-14 neurons, with weight concentrated on the sin/cos basis pair for that frequency. Individual neurons map onto specific Fourier-basis pairs, so the Fourier circuit is mechanistically visible.

The most lucid public articulations of this "double descent over time" reading come from Preetum Nakkiran's writing, and OpenAI's Deep Double Descent writeup paints the picture visually. On the mechanistic side, Neel Nanda wrote an intuitive walkthrough of the Fourier circuit story and maintains a corresponding paper page. Google PAIR's "Do Machine Learning Models Memorize or Generalize?" poses the same question in a similar visual vocabulary.

Figure 5 of Nanda et al. [2] (arXiv 2301.05217). After the transition, the discrete Fourier transform (DFT) of the network's learned input embeddings concentrates on a small number of frequencies, which is the Fourier-based modular arithmetic circuit made visible.
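What "the DFT concentrates on a few frequencies" means can be reproduced on a synthetic stand-in. This sketch does not use a trained network: it builds a hypothetical embedding matrix `W_E` (token index by `d_model`) as a random linear mix of $\cos(w_k a)$ and $\sin(w_k a)$ for two planted frequencies plus small noise, then takes the DFT over the token axis and sums power over embedding dimensions. The planted frequencies, sizes, and noise level are all assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
p, d_model = 113, 32
key_freqs = [14, 35]                # hypothetical frequencies planted in the embedding

# Synthetic W_E: each token a is a random linear mix of cos/sin(2*pi*k*a/p), plus noise.
a = np.arange(p)
basis = np.concatenate(
    [np.stack([np.cos(2 * np.pi * k * a / p), np.sin(2 * np.pi * k * a / p)], axis=1)
     for k in key_freqs],
    axis=1)                                           # shape (p, 2 * len(key_freqs))
W_E = basis @ rng.standard_normal((basis.shape[1], d_model))
W_E += 0.01 * rng.standard_normal(W_E.shape)

# One-sided DFT over the token axis, then total power per frequency bin.
spectrum = np.fft.rfft(W_E, axis=0)                  # shape (p // 2 + 1, d_model)
power = (np.abs(spectrum) ** 2).sum(axis=1)
top = np.argsort(power)[::-1][:len(key_freqs)]
frac = power[top].sum() / power.sum()
print(sorted(top.tolist()), f"{frac:.4f} of the one-sided power")
```

On a real grokked model the same two lines of analysis would be applied to the learned embedding matrix instead of this planted one.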

These four views look at the same puzzle from different angles, and together they read as one story told at different levels of abstraction. Weight decay is the optimization pressure: it selects a minimum-norm interpolant, and in modular arithmetic that interpolant can be implemented by the Fourier circuit, because Fourier features capture the low-rank ($\text{rank}(W) \ll d$) structure of the task. Omnigrok's Goldilocks zone is where that pressure ends up in weight-norm space. The broader pattern of fast memorization followed by slow compression is the shape double descent takes when you resolve it over time instead of over model size.
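The claim that weight decay "selects a minimum-norm interpolant" has a clean linear analogue, sketched here under loud assumptions: an underdetermined least-squares problem stands in for the network, and the data, sizes, and hyperparameters are hypothetical. Training starts at an exact fit that carries a large extra null-space component. The loss gradient never touches that component, so training loss stays near zero while weight decay alone removes it, and $\theta$ settles at the ridge solution, which tends to the minimum-norm interpolant as $\lambda \to 0$. It says nothing about the Fourier circuit itself.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 20                           # fewer examples than parameters: many exact fits
X, y = rng.standard_normal((n, d)), rng.standard_normal(n)
eta, lam, steps = 0.02, 0.01, 50_000   # hypothetical hyperparameters

pinv = np.linalg.pinv(X)
theta_mn = pinv @ y                    # minimum-norm interpolant
null_proj = np.eye(d) - pinv @ X       # projector onto the null space of X

# Start at an exact fit with a large null-space component (zero training loss).
theta = theta_mn + 10 * null_proj @ rng.standard_normal(d)

def train_loss(th):
    return np.mean((X @ th - y) ** 2)

print(f"start: loss={train_loss(theta):.2e}, |theta|={np.linalg.norm(theta):.2f}")
for _ in range(steps):
    # Gradient of mean squared error plus lam * ||theta||^2.
    grad = 2 / n * X.T @ (X @ theta - y) + 2 * lam * theta
    theta -= eta * grad
print(f"end:   loss={train_loss(theta):.2e}, |theta|={np.linalg.norm(theta):.2f}, "
      f"|theta_mn|={np.linalg.norm(theta_mn):.2f}")

# Fixed point of this update: the ridge solution X^T (X X^T + n*lam*I)^{-1} y.
theta_ridge = X.T @ np.linalg.solve(X @ X.T + n * lam * np.eye(n), y)
```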

Figure 3 of Nanda et al. [2]. The grokking pattern averaged over seeds (faded background lines show individual seeds): training accuracy saturates near 1.0 within about 1k epochs, while test accuracy stays at chance for about 5k epochs before its sudden rise. The corresponding average train and test log-loss curves are also shown.

The one thing none of these explanations cleanly accounts for is the abruptness of the transition. Smoothly shrinking the norm and smoothly assembling circuits should give smoothly rising validation accuracy, not a near-vertical jump. My read is that the sharpness is largely a measurement artifact. Accuracy is computed from the top-1 argmax of the softmax output $\sigma(z)_j = e^{z_j}/\sum_k e^{z_k}$, so logits that evolve continuously map onto a piecewise-constant accuracy curve: each example flips from wrong to right once the correct logit crosses its strongest competitor.
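A toy version of that argument, using made-up logit trajectories rather than anything measured from a grokking run. Each hypothetical example has a correct-minus-competitor margin that rises linearly and crosses zero near the same time. With two logits, the softmax probability of the correct class is the sigmoid of the margin and moves smoothly, while top-1 accuracy is a step per example.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 1000                                   # hypothetical number of validation examples
t = np.linspace(0, 10, 201)                # training "time"
# Hypothetical margins: rise linearly, cross zero near t = 5 with a small spread.
offsets = 5 + 0.05 * rng.standard_normal(N)
margin = t[:, None] - offsets[None, :]     # shape (time, examples)

# Two-logit softmax: z_correct = margin, z_competitor = 0.
z = np.stack([margin, np.zeros_like(margin)], axis=-1)
probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
p_correct = probs[..., 0].mean(axis=1)                  # smooth in t
accuracy = (np.argmax(z, axis=-1) == 0).mean(axis=1)    # top-1: flips per example

print("largest one-step jump, accuracy:       ", np.abs(np.diff(accuracy)).max())
print("largest one-step jump, mean p(correct):", np.abs(np.diff(p_correct)).max())
```

The artifact depends on the margins crossing zero close together; with widely spread crossing times the same argmax measurement would produce a gradual accuracy curve.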

Neel Nanda summarizing the three-phase mechanistic account of grokking from [2] (memorization, circuit formation, cleanup). The transition to generalization happens during cleanup, not during circuit formation, which is why it looks sudden.

Further reading

References