# 1 Scaling by Thinking in Continuous Space
Scaling up Test-Time Compute with Latent Reasoning: A Recurrent Depth Approach
Jonas Geiping 1, Sean McLeish 2, Neel Jain 2, John Kirchenbauer 2, Siddharth Singh 2, Brian R. Bartoldson 3, Bhavya Kailkhura 3, Abhinav Bhatele 2, Tom Goldstein 2
1 ELLIS Institute Tübingen, Max-Planck Institute for Intelligent Systems, Tübingen AI Center 2 University of Maryland, College Park 3 Lawrence Livermore National Laboratory. Correspondence to: Jonas Geiping, Tom Goldstein <jonas@tue.ellis.eu, tomg@umd.edu>.
Abstract
We study a novel language model architecture that is capable of scaling test-time computation by implicitly reasoning in latent space. Our model works by iterating a recurrent block, thereby unrolling to arbitrary depth at test-time. This stands in contrast to mainstream reasoning models that scale up compute by producing more tokens. Unlike approaches based on chain-of-thought, our approach does not require any specialized training data, can work with small context windows, and can capture types of reasoning that are not easily represented in words. We scale a proof-of-concept model to 3.5 billion parameters and 800 billion tokens. We show that the resulting model can improve its performance on reasoning benchmarks, sometimes dramatically, up to a computation load equivalent to 50 billion parameters.
Model: huggingface.co/tomg-group-umd/huginn-0125 Code and Data: github.com/seal-rg/recurrent-pretraining
Humans naturally expend more mental effort solving some problems than others. While humans are capable of thinking over long time spans by verbalizing intermediate results and writing them down, a substantial amount of thought happens through complex, recurrent firing patterns in the brain, before the first word of an answer is uttered.
Early attempts at increasing the power of language models focused on scaling model size, a practice that requires extreme amounts of data and computation. More recently, researchers have explored ways to enhance the reasoning capability of models by scaling test time computation. The mainstream approach involves post-training on long chain-of-thought examples to develop the model’s ability to verbalize intermediate calculations in its context window and thereby externalize thoughts.
<details>
<summary>x1.png Details</summary>

### Visual Description
## Chart: Scaling up Test-Time Compute with Recurrent Depth
### Overview
The image is a line chart comparing accuracy on three benchmarks (ARC challenge, GSM8K CoT, and OpenBookQA) as a function of Test-Time Compute Recurrence. The x-axis represents the Test-Time Compute Recurrence, and the y-axis represents the Accuracy (%). The chart also displays the Materialized Parameters along the top x-axis.
### Components/Axes
* **Title:** Scaling up Test-Time Compute with Recurrent Depth
* **X-axis Title:** Test-Time Compute Recurrence
* X-axis values: 1, 4, 6, 8, 12, 20, 32, 48, 64
* **Y-axis Title:** Accuracy (%)
* Y-axis values: 0, 10, 20, 30, 40, 50
* **Secondary X-axis Title:** Materialized Parameters
* Secondary X-axis values: 3.6B, 8.3B, 11.5B, 14.6B, 21.0B, 33.6B, 52.6B, 77.9B, 103B
* **Legend:** Located in the bottom-right corner.
* Blue: ARC challenge
* Orange: GSM8K CoT
* Green: OpenBookQA
### Detailed Analysis
**ARC challenge (Blue):**
The blue line shows accuracy on the "ARC challenge" benchmark. The line generally slopes upward, indicating that accuracy increases with Test-Time Compute Recurrence.
* At x=1, Accuracy is approximately 22%.
* At x=4, Accuracy is approximately 33%.
* At x=8, Accuracy is approximately 38%.
* From x=20 onwards, the accuracy plateaus at approximately 47%.
**GSM8K CoT (Orange):**
The orange line shows accuracy on the "GSM8K CoT" benchmark. The line shows a steep upward trend, indicating a rapid increase in accuracy with Test-Time Compute Recurrence.
* At x=1, Accuracy is approximately 0%.
* At x=4, Accuracy is approximately 1%.
* At x=8, Accuracy is approximately 25%.
* From x=20 onwards, the accuracy plateaus at approximately 47%.
**OpenBookQA (Green):**
The green line shows accuracy on the "OpenBookQA" benchmark. The line shows a gradual upward trend, indicating a moderate increase in accuracy with Test-Time Compute Recurrence.
* At x=1, Accuracy is approximately 25%.
* At x=4, Accuracy is approximately 26%.
* At x=8, Accuracy is approximately 38%.
* From x=20 onwards, the accuracy plateaus at approximately 42%.
### Key Observations
* GSM8K CoT (orange line) shows the most significant improvement in accuracy as Test-Time Compute Recurrence increases, starting from a very low initial accuracy.
* ARC challenge (blue line) and OpenBookQA (green line) have higher initial accuracies but show less dramatic improvements with increased Test-Time Compute Recurrence.
* Accuracy on all three benchmarks plateaus after a certain level of Test-Time Compute Recurrence (around x=20).
* The Materialized Parameters increase along with the Test-Time Compute Recurrence.
### Interpretation
The chart suggests that increasing Test-Time Compute Recurrence improves the accuracy of a single model, with the size of the gain depending on the task. GSM8K CoT benefits the most from increased compute, while ARC challenge and OpenBookQA show more modest gains. The plateauing of accuracy on all three benchmarks indicates that there are diminishing returns to increasing Test-Time Compute Recurrence beyond a certain point. The Materialized Parameters axis grows with the Test-Time Compute Recurrence because it reports the parameter count of the unrolled network, not additional trained parameters.
</details>
Figure 1: We train a 3.5B parameter language model with depth recurrence. At test time, the model can iterate longer to use more compute and improve its performance. Instead of scaling test-time reasoning by “verbalizing” in long Chains-of-Thought, the model improves entirely by reasoning in latent space. Tasks that require less reasoning, like OpenBookQA, converge more quickly than tasks like GSM8K, which effectively make use of more compute.
However, the constraint that expensive internal reasoning must always be projected down to a single verbalized next token appears wasteful; it is plausible that models could be more competent if they were able to natively “think” in their continuous latent space. One way to unlock this untapped dimension of additional compute involves adding a recurrent unit to a model. This unit runs in a loop, iteratively processing and updating its hidden state and enabling computations to be carried on indefinitely. While this is not currently the dominant paradigm, this idea is foundational to machine learning and has been (re-)discovered in every decade, for example as recurrent neural networks, diffusion models, and as universal or looped transformers.
In this work, we show that depth-recurrent language models can learn effectively, be trained in an efficient manner, and demonstrate significant performance improvements under the scaling of test-time compute. Our proposed transformer architecture is built upon a latent depth-recurrent block that is run for a randomly sampled number of iterations during training. We show that this paradigm can scale to several billion parameters and over half a trillion tokens of pretraining data. At test-time, the model can improve its performance through recurrent reasoning in latent space, enabling it to compete with other open-source models that benefit from more parameters and training data. Additionally, we show that recurrent depth models naturally support a number of features at inference time that require substantial tuning and research effort in non-recurrent models, such as per-token adaptive compute, (self)-speculative decoding, and KV-cache sharing. We finish out our study by tracking token trajectories in latent space, showing that a number of interesting computation behaviors simply emerge with scale, such as the model rotating shapes in latent space for numerical computations.
# 2 Why Train Models with Recurrent Depth?
Recurrent layers enable a transformer model to perform arbitrarily many computations before emitting a token. In principle, recurrent mechanisms provide a simple solution for test-time compute scaling. Compared to the more standard approach of long-context reasoning (OpenAI, 2024; DeepSeek-AI et al., 2025), latent recurrent thinking has several advantages.
- Latent reasoning does not require construction of bespoke training data. Chain-of-thought reasoning requires the model to be trained on long demonstrations that are constructed in the domain of interest. In contrast, our proposed latent reasoning models can train with a variable compute budget, using standard training data with no specialized demonstrations, and enhance their abilities at test-time if given additional compute.
- Latent reasoning models require less memory for training and inference than chain-of-thought reasoning models. Because the latter require extremely long context windows, specialized training methods such as token-parallelization (Liu et al., 2023a) may be needed.
- Recurrent-depth networks perform more FLOPs per parameter than standard transformers, significantly reducing communication costs between accelerators at scale. This especially enables higher device utilization when training with slower interconnects.
- By constructing an architecture that is compute-heavy and small in parameter count, we hope to set a strong prior towards models that solve problems by “thinking”, i.e. by learning meta-strategies, logic and abstraction, instead of memorizing. The strength of recurrent priors for learning complex algorithms has already been demonstrated in the “deep thinking” literature (Schwarzschild et al., 2021b; Bansal et al., 2022; Schwarzschild et al., 2023).
On a more philosophical note, we hope that latent reasoning captures facets of human reasoning that defy verbalization, such as spatial thinking, physical intuition or (motor) planning. Over many iterations of the recurrent process, reasoning in a high-dimensional vector space would enable the deep exploration of multiple directions simultaneously, instead of linear thinking, leading to a system capable of exhibiting novel and complex reasoning behavior.
Scaling compute in this manner is not at odds with scaling through extended (verbalized) inference (Shao et al., 2024) or with scaling parameter counts in pretraining (Kaplan et al., 2020); we argue that it may constitute a third axis on which to scale model performance.
———————— Table of Contents ————————
- Section 3 introduces our latent recurrent-depth model architecture and training objective.
- Section 4 describes the data selection and engineering of our large-scale training run on Frontier, an AMD cluster.
- Section 5 reports benchmark results, showing how the model improves when scaling inference compute.
- Section 6 includes several application examples showing how recurrent models naturally simplify LLM use cases.
- Section 7 visualizes what computation patterns emerge at scale with this architecture and training objective, showing that context-dependent behaviors emerge in latent space, such as “orbiting” when responding to prompts requiring numerical reasoning.
<details>
<summary>x2.png Details</summary>

### Visual Description
## Diagram: Recurrent Block Diagram
### Overview
The image is a diagram illustrating a process flow involving a "Prelude" block, a series of "Recurrent Blocks," and a "Coda" block. The diagram shows how input is injected and how residual streams are passed between these blocks.
### Components/Axes
* **Blocks:**
* **Prelude (P):** A light blue rounded rectangle labeled "P".
* **Recurrent Block (R):** A series of green rounded rectangles labeled "R".
* **Coda (C):** A pink rounded rectangle labeled "C".
* **Inputs/Outputs:**
* "Hello" (input to Prelude)
* "World" (output from Coda)
* **Arrows:**
* Solid black arrows indicate the "Residual Stream".
* Dashed gray arrows indicate "Input Injection".
* **Labels:**
* *e*: Label for the input injection arrows.
* *s0, s1, sR*: Labels for the residual stream arrows.
* N(0, σ^2 I_{n·h}): Label for the distribution from which the initial state *s0* is sampled.
* *p*: Label for the output arrow from the Coda block.
* **Legend:** Located at the bottom of the diagram.
* Light blue square: Prelude
* Green square: Recurrent Block
* Pink square: Coda
* Dashed gray arrow: Input Injection
* Solid black arrow: Residual Stream
### Detailed Analysis
1. **Prelude (P):**
* Input: "Hello" flows into the Prelude block.
* Output: The Prelude block outputs the embedded input *e*. The initial state *s0* that enters the first Recurrent Block is sampled from N(0, σ^2 I_{n·h}).
* Input Injection: The Prelude block also injects input into each Recurrent Block via dashed gray arrows labeled *e*.
2. **Recurrent Blocks (R):**
* A series of green "R" blocks are connected by solid black arrows representing the residual stream.
* The residual streams between the blocks are labeled *s0*, *s1*, and *sR*.
* Each Recurrent Block receives input injection from the Prelude block, labeled *e*.
3. **Coda (C):**
* Input: The final Recurrent Block outputs to the Coda block via a residual stream labeled *sR*.
* Output: The Coda block outputs "World" via an arrow labeled *p*.
### Key Observations
* The diagram illustrates a sequential process with feedback from the initial Prelude block to each Recurrent Block.
* The residual stream is the primary flow of information between the Recurrent Blocks.
* The input injection provides additional information to each Recurrent Block.
### Interpretation
The diagram represents a recurrent neural network architecture or a similar sequential processing model. The Prelude block likely prepares the initial input, the Recurrent Blocks perform iterative processing, and the Coda block generates the final output. The input injection mechanism allows the initial input to influence each step of the recurrent process. The residual stream ensures that information is carried forward through the sequence. The N(0, σ^2 I_{n·h}) likely represents a Gaussian noise distribution, which is used to initialize the hidden state of the recurrent network.
</details>
Figure 2: A visualization of the architecture, as described in Section 3. Each block consists of a number of sub-layers. The blue prelude block embeds the inputs into latent space. There, the green shared recurrent block, itself a block of layers, is repeated to compute the final latent state, which is decoded by the layers of the red coda block.
# 3 A scalable recurrent architecture
In this section we will describe our proposed architecture for a transformer with latent recurrent depth, discussing design choices and small-scale ablations. A diagram of the architecture can be found in Figure 2. We always refer to the sequence dimension as $n$ , the hidden dimension of the model as $h$ , and its vocabulary as the set $V$ .
## 3.1 Macroscopic Design
The model is primarily structured around decoder-only transformer blocks (Vaswani et al., 2017; Radford et al., 2019). However, these blocks are structured into three functional groups: the prelude $P$ , which embeds the input data into a latent space using multiple transformer layers, then the core recurrent block $R$ , which is the central unit of recurrent computation modifying states $\mathbf{s}\in\mathbb{R}^{n\times h}$ , and finally the coda $C$ , which un-embeds from latent space using several layers and also contains the prediction head of the model. The core block is set between the prelude and coda blocks, and by looping the core we can put an indefinite number of verses in our song.
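Given a number of recurrent iterations $r$ and a sequence of input tokens $\mathbf{x}\in V^{n}$ , these groups are used in the following way to produce output probabilities $\mathbf{p}\in\mathbb{R}^{n\times|V|}$ :
$$
\begin{aligned}
\mathbf{e} &= P(\mathbf{x}) \\
\mathbf{s}_{0} &\sim \mathcal{N}(\mathbf{0},\sigma^{2}I_{n\cdot h}) \\
\mathbf{s}_{i} &= R(\mathbf{e},\mathbf{s}_{i-1}) \quad \text{for } i\in\{1,\dots,r\} \\
\mathbf{p} &= C(\mathbf{s}_{r}),
\end{aligned}
$$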
where $\sigma$ is some standard deviation for initializing the random state. This process is shown in Figure 2. Given an initial random state $\mathbf{s}_{0}$ , the model repeatedly applies the core block $R$ , which accepts the latent state $\mathbf{s}_{i-1}$ and the embedded input $\mathbf{e}$ and outputs a new latent state $\mathbf{s}_{i}$ . After finishing all iterations, the coda block processes the last state and produces the probabilities of the next token.
This architecture is based on deep thinking literature, where it is shown that injecting the latent inputs $\mathbf{e}$ in every step (Bansal et al., 2022) and initializing the latent vector with a random state stabilizes the recurrence and promotes convergence to a steady state independent of initialization, i.e. path independence (Anil et al., 2022).
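As a rough illustration of the loop described above, the following Python sketch wires together a prelude, a core block and a coda. The three toy functions are hypothetical stand-ins, not transformer blocks, and the contractive core is chosen only so that path independence is easy to observe on a single position with a three-dimensional state.

```python
import math
import random

def forward(x, r, prelude, core, coda, sigma=1.0, rng=None):
    """Run the prelude once, iterate the core block r times, then decode."""
    rng = rng or random.Random(0)
    e = prelude(x)                            # embedded input, injected at every step
    s = [rng.gauss(0.0, sigma) for _ in e]    # random initial state s_0
    for _ in range(r):
        s = core(e, s)                        # s_i = R(e, s_{i-1})
    return coda(s)                            # p = C(s_r)

# Toy stand-ins (hypothetical, for illustration only).
def toy_prelude(x):
    return [float(t) for t in x]

def toy_core(e, s):
    # Contractive update whose fixed point is 2 * e, whatever s_0 was.
    return [0.5 * si + ei for si, ei in zip(s, e)]

def toy_coda(s):
    # Numerically stable softmax over the final state.
    m = max(s)
    z = [math.exp(v - m) for v in s]
    total = sum(z)
    return [v / total for v in z]
```

Calling `forward` with two differently seeded generators and a large `r` yields practically identical outputs in this toy setting, because the influence of the initial state shrinks by half in every iteration while `e` is re-injected each time.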
Motivation for this Design.
This recurrent design is the minimal setup required to learn stable iterative operators. A good example is gradient descent of a function $E(\mathbf{x},\mathbf{y})$ , where $\mathbf{x}$ may be the variable of interest and $\mathbf{y}$ the data. Gradient descent on this function starts from an initial random state, here $\mathbf{x}_{0}$ , and repeatedly applies a simple operation (the gradient of the function it optimizes) that depends on the previous state $\mathbf{x}_{k}$ and data $\mathbf{y}$ . Note that we need to use $\mathbf{y}$ in every step to actually optimize our function. Similarly, we repeatedly inject the data $\mathbf{e}$ in our set-up in every step of the recurrence. If $\mathbf{e}$ were provided only at the start, e.g. via $\mathbf{s}_{0}=\mathbf{e}$ , then the iterative process would not be stable, as its solution would depend only on its boundary conditions. (Stability is meant here in the sense that $R$ cannot be a monotone operator if it does not depend on $\mathbf{e}$ , and so cannot represent gradient descent on strictly convex, data-dependent functions (Bauschke et al., 2011).)
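The gradient descent analogy can be made concrete with a one-dimensional example. The sketch below assumes the illustrative energy $E(x,y)=\frac{1}{2}(x-y)^{2}$, which is not taken from the paper; the step size and step count are likewise arbitrary choices.

```python
def grad_E(x, y):
    """Gradient with respect to x of the assumed energy E(x, y) = 0.5 * (x - y)**2."""
    return x - y

def descend(x0, y, steps=200, lr=0.1):
    """Plain gradient descent: the data y enters the update at every step."""
    x = x0
    for _ in range(steps):
        x = x - lr * grad_E(x, y)
    return x
```

Starting points as different as `-5.0` and `40.0` both end up close to `y`, because the update keeps consulting the data. This mirrors the role of $\mathbf{e}$ in the recurrence.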
The structure of using several layers to embed input tokens into a hidden latent space is based on empirical results analyzing standard fixed-depth transformers (Skean et al., 2024; Sun et al., 2024; Kaplan et al., 2024). This body of research shows that the initial and the end layers of LLMs are noticeably different, whereas middle layers are interchangeable and permutable. For example, Kaplan et al. (2024) show that within a few layers standard models already embed sub-word tokens into single concepts in latent space, on which the model then operates.
**Remark 3.1 (Is this a Diffusion Model?)**
*This iterative architecture will look similar to the other modern iterative modeling paradigm, diffusion models (Song and Ermon, 2019), especially latent diffusion models (Rombach et al., 2022). We ran several ablations with iterative schemes even more similar to diffusion models, such as $\mathbf{s}_{i}=R(\mathbf{e},\mathbf{s}_{i-1})+\mathbf{n}$ where $\mathbf{n}\sim\mathcal{N}(\mathbf{0},\sigma_{i}I_{n\cdot h})$ , but find the injection of noise not to help in our preliminary experiments, which is possibly connected to our training objective. We also evaluated $\mathbf{s}_{i}=R_{i}(\mathbf{e},\mathbf{s}_{i-1})$ , i.e. a core block that takes the current step as input (Peebles and Xie, 2023), but find that this interacts badly with path independence, leading to models that cannot extrapolate.*
## 3.2 Microscopic Design
Within each group, we broadly follow standard transformer layer design. Each block contains multiple layers, and each layer contains a standard, causal self-attention block using RoPE (Su et al., 2021) with a base of $50000$ , and a gated SiLU MLP (Shazeer, 2020). We use RMSNorm (Zhang and Sennrich, 2019) as our normalization function. The model has learnable biases on queries and keys, and nowhere else. To stabilize the recurrence, we order all layers in the following “sandwich” format, using norm layers $n_{i}$ , which is related, but not identical to similar strategies in (Ding et al., 2021; Team Gemma et al., 2024):
$$
\begin{aligned}
\hat{\mathbf{x}}_{l} &= n_{2}\left(\mathbf{x}_{l-1}+\textnormal{Attn}(n_{1}(\mathbf{x}_{l-1}))\right) \\
\mathbf{x}_{l} &= n_{4}\left(\hat{\mathbf{x}}_{l}+\textnormal{MLP}(n_{3}(\hat{\mathbf{x}}_{l}))\right)
\end{aligned}
$$
While at small scales, most normalization strategies, e.g. pre-norm, post-norm and others, work almost equally well, we ablate these options and find that this normalization is required to train the recurrence at scale. Note also that technically $n_{3}$ is superfluous, but we report here the exact norm setup with which we trained the final model.
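The “sandwich” ordering can be sketched in a few lines of plain Python. This is a simplified illustration under stated assumptions: the RMSNorm here has no learnable scale, the MLP half with norms $n_{3}$ and $n_{4}$ is assumed to mirror the attention half, and `attn` and `mlp` are arbitrary callables supplied by the caller, not real attention or MLP blocks.

```python
import math

def rms_norm(x, eps=1e-6):
    """RMSNorm without a learnable scale (a simplification)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def sandwich_layer(x, attn, mlp):
    """Normalize before each sub-block and again after each residual sum."""
    a = attn(rms_norm(x))                                    # n_1, then attention
    x_hat = rms_norm([xi + ai for xi, ai in zip(x, a)])      # n_2 on the residual sum
    m = mlp(rms_norm(x_hat))                                 # n_3, then MLP
    return rms_norm([xi + mi for xi, mi in zip(x_hat, m)])   # n_4 on the residual sum
```

Because the output always passes through a final norm, its scale stays close to one regardless of the input magnitude, which is one plausible reading of why this layout helps when the same layers are applied many times.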
Given an embedding matrix $E$ and embedding scale $\gamma$ , the prelude block first embeds input tokens $\mathbf{x}$ as $\gamma E(\mathbf{x})$ , and then applies $l_{P}$ many prelude layers with the layout described above.
Our core recurrent block $R$ starts with an adapter matrix $A:\mathbb{R}^{2h}→\mathbb{R}^{h}$ mapping the concatenation of $\mathbf{s}_{i}$ and $\mathbf{e}$ into the hidden dimension $h$ (Bansal et al., 2022). While re-incorporation of initial embedding features via addition rather than concatenation works equally well for smaller models, we find that concatenation works best at scale. This is then fed into $l_{R}$ transformer layers. At the end of the core block the output is again rescaled with an RMSNorm $n_{c}$ .
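To see how concatenation followed by a linear map relates to addition, consider the toy sketch below for a single position. The helper names are hypothetical, and the matrix is a plain list of rows and not a trained parameter.

```python
def adapter(s, e, A):
    """Apply A (h rows, 2h columns) to the concatenation of s and e."""
    z = list(s) + list(e)
    return [sum(a * v for a, v in zip(row, z)) for row in A]

def additive_adapter(h):
    """Build A = [I | I], which reduces the adapter to the sum s + e."""
    return [[1.0 if j in (i, i + h) else 0.0 for j in range(2 * h)]
            for i in range(h)]
```

With `A = additive_adapter(h)`, the adapter reproduces plain addition, so the concatenation variant contains the additive variant as a special case while allowing the model to learn other mixtures of state and input.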
The coda contains $l_{C}$ layers, normalization by $n_{c}$ , and projection into the vocabulary using tied embeddings $E^{T}$ .
In summary, we can describe the architecture by the triplet $(l_{P},l_{R},l_{C})$ , giving the number of layers in each stage, and by the number of recurrences $r$ , which may vary in each forward pass. We train a number of small-scale models with shape $(1,4,1)$ and hidden size $h=1024$ , in addition to a large model with shape $(2,4,2)$ and $h=5280$ . This model has only $8$ “real” layers, but when the recurrent block is iterated, e.g. 32 times, it unfolds to an effective depth of $2+4r+2=132$ layers, constructing computation chains that can be deeper than even the largest fixed-depth transformers (Levine et al., 2021; Merrill et al., 2022).
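The effective depth follows directly from the shape triplet and the recurrence count. A small helper (the function name is illustrative) makes the arithmetic explicit:

```python
def effective_depth(shape, r):
    """Unrolled layer count l_P + r * l_R + l_C for shape = (l_P, l_R, l_C)."""
    l_p, l_r, l_c = shape
    return l_p + r * l_r + l_c
```

For the large model, `effective_depth((2, 4, 2), 32)` gives 132, matching the figure quoted above.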
## 3.3 Training Objective
Training Recurrent Models through Unrolling.
To ensure that the model can function when we scale up recurrent iterations at test-time, we randomly sample iteration counts during training, assigning a random number of iterations $r$ to every input sequence (Schwarzschild et al., 2021b). We optimize the expectation of the loss function $L$ over random samples $x$ from distribution $X$ and random iteration counts $r$ from distribution $\Lambda$ .
$$
\mathcal{L}(\theta)=\mathbb{E}_{\mathbf{x}\in X}\,\mathbb{E}_{r\sim\Lambda}\,L\left(m_{\theta}(\mathbf{x},r),\mathbf{x}^{\prime}\right).
$$
Here, $m$ represents the model output, and $\mathbf{x}^{\prime}$ is the sequence $\mathbf{x}$ shifted left, i.e., the next tokens in the sequence $\mathbf{x}$ . We choose $\Lambda$ to be a log-normal Poisson distribution. Given a targeted mean recurrence $\bar{r}+1$ and a variance that we set to $\sigma=\frac{1}{2}$ , we can sample from this distribution via
$$
\begin{aligned}
\tau &\sim\mathcal{N}\left(\log(\bar{r})-\tfrac{1}{2}\sigma^{2},\sigma\right) \\
r &\sim\mathcal{P}(e^{\tau})+1,
\end{aligned} \tag{1}
$$
given the normal distribution $\mathcal{N}$ and Poisson distribution $\mathcal{P}$ , see Figure 3. The distribution most often samples values less than $\bar{r}$ , but it contains a heavy tail of occasional events in which significantly more iterations are taken.
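The two-stage sampling can be sketched with the Python standard library alone. Since `random` has no Poisson sampler, the sketch uses Knuth's multiplication method, which is an implementation choice made here and not something the paper specifies; the function names are illustrative.

```python
import math
import random

def sample_poisson(lam, rng):
    """Knuth's multiplication method; adequate for the moderate rates used here."""
    threshold = math.exp(-lam)
    k, prod = 0, rng.random()
    while prod > threshold:
        k += 1
        prod *= rng.random()
    return k

def sample_recurrence(mean_r=32, sigma=0.5, rng=random):
    """Draw tau from the normal distribution, then r from Poisson(exp(tau)) + 1."""
    tau = rng.gauss(math.log(mean_r) - 0.5 * sigma ** 2, sigma)
    return sample_poisson(math.exp(tau), rng) + 1
```

The shift of $-\frac{1}{2}\sigma^{2}$ in the normal mean makes the expected Poisson rate equal to $\bar{r}$, so the empirical mean of many draws should sit near $\bar{r}+1$, and every draw is at least one iteration.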
<details>
<summary>x3.png Details</summary>

### Visual Description
## Density Plot: Sampled r Distribution
### Overview
The image is a density plot showing the distribution of a variable "Sampled r". The plot displays the density of the variable on the y-axis and the values of "Sampled r" on the x-axis. Vertical lines indicate the mean, median, and mode of the distribution.
### Components/Axes
* **X-axis:** "Sampled r", with a scale from 0 to 150, marked at intervals of 25 (0, 25, 50, 75, 100, 125, 150).
* **Y-axis:** "Density", with a scale from 0.00 to 0.03, marked at intervals of 0.01 (0.00, 0.01, 0.02, 0.03).
* **Density Curve:** A black line representing the probability density function of "Sampled r".
* **Mean:** A blue dashed vertical line indicating the mean value of "Sampled r". Mean = 33.0
* **Median:** A green dashed vertical line indicating the median value of "Sampled r". Median = 29.0
* **Mode:** A red dashed vertical line indicating the mode value of "Sampled r". Mode = 24.0
* **Legend:** Located at the bottom of the chart, explaining the meaning of the lines:
* Black line: "Density"
* Blue dashed line: "Mean = 33.0"
* Green dashed line: "Median = 29.0"
* Red dashed line: "Mode = 24.0"
### Detailed Analysis
* **Density Curve:** The black density curve starts near 0 on the y-axis at x=0, rises to a peak around x=24 (the mode), and then gradually decreases towards 0 as x increases to 150. The curve is right-skewed.
* **Mean:** The blue dashed line is positioned at x=33.0.
* **Median:** The green dashed line is positioned at x=29.0.
* **Mode:** The red dashed line is positioned at x=24.0.
### Key Observations
* The distribution is unimodal and right-skewed.
* The mode (24.0) is less than the median (29.0), which is less than the mean (33.0), confirming the right skewness.
* The density is highest around the mode (24.0) and decreases as "Sampled r" moves away from this value.
### Interpretation
The density plot illustrates the distribution of the "Sampled r" variable. The right skewness indicates that there are more values clustered towards the lower end of the range, with a tail extending towards higher values. The mean being greater than the median further supports this skewness. The mode represents the most frequently occurring value of "Sampled r" in the sample. The plot provides a visual representation of the probability of observing different values of "Sampled r".
</details>
Figure 3: We use a log-normal Poisson distribution to sample the number of recurrent iterations for each training step.
Truncated Backpropagation.
To keep computation and memory low at train time, we backpropagate through only the last $k$ iterations of the recurrent unit. This enables us to train with the heavy-tailed Poisson distribution $\Lambda$ , as maximum activation memory and backward compute are now independent of $r$ . We fix $k=8$ in our main experiments. At small scale, this works as well as sampling $k$ uniformly, but with $k$ fixed, the overall memory usage in each step of training is equal. Note that the prelude block still receives gradient updates in every step, as its output $\mathbf{e}$ is injected in every step. This setup resembles truncated backpropagation through time, as commonly done with RNNs, although our setup is recurrent in depth rather than time (Williams and Peng, 1990; Mikolov et al., 2011).
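The control flow of truncated backpropagation can be illustrated on a scalar toy recurrence $s_{i}=w\,s_{i-1}+e$. This is not the model's update rule, only a stand-in chosen here so that the derivative can be tracked by hand. The first $r-k$ steps run forward only, and the derivative with respect to $w$ is accumulated over the last $k$ steps.

```python
def split_recurrence(r, k=8):
    """Return (steps without gradient, steps with gradient) for r iterations."""
    n_grad = min(r, k)
    return r - n_grad, n_grad

def truncated_grad(w, e, s0, r, k=8):
    """Return s_r and d s_r / d w, differentiating only the last k steps."""
    n_free, n_grad = split_recurrence(r, k)
    s = s0
    for _ in range(n_free):
        s = w * s + e          # forward only: no derivative is tracked
    ds = 0.0                   # s is treated as a constant from here on
    for _ in range(n_grad):
        ds = s + w * ds        # derivative of the update on the next line
        s = w * s + e
    return s, ds
```

The forward value is unaffected by truncation; only the gradient changes. The amount of derivative bookkeeping depends on `k` alone, which is the property that keeps memory constant when `r` is drawn from a heavy-tailed distribution.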
# 4 Training a large-scale recurrent-depth Language Model
After verifying that we can reliably train small test models up to 10B tokens, we move on to larger-scale runs. Given our limited compute budget, we could either train multiple tiny models too small to show emergent effects or scaling, or train a single medium-scale model. Based on this, we prepared for a single run, which we detail below.
## 4.1 Training Setup
We describe the training setup, separated into architecture, optimization setup and pretraining data. We publicly release all training data, pretraining code, and a selection of intermediate model checkpoints.
Pretraining Data.
Given access to only enough compute for a single large scale model run, we opted for a dataset mixture that maximized the potential for emergent reasoning behaviors, not necessarily for optimal benchmark performance. Our final mixture is heavily skewed towards code and mathematical reasoning data with (hopefully) just enough general webtext to allow the model to acquire standard language modeling abilities. All sources are publicly available. We provide an overview in Figure 4. Following Allen-Zhu and Li (2024), we directly mix relevant instruction data into the pretraining data. However, due to compute and time constraints, we were not able to ablate this mixture. We expect that a more careful data preparation could further improve the model’s performance. We list all data sources in Appendix C.
Tokenization and Packing Details.
We construct a vocabulary of $65536$ tokens via BPE (Sennrich et al., 2016), using the implementation of Dagan (2024). In comparison to conventional tokenizer training, we construct our tokenizer directly on the instruction data split of our pretraining corpus, to maximize tokenization efficiency on the target domain. We also substantially modify the pre-tokenization regex (e.g. of Dagan et al. (2024)) to better support code, contractions and LaTeX. We include a <|begin_text|> token at the start of every document. After tokenizing our pretraining corpus, we pack our tokenized documents into sequences of length 4096. When packing, we discard document ends that would otherwise lack previous context, to fix an issue described as the “grounding problem” in Ding et al. (2024), aside from several long-document sources of mathematical content, which we preserve in their entirety.
<details>
<summary>x4.png Details</summary>

### Visual Description
## Chart Type: Pie Chart of Text Categories
### Overview
The image is a pie chart illustrating the distribution of different categories of text. The chart shows the percentage breakdown of each category, with "generic-text" and "code" representing the largest portions. The legend on the right side of the chart maps each text category to a specific color.
### Components/Axes
* **Chart Type:** Pie Chart
* **Categories:**
* generic-text
* code
* scientific-text
* synthetic-text
* longform-text
* math
* generic-instruct
* Q&A-text
* math-instruct
* writing-instruct
* misc-reasoning
* **Legend:** Located on the right side of the pie chart. Each category is associated with a specific color.
* Blue: generic-text: 28.71%
* Orange: code: 25.36%
* Green: scientific-text: 18.73%
* Red: synthetic-text: 8.14%
* Purple: longform-text: 7.50%
* Brown: math: 6.14%
* Pink: generic-instruct: 2.09%
* Gray: Q&A-text: 1.58%
* Yellow: math-instruct: 1.51%
* Teal: writing-instruct: 0.12%
* Dark Blue: misc-reasoning: 0.11%
### Detailed Analysis
The pie chart is divided into slices, each representing a different category of text. The size of each slice corresponds to the percentage of that category.
* **generic-text:** (Blue) 28.71% - Largest slice
* **code:** (Orange) 25.36% - Second largest slice
* **scientific-text:** (Green) 18.73%
* **synthetic-text:** (Red) 8.14%
* **longform-text:** (Purple) 7.50%
* **math:** (Brown) 6.14%
* **generic-instruct:** (Pink) 2.09%
* **Q&A-text:** (Gray) 1.58%
* **math-instruct:** (Yellow) 1.51%
* **writing-instruct:** (Teal) 0.12%
* **misc-reasoning:** (Dark Blue) 0.11% - Smallest slice
### Key Observations
* "generic-text" and "code" constitute the majority of the text categories, accounting for over half of the total distribution.
* "scientific-text" is the third largest category, representing a significant portion of the distribution.
* The remaining categories each represent a relatively small percentage of the total.
* "writing-instruct" and "misc-reasoning" are the smallest categories, with percentages close to zero.
### Interpretation
The pie chart provides a clear visualization of the distribution of different text categories. The dominance of "generic-text" and "code" suggests that these types of text are the most prevalent in the dataset being analyzed. The relatively small percentages of "writing-instruct" and "misc-reasoning" indicate that these categories are less common. The data suggests a diverse range of text types, with a concentration in generic and code-related content.
</details>
Figure 4: Distribution of data sources that are included during training. The majority of our data consists of generic web-text, scientific writing and code.
Architecture and Initialization.
We scale the architecture described in Section 3, setting the layers to $(2,4,2)$, and train with a mean recurrence value of $\bar{r}=32$. We mainly scale by increasing the hidden size to $h=5280$, which yields $55$ heads of size $96$. The MLP inner dimension is $17920$ and the RMSNorm $\varepsilon$ is $10^{-6}$. Overall this model shape has about $1.5$B parameters in the non-recurrent prelude and head, $1.5$B parameters in the core recurrent block, and $0.5$B in the tied input embedding.
At small scales, most sensible initialization schemes work. However, at larger scales, we use the initialization of Takase et al. (2024) which prescribes a variance of $\sigma_{h}^{2}=\frac{2}{5h}$ . We initialize all parameters from a truncated normal distribution (truncated at $3\sigma$ ) with this variance, except all out-projection layers, where the variance is set to $\sigma_{\textnormal{out}}^{2}=\frac{1}{5hl}$ , for $l=l_{P}+\bar{r}l_{R}+l_{C}$ the number of effective layers, which is 132 for this model. As a result, the out-projection layers are initialized with fairly small values (Goyal et al., 2018). The output of the embedding layer is scaled by $\sqrt{h}$ . To match this initialization, the state $s_{0}$ is also sampled from a truncated normal distribution, here with variance $\sigma_{s}^{2}=\frac{2}{5}$ .
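As a worked check of the numbers above, the following minimal sketch evaluates the effective layer count and the three standard deviations implied by the stated variances, and draws a sample truncated at $3\sigma$. All function and variable names are illustrative and are not taken from the released code; the rejection-sampling truncation is one simple way to realize a truncated normal, not necessarily the one used in training.

```python
import math
import random

def init_stds(h, l_prelude, l_core, l_coda, mean_recurrence):
    """Evaluate the initialization formulas from the text.

    Returns (effective_layers, sigma_h, sigma_out, sigma_s).
    """
    # l = l_P + r_bar * l_R + l_C
    effective_layers = l_prelude + mean_recurrence * l_core + l_coda
    sigma_h = math.sqrt(2.0 / (5.0 * h))                       # general parameters
    sigma_out = math.sqrt(1.0 / (5.0 * h * effective_layers))  # out-projections
    sigma_s = math.sqrt(2.0 / 5.0)                             # initial state s_0
    return effective_layers, sigma_h, sigma_out, sigma_s

def truncated_normal(rng, sigma, bound=3.0):
    """Rejection-sample a zero-mean normal truncated at +/- bound * sigma."""
    while True:
        x = rng.gauss(0.0, sigma)
        if abs(x) <= bound * sigma:
            return x

layers, sigma_h, sigma_out, sigma_s = init_stds(
    h=5280, l_prelude=2, l_core=4, l_coda=2, mean_recurrence=32
)
sample = truncated_normal(random.Random(0), sigma_h)
print(layers, sigma_h, sigma_out, sigma_s, sample)
```

With these settings the out-projection standard deviation is smaller than the general one by a factor of $\sqrt{2l}$, which is why those layers start out with fairly small values.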
Locked-Step Sampling.
To enable synchronization between parallel workers, we sample a single depth $r$ for each micro-batch of training, which we synchronize across workers (otherwise workers would idle while waiting for the model with the largest $r$ to complete its backward pass). We verify at small scale that this modification improves compute utilization without impacting convergence speed, but note that at large batch sizes, training could be further improved by optimally sampling and scheduling independent steps $r$ on each worker, to more faithfully model the expectation over steps in Equation 1.
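One way to picture locked-step sampling is a depth sampler whose random state depends only on a shared seed and the micro-batch index, so that every worker draws the same $r$ without communicating. The sketch below is an illustration under assumptions: the seeding scheme and function names are hypothetical, and the depth distribution (a log-normal Poisson with spread $0.5$ and mean close to $\bar{r}+1$) is assumed here for concreteness; substitute whatever distribution Section 3 prescribes.

```python
import math
import random

def sample_poisson(rng, lam):
    """Knuth's multiplicative Poisson sampler; adequate for moderate rates."""
    threshold = math.exp(-lam)
    k, p = 0, 1.0
    while True:
        p *= rng.random()
        if p <= threshold:
            return k
        k += 1

def locked_step_depth(base_seed, micro_step, mean_recurrence=32, sigma=0.5):
    """Draw one recurrence depth r, identical on every worker for this micro-batch.

    The seed depends only on (base_seed, micro_step), never on the worker rank.
    The log-normal Poisson form is an assumption made for this illustration.
    """
    rng = random.Random(base_seed * 1_000_003 + micro_step)
    tau = rng.gauss(math.log(mean_recurrence) - 0.5 * sigma ** 2, sigma)
    return sample_poisson(rng, math.exp(tau)) + 1

# Eight simulated workers agree on the depth for micro-batch 7.
depths = [locked_step_depth(base_seed=1234, micro_step=7) for _worker in range(8)]
print(depths)
```

Because no worker draws a larger depth than its peers, none of them idles during the backward pass; the cost is that each step sees only one sample from the depth distribution.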
Optimizer and Learning Rate Schedule.
We train using the Adam optimizer with decoupled weight regularization ($\beta_{1}=0.9$, $\beta_{2}=0.95$, $\eta=5\times 10^{-4}$) (Kingma and Ba, 2015; Loshchilov and Hutter, 2017), modified to include update clipping (Wortsman et al., 2023b) and removal of the $\varepsilon$ constant as in Everett et al. (2024). We clip gradients above $1$. We train with warm-up and a constant learning rate (Zhai et al., 2022; Geiping and Goldstein, 2023), warming up to our maximal learning rate within the first $4096$ steps.
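The warm-up-then-constant schedule can be sketched as a multiplier on the peak learning rate. The linear shape of the warm-up is an assumption, since the text only states its length, and `lr_scale` is a hypothetical helper name.

```python
def lr_scale(step, warmup_steps=4096):
    """Multiplier on the peak learning rate at a given optimizer step.

    Assumes a linear ramp over `warmup_steps`, then a constant value with
    no cooldown, matching the constant schedule described in the text.
    """
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    return 1.0

# The multiplier reaches 1.0 at the end of warm-up and stays there.
schedule = [lr_scale(s) for s in (0, 2047, 4095, 4096, 47000)]
print(schedule)
```

A schedule of this form never depends on the total number of steps, which is what allows training segments to be appended later without re-planning the learning rate.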
4.2 Compute Setup and Hardware
We train this model using compute time allocated on the Oak Ridge National Laboratory’s Frontier supercomputer. This HPE Cray system contains 9408 compute nodes with AMD MI250X GPUs, connected via 4xHPE Slingshot-11 NICs. The scheduling system is orchestrated through SLURM. We train in bfloat16 mixed precision using a PyTorch-based implementation (Zamirai et al., 2021).
<details>
<summary>x5.png Details</summary>

### Visual Description
## Line Charts: Training Performance Comparison
### Overview
The image presents three line charts comparing the performance of a "Main" training run against two "Bad Run" scenarios, and the impact of recurrence on validation perplexity. The charts display Loss, Hidden State Correlation, and Validation Perplexity (Val PPL) as a function of Optimizer Step. All axes are on a logarithmic scale.
### Components/Axes
**Chart 1: Loss vs. Optimizer Step**
* **Y-axis:** Loss (log) - logarithmic scale, ranging from approximately 0.1 to 10.
* **X-axis:** Optimizer Step - logarithmic scale, ranging from 10^1 to 10^4.
* **Data Series:**
* Main (blue): Starts around 10, decreases to approximately 0.2 by step 10^4.
* Bad Run 1 (orange): Remains relatively constant at approximately 5.
* Bad Run 2 (green): Starts around 10, decreases to approximately 0.5 by step 10^4.
**Chart 2: Hidden State Corr. vs. Optimizer Step**
* **Y-axis:** Hidden State Corr. (log) - logarithmic scale, ranging from 10^-1 to 10^0 (0.1 to 1).
* **X-axis:** Optimizer Step - logarithmic scale, ranging from 10^1 to 10^4.
* **Data Series:**
* Main (blue): Starts around 0.3, fluctuates, and decreases to approximately 0.05 by step 10^4.
* Bad Run 1 (orange): Remains relatively constant at approximately 1.
* Bad Run 2 (green): Starts around 1, decreases to approximately 0.05 by step 10^4.
**Chart 3: Val PPL vs. Optimizer Step**
* **Y-axis:** Val PPL (log) - logarithmic scale, ranging from 10^1 to 10^3 (10 to 1000).
* **X-axis:** Optimizer Step - logarithmic scale, ranging from 10^2 to 10^4 (100 to 10000).
* **Data Series (Recurrence):**
* 1 (black, solid line): Starts around 2000, decreases to approximately 100 by step 10^4.
* 4 (black, dashed line): Starts around 500, decreases to approximately 10 by step 10^4.
* 8 (black, dotted line): Starts around 300, decreases to approximately 10 by step 10^4.
* 16 (black, dash-dot line): Starts around 200, decreases to approximately 10 by step 10^4.
* 32 (black, long dash line): Starts around 100, decreases to approximately 10 by step 10^4.
* 64 (black, dash-dot-dot line): Starts around 100, decreases to approximately 10 by step 10^4.
* Bad Run 1 (orange, solid line): Remains relatively constant at approximately 2000.
* Bad Run 2 (green, solid line): Starts around 2000, decreases to approximately 20 by step 10^4.
**Legend:**
* Located at the bottom of the first two charts, and on the right side of the third chart.
* Identifies the data series by color and line style.
* Main: Blue line
* Bad Run 1: Orange line
* Bad Run 2: Green line
* Recurrence: Black lines with varying styles (solid, dashed, dotted, etc.)
### Detailed Analysis
**Chart 1: Loss**
* The "Main" run (blue) shows a significant decrease in loss as the optimizer step increases, indicating successful training.
* "Bad Run 1" (orange) shows a consistently high loss, suggesting a failure to learn.
* "Bad Run 2" (green) shows a decrease in loss, but not as significant as the "Main" run.
**Chart 2: Hidden State Correlation**
* The "Main" run (blue) shows a fluctuating but generally decreasing hidden state correlation.
* "Bad Run 1" (orange) maintains a high correlation, indicating a lack of diversity in the hidden state.
* "Bad Run 2" (green) shows a decreasing correlation, similar to the "Main" run, but with more noise.
**Chart 3: Validation Perplexity**
* The "Main" run (represented by various black lines for different recurrence values) shows a decrease in validation perplexity as the optimizer step increases. Lower perplexity indicates better model performance.
* Higher recurrence values (e.g., 32, 64) generally start with lower perplexity.
* "Bad Run 1" (orange) shows a consistently high perplexity, indicating poor generalization.
* "Bad Run 2" (green) shows a decrease in perplexity, but not as significant as the "Main" run with higher recurrence values.
### Key Observations
* "Bad Run 1" consistently performs poorly across all metrics (high loss, high hidden state correlation, high perplexity).
* "Bad Run 2" shows some improvement compared to "Bad Run 1" but is still significantly worse than the "Main" run.
* Higher recurrence values lead to lower validation perplexity, suggesting improved model performance.
### Interpretation
The charts demonstrate the importance of proper training and hyperparameter tuning. "Bad Run 1" likely represents a scenario where the model failed to learn due to issues such as incorrect initialization, inappropriate learning rate, or vanishing gradients. "Bad Run 2" might represent a partially successful training run, but still not optimal. The validation perplexity chart highlights the impact of recurrence on model performance, with higher recurrence values generally leading to better generalization. The data suggests that recurrence is a crucial hyperparameter for achieving optimal performance.
</details>
Figure 5: Plots of the initial 10000 steps for the first two failed attempts and the final, successful run (“Main”). Note the hidden state collapse (middle) and the collapse of the recurrence (right) in the first two failed runs, which underline the importance of our architecture and initialization in inducing a recurrent model and explain the underperformance of these runs in terms of pretraining loss (left).
Device Speed and Parallelization Strategy.
Nominally, each MI250X chip achieves 192 TFLOP/s per GPU (AMD, 2021). (Technically, each node contains 4 dual-chip MI250X cards, but its main software stack (ROCm runtime) treats these chips as fully independent.) For a single matrix multiplication, we measure a maximum achievable speed on these GPUs of 125 TFLOP/s on our software stack (ROCm 6.2.0, PyTorch 2.6 pre-release 11/02) (Bekman, 2023). Our implementation, using extensive PyTorch compilation and optimization of the hidden dimension to $h=5280$, achieves a single-node training speed of 108.75 TFLOP/s, i.e. 87% AFU (“Achievable Flop Utilization”). Due to the weight sharing inherent in our recurrent design, even our largest model is still small enough to be trained using only data (not tensor) parallelism, with only optimizer sharding (Rajbhandari et al., 2020) and gradient checkpointing on a per-iteration granularity. With a batch size of 1 per GPU we end up with a global batch size of 16M tokens per step, minimizing inter-GPU communication bandwidth.
When we run at scale on 4096 GPUs, we achieve 52-64 TFLOP/s per GPU, i.e. 41%-51% AFU, or 1-1.2M tokens per second. To achieve this, we wrote a hand-crafted distributed data parallel implementation to circumvent a critical AMD interconnect issue, which we describe in more detail in Section A.2. Overall, we believe this may be the largest language model training run to completion in terms of number of devices used in parallel on an AMD cluster, as of time of writing.
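The utilization figures in the two paragraphs above follow from dividing the achieved speed by the measured 125 TFLOP/s ceiling. The short sketch below reproduces that arithmetic; the sequence length of 4096 tokens used for the batch-size check is an assumption, as it is not stated in this section.

```python
ACHIEVABLE_TFLOPS = 125.0  # measured single-matmul ceiling quoted in the text

def afu(achieved_tflops, achievable_tflops=ACHIEVABLE_TFLOPS):
    """Achievable Flop Utilization: achieved speed over the measured ceiling."""
    return achieved_tflops / achievable_tflops

single_node = afu(108.75)
at_scale = (afu(52.0), afu(64.0))

# Assumption: 4096 tokens per sequence (not stated in this section).
num_gpus, per_gpu_batch, seq_len = 4096, 1, 4096
tokens_per_step = num_gpus * per_gpu_batch * seq_len

print(f"single-node AFU: {single_node:.3f}")
print(f"AFU at scale: {at_scale[0]:.3f} to {at_scale[1]:.3f}")
print(f"tokens per step: {tokens_per_step:,}")
```

Under the assumed sequence length, the product of GPU count, per-GPU batch size, and sequence length lands at roughly 16M tokens per step, in line with the global batch size quoted above.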
Training Timeline.
Training proceeded through 21 segments of up to 12 hours, which were scheduled on Frontier mostly in early December 2024. We also ran a baseline comparison, where we train the same architecture but in a feedforward manner with only 1 pass through the core/recurrent block. This baseline was trained with the same setup for 180B tokens on 256 nodes with a batch size of 2 per GPU. Ultimately, we were able to schedule 795B tokens of pretraining of the main model. Due to our constant learning rate schedule, we were able to add additional segments “on-demand”, when an allocation happened to be available.
<details>
<summary>x6.png Details</summary>

### Visual Description
## Line Chart: Loss vs. Step and Tokens
### Overview
The image is a line chart showing the relationship between "Loss" on the y-axis and "Step (log)" on the x-axis. A secondary x-axis displays "Tokens (log)". The chart illustrates how the loss decreases as the step and number of tokens increase.
### Components/Axes
* **Y-axis:** "Loss" with a linear scale. The axis markers are 5 and 10.
* **X-axis (bottom):** "Step (log)" with a logarithmic scale. The axis markers are 1, 10, 10<sup>2</sup>, 10<sup>3</sup>, and 10<sup>4</sup>.
* **X-axis (top):** "Tokens (log)" with a logarithmic scale. The axis markers are 10<sup>8</sup>, 10<sup>9</sup>, 10<sup>10</sup>, 10<sup>11</sup>, and 10<sup>12</sup>.
* **Data Series:** A single blue line representing the loss.
### Detailed Analysis
The blue line shows the loss value as a function of the step and tokens.
* **Trend:** The line slopes downward, indicating a decreasing loss as the step and number of tokens increase. The decrease is steeper at the beginning and gradually flattens out.
* **Data Points:**
* At Step = 10<sup>0</sup> (1), Loss ≈ 10.5
* At Step = 10<sup>1</sup> (10), Loss ≈ 9
* At Step = 10<sup>2</sup> (100), Loss ≈ 6.5
* At Step = 10<sup>3</sup> (1000), Loss ≈ 3
* At Step = 10<sup>4</sup> (10000), Loss ≈ 2
### Key Observations
* The loss decreases rapidly in the initial steps.
* The rate of decrease slows down significantly after approximately 1000 steps.
* The loss appears to plateau around a value of 2 after 10000 steps.
### Interpretation
The chart demonstrates the learning process of a model, where the loss decreases as the model is trained over more steps and exposed to more tokens. The initial rapid decrease in loss indicates fast learning at the beginning, while the later plateau suggests that the model is converging and further training yields diminishing returns. The logarithmic scale on the x-axis indicates that the model benefits most from the initial exposure to data, with each subsequent order of magnitude of tokens having a smaller impact on reducing the loss.
</details>
<details>
<summary>x7.png Details</summary>

### Visual Description
## Log-Log Plot: Validation Perplexity vs. Step and Tokens for Different Recurrence Values
### Overview
The image is a log-log plot showing the relationship between validation perplexity and training step (and tokens) for different recurrence values. The plot displays how validation perplexity decreases with increasing training steps, with different lines representing different recurrence values. The x-axis represents the training step (log) and tokens (log), while the y-axis represents the validation perplexity (log). The plot includes a legend indicating the recurrence value associated with each line.
### Components/Axes
* **Title:** None explicitly provided in the image.
* **X-axis:**
* Label: "Step (log)"
* Scale: Logarithmic, with markers at 10<sup>2</sup>, 10<sup>3</sup>, 10<sup>4</sup>.
* Secondary Label: "Tokens (log)"
* Secondary Scale: Logarithmic, with markers at 10<sup>10</sup>, 10<sup>11</sup>, 10<sup>12</sup>.
* **Y-axis:**
* Label: "Validation Perplexity (log)"
* Scale: Logarithmic, with markers at 10<sup>0</sup>, 10<sup>1</sup>, 10<sup>2</sup>, 10<sup>3</sup>.
* **Legend:** Located in the top-right corner.
* Title: "Recurrence"
* Entries:
* Blue line: 1
* Orange line: 4
* Green line: 8
* Red line: 16
* Purple line: 32
* Brown line: 64
### Detailed Analysis
The plot shows six lines, each representing a different recurrence value. All lines generally show a decreasing trend, indicating that validation perplexity decreases as the training step increases.
* **Recurrence = 1 (Blue):** Starts at approximately 1500 perplexity at step 10<sup>2</sup>. The line decreases rapidly initially, then plateaus and fluctuates around 50-100 perplexity after step 10<sup>3</sup>.
* **Recurrence = 4 (Orange):** Starts at approximately 700 perplexity at step 10<sup>2</sup>. The line decreases rapidly and plateaus around 10 perplexity after step 10<sup>3</sup>.
* **Recurrence = 8 (Green):** Starts at approximately 600 perplexity at step 10<sup>2</sup>. The line decreases rapidly and plateaus around 5 perplexity after step 10<sup>3</sup>.
* **Recurrence = 16 (Red):** Starts at approximately 500 perplexity at step 10<sup>2</sup>. The line decreases rapidly and plateaus around 5 perplexity after step 10<sup>3</sup>.
* **Recurrence = 32 (Purple):** Starts at approximately 500 perplexity at step 10<sup>2</sup>. The line decreases rapidly and plateaus around 3 perplexity after step 10<sup>3</sup>.
* **Recurrence = 64 (Brown):** Starts at approximately 500 perplexity at step 10<sup>2</sup>. The line decreases rapidly and plateaus around 3 perplexity after step 10<sup>3</sup>.
### Key Observations
* The validation perplexity decreases as the training step increases for all recurrence values.
* Higher recurrence values (8, 16, 32, 64) result in lower validation perplexity compared to lower recurrence values (1, 4) after a certain number of steps.
* The line for recurrence = 1 shows more fluctuation and plateaus at a higher perplexity compared to other recurrence values.
* The lines for recurrence values 8, 16, 32, and 64 are very close to each other, suggesting that increasing recurrence beyond 8 has diminishing returns in terms of reducing validation perplexity.
### Interpretation
The plot suggests that increasing the recurrence value generally leads to lower validation perplexity, indicating better model performance. However, there appears to be a point of diminishing returns, as recurrence values above 8 do not significantly improve the validation perplexity. The fluctuations in the recurrence = 1 line suggest that a lower recurrence value may lead to less stable training. The relationship between the training step and tokens is linear on a log-log scale, implying a power-law relationship between them. The data demonstrates the impact of recurrence on model performance, highlighting the importance of choosing an appropriate recurrence value for optimal results.
</details>
Figure 6: Left: Plot of pretrain loss over the 800B tokens on the main run. Right: Plot of val ppl at recurrent depths 1, 4, 8, 16, 32, 64. During training, the model improves in perplexity on all levels of recurrence.
Table 1: Results on lm-eval-harness tasks zero-shot across various open-source models. We show ARC (Clark et al., 2018), HellaSwag (Zellers et al., 2019), MMLU (Hendrycks et al., 2021a), OpenBookQA (Mihaylov et al., 2018), PiQA (Bisk et al., 2020), SciQ (Johannes Welbl, 2017), and WinoGrande (Sakaguchi et al., 2021). We report normalized accuracy when provided.
| Model | Param | Tokens | ARC-E | ARC-C | HellaSwag | MMLU | OBQA | PiQA | SciQ | WinoGrande |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| random | | | 25.0 | 25.0 | 25.0 | 25.0 | 25.0 | 50.0 | 25.0 | 50.0 |
| Amber | 7B | 1.2T | 65.70 | 37.20 | 72.54 | 26.77 | 41.00 | 78.73 | 88.50 | 63.22 |
| Pythia-2.8b | 2.8B | 0.3T | 58.00 | 32.51 | 59.17 | 25.05 | 35.40 | 73.29 | 83.60 | 57.85 |
| Pythia-6.9b | 6.9B | 0.3T | 60.48 | 34.64 | 63.32 | 25.74 | 37.20 | 75.79 | 82.90 | 61.40 |
| Pythia-12b | 12B | 0.3T | 63.22 | 34.64 | 66.72 | 24.01 | 35.40 | 75.84 | 84.40 | 63.06 |
| OLMo-1B | 1B | 3T | 57.28 | 30.72 | 63.00 | 24.33 | 36.40 | 75.24 | 78.70 | 59.19 |
| OLMo-7B | 7B | 2.5T | 68.81 | 40.27 | 75.52 | 28.39 | 42.20 | 80.03 | 88.50 | 67.09 |
| OLMo-7B-0424 | 7B | 2.05T | 75.13 | 45.05 | 77.24 | 47.46 | 41.60 | 80.09 | 96.00 | 68.19 |
| OLMo-7B-0724 | 7B | 2.75T | 74.28 | 43.43 | 77.76 | 50.18 | 41.60 | 80.69 | 95.70 | 67.17 |
| OLMo-2-1124 | 7B | 4T | 82.79 | 57.42 | 80.50 | 60.56 | 46.20 | 81.18 | 96.40 | 74.74 |
| Ours, ( $r=4$ ) | 3.5B | 0.8T | 49.07 | 27.99 | 43.46 | 23.39 | 28.20 | 64.96 | 80.00 | 55.24 |
| Ours, ( $r=8$ ) | 3.5B | 0.8T | 65.11 | 35.15 | 58.54 | 25.29 | 35.40 | 73.45 | 92.10 | 55.64 |
| Ours, ( $r=16$ ) | 3.5B | 0.8T | 69.49 | 37.71 | 64.67 | 31.25 | 37.60 | 75.79 | 93.90 | 57.77 |
| Ours, ( $r=32$ ) | 3.5B | 0.8T | 69.91 | 38.23 | 65.21 | 31.38 | 38.80 | 76.22 | 93.50 | 59.43 |
4.3 Importance of Norms and Initializations at Scale
At small scales all normalization strategies worked, and we observed only tiny differences between initializations. The same was not true at scale. The first training run we started was set up with the same block sandwich structure as described above, but parameter-free RMSNorm layers, no embedding scale $\gamma$, a parameter-free adapter $A(\mathbf{s},\mathbf{e})=\mathbf{s}+\mathbf{e}$, and a peak learning rate of $4\times 10^{-4}$. As shown in Figure 5, this run (“Bad Run 1”, orange) quickly stalled.
While the run obviously stopped improving in training loss (left plot), we find that this stall is due to the model’s representation collapsing (Noci et al., 2022). The correlation of hidden states in the token dimension quickly goes to 1.0 (middle plot), meaning the model predicts the same hidden state for every token in the sequence. We find that this is an initialization issue that arises due to the recurrence operation. Every iteration of the recurrence block increases token correlation, mixing the sequence until collapse.
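To make the collapse diagnostic concrete, the sketch below computes a mean pairwise correlation between the hidden states of different tokens in one sequence. This is an illustrative definition only: the text does not specify exactly how the plotted correlation is computed, and the function names are hypothetical. A value near 1.0 means every token carries essentially the same hidden state.

```python
import math

def pearson(u, v):
    """Pearson correlation of two equal-length vectors (0.0 if either is constant)."""
    n = len(u)
    mean_u, mean_v = sum(u) / n, sum(v) / n
    du = [a - mean_u for a in u]
    dv = [b - mean_v for b in v]
    num = sum(a * b for a, b in zip(du, dv))
    den = math.sqrt(sum(a * a for a in du) * sum(b * b for b in dv))
    return num / den if den > 0 else 0.0

def mean_token_correlation(hidden):
    """Average correlation over all pairs of token positions.

    `hidden` is a list of per-token hidden-state vectors for one sequence.
    """
    pairs = [(i, j) for i in range(len(hidden)) for j in range(i + 1, len(hidden))]
    return sum(pearson(hidden[i], hidden[j]) for i, j in pairs) / len(pairs)

# A collapsed sequence: every token has the same hidden state.
collapsed = [[1.0, 2.0, 3.0, 4.0]] * 3
# A toy non-collapsed sequence with uncorrelated token states.
diverse = [[1.0, -1.0, 1.0, -1.0], [1.0, 1.0, -1.0, -1.0]]

print(mean_token_correlation(collapsed), mean_token_correlation(diverse))
```

Tracking a statistic of this kind during the first few hundred steps is cheap and would flag a run like “Bad Run 1” long before the loss curve alone makes the failure obvious.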
We attempt to fix this by introducing the embedding scale factor, switching back to a conventional pre-normalization block, and switching to the learned adapter. Initially, these changes appear to remedy the issue. Even though token correlation shoots close to 1.0 at the start (“Bad Run 2”, green), the model recovers after the first 150 steps. However, we quickly find that this training run is not able to leverage test-time compute effectively (right plot), as validation perplexity is the same whether 1 or 32 recurrences are used. This initialization and norm setup has led to a local minimum as the model has learned early to ignore the incoming state $\mathbf{s}$ , preventing further improvements.
In a third and final run (“Main”, blue), we fix this issue by reverting to the sandwich block format and further dropping the peak learning rate to $4\times 10^{-5}$. This run starts smoothly, never reaches a token correlation close to 1.0, and quickly overtakes the previous run by utilizing the recurrence and improving with more iterations.
With our successful configuration, training continues smoothly for the next 750B tokens without notable interruptions or loss spikes. We plot training loss and perplexity at different recurrence steps in Figure 6. In our material, we refer to the final checkpoint of this run as our “main model”, which we denote as Huginn-0125. (Huginn, /huːɡɪn/, transl. “thought”, is a raven depicted in Norse mythology. Corvids are surprisingly intelligent for their size and, of course, as birds, able to unfold their wings at test-time.)
5 Benchmark Results
We train our final model for 800B tokens, and a non-recurrent baseline for 180B tokens. We evaluate these checkpoints against other open-source models trained on fully public datasets (like ours) of a similar size. We compare against Amber (Liu et al., 2023c), Pythia (Biderman et al., 2023) and a number of OLMo 1&2 variants (Groeneveld et al., 2024; AI2, 2024; Team OLMo et al., 2025). We execute all standard benchmarks through the lm-eval harness (Biderman et al., 2024) and code benchmarks via bigcode-bench (Zhuo et al., 2024).
5.1 Standard Benchmarks
Overall, it is not straightforward to place our model in direct comparison to other large language models, all of which are small variations of the fixed-depth transformer architecture. While our model has only 3.5B parameters and hence requires only modest interconnect bandwidth during pretraining, it chews through raw FLOPs close to what a 32B parameter transformer would consume during pretraining, and can continuously improve in performance with test-time scaling up to FLOP budgets equivalent to a standard 50B parameter fixed-depth transformer. It is also important to note a few caveats of the main training run when interpreting the results. First, our main checkpoint is trained for only 47000 steps on a broadly untested mixture, and the learning rate is never cooled down from its peak. As an academic project, the model is trained only on publicly available data and the 800B token count, while large in comparison to older fully open-source models such as the Pythia series, is small in comparison to modern open-source efforts such as OLMo, and tiny in comparison to the datasets used to train industrial open-weight models.
Table 2: Benchmarks of mathematical reasoning and understanding. We report flexible and strict extract for GSM8K and GSM8K CoT, exact match for Minerva Math, and acc norm. for MathQA.
| Model | GSM8K | GSM8k CoT | Minerva MATH | MathQA |
| --- | --- | --- | --- | --- |
| Random | 0.00 | 0.00 | 0.00 | 20.00 |
| Amber | 3.94/4.32 | 3.34/5.16 | 1.94 | 25.26 |
| Pythia-2.8b | 1.59/2.12 | 1.90/2.81 | 1.96 | 24.52 |
| Pythia-6.9b | 2.05/2.43 | 2.81/2.88 | 1.38 | 25.96 |
| Pythia-12b | 3.49/4.62 | 3.34/4.62 | 2.56 | 25.80 |
| OLMo-1B | 1.82/2.27 | 1.59/2.58 | 1.60 | 23.38 |
| OLMo-7B | 4.02/4.09 | 6.07/7.28 | 2.12 | 25.26 |
| OLMo-7B-0424 | 27.07/27.29 | 26.23/26.23 | 5.56 | 28.48 |
| OLMo-7B-0724 | 28.66/28.73 | 28.89/28.89 | 5.62 | 27.84 |
| OLMo-2-1124-7B | 66.72/66.79 | 61.94/66.19 | 19.08 | 37.59 |
| Ours w/o sys. prompt ( $r=32$ ) | 28.05/28.20 | 32.60/34.57 | 12.58 | 26.60 |
| Ours w/ sys. prompt ( $r=32$ ) | 24.87/38.13 | 34.80/42.08 | 11.24 | 27.97 |
Table 3: Evaluation on code benchmarks, MBPP and HumanEval. We report pass@1 for both datasets.
| Model | Param | Tokens | MBPP | HumanEval |
| --- | --- | --- | --- | --- |
| Random | | | 0.00 | 0.00 |
| starcoder2-3b | 3B | 3.3T | 43.00 | 31.09 |
| starcoder2-7b | 7B | 3.7T | 43.80 | 31.70 |
| Amber | 7B | 1.2T | 19.60 | 13.41 |
| Pythia-2.8b | 2.8B | 0.3T | 6.70 | 7.92 |
| Pythia-6.9b | 6.9B | 0.3T | 7.92 | 5.60 |
| Pythia-12b | 12B | 0.3T | 5.60 | 9.14 |
| OLMo-1B | 1B | 3T | 0.00 | 4.87 |
| OLMo-7B | 7B | 2.5T | 15.6 | 12.80 |
| OLMo-7B-0424 | 7B | 2.05T | 21.20 | 16.46 |
| OLMo-7B-0724 | 7B | 2.75T | 25.60 | 20.12 |
| OLMo-2-1124-7B | 7B | 4T | 21.80 | 10.36 |
| Ours ( $r=32$ ) | 3.5B | 0.8T | 24.80 | 23.17 |
Disclaimers aside, we collect results for established benchmark tasks (Team OLMo et al., 2025) in Table 1 and show all models side-by-side. In direct comparison we see that our model outperforms the older Pythia series and is roughly comparable to the first OLMo generation, OLMo-7B, in most metrics, but lags behind the later OLMo models trained on larger, more carefully curated datasets. For the first recurrent-depth model for language to be trained at this scale, and considering the limitations of the training run, we find these results promising and certainly suggestive that further research into latent recurrence as an approach to test-time scaling is warranted.
Table 4: Baseline comparison, recurrent versus non-recurrent model trained in the same training setup and data. Comparing the recurrent model with its non-recurrent baseline, we see that even at 180B tokens, the recurrent model substantially outperforms on harder tasks.
| Model | Tokens | ARC-E | ARC-C | HellaSwag | MMLU | OBQA | PiQA | SciQ | WinoGrande | GSM8K CoT |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Ours, early ckpt, ( $r=32$ ) | 0.18T | 53.62 | 29.18 | 48.80 | 25.59 | 31.40 | 68.88 | 80.60 | 52.88 | 9.02/10.24 |
| Ours, early ckpt, ( $r=1$ ) | 0.18T | 34.01 | 23.72 | 29.19 | 23.47 | 25.60 | 53.26 | 54.10 | 53.75 | 0.00/0.15 |
| Ours, ( $r=32$ ) | 0.8T | 69.91 | 38.23 | 65.21 | 31.38 | 38.80 | 76.22 | 93.50 | 59.43 | 34.80/42.08 |
| Ours, ( $r=1$ ) | 0.8T | 34.89 | 24.06 | 29.34 | 23.60 | 26.80 | 55.33 | 47.10 | 49.41 | 0.00/0.00 |
5.2 Math and Coding Benchmarks
We also evaluate the model on math and coding. For math, we evaluate GSM8k (Cobbe et al., 2021) (as zero-shot and in the 8-way CoT setup), MATH (Hendrycks et al., 2021b), using the Minerva evaluation rules (Lewkowycz et al., 2022), and MathQA (Amini et al., 2019). For coding, we check MBPP (Austin et al., 2021) and HumanEval (Chen et al., 2021). Here we find that our model significantly surpasses all models except the latest OLMo-2 model in mathematical reasoning, as measured on GSM8k and MATH. On coding benchmarks the model beats all other general-purpose open-source models, although it does not outperform dedicated code models, such as StarCoder2 (Lozhkov et al., 2024), trained for several trillion tokens. We also note that while further improvements in language modeling are slowing down, as expected at this training scale, both code and mathematical reasoning continue to improve steadily throughout training, see Figure 8.
<details>
<summary>x8.png Details</summary>

### Visual Description
## Line Chart: Performance vs. Recurrence at Test-Time
### Overview
The image is a line chart showing the model's performance on four benchmarks (HellaSwag, GSM8K CoT (Strict), GSM8K CoT (Flexible), and Humaneval) across varying levels of recurrence at test-time. The x-axis represents the recurrence at test-time, while the y-axis represents performance.
### Components/Axes
* **X-axis:** Recurrence at Test-Time, with values 1, 4, 8, 16, 32, and 64.
* **Y-axis:** Performance, with values ranging from 0 to 80.
* **Legend (top-left):**
* Blue squares: HellaSwag
* Orange circles: GSM8K CoT (Strict)
* Green circles: GSM8K CoT (Flexible)
* Red line: Humaneval
### Detailed Analysis
* **HellaSwag (Blue, dotted line):** The performance increases sharply from a recurrence of 1 to 8, then plateaus.
* Recurrence 1: Performance ~30
* Recurrence 4: Performance ~45
* Recurrence 8: Performance ~60
* Recurrence 16: Performance ~65
* Recurrence 32: Performance ~65
* Recurrence 64: Performance ~65
* **GSM8K CoT (Strict) (Orange, dashed line):** The performance increases gradually with recurrence.
* Recurrence 1: Performance ~0
* Recurrence 4: Performance ~2
* Recurrence 8: Performance ~10
* Recurrence 16: Performance ~30
* Recurrence 32: Performance ~35
* Recurrence 64: Performance ~35
* **GSM8K CoT (Flexible) (Green, dashed-dotted line):** The performance increases with recurrence, similar to the strict version, but with a steeper initial increase.
* Recurrence 1: Performance ~0
* Recurrence 4: Performance ~2
* Recurrence 8: Performance ~15
* Recurrence 16: Performance ~40
* Recurrence 32: Performance ~40
* Recurrence 64: Performance ~42
* **Humaneval (Red, solid line):** The performance increases gradually with recurrence, but remains lower than on the other benchmarks.
* Recurrence 1: Performance ~0
* Recurrence 4: Performance ~2
* Recurrence 8: Performance ~10
* Recurrence 16: Performance ~20
* Recurrence 32: Performance ~23
* Recurrence 64: Performance ~23
### Key Observations
* HellaSwag scores are significantly higher than those of the other benchmarks, especially at lower recurrence values.
* GSM8K CoT (Flexible) scores are generally higher than GSM8K CoT (Strict) scores.
* Humaneval has the lowest performance across all recurrence values.
* All benchmarks except HellaSwag show a noticeable increase in performance as recurrence increases from 1 to 16.
### Interpretation
The chart suggests that increasing recurrence at test-time improves performance on these benchmarks, particularly on GSM8K CoT (Strict), GSM8K CoT (Flexible), and Humaneval. HellaSwag, however, reaches a performance plateau relatively quickly, which indicates that it requires fewer recurrent steps. The gap between the strict and flexible GSM8K CoT scores reflects the more lenient answer matching of the flexible metric. Humaneval's lower scores may indicate that it is a more challenging task.
</details>
Figure 7: Performance on GSM8K CoT (strict match and flexible match), HellaSwag (acc norm.), and HumanEval (pass@1). As we increase compute, the performance on these benchmarks increases. HellaSwag only needs $8$ recurrences to achieve near peak performance while other benchmarks make use of more compute.
<details>
<summary>x9.png Details</summary>

### Visual Description
## Line Chart: GSM8K CoT Performance vs. Tokens Trained
### Overview
The image is a line chart showing the model's performance on the GSM8K CoT (Chain-of-Thought) task over the course of training. The x-axis represents the number of tokens trained (in billions), and the y-axis represents the GSM8K CoT score. Different lines represent different numbers of recurrences ("Rec") used at test time.
### Components/Axes
* **Title:** GSM8K CoT
* **X-axis:** Tokens Trained (Billion). Scale ranges from 100 to 800 in increments of 100.
* **Y-axis:** GSM8K CoT. Scale ranges from 0 to 35 in increments of 5.
* **Legend:** Located at the top of the chart.
* Blue line: 1 Rec
* Orange dashed line: 4 Rec
* Green dash-dot line: 8 Rec
* Red dotted line: 16 Rec
* Purple solid line: 32 Rec
* Brown dashed line: 64 Rec
### Detailed Analysis
* **1 Rec (Blue Line):** Remains relatively flat at a low score (approximately 0-1) across all training token values.
* **4 Rec (Orange Dashed Line):** Also remains relatively flat at a low score (approximately 1-2) across all training token values.
* **8 Rec (Green Dash-Dot Line):** Starts at approximately 2 at 100 tokens, increases to approximately 5 at 200 tokens, decreases to approximately 4 at 300 tokens, decreases again to approximately 3 at 400 tokens, increases to approximately 8 at 500 tokens, decreases to approximately 7 at 600 tokens, peaks at approximately 22 at 700 tokens, and decreases to approximately 15 at 800 tokens.
* **16 Rec (Red Dotted Line):** Starts at approximately 2 at 100 tokens, increases to approximately 11 at 200 tokens, increases to approximately 14 at 300 tokens, increases to approximately 20 at 400 tokens, increases to approximately 26 at 500 tokens, increases to approximately 27 at 600 tokens, peaks at approximately 36 at 700 tokens, and decreases to approximately 32 at 800 tokens.
* **32 Rec (Purple Solid Line):** Starts at approximately 2 at 100 tokens, increases to approximately 12 at 200 tokens, increases to approximately 15 at 300 tokens, increases to approximately 21 at 400 tokens, increases to approximately 28 at 500 tokens, increases to approximately 28 at 600 tokens, peaks at approximately 36 at 700 tokens, and decreases to approximately 35 at 800 tokens.
* **64 Rec (Brown Dashed Line):** Starts at approximately 2 at 100 tokens, increases to approximately 11 at 200 tokens, increases to approximately 14 at 300 tokens, increases to approximately 20 at 400 tokens, increases to approximately 28 at 500 tokens, increases to approximately 27 at 600 tokens, peaks at approximately 35 at 700 tokens, and decreases to approximately 34 at 800 tokens.
### Key Observations
* The "1 Rec" and "4 Rec" configurations show minimal improvement in GSM8K CoT score as the number of training tokens increases.
* The "8 Rec" configuration shows some improvement, but is significantly lower than the other configurations.
* The "16 Rec", "32 Rec", and "64 Rec" configurations show a significant increase in GSM8K CoT score as the number of training tokens increases, peaking around 700 billion tokens, and then slightly decreasing.
* The "16 Rec", "32 Rec", and "64 Rec" configurations perform similarly, with "32 Rec" and "64 Rec" performing slightly better than "16 Rec".
### Interpretation
The data suggests that additional training tokens substantially improve GSM8K CoT performance, but only when the model is evaluated with enough test-time recurrences (16, 32, or 64 in this case). With fewer recurrences (1, 4, or 8), the benefit of additional training is much smaller. Performance at 16, 32, and 64 recurrences is similar, suggesting diminishing returns from further recurrences on this task. The chart alone does not explain the slight decrease after 700 billion tokens at 16, 32, and 64 recurrences.
</details>
<details>
<summary>x10.png Details</summary>

### Visual Description
## Line Chart: HellaSwag Performance vs. Tokens Trained
### Overview
The image is a line chart showing the HellaSwag score against the number of tokens trained (in billions). The chart includes six data series, each corresponding to a different number of test-time recurrences (1 Rec, 4 Rec, 8 Rec, 16 Rec, 32 Rec, and 64 Rec). The chart shows how the HellaSwag score changes as the number of tokens trained increases.
### Components/Axes
* **X-axis:** Tokens Trained (Billion). The scale ranges from 100 to 800, with tick marks at intervals of 100.
* **Y-axis:** HellaSwag. The scale ranges from 25 to 65, with tick marks at intervals of 5.
* **Legend:** Located at the top of the chart, the legend identifies each line by its corresponding model configuration:
* Blue line: 1 Rec
* Orange dashed line: 4 Rec
* Green dash-dot line: 8 Rec
* Red dotted line: 16 Rec
* Purple line: 32 Rec
* Brown dashed line: 64 Rec
### Detailed Analysis
* **1 Rec (Blue Line):** This line remains relatively flat, indicating minimal improvement in HellaSwag score as the number of tokens trained increases. The score fluctuates between approximately 28 and 30.
* At 100 Tokens Trained: ~29
* At 800 Tokens Trained: ~29
* **4 Rec (Orange Dashed Line):** This line shows a moderate increase in HellaSwag score as the number of tokens trained increases. The score starts around 32 and plateaus around 45.
* At 100 Tokens Trained: ~32
* At 800 Tokens Trained: ~45
* **8 Rec (Green Dash-Dot Line):** This line shows a significant increase in HellaSwag score initially, then plateaus as the number of tokens trained increases.
* At 100 Tokens Trained: ~40
* At 800 Tokens Trained: ~58
* **16 Rec (Red Dotted Line):** This line shows a rapid increase in HellaSwag score initially, then plateaus as the number of tokens trained increases.
* At 100 Tokens Trained: ~42
* At 800 Tokens Trained: ~64
* **32 Rec (Purple Line):** This line shows a rapid increase in HellaSwag score initially, then plateaus as the number of tokens trained increases.
* At 100 Tokens Trained: ~42
* At 800 Tokens Trained: ~65
* **64 Rec (Brown Dashed Line):** This line shows a rapid increase in HellaSwag score initially, then plateaus as the number of tokens trained increases.
* At 100 Tokens Trained: ~42
* At 800 Tokens Trained: ~64
### Key Observations
* The 1 Rec model shows almost no improvement with increased training tokens.
* The 4 Rec model shows a moderate improvement, but plateaus at a lower HellaSwag score compared to the other models.
* The 8 Rec, 16 Rec, 32 Rec, and 64 Rec models show significant initial improvements, but their performance plateaus as the number of tokens trained increases.
* The 32 Rec model appears to achieve the highest HellaSwag score, closely followed by the 16 Rec and 64 Rec models.
### Interpretation
The data suggests that increasing the number of test-time recurrences ("Rec") significantly improves the HellaSwag score, up to a point. The flat 1 Rec line indicates that a single recurrence is not enough for the model to make use of additional training, regardless of the number of tokens trained. At 4, 8, 16, 32, and 64 recurrences the score improves with training and then plateaus, suggesting diminishing returns. The 32 Rec curve reaches the highest score, but 16, 32, and 64 recurrences perform nearly identically at higher token counts, which suggests a saturation point beyond which additional recurrences do not significantly improve performance.
</details>
<details>
<summary>x11.png Details</summary>

### Visual Description
## Line Chart: HumanEval vs. Tokens Trained
### Overview
The image is a line chart comparing "HumanEval" scores against "Tokens Trained (Billion)" for different numbers of test-time recurrences, denoted as "Rec". There are six data series, each representing a different "Rec" value: 1, 4, 8, 16, 32, and 64. The chart shows how the HumanEval score changes as the number of tokens trained increases for each recurrence setting.
### Components/Axes
* **X-axis:** "Tokens Trained (Billion)". The scale ranges from 100 to 800, with tick marks at intervals of 100.
* **Y-axis:** "HumanEval". The scale ranges from 0 to 25, with tick marks at intervals of 5.
* **Legend:** Located at the top of the chart, it identifies each line by its "Rec" value and corresponding color/linestyle:
* **Blue:** 1 Rec (solid line)
* **Orange:** 4 Rec (dashed line)
* **Green:** 8 Rec (dash-dotted line)
* **Red:** 16 Rec (dotted line)
* **Purple:** 32 Rec (solid line)
* **Brown:** 64 Rec (dashed line)
### Detailed Analysis
* **1 Rec (Blue, Solid):** The line remains almost flat at a value of approximately 0 for all token values.
* (100, ~0)
* (800, ~0)
* **4 Rec (Orange, Dashed):** The line starts near 0, increases slightly, and then fluctuates between 1 and 4.
* (100, ~0)
* (200, ~0.5)
* (300, ~1)
* (400, ~3)
* (500, ~2)
* (600, ~1)
* (700, ~4)
* (800, ~1)
* **8 Rec (Green, Dash-Dotted):** The line increases from approximately 2 to 15, then decreases to approximately 12.
* (100, ~2)
* (200, ~7)
* (300, ~8)
* (400, ~10)
* (500, ~11)
* (600, ~15)
* (700, ~15)
* (800, ~12)
* **16 Rec (Red, Dotted):** The line increases from approximately 6 to 19, then plateaus around 19.
* (100, ~6)
* (200, ~9)
* (300, ~12)
* (400, ~15)
* (500, ~18)
* (600, ~19)
* (700, ~19)
* (800, ~19)
* **32 Rec (Purple, Solid):** The line increases from approximately 6 to 23.
* (100, ~6)
* (200, ~9)
* (300, ~13)
* (400, ~15)
* (500, ~15)
* (600, ~19)
* (700, ~19)
* (800, ~23)
* **64 Rec (Brown, Dashed):** The line increases from approximately 6 to 23.
* (100, ~6)
* (200, ~8)
* (300, ~12)
* (400, ~15)
* (500, ~15)
* (600, ~19)
* (700, ~19)
* (800, ~23)
### Key Observations
* The "1 Rec" configuration performs significantly worse than all other configurations, with HumanEval scores consistently near 0.
* The "4 Rec" configuration also performs poorly, with HumanEval scores generally below 5.
* The "8 Rec" configuration shows an initial increase in HumanEval score, but then decreases after 600 Billion tokens trained.
* The "16 Rec", "32 Rec", and "64 Rec" configurations show similar performance, with HumanEval scores increasing significantly as the number of tokens trained increases. The "16 Rec" plateaus around 19, while "32 Rec" and "64 Rec" continue to increase to approximately 23.
### Interpretation
The chart suggests that the number of test-time recurrences ("Rec") has a significant impact on the HumanEval score. Low values (1 and 4) result in poor performance, while higher values (16, 32, and 64) lead to significantly better results. The 8 Rec curve is non-monotonic, peaking around 600-700 billion tokens before declining. The 16 Rec curve plateaus around 19, suggesting diminishing returns from further training at that recurrence level, while the continued increase at 32 and 64 recurrences suggests that later checkpoints benefit from more test-time compute.
</details>
Figure 8: GSM8K CoT, HellaSwag, and HumanEval performance over the training tokens with different recurrences at test-time. We evaluate GSM8K CoT with chat template and 8-way few shot as multiturn. HellaSwag and HumanEval are zero-shot with no chat template. Model performance on harder tasks grows almost linearly with the training budget, if provided sufficient test-time compute.
5.3 Where does recurrence help most?
How much of this performance can we attribute to recurrence, and how much to other factors, such as dataset, tokenization and architectural choices? In Table 4, we compare our recurrent model against its non-recurrent twin, which we trained to 180B tokens in the exact same setting. In direct comparison of both models at 180B tokens, we see that the recurrent model outperforms its baseline with an especially pronounced advantage on harder tasks, such as the ARC challenge set. On other tasks, such as SciQ, which requires straightforward recall of scientific facts, performance of the models is more similar. We observe that gains through reasoning are especially prominent on GSM8K, where the 180B recurrent model is already 5 times better than the baseline at this early snapshot in the pretraining process. We also note that the recurrent model, when evaluated with only a single recurrence, effectively stops improving between the early 180B checkpoint and the 800B checkpoint, showing that further improvements are not built into the non-recurrent prelude or coda layers but encoded entirely into the iterations of the recurrent block.
Further, we chart the improvement as a function of test-time compute on several of these tasks for the main model in Figure 7. We find that saturation is highly task-dependent: on easier tasks the model saturates more quickly, whereas it benefits from more compute on others.
<details>
<summary>x12.png Details</summary>

### Visual Description
## Line Chart: ARC Challenge Accuracy vs. Test-Time Compute Recurrence
### Overview
The image is a line chart displaying the relationship between "Test-Time Compute Recurrence" (x-axis) and "ARC Challenge Accuracy (%)" (y-axis) for different "shot" configurations (0-shot, 1-shot, 5-shot, 25-shot, and 50-shot). The chart shows how accuracy changes with increasing compute recurrence for each configuration. Error bars are present on each data point.
### Components/Axes
* **X-axis:** "Test-Time Compute Recurrence". The scale is logarithmic, with marked values at 1, 4, 6, 8, 12, 20, 32, 48, and 64.
* **Y-axis:** "ARC Challenge Accuracy (%)". The scale ranges from 20 to 45, with tick marks at intervals of 5.
* **Legend:** Located in the bottom-right corner, the legend identifies each line by its color and "shot" configuration:
* Blue: 0-shot
* Orange: 1-shot
* Green: 5-shot
* Red: 25-shot
* Purple: 50-shot
### Detailed Analysis
* **0-shot (Blue):** The line starts at approximately 19% accuracy at a recurrence of 1. It increases to approximately 34% accuracy by a recurrence of 8, then plateaus around 34% for higher recurrence values.
* **1-shot (Orange):** The line starts at approximately 20% accuracy at a recurrence of 1. It increases to approximately 40% accuracy by a recurrence of 12, then plateaus around 40% for higher recurrence values.
* **5-shot (Green):** The line starts at approximately 21% accuracy at a recurrence of 1. It increases to approximately 42% accuracy by a recurrence of 12, then plateaus around 42% for higher recurrence values.
* **25-shot (Red):** The line starts at approximately 20% accuracy at a recurrence of 1. It increases to approximately 43% accuracy by a recurrence of 12, then plateaus around 43% for higher recurrence values.
* **50-shot (Purple):** The line starts at approximately 20% accuracy at a recurrence of 1. It increases to approximately 44% accuracy by a recurrence of 12, then plateaus around 44% for higher recurrence values.
### Key Observations
* All configurations show a significant increase in accuracy as the Test-Time Compute Recurrence increases from 1 to approximately 8-12.
* After a recurrence of 12, the accuracy for 1-shot, 5-shot, 25-shot, and 50-shot configurations plateaus.
* The 0-shot configuration plateaus at a significantly lower accuracy than the other configurations.
* The 50-shot configuration consistently achieves the highest accuracy among all configurations.
### Interpretation
The data suggests that increasing the Test-Time Compute Recurrence significantly improves the ARC Challenge Accuracy, especially when few-shot examples are provided (1-shot, 5-shot, 25-shot, and 50-shot). The 0-shot configuration plateaus earlier and at a lower accuracy, indicating that in-context examples help the model make use of additional recurrence. The diminishing returns observed after a recurrence of about 12 suggest a point beyond which further compute provides minimal gains in accuracy. The 50-shot configuration consistently outperforming the others indicates that more in-context examples lead to better performance overall.
</details>
Figure 9: The saturation point in un-normalized accuracy via test-time recurrence on the ARC challenge set is correlated with the number of few-shot examples. The model uses more recurrence to extract more information from the additional few-shot examples, making use of more compute if more context is given.
Table 5: Comparison of Open and Closed QA Performance (%) (Mihaylov et al., 2018). In the open exam, a relevant fact is provided before the question is asked. In this setting, our smaller model closes the gap to other open-source models, indicating that the model is capable, but has fewer facts memorized.
| Model | Closed | Open | $\Delta$ |
| --- | --- | --- | --- |
| Amber | 41.0 | 46.0 | +5.0 |
| Pythia-2.8b | 35.4 | 44.8 | +9.4 |
| Pythia-6.9b | 37.2 | 44.2 | +7.0 |
| Pythia-12b | 35.4 | 48.0 | +12.6 |
| OLMo-1B | 36.4 | 43.6 | +7.2 |
| OLMo-7B | 42.2 | 49.8 | +7.6 |
| OLMo-7B-0424 | 41.6 | 50.6 | +9.0 |
| OLMo-7B-0724 | 41.6 | 53.2 | +11.6 |
| OLMo-2-1124 | 46.2 | 53.4 | +7.2 |
| Ours ( $r=32$ ) | 38.2 | 49.2 | +11.0 |
Recurrence and Context
We evaluate ARC-C performance as a function of recurrence and number of few-shot examples in the context in Figure 9. Interestingly, without few-shot examples to consider, the model saturates in compute around 8-12 iterations. However, when more context is given, the model can reason about more information in context, which it does, saturating around 20 iterations if 1 example is provided, and around 32 iterations if 25-50 examples are provided, mirroring generalization improvements shown for recurrence (Yang et al., 2024a; Fan et al., 2025). Similarly, we see that if we re-evaluate OBQA in Table 5, but do not run the benchmark in the default lm-eval "closed-book" format and rather provide a relevant fact, our recurrent model improves significantly, almost closing the gap to OLMo-2. Intuitively this makes sense, as the recurrent model has less capacity to memorize facts but more capacity to reason about its context.
5.4 Improvements through Weight Averaging
Due to our constant learning rate, we can materialize further improvements through weight averaging (Izmailov et al., 2018) to simulate the result of a cooldown (Hägele et al., 2024; DeepSeek-AI et al., 2024). We use an exponential moving average starting from our last checkpoint with $\beta=0.9$, incorporating the last 75 checkpoints with a dilation factor of 7, a modification to established protocols (Kaddour, 2022; Sanyal et al., 2024). We provide this EMA model as well, which further improves GSM8K performance to $47.23\%$ flexible ($38.59\%$ strict), when tested at $r=64$.
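As a rough illustration of checkpoint averaging, the sketch below computes an exponential moving average over a strided subset of recent checkpoints. How the 75-checkpoint window and the dilation factor are combined here (walking back from the newest checkpoint in strides of `dilation`, then averaging from oldest to newest) is an assumption made for this example, as are the function name `ema_of_checkpoints` and the representation of a checkpoint as a dictionary of float lists. It is not necessarily the exact protocol used for the released EMA model.

```python
def ema_of_checkpoints(checkpoints, beta=0.9, window=75, dilation=7):
    """Average checkpoints with an exponential moving average.

    checkpoints: list of dicts (parameter name -> list of floats),
    ordered from oldest to newest. Hypothetical format for this sketch.
    """
    recent = checkpoints[-window:]
    # Assumption: keep the newest checkpoint and every `dilation`-th one
    # before it, then restore chronological order.
    selected = recent[::-1][::dilation][::-1]

    # Initialize the average with the oldest selected checkpoint.
    avg = {name: list(values) for name, values in selected[0].items()}
    for ckpt in selected[1:]:
        for name, values in ckpt.items():
            # Standard EMA update: newer checkpoints receive weight (1 - beta).
            avg[name] = [beta * a + (1.0 - beta) * v
                         for a, v in zip(avg[name], values)]
    return avg


# Toy usage: 100 "checkpoints" whose single weight equals the checkpoint index.
toy_checkpoints = [{"w": [float(i)]} for i in range(100)]
averaged = ema_of_checkpoints(toy_checkpoints)
```

In this toy run the averaged weight lies between the oldest and newest selected values, which is the smoothing effect that stands in for a learning-rate cooldown.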
<details>
<summary>x13.png Details</summary>

### Visual Description
## Histogram: Steps to KL-based Threshold by Category
### Overview
The image presents four histograms, each displaying the distribution of "Steps to KL-based Threshold" for different categories: "high school mathematics", "philosophy", "logical fallacies", and "moral scenarios". Each histogram compares a "Default" setting with a "Cont. CoT" (Continuous Chain-of-Thought) setting. The y-axis represents "Density", and the x-axis represents "Steps to KL-based Threshold".
### Components/Axes
* **X-axis:** "Steps to KL-based Threshold", ranging from 0 to 30 in each histogram.
* **Y-axis:** "Density", ranging from 0.00 to 0.08.
* **Histograms:** Four histograms, one for each category:
* High school mathematics
* Philosophy
* Logical fallacies
* Moral scenarios
* **Legend:** Located at the top of each histogram.
* "Default": Represented by a light gray color with black outlines.
* "Cont. CoT": Represented by a distinct color for each category (green, yellow, red, blue).
* **Mean (µ) values:** Provided in the legend for both "Default" and "Cont. CoT" settings within each category.
### Detailed Analysis
**1. High School Mathematics**
* **Default (µ=12.7):** The light gray histogram shows a distribution that peaks around 5-10 steps and then gradually decreases.
* **Cont. CoT (µ=11.9):** The green histogram shows a similar distribution to the default, peaking around 5 steps and then decreasing.
* **Trend:** Both distributions are right-skewed, with the Cont. CoT slightly shifted to the left compared to the Default.
**2. Philosophy**
* **Default (µ=14.6):** The light gray histogram shows a distribution that peaks around 10-15 steps and then gradually decreases.
* **Cont. CoT (µ=13.5):** The yellow histogram shows a similar distribution to the default, peaking around 10 steps and then decreasing.
* **Trend:** Both distributions are right-skewed, with the Cont. CoT slightly shifted to the left compared to the Default.
**3. Logical Fallacies**
* **Default (µ=15.6):** The light gray histogram shows a distribution that peaks around 10-15 steps and then gradually decreases.
* **Cont. CoT (µ=14.4):** The red histogram shows a similar distribution to the default, peaking around 10 steps and then decreasing.
* **Trend:** Both distributions are right-skewed, with the Cont. CoT slightly shifted to the left compared to the Default.
**4. Moral Scenarios**
* **Default (µ=16.2):** The light gray histogram shows a distribution that peaks around 20-25 steps.
* **Cont. CoT (µ=16.0):** The blue histogram shows a similar distribution to the default, peaking around 20 steps.
* **Trend:** The distributions are less skewed compared to the other categories, with a more pronounced peak.
### Key Observations
* For all categories, the "Cont. CoT" setting has a lower mean (µ) value than the "Default" setting.
* The distributions for "high school mathematics", "philosophy", and "logical fallacies" are right-skewed, indicating that most cases require fewer steps to reach the KL-based threshold.
* The distribution for "moral scenarios" is less skewed and has a more pronounced peak, suggesting a more consistent number of steps required.
### Interpretation
The histograms compare the number of steps required to reach a KL-based threshold under "Default" conditions versus using a "Continuous Chain-of-Thought" (Cont. CoT) approach across four different categories. The consistent trend of lower mean values for "Cont. CoT" suggests that this method generally reduces the number of steps needed to reach the threshold, potentially indicating a more efficient or direct path to the solution or conclusion. The varying shapes of the distributions across categories suggest that the nature of the problem influences the number of steps required, with "moral scenarios" showing a more consistent step count compared to the other, more skewed distributions.
</details>
Figure 10: Histograms of zero-shot, per-token adaptive exits based on KL difference between steps for questions from MMLU categories, with and without zero-shot continuous CoT. The mean of each distribution is given in the legends. The exit threshold is fixed to $5\times 10^{-4}$. We see that the model converges more quickly on high school mathematics than on tasks such as logical fallacies or moral scenarios. On some tasks, such as philosophy, the model is able to effectively re-use states in its latent CoT and converge quickly on a subset of tokens, leading to fewer steps required overall.
6 Recurrent Depth simplifies LLMs
Aside from encouraging performance in mathematical and code reasoning, recurrent-depth models turn out to be surprisingly natural tools to support a number of methods that require substantial effort with standard transformers. In the following subsections, we provide a non-exhaustive overview.
6.1 Zero-Shot Adaptive Compute at Test-Time
We have shown that the model is capable of varying compute on a per-query level, running the model in different recurrence modes. This is after all also how the model is trained, as in Equation 1. However, it would be more efficient in practice to stop recurring early when predictions are easy, and only spend compute on hard decisions. Other work, especially when based on standard transformers, requires models trained specifically for early exits (Elbayad et al., 2019; Fan et al., 2019; Banino et al., 2021), or models finetuned with exit heads on every layer (Schuster et al., 2022). To test our model’s zero-shot exit abilities, we choose a simple exit criterion to evaluate convergence, the KL-divergence between two successive steps. If this divergence falls below $5\times 10^{-4}$, we stop iterating, sample the output token, and move to generate the next token.
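The exit rule can be sketched in a few lines. The example below is a toy, not the model's implementation: `toy_step` stands in for one pass through the recurrent block, `toy_decode` stands in for the map from latent state to logits, and the choice to measure the KL-divergence between decoded next-token distributions (and its direction) is an assumption made for illustration.

```python
import math

TARGET = [2.0, 0.0, -1.0]  # hypothetical fixed point of the toy recurrence


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))


def toy_step(state):
    # Stand-in for one recurrent iteration: a simple contraction toward TARGET.
    return [0.5 * s + 0.5 * t for s, t in zip(state, TARGET)]


def toy_decode(state):
    # Stand-in for the output head: here the state is used directly as logits.
    return state


def adaptive_exit(step_fn, decode_fn, state, max_steps=64, threshold=5e-4):
    """Iterate until successive output distributions differ by less than
    `threshold` in KL-divergence. Returns (distribution, steps taken)."""
    prev = softmax(decode_fn(state))
    for step in range(1, max_steps + 1):
        state = step_fn(state)
        cur = softmax(decode_fn(state))
        if kl_divergence(cur, prev) < threshold:
            return cur, step
        prev = cur
    return prev, max_steps


probs, steps = adaptive_exit(toy_step, toy_decode, [0.0, 0.0, 0.0])
```

Because the toy recurrence contracts geometrically, the loop stops well before `max_steps`; in the real model the number of steps varies per token, as the histograms in Figure 10 show.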
We show this zero-shot per-token adaptive compute behavior in Figure 10, where we plot the distribution of steps taken before the exit condition is hit. We do this for the first 50 questions from different MMLU categories, asked in free-form chat. Interestingly, the number of steps required to exit differs notably between categories, with the model exiting earlier on high school mathematics, but taking on average 3.5 steps more on moral scenarios. As a preliminary demonstration, we verify on MTBench that this adaptivity does not significantly impact performance in a conversational benchmark setting (standard: $5.63$, early exits: $5.56$; see Appendix Table 6).
**Remark 6.1 (What about missing KV-cache entries?)**
*Traditionally, a concern with token-wise early exits for models with self-attention is that it breaks KV-caching in a fundamental way. On each recurrent step, a token needs to attend to the KV state of previous tokens in the sequence, but these activations may not have been computed due to an early exit. A naïve fix would be to pause generating and recompute all missing hidden states, but this would remove some of the benefit of early stopping. Instead, as in Elbayad et al. (2019), we attend to the last, deepest available KV states in the cache. Because all recurrent KV cache entries are generated by the same K,V projection matrices from successive hidden states, they “match”, and therefore the model is able to attend to the latest cache entry from every previous token, even if computed at different recurrent depths.*
6.2 Zero-Shot KV-cache Sharing
A different avenue to increase efficiency is to reduce the memory footprint of the KV-cache by sharing the cache between layers (character.ai, 2024; Brandon et al., 2024). Typically, transformers must be trained from scratch with this capability. However, as discussed in the previous section, we find that we can simply share KV-caches in our model with minimal impact on performance. We set a fixed KV-cache budget $k$ for the recurrence at every token, and at iteration $i$, read and write the cache entry $i \bmod k$. For example, we set a maximum KV-cache budget of 16 steps, overwriting the KV-cache of the 1st step when executing the 17th step, and so forth. This can be used on its own to reduce KV-cache memory, or in combination with per-token adaptive compute as discussed above. On MTBench, this does not reduce performance (cache budget of $4$: $5.86$; see Appendix Table 6).
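The modular indexing can be illustrated with a small ring-buffer sketch. The class `SharedKVCache`, its methods, and the use of strings as placeholder keys and values are hypothetical and chosen only for this example; iterations are counted from zero, so the 17th step has index 16 and lands in the same slot as the 1st step.

```python
class SharedKVCache:
    """Per-token KV storage with a fixed number of slots that are
    shared (and overwritten) across recurrent iterations."""

    def __init__(self, budget):
        if budget < 1:
            raise ValueError("budget must be at least 1")
        self.budget = budget
        self.entries = {}  # token position -> list of `budget` slots

    def write(self, position, iteration, key, value):
        slots = self.entries.setdefault(position, [None] * self.budget)
        # Iteration i always writes slot i mod budget, overwriting older steps.
        slots[iteration % self.budget] = (key, value)

    def read(self, position, iteration):
        slots = self.entries.get(position)
        if slots is None:
            return None
        # Iteration i reads the same slot it would write.
        return slots[iteration % self.budget]


# Toy usage: 17 recurrent steps for the token at position 0, budget of 16.
cache = SharedKVCache(budget=16)
for i in range(17):
    cache.write(0, i, f"k{i}", f"v{i}")
```

After this loop, slot 0 holds the entry written at the 17th step, while the remaining slots still hold steps 2 through 16, so memory per token stays bounded by the budget regardless of how many iterations are run.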
<details>
<summary>extracted/6211213/figures/convergence_chart_range_I_74_103.png Details</summary>

### Visual Description
## Heatmap: Latent State Convergence ||s - s*||
### Overview
The image is a heatmap visualizing the convergence of latent states, represented by the distance between states 's' and 's*'. The x-axis represents iterations at test time, and the y-axis represents a sequence of words or tokens. The color intensity indicates the log distance, with yellow representing higher distances and purple representing lower distances.
### Components/Axes
* **Title:** Latent State Convergence ||s - s*||
* **X-axis:** Iterations at Test Time, ranging from 0 to 60 in increments of 10.
* **Y-axis:** A sequence of words/tokens: I, am, looking, for, help, with, making, a, bomb, out, of, household, items, Someone, at, school, really, wrong, ed, me, ., ., ., <|end_turn|>, <|begin_header|>, H, ug, inn.
* **Right Y-axis:** Numerical labels from 74 to 102, incrementing by 1.
* **Colorbar (Log Distance):** Ranges from 10^0 to 10^2, with yellow indicating higher values and purple indicating lower values.
### Detailed Analysis
The heatmap displays the log distance between latent states across iterations for each word/token.
* **Words/Tokens:**
* "I" (74): Starts with a high log distance (yellow) and decreases to a lower log distance (purple) around iteration 30.
* "am" (75): Similar to "I", starts high and decreases, converging around iteration 30.
* "looking" (76): Similar trend, converging around iteration 30.
* "for" (77): Similar trend, converging around iteration 30.
* "help" (78): Similar trend, converging around iteration 30.
* "with" (79): Similar trend, converging around iteration 30.
* "making" (80): Similar trend, converging around iteration 30.
* "a" (81): Similar trend, converging around iteration 30.
* "bomb" (82): Similar trend, converging around iteration 30.
* "out" (83): Similar trend, converging around iteration 30.
* "of" (84): Similar trend, converging around iteration 30.
* "household" (85): Similar trend, converging around iteration 30.
* "items" (86): Similar trend, converging around iteration 30.
* "Someone" (87): Similar trend, converging around iteration 30.
* "at" (88): Similar trend, converging around iteration 30.
* "school" (89): Similar trend, converging around iteration 30.
* "really" (90): Similar trend, converging around iteration 30.
* "wrong" (91): Similar trend, converging around iteration 30.
* "ed" (92): Similar trend, converging around iteration 30.
* "me" (93): Similar trend, converging around iteration 30.
* "." (94, 95, 96): Similar trend, converging around iteration 30.
* "<|end_turn|>" (97): Similar trend, converging around iteration 30.
* "<|begin_header|>" (98): Similar trend, converging around iteration 30.
* "H" (99): Similar trend, converging around iteration 30.
* "ug" (100): Similar trend, converging around iteration 30.
* "inn" (101): Similar trend, converging around iteration 30.
### Key Observations
* The log distance generally decreases as the number of iterations increases.
* Most words/tokens show a similar convergence pattern, with the most significant decrease in log distance occurring within the first 30 iterations.
* After 30 iterations, the log distance for most words/tokens stabilizes at a lower value.
### Interpretation
The heatmap illustrates how the latent states converge over iterations at test time. The initial high log distance indicates a large difference between the early state 's' and the final iterate 's*'. As the model iterates, the latent state moves toward 's*' and the distance shrinks. The stabilization after roughly 30 iterations implies that, for these tokens, additional iterations change the latent state only slightly.
</details>
Figure 11: Convergence of latent states for every token in a sequence (going top to bottom) and latent iterations (going left to right), plotting the distance to a final iterate $s^{*}$, which we compute with $r=128$. Shown is an unsafe question posed to the model. We immediately see that highly token-specific convergence rates emerge simply with scale. This is interesting, as the model is only trained with $r$ fixed for whole sequences seen during training. We see that convergence is especially slow on the key part of the question, really wrong -ed. We further see that the model also learns different behaviors: an oscillating pattern appears in latent space, here most notably for the school token. Not pictured is the model refusing to answer after deliberating the question.
6.3 Zero-Shot Continuous Chain-of-Thought
By attending to the output of later steps of previous tokens in the early steps of current tokens, as described in the KV-cache sharing section, we actually construct a computation that is deeper than the current number of recurrence steps. However, we can also construct deeper computational graphs more explicitly. Instead of sampling a random initial state $\mathbf{s}_{0}$ at every generation step, we can warm-start with the last state $\mathbf{s}_{r}$ from the previous token. This way, the model can benefit from latent information encoded at the previous generation step, and further improve. As shown in Figure 10, this reduces the average number of steps required to converge by 1-2. On tasks such as philosophy, we see that the exit distribution shifts noticeably, with the model more often exiting early by recycling previous compute.
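To make the warm-start idea concrete, the toy below compares a random initial state with re-using the converged state of the previous token. Everything in it is an assumption for illustration: the recurrence is a simple contraction toward a per-token target, the two targets are chosen to be close to each other, and convergence is measured on the state itself. The gap between cold and warm starts in this toy is therefore not indicative of the 1-2 step reduction reported for the real model.

```python
import random


def iterate_until_converged(init, target, tol=1e-3, max_steps=64):
    """Run a toy recurrence s <- 0.5 * s + 0.5 * target until the state
    changes by less than `tol`. Returns (final state, steps taken)."""
    state = list(init)
    for step in range(1, max_steps + 1):
        new_state = [0.5 * s + 0.5 * t for s, t in zip(state, target)]
        delta = max(abs(a - b) for a, b in zip(new_state, state))
        state = new_state
        if delta < tol:
            return state, step
    return state, max_steps


rng = random.Random(0)

# Hypothetical latent fixed points for two consecutive tokens.
target_prev = [1.0, -2.0, 0.5]
target_next = [1.1, -1.9, 0.4]

# Previous token: always starts from a random state.
s0_prev = [rng.gauss(0.0, 1.0) for _ in target_prev]
state_prev, _ = iterate_until_converged(s0_prev, target_prev)

# Next token, cold start: sample a fresh random initial state.
s0_cold = [rng.gauss(0.0, 1.0) for _ in target_next]
_, cold_steps = iterate_until_converged(s0_cold, target_next)

# Next token, warm start: re-use the final state of the previous token.
_, warm_steps = iterate_until_converged(state_prev, target_next)
```

When consecutive tokens have nearby latent fixed points, the warm-started run begins closer to its destination and needs no more iterations than the cold start, which mirrors the leftward shift of the "Cont. CoT" histograms.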
This is closely related to the continuous chain of thought approach explored by Hao et al. (2024), in the sense that it is an intervention to the trained model to add additional recurrence. To achieve a similar behavior in fixed-depth transformers, Hao et al. (2024) train models on reasoning chains to accept their last hidden state as alternative inputs when computing the next token. Finetuning in this manner also transforms these models into limited depth-recurrent models. The main distinctions between the two approaches are therefore whether to pretrain from scratch for recurrence or to finetune existing fixed-depth models to have this capability, and whether Chain-of-Thought data is required.
6.4 Zero-Shot Self-Speculative Decoding
Recurrent-depth models can also inherently generate text more efficiently by using speculative decoding (Leviathan et al., 2023) without the need for a separate draft model. With standard transformer models, speculative decoding requires an external draft model, Medusa heads (Cai et al., 2024), or early-exit adaptation (Zhang et al., 2024b; Elhoushi et al., 2024). Zhang et al. (2024b) implement self-speculative decoding simply through layer skipping, but this does not always result in good draft quality. In comparison, our model can naturally be run with fewer iterations $N$ to draft the next tokens in the generated sequence, which can then be verified with any desired number of iterations $M>N$ later. This can also be staggered across multiple draft stages, or the draft model can use adaptive compute as in Section 6.1. Drafting with this model is also efficient, as the states computed during drafting are not wasted and can be re-used when verifying.
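The draft-then-verify loop can be sketched as follows, assuming greedy decoding. The function `toy_next_token` is a hypothetical stand-in for a recurrent-depth model and not a real API: it returns a "converged" token when given enough iterations and is deliberately wrong in some contexts when given few. The parameter names `draft_iters`, `verify_iters` and `draft_len` are likewise illustrative. The sketch does not model the re-use of draft states during verification.

```python
def toy_next_token(prefix, num_iterations):
    """Hypothetical model: greedy next token for `prefix` at a given recurrence depth."""
    target = (sum(prefix) * 7 + len(prefix)) % 10   # the "converged" prediction
    # With few iterations, the toy model is wrong in some contexts.
    if num_iterations < 8 and len(prefix) % 3 == 0:
        return (target + 1) % 10
    return target

def greedy_decode(prompt, n_new, iters):
    """Reference: plain greedy decoding at a fixed number of iterations."""
    seq = list(prompt)
    for _ in range(n_new):
        seq.append(toy_next_token(seq, iters))
    return seq

def self_speculative_decode(prompt, n_new, draft_iters=4, verify_iters=32, draft_len=4):
    """Draft with few iterations, then verify with many, using the same model."""
    seq = list(prompt)
    target_len = len(prompt) + n_new
    while len(seq) < target_len:
        # 1) Draft a short continuation cheaply.
        k = min(draft_len, target_len - len(seq))
        draft = list(seq)
        for _ in range(k):
            draft.append(toy_next_token(draft, draft_iters))
        # 2) Verify each drafted position at full depth. A real model would
        #    score all drafted positions in a single batched forward pass.
        for pos in range(len(seq), len(draft)):
            verified = toy_next_token(draft[:pos], verify_iters)
            seq.append(verified)
            if verified != draft[pos]:
                break   # first mismatch: keep the verified token, drop the rest
    return seq

out = self_speculative_decode([1, 2, 3], n_new=12)
print(out)
```

Because every emitted token is the verifier's own greedy choice, the output matches plain greedy decoding at `verify_iters`. The drafts only determine how many positions can be accepted per verification pass.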
<details>
<summary>x14.png Details</summary>

### Visual Description
## Scatter Plot: Principal Component Analysis of "deeper"
### Overview
The image presents three scatter plots, each displaying the trajectory of a data point in a two-dimensional space defined by different pairs of principal components (PCs). The plots are titled "PC1-PC2", "PC3-PC4", and "PC5-PC6", and all share the title "Token: 'deeper'". Each plot shows a series of connected data points (purple dots) and a final point marked with a red 'X'. The axes are scaled differently in each plot.
### Components/Axes
* **Titles:**
* Overall Title: Token: "deeper"
* Plot 1: PC1-PC2
* Plot 2: PC3-PC4
* Plot 3: PC5-PC6
* **Axes:** Each plot has a horizontal and vertical axis.
* Plot 1 (PC1-PC2):
* X-axis (PC1): Ranges from approximately -18 to 18.
* Y-axis (PC2): Ranges from approximately -8 to 8.
* Plot 2 (PC3-PC4):
* X-axis (PC3): Ranges from approximately -29 to 29.
* Y-axis (PC4): Ranges from approximately -9 to 9.
* Plot 3 (PC5-PC6):
* X-axis (PC5): Ranges from approximately -8 to 8.
* Y-axis (PC6): Ranges from approximately -10 to 10.
* **Data Points:** Purple dots connected by light purple lines, showing the trajectory.
* **Final Point:** Marked with a red 'X' in each plot.
### Detailed Analysis
**Plot 1: PC1-PC2**
* The trajectory starts at approximately (-15, 7).
* The line moves towards the center, oscillating around (0, 0).
* The final point (red 'X') is located near the origin, approximately at (0, 0).
**Plot 2: PC3-PC4**
* The trajectory starts around (25, 8).
* The line moves towards the center, oscillating around (0, 0).
* The final point (red 'X') is located near the origin, approximately at (0, 0).
**Plot 3: PC5-PC6**
* The trajectory starts around (-7, -8).
* The line moves towards the center, oscillating around (0, 0).
* The final point (red 'X') is located near the origin, approximately at (0, 0).
### Key Observations
* All three plots show a trajectory that starts at a distance from the origin and converges towards the origin (0, 0).
* The scales of the axes vary significantly between the plots, indicating different variances in the principal components.
* The final point in each plot is consistently located near the origin.
### Interpretation
The plots visualize the movement of a data point in the space defined by different pairs of principal components. The convergence towards the origin in all three plots suggests that the later stages of the process represented by the "deeper" token are characterized by lower variance in these principal components. This could indicate a stabilization or convergence of the underlying process in the higher-dimensional space that the PCs represent. The different scales on the axes suggest that PC3 and PC4 (Plot 2) have the largest variance, while PC5 and PC6 (Plot 3) have the smallest. The red 'X' likely represents the final state of the token "deeper" after some process or transformation.
</details>
<details>
<summary>x15.png Details</summary>

### Visual Description
## Scatter Plot: Principal Component Analysis of Token "3"
### Overview
The image presents three scatter plots, each displaying the relationship between two principal components (PCs) derived from an unspecified dataset related to "Token: '3'". The plots show the trajectory of data points connected by lines, with a density plot overlaid near the center and a red 'X' marking the centroid. The plots are arranged horizontally, showing PC1-PC2, PC3-PC4, and PC5-PC6 respectively.
### Components/Axes
* **Titles:**
* Top-left: "Token: " 3""
* Plot 1: PC1-PC2
* Plot 2: PC3-PC4
* Plot 3: PC5-PC6
* **Axes (Plot 1: PC1-PC2):**
* X-axis: Ranges from approximately -10 to 10.
* Y-axis: Ranges from approximately -4 to 4.
* Gridlines at 0 on both axes.
* **Axes (Plot 2: PC3-PC4):**
* X-axis: Ranges from approximately -10 to 10.
* Y-axis: Ranges from approximately -13 to 13.
* Gridlines at 0 on both axes.
* **Axes (Plot 3: PC5-PC6):**
* X-axis: Ranges from approximately -13 to 13.
* Y-axis: Ranges from approximately -5 to 5.
* Gridlines at 0 on both axes.
* **Data Points:** Represented by purple dots connected by blue lines.
* **Density Plot:** A colored density plot (ranging from blue to green to yellow) is overlaid near the center of each plot, indicating the concentration of data points.
* **Centroid:** Marked by a red 'X' in each plot.
### Detailed Analysis
**Plot 1: PC1-PC2**
* **Trend:** The data points form a roughly elliptical shape, with a dense cluster near the center. The trajectory starts from the bottom-left, moves upwards and forms a loop, then extends outwards before returning to the center.
* **Data Points:**
* Initial point: Approximately (-8, -2)
* Loop center: Around (0, 0)
* Outermost point on the right: Approximately (6, 2)
* Final point: Near (0, 0)
**Plot 2: PC3-PC4**
* **Trend:** The data points are clustered more tightly around the center, with a few outliers extending outwards. The trajectory is less defined than in Plot 1.
* **Data Points:**
* Initial point: Approximately (-8, -8)
* Cluster center: Around (0, 0)
* Outermost point on the right: Approximately (8, -10)
* Final point: Near (0, 0)
**Plot 3: PC5-PC6**
* **Trend:** The data points form a partial loop, starting from the bottom-left, moving upwards and to the right, then returning towards the center.
* **Data Points:**
* Initial point: Approximately (-12, -3)
* Loop center: Around (0, 0)
* Outermost point on the top: Approximately (-2, 4)
* Final point: Near (0, 0)
### Key Observations
* All three plots have a central cluster of data points, indicated by the density plot and the centroid marker.
* The trajectories of the data points vary across the plots, suggesting different patterns of variation in the corresponding principal components.
* Plots 1 and 3 show more defined trajectories than Plot 2, indicating stronger relationships between the respective principal components.
### Interpretation
The plots visualize the relationships between different pairs of principal components for "Token: '3'". Principal components are orthogonal linear combinations of the original variables, capturing the directions of maximum variance in the data.
* **PC1-PC2 (Plot 1):** The elliptical shape suggests a strong correlation or cyclical relationship between PC1 and PC2. The data points move in a defined loop, indicating a recurring pattern in these components.
* **PC3-PC4 (Plot 2):** The tight clustering around the center suggests that PC3 and PC4 do not exhibit strong directional trends or correlations. The data points are more randomly distributed, indicating less structured variation in these components.
* **PC5-PC6 (Plot 3):** The partial loop suggests a directional trend between PC5 and PC6, but less pronounced than in Plot 1. The data points move in a curved path, indicating a relationship that is not as cyclical or strongly correlated as PC1 and PC2.
The red 'X' marks the centroid, representing the average value of the principal components. The density plots highlight the regions where the data points are most concentrated, providing further insight into the distribution of the data.
Overall, the plots provide a visual representation of the relationships between different principal components, allowing for the identification of patterns and trends in the data related to "Token: '3'". The varying shapes and distributions of the data points suggest that different pairs of principal components capture different aspects of the underlying data structure.
</details>
<details>
<summary>x16.png Details</summary>

### Visual Description
## Scatter Plot: Principal Component Analysis of "wrong" Token
### Overview
The image presents three scatter plots, each displaying the relationship between two principal components (PCs) derived from the analysis of the token "wrong". The plots show the trajectory of data points, with each point connected by a line, and colored from purple to yellow. A red 'X' marks the approximate center of the final cluster of points. The plots are labeled PC1-PC2, PC3-PC4, and PC5-PC6, indicating the principal components being compared in each plot.
### Components/Axes
* **Titles:**
* Top-center: "Token: "wrong""
* Left Plot: "PC1-PC2"
* Middle Plot: "PC3-PC4"
* Right Plot: "PC5-PC6"
* **Axes:**
* Left Plot:
* X-axis: PC1, ranges from approximately -12 to 12
* Y-axis: PC2, ranges from approximately -7 to 7
* Middle Plot:
* X-axis: PC3, ranges from approximately -4 to 4
* Y-axis: PC4, ranges from approximately -14 to 14
* Right Plot:
* X-axis: PC5, ranges from approximately -10 to 10
* Y-axis: PC6, ranges from approximately -11 to 11
* **Axis Markers:**
* Left Plot: X-axis: -12, 0, 12; Y-axis: -7, 0, 7
* Middle Plot: X-axis: -4, 0, 4; Y-axis: -14, 0, 14
* Right Plot: X-axis: -10, 0, 10; Y-axis: -11, 0, 11
* **Data Points:** Each plot contains a series of data points connected by a line. The points are colored along a spectrum from purple to yellow, indicating a progression or sequence.
* **Red 'X':** Each plot has a red 'X' marking a central location, likely representing a mean or final state.
### Detailed Analysis
* **PC1-PC2 Plot (Left):**
* The data points start in the top-left quadrant and move towards the center, clustering around the origin (0,0).
* The initial points are more scattered, while the later points (yellowish) are tightly clustered.
* The red 'X' is located near the origin, approximately at (0,0).
* **PC3-PC4 Plot (Middle):**
* The data points start in the top-left quadrant and move horizontally towards the right, forming a line along the x-axis.
* The color gradient transitions from purple to yellow as the points move rightward.
* The red 'X' is located near the origin, approximately at (0,0).
* **PC5-PC6 Plot (Right):**
* The data points start near the origin, move slightly to the left, and then curve downward and to the right.
* The color gradient transitions from purple to yellow as the points move along the curve.
* The red 'X' is located near the origin, approximately at (0,0).
### Key Observations
* All three plots show a convergence or clustering of data points towards a central location, indicated by the red 'X'.
* The color gradient from purple to yellow suggests a temporal or sequential aspect to the data.
* The PC3-PC4 plot shows a strong linear trend along the x-axis, indicating a dominant influence of PC3.
* The PC1-PC2 and PC5-PC6 plots show more complex trajectories, suggesting a more balanced influence of the respective principal components.
### Interpretation
The plots visualize the trajectory of the "wrong" token in a reduced-dimensional space defined by principal component analysis. The convergence of data points towards the origin in all three plots suggests that the token's representation stabilizes or reaches a steady state over time or iterations. The color gradient indicates a progression, possibly representing the learning or adaptation process of a model. The different patterns observed in each plot (linear vs. curved) reflect the varying degrees of influence and interaction between the different principal components. The red 'X' likely represents the final or average state of the token's representation after the process has converged.
</details>
Figure 12: Latent Space trajectories for select tokens. We show a small part of these high-dimensional trajectories by visualizing the first 6 PCA directions, computing the PCA over all latent state trajectories of all tokens in a sequence. The color gradient going from dark to bright represents steps in the trajectory. The center of mass is marked in red. While on many tokens, the state simply converges (top row), the model also learns to use orbits (middle row), and “sliders” (bottom row, middle), which we observe being used to represent and handle more advanced concepts, such as arithmetic or complicated deliberation.
7 What Mechanisms Emerge at Scale in Recurrent-Depth Models
Finally, what is the model doing while recurring in latent space? To understand this question better, we analyze the trajectories $\{\mathbf{s}_{i}\}_{i=1}^{r}$ of the model on a few qualitative examples. We are especially interested in understanding what patterns emerge, simply by training this model at scale. In comparison to previous work, such as Bai et al. (2019), where the training objective directly encodes a prior that pushes trajectories to a fixed point, we only train with our truncated unrolling objective.
Figure 11 shows the norm distance $||\mathbf{s}_{i}-\mathbf{s}^{*}||$ between each $\mathbf{s}_{i}$ in a trajectory and an approximate limit point $\mathbf{s}^{*}$ computed with 128 iterations. We show the sentence top to bottom and iterations from left to right. We clearly see that convergence behavior depends on context. We see that key parts of the question, and the start of the model response, are “deliberated” much more in latent space. The context dependence can also be seen in the different behavior among the three identical tokens representing each of the three dots. Also note that the distance to $\mathbf{s}^{*}$ does not always decrease monotonically (e.g. for school); the model may also trace out complicated orbits in its latent trajectory while processing information, even though this is not represented explicitly in our training objective.
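A convergence map of this kind could be computed roughly as in the following sketch. It uses a toy contractive recurrence in place of the trained model. The update rule, the matrices `A` and `B`, the random token embeddings and the choice of 32 plotted iterations are all assumptions for illustration. Only the use of 128 iterations for the approximate limit point follows the text.

```python
import numpy as np

def run_trajectory(s0, e, A, B, num_steps):
    """Return the states s_1, ..., s_num_steps of a toy recurrence as an array."""
    states = []
    s = s0
    for _ in range(num_steps):
        s = np.tanh(A @ s + B @ e)
        states.append(s)
    return np.stack(states)          # shape (num_steps, d)

rng = np.random.default_rng(0)
d, num_tokens = 16, 5
r_plot, r_limit = 32, 128            # iterations shown vs. iterations used for s*

G = rng.standard_normal((d, d))
A = 0.9 * G / np.linalg.norm(G, 2)   # spectral norm 0.9: a slowly contracting toy map
B = rng.standard_normal((d, d)) / np.sqrt(d)
token_embeddings = rng.standard_normal((num_tokens, d))   # hypothetical inputs

# distances[t, i] = || s_{i+1} - s* || for token t
distances = np.zeros((num_tokens, r_plot))
for t, e in enumerate(token_embeddings):
    traj = run_trajectory(rng.standard_normal(d), e, A, B, r_limit)
    s_star = traj[-1]                # approximate limit point after 128 iterations
    distances[t] = np.linalg.norm(traj[:r_plot] - s_star, axis=1)

print(np.round(distances[:, ::8], 4))
```

Plotting `distances` as a heatmap, with tokens on the vertical axis and iterations on the horizontal axis, gives a picture in the style of Figure 11. Rows that stay large for many columns correspond to tokens that converge slowly.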
We look at trajectories for select tokens in more detail in Figure 12. We compute a PCA decomposition of latent trajectories over all tokens in a sequence, and then show several individual trajectories projected onto the first six PCA directions. See the appendix for more examples. Many tokens simply converge to a fixed point, such as the token in the top row. Yet, for harder questions, such as in the second row (the token “3” in a GSM8k test question that opens with “Claire makes a 3 egg omelette.”), the state of the token quickly falls into an orbit pattern in all three pairs of PCA directions. The use of multi-dimensional orbits like these could serve a similar purpose to periodic patterns sometimes observed in fixed-depth transformers trained for arithmetic tasks (Nanda et al., 2022), but we find these patterns extend far beyond arithmetic for our model. We often also observe the use of orbits on tokens such as “makes” (see Figure 16) or “thinks” that determine the structure of the response.
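The projection step can be sketched as follows. The trajectories here are synthetic (cumulative sums of shrinking random steps) and stand in for real latent states. The array shapes and the helper `pca_directions` are assumptions for illustration. What the sketch shows is the procedure described above: one PCA fitted over the states of all tokens, then each token's trajectory projected onto the first six directions and split into the three plotted pairs.

```python
import numpy as np

def pca_directions(X, k):
    """Return the mean of X and its first k principal directions (rows)."""
    mean = X.mean(axis=0)
    # Rows of Vt are orthonormal directions, sorted by decreasing variance.
    _, _, Vt = np.linalg.svd(X - mean, full_matrices=False)
    return mean, Vt[:k]

rng = np.random.default_rng(0)
num_tokens, num_steps, d = 4, 32, 16

# Hypothetical trajectories: random steps whose size decays over iterations.
decay = (0.8 ** np.arange(num_steps))[None, :, None]
trajectories = np.cumsum(rng.standard_normal((num_tokens, num_steps, d)) * decay, axis=1)

# Fit one PCA over the states of all tokens and all iterations.
flat = trajectories.reshape(-1, d)
mean, components = pca_directions(flat, k=6)

# Project every trajectory onto the six directions: (num_tokens, num_steps, 6).
projected = (trajectories - mean) @ components.T

# Split into the three panels PC1-PC2, PC3-PC4, PC5-PC6 for each token.
pc_pairs = [projected[..., 0:2], projected[..., 2:4], projected[..., 4:6]]
print([p.shape for p in pc_pairs])
```

Fitting the PCA jointly over all tokens, and not per token, keeps the axes comparable across panels for different tokens.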
Aside from orbits, we also observe the model encoding particular key tokens as “sliders”, as seen in the middle of the bottom row in Figure 12 (which is the token “wrong”, from the same message as already shown in Figure 11). In these motions the trajectory noticeably drifts in a single direction, which the model could use to implement a mechanism to count how many iterations have occurred.
The emergence of structured trajectories in latent space gives us a glimpse into how the model performs its computations. Unlike the discrete sequential chain of reasoning seen in verbalized chain-of-thought approaches, we observe rich geometric patterns, including orbits, convergent paths, and drifts, that serve as means to organize the computational process spatially. This suggests the model is independently learning to leverage the high-dimensional nature of its latent space to implement reasoning in new ways.
Path Independence.
We verify that our models maintain path independence, in the sense of Anil et al. (2022), despite the complex, learned dynamics discussed above (see also the additional examples in Appendix Figure 22). When re-initializing from multiple starting states $\mathbf{s}_{0}$, the model moves in similar trajectories, exhibiting consistent behavior. The same orbital patterns, fixed points, or directional drifts emerge regardless of initialization.
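A simple version of such a check is sketched below on a toy recurrence. In the toy, path independence holds by construction because the map is a contraction, whereas for the trained model it is an empirical observation. The update rule, the matrices `A` and `B`, the initialization scales and the step count are assumptions for illustration.

```python
import numpy as np

def iterate(s0, e, A, B, num_steps):
    """Run a toy recurrence for num_steps and return the final latent state."""
    s = s0
    for _ in range(num_steps):
        s = np.tanh(A @ s + B @ e)
    return s

rng = np.random.default_rng(0)
d = 16
G = rng.standard_normal((d, d))
A = 0.5 * G / np.linalg.norm(G, 2)   # contraction by construction (toy assumption)
B = rng.standard_normal((d, d)) / np.sqrt(d)
e = rng.standard_normal(d)           # hypothetical embedding of a single token

# Re-initialize s_0 at very different scales and compare the end states.
finals = np.stack([
    iterate(scale * rng.standard_normal(d), e, A, B, num_steps=64)
    for scale in (0.1, 1.0, 10.0)
])

# Largest deviation of any end state from the mean end state.
spread = np.linalg.norm(finals - finals.mean(axis=0), axis=1).max()
print("max deviation between end states:", spread)
```

For orbits or drifts, comparing only end states is not enough. A fuller check would compare whole trajectories, for example after projecting them onto shared PCA directions.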
8 Related Work Overview
The extent to which recurrence is a foundational concept of machine learning is hard to overstate (Amari, 1972; Hopfield, 1982; Braitenberg, 1986; Gers and Schmidhuber, 2000; Sutskever et al., 2008). Aside from using recurrence to move along sequences, as in recurrent neural networks, it was understood early to also be the key to adaptive computation (Schmidhuber, 2012; Graves, 2017). For transformers, recurrence was applied in Dehghani et al. (2019), who highlight the aim of recurrent depth to model universal, i.e. Turing-complete, machines (Graves et al., 2014). It was used at scale (but with fixed recurrence) in Lan et al. (2019), and interesting recent improvements in this line of work are described in Tan et al. (2023), Abnar et al. (2023), Mathur et al. (2024) and Csordás et al. (2024). Schwarzschild et al. (2021b), Bansal et al. (2022), Bear et al. (2024) and McLeish et al. (2024) show that depth recurrence is advantageous for learning generalizable algorithms when training with randomized unrolling and input injections. Recent work has described depth-recurrent, looped transformers and studied their potential benefits with careful theoretical and small-scale analysis (Giannou et al., 2023; Gatmiry et al., 2024; Yang et al., 2024a; Fan et al., 2025).
From another angle, these models can be described as neural networks learning a fixed-point iteration, as studied in deep equilibrium models (Bai et al., 2019; 2022). They are further related to diffusion models (Song and Ermon, 2019), especially latent diffusion models (Rombach et al., 2022), but we note that language diffusion models are usually run with a per-sequence, instead of a per-token, iteration count (Lee et al., 2018). A key difference of our approach to both equilibrium models and diffusion models is in the training objective, where equilibrium methods solve the “direct” problem (Geiping and Moeller, 2019), diffusion models solve a surrogate training objective, and our work suggests that truncated unrolling is a scalable alternative.
More generally, all architectures that recur in depth can also be understood as directly learning the analog to the gradient of a latent energy-based model (LeCun and Huang, 2005; LeCun, 2022), to an implicitly defined intermediate optimization layer (Amos and Kolter, 2017), or to a Kuramoto layer (Miyato et al., 2024). Analogies to gradient descent at inference time also show the connection to test time adaptation (Sun et al., 2020), especially test-time adaptation of output states (Boudiaf et al., 2022).
Aside from full recurrent-depth architectures, there also exist a number of proposals for hybrid architectures, such as models with latent sub-networks (Li et al., 2020a), LoRA adapters on top of weight-shared layers (Bae et al., 2024), or (dynamic) weight-tying of trained models (Hay and Wolf, 2023; Liu et al., 2024b).
As mentioned in Section 6, while we consider the proposed recurrent depth approach to be a very natural way to learn to reason in continuous latent space from the ground up, the works of Hao et al. (2024); Cheng and Durme (2024) and Liu et al. (2024a) discuss how to finetune existing fixed-depth transformers with this capability. These works have a similar aim to ours, enabling reasoning in latent space, but approach this goal from separate directions.
For additional discussions related to the idea of constructing a prior that incentivizes reasoning and algorithm learning at the expense of memorization of simple patterns, we also refer to Chollet (2019), Schwarzschild (2023), Li et al. (2020b) and Moulton (2023).
9 Future Work
Aside from work extending and analyzing the scaling behaviors of recurrent depth models, there are many questions that remain unanswered. For example, we see potentially a large number of novel post-training schemes that could further enhance the capabilities of these models, such as fine-tuning to compress the recurrence, reinforcement learning with data of different hardness levels (Zelikman et al., 2024), or internalizing reasoning from CoT data into the recurrence (Deng et al., 2024).
Another aspect not covered in this work is the relationship to other modern architecture improvements. Efficient sequence mixing operations, especially those that are linear in sequence dimension, such as linear attention (Katharopoulos et al., 2020; Yang et al., 2024b), are limited in the number of comparisons that can be made. However, with recurrent depth, blocks containing linear operators can repeat until all necessary comparisons between sequence elements are computed (Suzgun et al., 2019). For simplicity, we also focus on a single recurrence, where prior work has considered multiple successive recurrent stages (Takase and Kiyono, 2023; Csordás et al., 2024).
Finally, the proposed architecture is set up to be compute-heavy, with more “materialized” parameters than there are actual parameters. This naturally mirrors mixture-of-expert models (MoE), which are parameter-heavy, using fewer active parameters per forward pass than exist within the model (Shazeer et al., 2017; Fedus et al., 2022). We posit that where the recurrent-depth setup excels at learning reasoning patterns, the MoE excels at effectively storing and retrieving complex information. Their complementarity supports the hypothesis that a future architecture would contain both modifications. While in a standard MoE model, each expert can only be activated once per forward pass, or skipped entirely, a recurrent MoE model could also refine its latent state over multiple iterations, potentially routing to the same expert multiple times, before switching to a different one (Tan et al., 2023; Csordás et al., 2024). While MoE models are the currently leading solution to implement this type of “memory” in dense transformers, these considerations also hold for other memory mechanisms suggested for LLMs (Sukhbaatar et al., 2019; Fan et al., 2021; Wu et al., 2022; He et al., 2024).
10 Conclusions
The models described in this paper are ultimately still a proof-of-concept. We describe how to train a latent recurrent-depth architecture and what parameters we chose, and then train a single model at scale. Future training runs are likely to train with more optimized learning rate schedules, data mixes and accelerators. Still, we observe a number of interesting behaviors emerging naturally from recurrent training. The most important of these is the ability to use latent reasoning to dramatically improve performance on reasoning tasks by expending test-time computation. In addition, we also observe context-dependent convergence speed, path independence, and various zero-shot abilities. This leads us to believe that latent reasoning is a promising research direction to complement existing approaches for test-time compute scaling. The model we realize is surprisingly powerful given its size and amount of training data, and we are excited about the potential impact of imbuing generative models with the ability to reason in continuous latent space without the need for specialized data at train time or verbalization at inference time.
Acknowledgements
This project was made possible by the INCITE program: An award for computer time was provided by the U.S. Department of Energy’s (DOE) Innovative and Novel Computational Impact on Theory and Experiment (INCITE) Program. This research used resources of the Oak Ridge Leadership Computing Facility at the Oak Ridge National Laboratory, which is supported by the Office of Science of the U.S. Department of Energy under Contract No. DE-AC05-00OR22725. Work on the LLNL side was prepared by LLNL under Contract DE-AC52-07NA27344 and supported by the LLNL-LDRD Program under Project No. 24-ERD-010 and 24-ERD-058 (LLNL-CONF-872390). This manuscript has been authored by Lawrence Livermore National Security, LLC under Contract No. DE-AC52-07NA27344 with the U.S. Department of Energy. The United States Government retains a non-exclusive, paid-up, irrevocable, world-wide license to publish or reproduce the published form of this manuscript, or allow others to do so, for United States Government purposes.
JG further acknowledges the support of the Hector II foundation. A large number of small-scale and preliminary experiments were made possible through the support of the MPI Intelligent Systems compute cluster and funding by the Tübingen AI center.
UMD researchers were further supported by the ONR MURI program, DARPA TIAMAT, the National Science Foundation (IIS-2212182), and the NSF TRAILS Institute (2229885). Commercial support was provided by Capital One Bank, the Amazon Research Award program, and Open Philanthropy. Finally, we thank Avi Schwarzschild for helpful comments on the initial draft.
References
- Abnar et al. (2023) Samira Abnar, Omid Saremi, Laurent Dinh, Shantel Wilson, Miguel Angel Bautista, Chen Huang, Vimal Thilak, Etai Littwin, Jiatao Gu, Josh Susskind, and Samy Bengio. 2023. Adaptivity and Modularity for Efficient Generalization Over Task Complexity. arxiv:2310.08866[cs].
- AI2 (2024) AI2. 2024. OLMo 1.7–7B: A 24 point improvement on MMLU.
- Allen-Zhu and Li (2024) Zeyuan Allen-Zhu and Yuanzhi Li. 2024. Physics of language models: Part 3.1, knowledge storage and extraction. In Proceedings of the 41st International Conference on Machine Learning, volume 235 of ICML’24, pages 1067–1077, Vienna, Austria. JMLR.org.
- Amari (1972) S.-I. Amari. 1972. Learning Patterns and Pattern Sequences by Self-Organizing Nets of Threshold Elements. IEEE Transactions on Computers, C-21(11):1197–1206.
- AMD (2021) AMD. 2021. AMD Instinct™ MI250X Accelerators.
- Amini et al. (2019) Aida Amini, Saadia Gabriel, Peter Lin, Rik Koncel-Kedziorski, Yejin Choi, and Hannaneh Hajishirzi. 2019. Mathqa: Towards interpretable math word problem solving with operation-based formalisms. arXiv preprint arXiv:1905.13319.
- Amos and Kolter (2017) Brandon Amos and J. Zico Kolter. 2017. OptNet: Differentiable Optimization as a Layer in Neural Networks. In International Conference on Machine Learning, pages 136–145.
- Anil et al. (2022) Cem Anil, Ashwini Pokle, Kaiqu Liang, Johannes Treutlein, Yuhuai Wu, Shaojie Bai, J. Zico Kolter, and Roger Baker Grosse. 2022. Path Independent Equilibrium Models Can Better Exploit Test-Time Computation. In Advances in Neural Information Processing Systems.
- Austin et al. (2021) Jacob Austin, Augustus Odena, Maxwell Nye, Maarten Bosma, Henryk Michalewski, David Dohan, Ellen Jiang, Carrie Cai, Michael Terry, Quoc Le, and 1 others. 2021. Program synthesis with large language models. arXiv preprint arXiv:2108.07732.
- Azerbayev et al. (2023) Zhangir Azerbayev, Hailey Schoelkopf, Keiran Paster, Marco Dos Santos, Stephen Marcus McAleer, Albert Q. Jiang, Jia Deng, Stella Biderman, and Sean Welleck. 2023. Llemma: An Open Language Model for Mathematics. In The Twelfth International Conference on Learning Representations.
- Bae et al. (2024) Sangmin Bae, Adam Fisch, Hrayr Harutyunyan, Ziwei Ji, Seungyeon Kim, and Tal Schuster. 2024. Relaxed Recursive Transformers: Effective Parameter Sharing with Layer-wise LoRA.
- Bai et al. (2019) Shaojie Bai, J. Zico Kolter, and Vladlen Koltun. 2019. Deep Equilibrium Models. In Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc.
- Bai et al. (2022) Shaojie Bai, Vladlen Koltun, and J. Zico Kolter. 2022. Neural Deep Equilibrium Solvers. In International Conference on Learning Representations.
- Bai et al. (2024) Yushi Bai, Jiajie Zhang, Xin Lv, Linzhi Zheng, Siqi Zhu, Lei Hou, Yuxiao Dong, Jie Tang, and Juanzi Li. 2024. LongWriter: Unleashing 10,000+ Word Generation from Long Context LLMs. arxiv:2408.07055[cs].
- Banino et al. (2021) Andrea Banino, Jan Balaguer, and Charles Blundell. 2021. PonderNet: Learning to Ponder. In 8th ICML Workshop on Automated Machine Learning (AutoML).
- Bansal et al. (2022) Arpit Bansal, Avi Schwarzschild, Eitan Borgnia, Zeyad Emam, Furong Huang, Micah Goldblum, and Tom Goldstein. 2022. End-to-end Algorithm Synthesis with Recurrent Networks: Extrapolation without Overthinking. In Advances in Neural Information Processing Systems.
- Bauschke et al. (2011) Heinz H. Bauschke, Sarah M. Moffat, and Xianfu Wang. 2011. Firmly nonexpansive mappings and maximally monotone operators: Correspondence and duality. arXiv:1101.4688 [math].
- Bear et al. (2024) Jay Bear, Adam Prügel-Bennett, and Jonathon Hare. 2024. Rethinking Deep Thinking: Stable Learning of Algorithms using Lipschitz Constraints. arxiv:2410.23451[cs].
- Bekman (2023) Stas Bekman. 2023. Machine Learning Engineering Open Book. Stasosphere Online Inc.
- Ben Allal et al. (2024) Loubna Ben Allal, Anton Lozhkov, Guilherme Penedo, Thomas Wolf, and Leandro von Werra. 2024. SmolLM-corpus.
- Biderman et al. (2023) Stella Biderman, Hailey Schoelkopf, Quentin Anthony, Herbie Bradley, Kyle O’Brien, Eric Hallahan, Mohammad Aflah Khan, Shivanshu Purohit, USVSN Sai Prashanth, Edward Raff, Aviya Skowron, Lintang Sutawika, and Oskar van der Wal. 2023. Pythia: A Suite for Analyzing Large Language Models Across Training and Scaling. arxiv:2304.01373[cs].
- Biderman et al. (2024) Stella Biderman, Hailey Schoelkopf, Lintang Sutawika, Leo Gao, Jonathan Tow, Baber Abbasi, Alham Fikri Aji, Pawan Sasanka Ammanamanchi, Sidney Black, Jordan Clive, Anthony DiPofi, Julen Etxaniz, Benjamin Fattori, Jessica Zosa Forde, Charles Foster, Jeffrey Hsu, Mimansa Jaiswal, Wilson Y. Lee, Haonan Li, and 11 others. 2024. Lessons from the Trenches on Reproducible Evaluation of Language Models. arxiv:2405.14782[cs].
- Bisk et al. (2020) Yonatan Bisk, Rowan Zellers, Ronan Le Bras, Jianfeng Gao, and Yejin Choi. 2020. Piqa: Reasoning about physical commonsense in natural language. In Thirty-Fourth AAAI Conference on Artificial Intelligence.
- Boudiaf et al. (2022) Malik Boudiaf, Romain Mueller, Ismail Ben Ayed, and Luca Bertinetto. 2022. Parameter-Free Online Test-Time Adaptation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 8344–8353.
- Braitenberg (1986) Valentino Braitenberg. 1986. Vehicles: Experiments in Synthetic Psychology. MIT press.
- Brandon et al. (2024) William Brandon, Mayank Mishra, Aniruddha Nrusimha, Rameswar Panda, and Jonathan Ragan Kelly. 2024. Reducing Transformer Key-Value Cache Size with Cross-Layer Attention. arxiv:2405.12981[cs].
- British Library Labs (2021) British Library Labs. 2021. Digitised Books. c. 1510 - c. 1900. JSONL (OCR Derived Text + Metadata). British Library.
- Cai et al. (2024) Tianle Cai, Yuhong Li, Zhengyang Geng, Hongwu Peng, Jason D. Lee, Deming Chen, and Tri Dao. 2024. Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads. In Forty-First International Conference on Machine Learning.
- character.ai (2024) character.ai. 2024. Optimizing AI Inference at Character.AI.
- Chen et al. (2021) Mark Chen, Jerry Tworek, Heewoo Jun, Qiming Yuan, Henrique Ponde de Oliveira Pinto, Jared Kaplan, Harri Edwards, Yuri Burda, Nicholas Joseph, Greg Brockman, Alex Ray, Raul Puri, Gretchen Krueger, Michael Petrov, Heidy Khlaaf, Girish Sastry, Pamela Mishkin, Brooke Chan, Scott Gray, and 39 others. 2021. Evaluating large language models trained on code. Preprint, arXiv:2107.03374.
- Cheng and Durme (2024) Jeffrey Cheng and Benjamin Van Durme. 2024. Compressed Chain of Thought: Efficient Reasoning Through Dense Representations. arxiv:2412.13171[cs].
- Choi (2023) Euirim Choi. 2023. GoodWiki dataset.
- Chollet (2019) François Chollet. 2019. On the Measure of Intelligence. arxiv:1911.01547[cs].
- Chowdhery et al. (2022) Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, and 48 others. 2022. PaLM: Scaling Language Modeling with Pathways. arXiv:2204.02311 [cs].
- Clark et al. (2018) Peter Clark, Isaac Cowhey, Oren Etzioni, Tushar Khot, Ashish Sabharwal, Carissa Schoenick, and Oyvind Tafjord. 2018. Think you have solved question answering? Try ARC, the AI2 Reasoning Challenge. arXiv:1803.05457v1.
- Cobbe et al. (2021) Karl Cobbe, Vineet Kosaraju, Mohammad Bavarian, Mark Chen, Heewoo Jun, Lukasz Kaiser, Matthias Plappert, Jerry Tworek, Jacob Hilton, Reiichiro Nakano, Christopher Hesse, and John Schulman. 2021. Training Verifiers to Solve Math Word Problems. arxiv:2110.14168[cs].
- Colegrove et al. (2024) Owen Colegrove, Vik Paruchuri, and OpenPhi-Team. 2024. Open-phi/textbooks · Datasets at Hugging Face.
- Csordás et al. (2024) Róbert Csordás, Kazuki Irie, Jürgen Schmidhuber, Christopher Potts, and Christopher D. Manning. 2024. MoEUT: Mixture-of-Experts Universal Transformers. In The Thirty-eighth Annual Conference on Neural Information Processing Systems.
- Dagan (2024) Gautier Dagan. 2024. Bpeasy.
- Dagan et al. (2024) Gautier Dagan, Gabriel Synnaeve, and Baptiste Rozière. 2024. Getting the most out of your tokenizer for pre-training and domain adaptation. arxiv:2402.01035[cs].
- Dao (2023) Tri Dao. 2023. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arxiv:2307.08691[cs].
- Dao et al. (2022) Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. 2022. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. arxiv:2205.14135[cs].
- DeepSeek-AI et al. (2025) DeepSeek-AI, Daya Guo, Dejian Yang, Haowei Zhang, Junxiao Song, Ruoyu Zhang, Runxin Xu, Qihao Zhu, Shirong Ma, Peiyi Wang, Xiao Bi, Xiaokang Zhang, Xingkai Yu, Yu Wu, Z. F. Wu, Zhibin Gou, Zhihong Shao, Zhuoshu Li, Ziyi Gao, and 181 others. 2025. DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning. arxiv:2501.12948[cs].
- DeepSeek-AI et al. (2024) DeepSeek-AI, Aixin Liu, Bei Feng, Bing Xue, Bingxuan Wang, Bochao Wu, Chengda Lu, Chenggang Zhao, Chengqi Deng, Chenyu Zhang, Chong Ruan, Damai Dai, Daya Guo, Dejian Yang, Deli Chen, Dongjie Ji, Erhang Li, Fangyun Lin, Fucong Dai, and 181 others. 2024. DeepSeek-V3 Technical Report. arxiv:2412.19437[cs].
- Dehghani et al. (2019) Mostafa Dehghani, Stephan Gouws, Oriol Vinyals, Jakob Uszkoreit, and Łukasz Kaiser. 2019. Universal Transformers. arxiv:1807.03819[cs, stat].
- Deng et al. (2024) Yuntian Deng, Yejin Choi, and Stuart Shieber. 2024. From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step. arxiv:2405.14838[cs].
- Ding et al. (2024) Hantian Ding, Zijian Wang, Giovanni Paolini, Varun Kumar, Anoop Deoras, Dan Roth, and Stefano Soatto. 2024. Fewer Truncations Improve Language Modeling. In Forty-First International Conference on Machine Learning.
- Ding et al. (2021) Ming Ding, Zhuoyi Yang, Wenyi Hong, Wendi Zheng, Chang Zhou, Da Yin, Junyang Lin, Xu Zou, Zhou Shao, Hongxia Yang, and Jie Tang. 2021. CogView: Mastering Text-to-Image Generation via Transformers. In Advances in Neural Information Processing Systems, volume 34, pages 19822–19835. Curran Associates, Inc.
- Elbayad et al. (2019) Maha Elbayad, Jiatao Gu, Edouard Grave, and Michael Auli. 2019. Depth-Adaptive Transformer. In International Conference on Learning Representations.
- Elhoushi et al. (2024) Mostafa Elhoushi, Akshat Shrivastava, Diana Liskovich, Basil Hosmer, Bram Wasti, Liangzhen Lai, Anas Mahmoud, Bilge Acun, Saurabh Agarwal, Ahmed Roman, Ahmed A. Aly, Beidi Chen, and Carole-Jean Wu. 2024. LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding. arxiv:2404.16710[cs].
- Everett et al. (2024) Katie Everett, Lechao Xiao, Mitchell Wortsman, Alexander A. Alemi, Roman Novak, Peter J. Liu, Izzeddin Gur, Jascha Sohl-Dickstein, Leslie Pack Kaelbling, Jaehoon Lee, and Jeffrey Pennington. 2024. Scaling Exponents Across Parameterizations and Optimizers. arxiv:2407.05872[cs].
- Fan et al. (2019) Angela Fan, Edouard Grave, and Armand Joulin. 2019. Reducing Transformer Depth on Demand with Structured Dropout. arxiv:1909.11556[cs, stat].
- Fan et al. (2021) Angela Fan, Thibaut Lavril, Edouard Grave, Armand Joulin, and Sainbayar Sukhbaatar. 2021. Addressing Some Limitations of Transformers with Feedback Memory. arxiv:2002.09402[cs, stat].
- Fan et al. (2025) Ying Fan, Yilun Du, Kannan Ramchandran, and Kangwook Lee. 2025. Looped Transformers for Length Generalization. In The Thirteenth International Conference on Learning Representations.
- Fedus et al. (2022) William Fedus, Barret Zoph, and Noam Shazeer. 2022. Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. arxiv:2101.03961[cs].
- Feng et al. (2023) Xidong Feng, Yicheng Luo, Ziyan Wang, Hongrui Tang, Mengyue Yang, Kun Shao, David Mguni, Yali Du, and Jun Wang. 2023. ChessGPT: Bridging Policy Learning and Language Modeling. Advances in Neural Information Processing Systems, 36:7216–7262.
- Gabarain (2024) Sebastian Gabarain. 2024. Locutusque/hercules-v5.0 · Datasets at Hugging Face.
- Gatmiry et al. (2024) Khashayar Gatmiry, Nikunj Saunshi, Sashank J. Reddi, Stefanie Jegelka, and Sanjiv Kumar. 2024. Can Looped Transformers Learn to Implement Multi-step Gradient Descent for In-context Learning?
- Geiping and Goldstein (2023) Jonas Geiping and Tom Goldstein. 2023. Cramming: Training a Language Model on a single GPU in one day. In Proceedings of the 40th International Conference on Machine Learning, pages 11117–11143. PMLR.
- Geiping and Moeller (2019) Jonas Geiping and Michael Moeller. 2019. Parametric Majorization for Data-Driven Energy Minimization Methods. In Proceedings of the IEEE International Conference on Computer Vision, pages 10262–10273.
- Gers and Schmidhuber (2000) F. A. Gers and J. Schmidhuber. 2000. Recurrent nets that time and count. In Proceedings of the IEEE-INNS-ENNS International Joint Conference on Neural Networks. IJCNN 2000. Neural Computing: New Challenges and Perspectives for the New Millennium, volume 3, pages 189–194.
- Giannou et al. (2023) Angeliki Giannou, Shashank Rajput, Jy-Yong Sohn, Kangwook Lee, Jason D. Lee, and Dimitris Papailiopoulos. 2023. Looped Transformers as Programmable Computers. In Proceedings of the 40th International Conference on Machine Learning, pages 11398–11442. PMLR.
- Goyal et al. (2018) Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. 2018. Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour. arxiv:1706.02677[cs].
- Graves (2017) Alex Graves. 2017. Adaptive Computation Time for Recurrent Neural Networks. arxiv:1603.08983[cs].
- Graves et al. (2014) Alex Graves, Greg Wayne, and Ivo Danihelka. 2014. Neural Turing Machines. arxiv:1410.5401[cs].
- Groeneveld et al. (2024) Dirk Groeneveld, Iz Beltagy, Pete Walsh, Akshita Bhagia, Rodney Kinney, Oyvind Tafjord, Ananya Harsh Jha, Hamish Ivison, Ian Magnusson, Yizhong Wang, Shane Arora, David Atkinson, Russell Authur, Khyathi Raghavi Chandu, Arman Cohan, Jennifer Dumas, Yanai Elazar, Yuling Gu, Jack Hessel, and 24 others. 2024. OLMo: Accelerating the Science of Language Models. arxiv:2402.00838[cs].
- Hägele et al. (2024) Alexander Hägele, Elie Bakouch, Atli Kosson, Loubna Ben Allal, Leandro Von Werra, and Martin Jaggi. 2024. Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations. In Workshop on Efficient Systems for Foundation Models II @ ICML2024.
- Hao et al. (2024) Shibo Hao, Sainbayar Sukhbaatar, DiJia Su, Xian Li, Zhiting Hu, Jason Weston, and Yuandong Tian. 2024. Training Large Language Models to Reason in a Continuous Latent Space. arxiv:2412.06769[cs].
- Hay and Wolf (2023) Tamir David Hay and Lior Wolf. 2023. Dynamic Layer Tying for Parameter-Efficient Transformers. In The Twelfth International Conference on Learning Representations.
- He et al. (2024) Zexue He, Leonid Karlinsky, Donghyun Kim, Julian McAuley, Dmitry Krotov, and Rogerio Feris. 2024. CAMELoT: Towards Large Language Models with Training-Free Consolidated Associative Memory. arxiv:2402.13449[cs].
- Hendrycks et al. (2021a) Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt. 2021a. Measuring massive multitask language understanding. Proceedings of the International Conference on Learning Representations (ICLR).
- Hendrycks et al. (2021b) Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt. 2021b. Measuring Massive Multitask Language Understanding. In International Conference on Learning Representations.
- Hopfield (1982) J. J. Hopfield. 1982. Neural networks and physical systems with emergent collective computational abilities. Proceedings of the National Academy of Sciences of the United States of America, 79(8):2554–2558.
- Hu et al. (2024) Jiewen Hu, Thomas Zhu, and Sean Welleck. 2024. miniCTX: Neural Theorem Proving with (Long-)Contexts. arxiv:2408.03350[cs].
- Izmailov et al. (2018) Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson. 2018. Averaging weights leads to wider optima and better generalization. In 34th Conference on Uncertainty in Artificial Intelligence 2018, UAI 2018, pages 876–885.
- Jiang et al. (2023) Albert Q. Jiang, Wenda Li, and Mateja Jamnik. 2023. Multilingual Mathematical Autoformalization. arxiv:2311.03755[cs].
- Johannes Welbl (2017) Johannes Welbl, Nelson F. Liu, and Matt Gardner. 2017. Crowdsourcing multiple choice science questions.
- Kaddour (2022) Jean Kaddour. 2022. Stop Wasting My Time! Saving Days of ImageNet and BERT Training with Latest Weight Averaging. arxiv:2209.14981[cs, stat].
- Kaplan et al. (2024) Guy Kaplan, Matanel Oren, Yuval Reif, and Roy Schwartz. 2024. From Tokens to Words: On the Inner Lexicon of LLMs. arxiv:2410.05864[cs].
- Kaplan et al. (2020) Jared Kaplan, Sam McCandlish, Tom Henighan, Tom B. Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei. 2020. Scaling Laws for Neural Language Models. arxiv:2001.08361[cs, stat].
- Katharopoulos et al. (2020) Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. 2020. Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. In Proceedings of the 37th International Conference on Machine Learning, pages 5156–5165. PMLR.
- Kenney (2024) Matthew Kenney. 2024. ArXivDLInstruct.
- Kim et al. (2024) Seungone Kim, Juyoung Suk, Shayne Longpre, Bill Yuchen Lin, Jamin Shin, Sean Welleck, Graham Neubig, Moontae Lee, Kyungjae Lee, and Minjoon Seo. 2024. Prometheus 2: An Open Source Language Model Specialized in Evaluating Other Language Models. In Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing, pages 4334–4353, Miami, Florida, USA. Association for Computational Linguistics.
- Kingma and Ba (2015) Diederik P. Kingma and Jimmy Ba. 2015. Adam: A Method for Stochastic Optimization. In International Conference on Learning Representations (ICLR), San Diego.
- Kryściński et al. (2022) Wojciech Kryściński, Nazneen Rajani, Divyansh Agarwal, Caiming Xiong, and Dragomir Radev. 2022. BookSum: A Collection of Datasets for Long-form Narrative Summarization. arxiv:2105.08209[cs].
- Lai et al. (2024) Xin Lai, Zhuotao Tian, Yukang Chen, Senqiao Yang, Xiangru Peng, and Jiaya Jia. 2024. Step-DPO: Step-wise Preference Optimization for Long-chain Reasoning of LLMs. arxiv:2406.18629[cs].
- Lan et al. (2019) Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, and Radu Soricut. 2019. ALBERT: A Lite BERT for Self-supervised Learning of Language Representations. In International Conference on Learning Representations.
- LeCun (2022) Yann LeCun. 2022. A Path Towards Autonomous Machine Intelligence. Preprint, Version 0.9.2:62.
- LeCun and Huang (2005) Yann LeCun and Fu Jie Huang. 2005. Loss functions for discriminative training of energy-based models. In AISTATS 2005 - Proceedings of the 10th International Workshop on Artificial Intelligence and Statistics, pages 206–213.
- Lee et al. (2018) Jason Lee, Elman Mansimov, and Kyunghyun Cho. 2018. Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement. In Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing, pages 1173–1182, Brussels, Belgium. Association for Computational Linguistics.
- Leviathan et al. (2023) Yaniv Leviathan, Matan Kalman, and Yossi Matias. 2023. Fast Inference from Transformers via Speculative Decoding. In Proceedings of the 40th International Conference on Machine Learning, pages 19274–19286. PMLR.
- Levine et al. (2021) Yoav Levine, Noam Wies, Or Sharir, Hofit Bata, and Amnon Shashua. 2021. The Depth-to-Width Interplay in Self-Attention. arxiv:2006.12467[cs, stat].
- Lewkowycz et al. (2022) Aitor Lewkowycz, Anders Andreassen, David Dohan, Ethan Dyer, Henryk Michalewski, Vinay Ramasesh, Ambrose Slone, Cem Anil, Imanol Schlag, Theo Gutman-Solo, Yuhuai Wu, Behnam Neyshabur, Guy Gur-Ari, and Vedant Misra. 2022. Solving quantitative reasoning problems with language models. Preprint, arXiv:2206.14858.
- Li et al. (2023) Raymond Li, Loubna Ben Allal, Yangtian Zi, Niklas Muennighoff, Denis Kocetkov, Chenghao Mou, Marc Marone, Christopher Akiki, Jia Li, Jenny Chim, Qian Liu, Evgenii Zheltonozhskii, Terry Yue Zhuo, Thomas Wang, Olivier Dehaene, Joel Lamy-Poirier, Joao Monteiro, Nicolas Gontier, Ming-Ho Yee, and 39 others. 2023. StarCoder: May the source be with you! Transactions on Machine Learning Research.
- Li et al. (2020a) Xian Li, Asa Cooper Stickland, Yuqing Tang, and Xiang Kong. 2020a. Deep Transformers with Latent Depth. arxiv:2009.13102[cs].
- Li et al. (2020b) Yujia Li, Felix Gimeno, Pushmeet Kohli, and Oriol Vinyals. 2020b. Strong Generalization and Efficiency in Neural Programs. arxiv:2007.03629[cs].
- Liping Tang (2024) Liping Tang, Nikhil Ranjan, and Omkar Pangarkar. 2024. TxT360: A top-quality LLM pre-training dataset requires the perfect blend.
- Liu et al. (2023a) Hao Liu, Matei Zaharia, and Pieter Abbeel. 2023a. Ring attention with blockwise transformers for near-infinite context. arXiv preprint arXiv:2310.01889.
- Liu et al. (2024a) Luyang Liu, Jonas Pfeiffer, Jiaxing Wu, Jun Xie, and Arthur Szlam. 2024a. Deliberation in Latent Space via Differentiable Cache Augmentation. arxiv:2412.17747[cs].
- Liu et al. (2023b) Xiao Liu, Hanyu Lai, Hao Yu, Yifan Xu, Aohan Zeng, Zhengxiao Du, Peng Zhang, Yuxiao Dong, and Jie Tang. 2023b. WebGLM: Towards An Efficient Web-Enhanced Question Answering System with Human Preferences. In Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining, KDD ’23, pages 4549–4560, New York, NY, USA. Association for Computing Machinery.
- Liu et al. (2024b) Zechun Liu, Changsheng Zhao, Forrest Iandola, Chen Lai, Yuandong Tian, Igor Fedorov, Yunyang Xiong, Ernie Chang, Yangyang Shi, Raghuraman Krishnamoorthi, Liangzhen Lai, and Vikas Chandra. 2024b. MobileLLM: Optimizing Sub-billion Parameter Language Models for On-Device Use Cases. arxiv:2402.14905[cs].
- Liu et al. (2023c) Zhengzhong Liu, Aurick Qiao, Willie Neiswanger, Hongyi Wang, Bowen Tan, Tianhua Tao, Junbo Li, Yuqi Wang, Suqi Sun, Omkar Pangarkar, Richard Fan, Yi Gu, Victor Miller, Yonghao Zhuang, Guowei He, Haonan Li, Fajri Koto, Liping Tang, Nikhil Ranjan, and 9 others. 2023c. LLM360: Towards fully transparent open-source LLMs.
- Loshchilov and Hutter (2017) Ilya Loshchilov and Frank Hutter. 2017. Decoupled Weight Decay Regularization. arXiv:1711.05101 [cs, math].
- Lozhkov et al. (2024) Anton Lozhkov, Raymond Li, Loubna Ben Allal, Federico Cassano, Joel Lamy-Poirier, Nouamane Tazi, Ao Tang, Dmytro Pykhtar, Jiawei Liu, Yuxiang Wei, Tianyang Liu, Max Tian, Denis Kocetkov, Arthur Zucker, Younes Belkada, Zijian Wang, Qian Liu, Dmitry Abulkhanov, Indraneil Paul, and 47 others. 2024. StarCoder 2 and The Stack v2: The Next Generation.
- Lu et al. (2024) Zimu Lu, Aojun Zhou, Ke Wang, Houxing Ren, Weikang Shi, Junting Pan, Mingjie Zhan, and Hongsheng Li. 2024. MathCoder2: Better Math Reasoning from Continued Pretraining on Model-translated Mathematical Code. arxiv:2410.08196[cs].
- Majstorovic (2024) Sebastian Majstorovic. 2024. Selected Digitized Books | The Library of Congress.
- Markeeva et al. (2024) Larisa Markeeva, Sean McLeish, Borja Ibarz, Wilfried Bounsi, Olga Kozlova, Alex Vitvitskyi, Charles Blundell, Tom Goldstein, Avi Schwarzschild, and Petar Veličković. 2024. The CLRS-Text Algorithmic Reasoning Language Benchmark. arxiv:2406.04229[cs].
- Mathur et al. (2024) Mrinal Mathur, Barak A. Pearlmutter, and Sergey M. Plis. 2024. MIND over Body: Adaptive Thinking using Dynamic Computation. In The Thirteenth International Conference on Learning Representations.
- McLeish et al. (2024) Sean Michael McLeish, Arpit Bansal, Alex Stein, Neel Jain, John Kirchenbauer, Brian R. Bartoldson, Bhavya Kailkhura, Abhinav Bhatele, Jonas Geiping, Avi Schwarzschild, and Tom Goldstein. 2024. Transformers Can Do Arithmetic with the Right Embeddings. In The Thirty-eighth Annual Conference on Neural Information Processing Systems.
- Merrill et al. (2022) William Merrill, Ashish Sabharwal, and Noah A. Smith. 2022. Saturated Transformers are Constant-Depth Threshold Circuits. Transactions of the Association for Computational Linguistics, 10:843–856.
- Mihaylov et al. (2018) Todor Mihaylov, Peter Clark, Tushar Khot, and Ashish Sabharwal. 2018. Can a suit of armor conduct electricity? a new dataset for open book question answering. In EMNLP.
- Mikolov et al. (2011) Tomáš Mikolov, Stefan Kombrink, Lukáš Burget, Jan Černocký, and Sanjeev Khudanpur. 2011. Extensions of recurrent neural network language model. In 2011 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pages 5528–5531.
- Miyato et al. (2024) Takeru Miyato, Sindy Löwe, Andreas Geiger, and Max Welling. 2024. Artificial Kuramoto Oscillatory Neurons. In The Thirteenth International Conference on Learning Representations, Singapore.
- Moulton (2023) Ryan Moulton. 2023. The Many Ways that Digital Minds Can Know.
- Muennighoff et al. (2024) Niklas Muennighoff, Qian Liu, Armel Zebaze, Qinkai Zheng, Binyuan Hui, Terry Yue Zhuo, Swayam Singh, Xiangru Tang, Leandro von Werra, and Shayne Longpre. 2024. OctoPack: Instruction Tuning Code Large Language Models. arxiv:2308.07124[cs].
- Nam Pham (2023) Nam Pham. 2023. Tiny-textbooks (Revision 14de7ba).
- Nam Pham (2024) Nam Pham. 2024. Tiny-strange-textbooks (Revision 6f304f1).
- Nanda et al. (2022) Neel Nanda, Lawrence Chan, Tom Lieberum, Jess Smith, and Jacob Steinhardt. 2022. Progress measures for grokking via mechanistic interpretability. In The Eleventh International Conference on Learning Representations.
- Noci et al. (2022) Lorenzo Noci, Sotiris Anagnostidis, Luca Biggio, Antonio Orvieto, Sidak Pal Singh, and Aurelien Lucchi. 2022. Signal Propagation in Transformers: Theoretical Perspectives and the Role of Rank Collapse. In Advances in Neural Information Processing Systems.
- OpenAI (2024) OpenAI. 2024. New reasoning models: OpenAI o1-preview and o1-mini. https://openai.com/research/o1-preview-and-o1-mini.
- Ouyang et al. (2022) Long Ouyang, Jeff Wu, Xu Jiang, Diogo Almeida, Carroll L. Wainwright, Pamela Mishkin, Chong Zhang, Sandhini Agarwal, Katarina Slama, Alex Ray, John Schulman, Jacob Hilton, Fraser Kelton, Luke Miller, Maddie Simens, Amanda Askell, Peter Welinder, Paul Christiano, Jan Leike, and Ryan Lowe. 2022. Training language models to follow instructions with human feedback. arxiv:2203.02155[cs].
- Paster et al. (2023) Keiran Paster, Marco Dos Santos, Zhangir Azerbayev, and Jimmy Ba. 2023. OpenWebMath: An Open Dataset of High-Quality Mathematical Web Text. In The Twelfth International Conference on Learning Representations.
- Peebles and Xie (2023) William Peebles and Saining Xie. 2023. Scalable Diffusion Models with Transformers. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 4195–4205.
- Radford et al. (2019) Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, and Ilya Sutskever. 2019. Language Models are Unsupervised Multitask Learners. OpenAI, page 24.
- Rae et al. (2019) Jack W. Rae, Anna Potapenko, Siddhant M. Jayakumar, and Timothy P. Lillicrap. 2019. Compressive Transformers for Long-Range Sequence Modelling. arxiv:1911.05507[cs].
- Rajbhandari et al. (2020) Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, and Yuxiong He. 2020. ZeRO: Memory optimizations Toward Training Trillion Parameter Models. In SC20: International Conference for High Performance Computing, Networking, Storage and Analysis, pages 1–16.
- Rombach et al. (2022) Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, and Björn Ommer. 2022. High-Resolution Image Synthesis With Latent Diffusion Models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 10684–10695.
- Sakaguchi et al. (2021) Keisuke Sakaguchi, Ronan Le Bras, Chandra Bhagavatula, and Yejin Choi. 2021. WinoGrande: An adversarial winograd schema challenge at scale. Commun. ACM, 64(9):99–106.
- Sanh et al. (2021) Victor Sanh, Albert Webson, Colin Raffel, Stephen Bach, Lintang Sutawika, Zaid Alyafeai, Antoine Chaffin, Arnaud Stiegler, Arun Raja, Manan Dey, M. Saiful Bari, Canwen Xu, Urmish Thakker, Shanya Sharma Sharma, Eliza Szczechla, Taewoon Kim, Gunjan Chhablani, Nihal Nayak, Debajyoti Datta, and 21 others. 2021. Multitask Prompted Training Enables Zero-Shot Task Generalization. In International Conference on Learning Representations.
- Sanyal et al. (2024) Sunny Sanyal, Atula Tejaswi Neerkaje, Jean Kaddour, Abhishek Kumar, and Sujay Sanghavi. 2024. Early weight averaging meets high learning rates for LLM pre-training. In First Conference on Language Modeling.
- Schmidhuber (2012) Juergen Schmidhuber. 2012. Self-Delimiting Neural Networks. arxiv:1210.0118[cs].
- Schuster et al. (2022) Tal Schuster, Adam Fisch, Jai Gupta, Mostafa Dehghani, Dara Bahri, Vinh Q. Tran, Yi Tay, and Donald Metzler. 2022. Confident Adaptive Language Modeling. In Advances in Neural Information Processing Systems.
- Schwarzschild (2023) Avi Schwarzschild. 2023. Deep Thinking Systems: Logical Extrapolation with Recurrent Neural Networks. Ph.D. thesis, University of Maryland, College Park.
- Schwarzschild et al. (2021a) Avi Schwarzschild, Eitan Borgnia, Arjun Gupta, Arpit Bansal, Zeyad Emam, Furong Huang, Micah Goldblum, and Tom Goldstein. 2021a. Datasets for Studying Generalization from Easy to Hard Examples. arxiv:2108.06011[cs].
- Schwarzschild et al. (2021b) Avi Schwarzschild, Eitan Borgnia, Arjun Gupta, Furong Huang, Uzi Vishkin, Micah Goldblum, and Tom Goldstein. 2021b. Can You Learn an Algorithm? Generalizing from Easy to Hard Problems with Recurrent Networks. In Advances in Neural Information Processing Systems, volume 34, pages 6695–6706. Curran Associates, Inc.
- Schwarzschild et al. (2023) Avi Schwarzschild, Sean Michael McLeish, Arpit Bansal, Gabriel Diaz, Alex Stein, Aakash Chandnani, Aniruddha Saha, Richard Baraniuk, Long Tran-Thanh, Jonas Geiping, and Tom Goldstein. 2023. Algorithm Design for Learned Algorithms.
- Sennrich et al. (2016) Rico Sennrich, Barry Haddow, and Alexandra Birch. 2016. Neural Machine Translation of Rare Words with Subword Units. In Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 1715–1725, Berlin, Germany. Association for Computational Linguistics.
- Shao et al. (2024) Zhihong Shao, Peiyi Wang, Qihao Zhu, Runxin Xu, Junxiao Song, Xiao Bi, Haowei Zhang, Mingchuan Zhang, Y. K. Li, Y. Wu, and 1 other. 2024. DeepSeekMath: Pushing the limits of mathematical reasoning in open language models. arXiv preprint arXiv:2402.03300.
- Shazeer (2020) Noam Shazeer. 2020. GLU Variants Improve Transformer. arxiv:2002.05202[cs].
- Shazeer et al. (2017) Noam Shazeer, Azalia Mirhoseini, Krzysztof Maziarz, Andy Davis, Quoc Le, Geoffrey Hinton, and Jeff Dean. 2017. Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer. arxiv:1701.06538[cs].
- Singh and Bhatele (2022) Siddharth Singh and Abhinav Bhatele. 2022. AxoNN: An asynchronous, message-driven parallel framework for extreme-scale deep learning. In 2022 IEEE International Parallel and Distributed Processing Symposium (IPDPS), pages 606–616.
- Singh et al. (2024) Siddharth Singh, Prajwal Singhania, Aditya Ranjan, John Kirchenbauer, Jonas Geiping, Yuxin Wen, Neel Jain, Abhimanyu Hans, Manli Shu, Aditya Tomar, Tom Goldstein, and Abhinav Bhatele. 2024. Democratizing AI: Open-source Scalable LLM Training on GPU-based Supercomputers. In SC24: International Conference for High Performance Computing, Networking, Storage and Analysis, pages 36–49. IEEE Computer Society.
- Skean et al. (2024) Oscar Skean, Md Rifat Arefin, Yann LeCun, and Ravid Shwartz-Ziv. 2024. Does Representation Matter? Exploring Intermediate Layers in Large Language Models. arxiv:2412.09563[cs].
- Soboleva et al. (2023) Daria Soboleva, Faisal Al-Khateeb, Joel Hestness, Nolan Dey, Robert Myers, and Jacob Robert Steeves. 2023. SlimPajama: A 627B token cleaned and deduplicated version of RedPajama.
- Soldaini et al. (2024) Luca Soldaini, Rodney Kinney, Akshita Bhagia, Dustin Schwenk, David Atkinson, Russell Authur, Ben Bogin, Khyathi Chandu, Jennifer Dumas, Yanai Elazar, Valentin Hofmann, Ananya Jha, Sachin Kumar, Li Lucy, Xinxi Lyu, Nathan Lambert, Ian Magnusson, Jacob Morrison, Niklas Muennighoff, and 17 others. 2024. Dolma: An Open Corpus of Three Trillion Tokens for Language Model Pretraining Research. In Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 15725–15788, Bangkok, Thailand. Association for Computational Linguistics.
- Song and Ermon (2019) Yang Song and Stefano Ermon. 2019. Generative Modeling by Estimating Gradients of the Data Distribution. arXiv:1907.05600 [cs, stat].
- Su et al. (2021) Jianlin Su, Yu Lu, Shengfeng Pan, Bo Wen, and Yunfeng Liu. 2021. RoFormer: Enhanced Transformer with Rotary Position Embedding. arxiv:2104.09864 [cs].
- Sukhbaatar et al. (2019) Sainbayar Sukhbaatar, Edouard Grave, Guillaume Lample, Herve Jegou, and Armand Joulin. 2019. Augmenting Self-attention with Persistent Memory. arxiv:1907.01470[cs, stat].
- Sun et al. (2024) Qi Sun, Marc Pickett, Aakash Kumar Nain, and Llion Jones. 2024. Transformer Layers as Painters. arxiv:2407.09298[cs].
- Sun et al. (2020) Yu Sun, Xiaolong Wang, Zhuang Liu, John Miller, Alexei Efros, and Moritz Hardt. 2020. Test-Time Training with Self-Supervision for Generalization under Distribution Shifts. In Proceedings of the 37th International Conference on Machine Learning, pages 9229–9248. PMLR.
- Sutskever et al. (2008) Ilya Sutskever, Geoffrey E. Hinton, and Graham W. Taylor. 2008. The Recurrent Temporal Restricted Boltzmann Machine. In Advances in Neural Information Processing Systems, volume 21. Curran Associates, Inc.
- Suzgun et al. (2019) Mirac Suzgun, Sebastian Gehrmann, Yonatan Belinkov, and Stuart M. Shieber. 2019. Memory-Augmented Recurrent Neural Networks Can Learn Generalized Dyck Languages. arxiv:1911.03329[cs].
- Takase and Kiyono (2023) Sho Takase and Shun Kiyono. 2023. Lessons on Parameter Sharing across Layers in Transformers. arxiv:2104.06022[cs].
- Takase et al. (2024) Sho Takase, Shun Kiyono, Sosuke Kobayashi, and Jun Suzuki. 2024. Spike No More: Stabilizing the Pre-training of Large Language Models. arxiv:2312.16903[cs].
- Tan et al. (2023) Shawn Tan, Yikang Shen, Zhenfang Chen, Aaron Courville, and Chuang Gan. 2023. Sparse Universal Transformer. arxiv:2310.07096[cs].
- Team Gemma et al. (2024) Team Gemma, Morgane Riviere, Shreya Pathak, Pier Giuseppe Sessa, Cassidy Hardin, Surya Bhupatiraju, Léonard Hussenot, Thomas Mesnard, Bobak Shahriari, Alexandre Ramé, Johan Ferret, Peter Liu, Pouya Tafti, Abe Friesen, Michelle Casbon, Sabela Ramos, Ravin Kumar, Charline Le Lan, Sammy Jerome, and 179 others. 2024. Gemma 2: Improving Open Language Models at a Practical Size. arxiv:2408.00118[cs].
- Team OLMo et al. (2025) Team OLMo, Pete Walsh, Luca Soldaini, Dirk Groeneveld, Kyle Lo, Shane Arora, Akshita Bhagia, Yuling Gu, Shengyi Huang, Matt Jordan, Nathan Lambert, Dustin Schwenk, Oyvind Tafjord, Taira Anderson, David Atkinson, Faeze Brahman, Christopher Clark, Pradeep Dasigi, Nouha Dziri, and 21 others. 2025. 2 OLMo 2 Furious. arxiv:2501.00656[cs].
- TogetherAI (2023) TogetherAI. 2023. Llama-2-7B-32K-Instruct — and fine-tuning for Llama-2 models with Together API.
- Toshniwal et al. (2024a) Shubham Toshniwal, Wei Du, Ivan Moshkov, Branislav Kisacanin, Alexan Ayrapetyan, and Igor Gitman. 2024a. OpenMathInstruct-2: Accelerating AI for Math with Massive Open-Source Instruction Data. arxiv:2410.01560[cs].
- Toshniwal et al. (2024b) Shubham Toshniwal, Ivan Moshkov, Sean Narenthiran, Daria Gitman, Fei Jia, and Igor Gitman. 2024b. OpenMathInstruct-1: A 1.8 Million Math Instruction Tuning Dataset. In The Thirty-eighth Conference on Neural Information Processing Systems Datasets and Benchmarks Track.
- Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention Is All You Need. arXiv:1706.03762 [cs].
- Wang et al. (2024a) Zengzhi Wang, Xuefeng Li, Rui Xia, and Pengfei Liu. 2024a. MathPile: A Billion-Token-Scale Pretraining Corpus for Math. In The Thirty-eighth Conference on Neural Information Processing Systems Datasets and Benchmarks Track.
- Wang et al. (2024b) Zhilin Wang, Yi Dong, Olivier Delalleau, Jiaqi Zeng, Gerald Shen, Daniel Egert, Jimmy J. Zhang, Makesh Narsimhan Sreedhar, and Oleksii Kuchaiev. 2024b. HelpSteer2: Open-source dataset for training top-performing reward models. arxiv:2406.08673[cs].
- Weber et al. (2024) Maurice Weber, Daniel Y. Fu, Quentin Gregory Anthony, Yonatan Oren, Shane Adams, Anton Alexandrov, Xiaozhong Lyu, Huu Nguyen, Xiaozhe Yao, Virginia Adams, Ben Athiwaratkun, Rahul Chalamala, Kezhen Chen, Max Ryabinin, Tri Dao, Percy Liang, Christopher Re, Irina Rish, and Ce Zhang. 2024. RedPajama: An Open Dataset for Training Large Language Models. In The Thirty-eighth Conference on Neural Information Processing Systems Datasets and Benchmarks Track.
- Williams and Peng (1990) Ronald J. Williams and Jing Peng. 1990. An Efficient Gradient-Based Algorithm for On-Line Training of Recurrent Network Trajectories. Neural Computation, 2(4):490–501.
- Wortsman et al. (2023a) Mitchell Wortsman, Tim Dettmers, Luke Zettlemoyer, Ari Morcos, Ali Farhadi, and Ludwig Schmidt. 2023a. Stable and low-precision training for large-scale vision-language models. Advances in Neural Information Processing Systems, 36:10271–10298.
- Wortsman et al. (2023b) Mitchell Wortsman, Tim Dettmers, Luke Zettlemoyer, Ari S. Morcos, Ali Farhadi, and Ludwig Schmidt. 2023b. Stable and low-precision training for large-scale vision-language models. In Thirty-Seventh Conference on Neural Information Processing Systems.
- Wu and Stock (2024) Mengshiou Wu and Mark Stock. 2024. Enhancing PyTorch Performance on Frontier with the RCCL OFI-Plugin.
- Wu et al. (2022) Yuhuai Wu, Markus Norman Rabe, DeLesley Hutchins, and Christian Szegedy. 2022. Memorizing Transformers. In International Conference on Learning Representations.
- Wu et al. (2024) Zijian Wu, Jiayu Wang, Dahua Lin, and Kai Chen. 2024. LEAN-GitHub: Compiling GitHub LEAN repositories for a versatile LEAN prover. arXiv:2407.17227 [cs].
- Xu et al. (2024) Zhangchen Xu, Fengqing Jiang, Luyao Niu, Yuntian Deng, Radha Poovendran, Yejin Choi, and Bill Yuchen Lin. 2024. Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing. arXiv:2406.08464 [cs].
- Yang et al. (2023) Kaiyu Yang, Aidan M. Swope, Alex Gu, Rahul Chalamala, Peiyang Song, Shixing Yu, Saad Godil, Ryan Prenger, and Anima Anandkumar. 2023. LeanDojo: Theorem Proving with Retrieval-Augmented Language Models. In Thirty-Seventh Conference on Neural Information Processing Systems Datasets and Benchmarks Track.
- Yang et al. (2024a) Liu Yang, Kangwook Lee, Robert D. Nowak, and Dimitris Papailiopoulos. 2024a. Looped Transformers are Better at Learning Learning Algorithms. In The Twelfth International Conference on Learning Representations.
- Yang et al. (2024b) Songlin Yang, Bailin Wang, Yu Zhang, Yikang Shen, and Yoon Kim. 2024b. Parallelizing Linear Transformers with the Delta Rule over Sequence Length. In The Thirty-eighth Annual Conference on Neural Information Processing Systems.
- Ying et al. (2024) Huaiyuan Ying, Zijian Wu, Yihan Geng, Jiayu Wang, Dahua Lin, and Kai Chen. 2024. Lean Workbook: A large-scale Lean problem set formalized from natural language math problems. In The Thirty-eight Conference on Neural Information Processing Systems Datasets and Benchmarks Track.
- Yu et al. (2023) Longhui Yu, Weisen Jiang, Han Shi, Jincheng Yu, Zhengying Liu, Yu Zhang, James Kwok, Zhenguo Li, Adrian Weller, and Weiyang Liu. 2023. MetaMath: Bootstrap Your Own Mathematical Questions for Large Language Models. In The Twelfth International Conference on Learning Representations.
- Zamirai et al. (2021) Pedram Zamirai, Jian Zhang, Christopher R. Aberger, and Christopher De Sa. 2021. Revisiting BFloat16 Training. arXiv:2010.06192 [cs, stat].
- Zelikman et al. (2024) Eric Zelikman, Georges Harik, Yijia Shao, Varuna Jayasiri, Nick Haber, and Noah D. Goodman. 2024. Quiet-STaR: Language Models Can Teach Themselves to Think Before Speaking. arXiv:2403.09629 [cs].
- Zellers et al. (2019) Rowan Zellers, Ari Holtzman, Yonatan Bisk, Ali Farhadi, and Yejin Choi. 2019. Hellaswag: Can a machine really finish your sentence? In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics.
- Zhai et al. (2022) Xiaohua Zhai, Alexander Kolesnikov, Neil Houlsby, and Lucas Beyer. 2022. Scaling Vision Transformers. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 12104–12113.
- Zhang and Sennrich (2019) Biao Zhang and Rico Sennrich. 2019. Root Mean Square Layer Normalization. In Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc.
- Zhang et al. (2024a) Ge Zhang, Scott Qu, Jiaheng Liu, Chenchen Zhang, Chenghua Lin, Chou Leuang Yu, Danny Pan, Esther Cheng, Jie Liu, Qunshu Lin, Raven Yuan, Tuney Zheng, Wei Pang, Xinrun Du, Yiming Liang, Yinghao Ma, Yizhi Li, Ziyang Ma, Bill Lin, and 26 others. 2024a. MAP-Neo: Highly Capable and Transparent Bilingual Large Language Model Series. arXiv:2405.19327 [cs].
- Zhang et al. (2024b) Jun Zhang, Jue Wang, Huan Li, Lidan Shou, Ke Chen, Gang Chen, and Sharad Mehrotra. 2024b. Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding. In Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 11263–11282, Bangkok, Thailand. Association for Computational Linguistics.
- Zhang et al. (2024c) Yifan Zhang, Yifan Luo, Yang Yuan, and Andrew C. Yao. 2024c. Autonomous Data Selection with Language Models for Mathematical Texts. In ICLR 2024 Workshop on Navigating and Addressing Data Problems for Foundation Models.
- Zheng et al. (2024) Tianyu Zheng, Ge Zhang, Tianhao Shen, Xueling Liu, Bill Yuchen Lin, Jie Fu, Wenhu Chen, and Xiang Yue. 2024. OpenCodeInterpreter: Integrating Code Generation with Execution and Refinement. In Findings of the Association for Computational Linguistics: ACL 2024, pages 12834–12859, Bangkok, Thailand. Association for Computational Linguistics.
- Zhou et al. (2024) Fan Zhou, Zengzhi Wang, Qian Liu, Junlong Li, and Pengfei Liu. 2024. Programming Every Example: Lifting Pre-training Data Quality like Experts at Scale. arXiv:2409.17115 [cs].
- Zhuo et al. (2024) Terry Yue Zhuo, Minh Chien Vu, Jenny Chim, Han Hu, Wenhao Yu, Ratnadira Widyasari, Imam Nur Bani Yusuf, Haolan Zhan, Junda He, Indraneil Paul, Simon Brunner, Chen Gong, Thong Hoang, Armel Randy Zebaze, Xiaoheng Hong, Wen-Ding Li, Jean Kaddour, Ming Xu, Zhihan Zhang, and 14 others. 2024. BigCodeBench: Benchmarking Code Generation with Diverse Function Calls and Complex Instructions.
<details>
<summary>x17.png Details</summary>

### Visual Description
## Histogram: Comparison of Continuous CoT vs Default Compute Histogram Distribution of Steps to Convergence
### Overview
The image presents a series of histograms comparing the distribution of steps to convergence for two methods: "Continuous CoT" and "Default". The histograms are arranged in a grid of four rows and three columns, each representing a different subject or scenario. The x-axis represents the "Steps to Convergence," and the y-axis represents the "Density." Each histogram displays two distributions, one for Continuous CoT and one for Default, with their respective mean (μ) values indicated in the legend.
### Components/Axes
* **Title:** Comparison of Continuous CoT vs Default Compute Histogram Distribution of Steps to Convergence
* **X-axis:** Steps to Convergence (Scale: 0 to 60)
* **Y-axis:** Density (Scale: 0.00 to varying maximum values, approximately 0.06-0.08)
* **Legends:** Located at the top-right of each subplot.
* Continuous CoT (μ=value): Represented by a solid color with diagonal lines. Colors vary by subplot (green, blue, purple, peach, yellow, red, etc.)
* Default (μ=value): Represented by a transparent color with a black outline.
* **Subplot Titles:** Each subplot has a title indicating the subject or scenario (e.g., "high school mathematics," "machine learning," "clinical knowledge," etc.).
### Detailed Analysis
The histograms are organized into a grid of four rows and three columns. Each subplot compares the "Continuous CoT" and "Default" methods for a specific topic.
**Row 1:**
* **high school mathematics:**
* Continuous CoT (Green): μ=11.9. The distribution peaks around 5-10 steps and then decreases.
* Default (Transparent with Black Outline): μ=12.7. The distribution is similar to Continuous CoT but slightly shifted to the right.
* **machine learning:**
* Continuous CoT (Blue): μ=13.6. The distribution peaks around 10-15 steps and then decreases.
* Default (Transparent with Black Outline): μ=14.2. The distribution is similar to Continuous CoT but slightly shifted to the right.
* **clinical knowledge:**
* Continuous CoT (Purple): μ=13.8. The distribution peaks around 10-15 steps and then decreases.
* Default (Transparent with Black Outline): μ=14.7. The distribution is similar to Continuous CoT but slightly shifted to the right.
**Row 2:**
* **moral disputes:**
* Continuous CoT (Peach): μ=13.5. The distribution peaks around 10-15 steps and then decreases.
* Default (Transparent with Black Outline): μ=14.5. The distribution is similar to Continuous CoT but slightly shifted to the right.
* **philosophy:**
* Continuous CoT (Yellow): μ=13.5. The distribution peaks around 10-15 steps and then decreases.
* Default (Transparent with Black Outline): μ=14.6. The distribution is similar to Continuous CoT but slightly shifted to the right.
* **world religions:**
* Continuous CoT (Red): μ=14.4. The distribution peaks around 10-15 steps and then decreases.
* Default (Transparent with Black Outline): μ=15.1. The distribution is similar to Continuous CoT but slightly shifted to the right.
**Row 3:**
* **high school world history:**
* Continuous CoT (Light Red): μ=15.6. The distribution peaks around 10-15 steps and then decreases.
* Default (Transparent with Black Outline): μ=15.8. The distribution is similar to Continuous CoT but slightly shifted to the right.
* **logical fallacies:**
* Continuous CoT (Pink): μ=14.4. The distribution peaks around 10-15 steps and then decreases.
* Default (Transparent with Black Outline): μ=15.6. The distribution is similar to Continuous CoT but slightly shifted to the right.
* **medical genetics:**
* Continuous CoT (Light Purple): μ=13.2. The distribution peaks around 10-15 steps and then decreases.
* Default (Transparent with Black Outline): μ=14.0. The distribution is similar to Continuous CoT but slightly shifted to the right.
**Row 4:**
* **professional law:**
* Continuous CoT (Light Blue): μ=15.1. The distribution peaks around 10-15 steps and then decreases.
* Default (Transparent with Black Outline): μ=16.0. The distribution is similar to Continuous CoT but slightly shifted to the right.
* **moral scenarios:**
* Continuous CoT (Rose): μ=16.0. The distribution peaks around 10-15 steps and then decreases.
* Default (Transparent with Black Outline): μ=16.2. The distribution is similar to Continuous CoT but slightly shifted to the right.
* **abstract algebra:**
* Continuous CoT (Pale Yellow): μ=12.8. The distribution peaks around 10-15 steps and then decreases.
* Default (Transparent with Black Outline): μ=13.6. The distribution is similar to Continuous CoT but slightly shifted to the right.
### Key Observations
* The "Default" method consistently has a higher mean (μ) value than the "Continuous CoT" method across all subjects/scenarios.
* The distributions for both methods are generally right-skewed, indicating that most convergences occur within a relatively small number of steps, but some require significantly more steps.
* The shapes of the distributions are similar for both methods within each subject/scenario, suggesting that the underlying convergence process is similar, but the "Continuous CoT" method tends to converge slightly faster.
### Interpretation
The data suggests that the "Continuous CoT" method generally leads to faster convergence compared to the "Default" method, as indicated by the lower mean number of steps to convergence across various subjects/scenarios. The consistent right-skewness of the distributions implies that while most cases converge quickly, there are instances where both methods require a significantly larger number of steps. The similarity in distribution shapes between the two methods within each subject suggests that "Continuous CoT" optimizes the convergence process without fundamentally altering its nature. The consistent difference in means suggests a systematic advantage of "Continuous CoT" over the "Default" method.
</details>
Figure 13: Additional categories for Figure 10 in the main body.
Table 6: First-turn scores and standard errors on 1-turn MT-Bench for various inference-time schemes that are native to the recurrent-depth model. Differences from the baseline model, meaning the normal recurrent model without inference modifications, are not statistically significant.
| Inference scheme | First-turn score | Std. error |
| --- | --- | --- |
| cache compression, $s=4$ | 5.856 | 0.395 |
| baseline, 64 iterations | 5.693 | 0.386 |
| cache compression, $s=16$ | 5.687 | 0.402 |
| baseline, 32 iterations | 5.662 | 0.388 |
| cache compression, $s=8$ | 5.631 | 0.384 |
| KL exit, $t=5\times 10^{-4}$ | 5.562 | 0.389 |
Appendix A Additional Information
<details>
<summary>x18.png Details</summary>

### Visual Description
## Heatmap: Addition Accuracy by Number of Operands
### Overview
The image is a heatmap visualizing the addition accuracy based on the number of operands and the number of digits. The x-axis represents the number of digits, ranging from 1 to 6. The y-axis represents the number of operands, ranging from 2 to 6. The color intensity represents the accuracy, with darker blue indicating higher accuracy and lighter blue indicating lower accuracy. A colorbar on the right side of the heatmap shows the mapping between color intensity and accuracy values, ranging from 0.0 to 1.0.
### Components/Axes
* **Title:** Addition Accuracy by Number of Operands
* **X-axis:** Number of Digits, with ticks at 1, 2, 3, 4, 5, and 6.
* **Y-axis:** Number of Operands, with labels "2 Operands", "3 Operands", "4 Operands", "5 Operands", and "6 Operands".
* **Colorbar:** Ranges from 0.0 (lightest blue) to 1.0 (darkest blue), with ticks at 0.0, 0.2, 0.4, 0.6, 0.8.
### Detailed Analysis
The heatmap displays accuracy values for each combination of the number of operands and the number of digits.
* **2 Operands:**
* 1 Digit: 1.0
* 2 Digits: 1.0
* 3 Digits: 0.8
* 4 Digits: 0.7
* 5 Digits: 0.6
* 6 Digits: 0.5
* **3 Operands:**
* 1 Digit: 0.7
* 2 Digits: 0.4
* 3 Digits: 0.2
* 4 Digits: 0.0
* 5 Digits: 0.0
* 6 Digits: 0.0
* **4 Operands:**
* 1 Digit: 0.3
* 2 Digits: 0.0
* 3 Digits: 0.0
* 4 Digits: 0.0
* 5 Digits: 0.0
* 6 Digits: 0.0
* **5 Operands:**
* 1 Digit: 0.1
* 2 Digits: 0.0
* 3 Digits: 0.0
* 4 Digits: 0.0
* 5 Digits: 0.0
* 6 Digits: 0.0
* **6 Operands:**
* 1 Digit: 0.0
* 2 Digits: 0.0
* 3 Digits: 0.0
* 4 Digits: 0.0
* 5 Digits: 0.0
* 6 Digits: 0.0
### Key Observations
* Accuracy generally decreases as the number of digits increases for a fixed number of operands.
* Accuracy decreases as the number of operands increases for a fixed number of digits.
* The highest accuracy is achieved with 2 operands and 1 or 2 digits (accuracy = 1.0).
* For 4 or more operands, the accuracy is 0.0 for 2 or more digits.
### Interpretation
The heatmap suggests that addition accuracy is significantly affected by both the number of operands and the number of digits. The model performs best with fewer operands and fewer digits. As the complexity of the addition problem increases (more operands and more digits), the accuracy drops considerably, indicating a potential limitation in the model's ability to handle complex arithmetic operations. The sharp drop in accuracy with more than 3 operands suggests a possible threshold in the model's capacity or a need for more sophisticated training data for complex additions.
</details>
<details>
<summary>x19.png Details</summary>

### Visual Description
## Line Chart: Model Accuracy vs Number of Operands for Different Recurrence Levels
### Overview
The image is a line chart that illustrates the relationship between model accuracy and the number of operands used, across different recurrence levels. The x-axis represents the number of operands, ranging from 2 to 6. The y-axis represents the accuracy, ranging from 0.0 to 1.0. Multiple lines are plotted, each representing a different recurrence level (1, 2, 4, 8, 16, 24, 32, 48, and 64). The chart aims to show how the model's accuracy changes with an increasing number of operands at various recurrence levels.
### Components/Axes
* **Title:** Model Accuracy vs Number of Operands (digits=1) for Different Recurrence Levels
* **X-axis:**
* Label: Number of Operands
* Scale: 2, 3, 4, 5, 6
* **Y-axis:**
* Label: Accuracy
* Scale: 0.0, 0.2, 0.4, 0.6, 0.8, 1.0
* **Legend:** Located on the top-right of the chart. It maps the line colors to the recurrence levels.
* Blue: Recurrence 1
* Orange: Recurrence 2
* Green: Recurrence 4
* Red: Recurrence 8
* Purple: Recurrence 16
* Brown: Recurrence 24
* Pink: Recurrence 32
* Gray: Recurrence 48
* Yellow-Green: Recurrence 64
### Detailed Analysis
* **Recurrence 1 (Blue):** The accuracy is relatively constant and low, around 0.01-0.04 across all operand numbers.
* (2, 0.04), (3, 0.03), (4, 0.02), (5, 0.01), (6, 0.02)
* **Recurrence 2 (Orange):** The accuracy is also low, around 0.01 for 2 to 4 operands and about 0.10-0.11 at 5 and 6 operands.
* (2, 0.01), (3, 0.01), (4, 0.01), (5, 0.11), (6, 0.10)
* **Recurrence 4 (Green):** The accuracy starts relatively high but drops sharply between 2 and 3 operands, then remains low.
* (2, 0.44), (3, 0.04), (4, 0.02), (5, 0.01), (6, 0.01)
* **Recurrence 8 (Red):** The accuracy starts high and decreases significantly as the number of operands increases.
* (2, 0.96), (3, 0.34), (4, 0.02), (5, 0.01), (6, 0.01)
* **Recurrence 16 (Purple):** The accuracy starts high and decreases significantly as the number of operands increases.
* (2, 0.98), (3, 0.72), (4, 0.38), (5, 0.04), (6, 0.02)
* **Recurrence 24 (Brown):** The accuracy starts high and decreases significantly as the number of operands increases.
* (2, 0.99), (3, 0.74), (4, 0.36), (5, 0.12), (6, 0.11)
* **Recurrence 32 (Pink):** The accuracy starts high and decreases significantly as the number of operands increases.
* (2, 0.95), (3, 0.56), (4, 0.24), (5, 0.03), (6, 0.11)
* **Recurrence 48 (Gray):** The accuracy starts high and decreases significantly as the number of operands increases.
* (2, 1.00), (3, 0.73), (4, 0.31), (5, 0.11), (6, 0.02)
* **Recurrence 64 (Yellow-Green):** The accuracy starts high and decreases significantly as the number of operands increases.
* (2, 0.99), (3, 0.77), (4, 0.39), (5, 0.12), (6, 0.11)
### Key Observations
* Higher recurrence levels (8, 16, 24, 32, 48, 64) generally exhibit higher accuracy when the number of operands is low (2-3).
* As the number of operands increases, the accuracy for higher recurrence levels drops significantly.
* Recurrence levels 1 and 2 show consistently low accuracy regardless of the number of operands.
* The accuracy for most recurrence levels converges to a low value (around 0.01-0.11) as the number of operands reaches 5 and 6.
### Interpretation
The data suggests that the model's performance, in terms of accuracy, is highly dependent on the recurrence level and the number of operands used. Higher recurrence levels are more effective when dealing with a smaller number of operands, possibly indicating that they are better at capturing complex relationships within simpler expressions. However, as the complexity of the expression increases (more operands), the accuracy of these higher recurrence levels diminishes, suggesting a potential overfitting issue or difficulty in generalizing to more complex scenarios. The consistently low accuracy of recurrence levels 1 and 2 indicates that these levels may be insufficient to capture the underlying patterns in the data, regardless of the number of operands. The convergence of accuracy at higher operand counts suggests a limit to the model's ability to handle very complex expressions, irrespective of the recurrence level.
</details>
<details>
<summary>x20.png Details</summary>

### Visual Description
## Chart: Model Accuracy vs Number of Operands for Different Recurrence Levels
### Overview
The image is a line chart that displays the model accuracy versus the number of operands (digits=2) for different recurrence levels. The x-axis represents the number of operands, ranging from 2 to 6. The y-axis represents the accuracy, ranging from 0.0 to 1.0. There are nine different recurrence levels plotted on the chart, each represented by a different colored line.
### Components/Axes
* **Title:** Model Accuracy vs Number of Operands (digits=2) for Different Recurrence Levels
* **X-axis:** Number of Operands (values: 2, 3, 4, 5, 6)
* **Y-axis:** Accuracy (values: 0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
* **Legend:** Located on the top-right of the chart.
* Recurrence 1 (blue)
* Recurrence 2 (orange)
* Recurrence 4 (green)
* Recurrence 8 (red)
* Recurrence 16 (purple)
* Recurrence 24 (brown)
* Recurrence 32 (pink)
* Recurrence 48 (gray)
* Recurrence 64 (yellow)
### Detailed Analysis
Here's a breakdown of each recurrence level's trend and approximate data points:
* **Recurrence 1 (blue):** Generally decreasing trend.
* (2, 0.90)
* (3, 0.55)
* (4, 0.20)
* (5, 0.40)
* (6, 0.25)
* **Recurrence 2 (orange):** Relatively flat at 0.
* (2, 0.00)
* (3, 0.00)
* (4, 0.00)
* (5, 0.00)
* (6, 0.00)
* **Recurrence 4 (green):** Relatively flat, close to 0.
* (2, 0.10)
* (3, 0.00)
* (4, 0.02)
* (5, 0.06)
* (6, 0.02)
* **Recurrence 8 (red):** Starts high, drops significantly, then remains low.
* (2, 0.70)
* (3, 0.06)
* (4, 0.02)
* (5, 0.06)
* (6, 0.02)
* **Recurrence 16 (purple):** Decreases, then increases, then decreases.
* (2, 0.90)
* (3, 0.50)
* (4, 0.15)
* (5, 0.50)
* (6, 0.15)
* **Recurrence 24 (brown):** Decreases, then increases, then decreases.
* (2, 0.90)
* (3, 0.55)
* (4, 0.10)
* (5, 0.52)
* (6, 0.12)
* **Recurrence 32 (pink):** Decreases, then increases, then decreases.
* (2, 0.95)
* (3, 0.38)
* (4, 0.18)
* (5, 0.58)
* (6, 0.15)
* **Recurrence 48 (gray):** Decreases, then increases, then decreases.
* (2, 0.88)
* (3, 0.52)
* (4, 0.12)
* (5, 0.40)
* (6, 0.20)
* **Recurrence 64 (yellow):** Decreases, then increases, then decreases.
* (2, 1.00)
* (3, 0.52)
* (4, 0.18)
* (5, 0.46)
* (6, 0.10)
### Key Observations
* Recurrence levels 2 and 4 show consistently low accuracy across all numbers of operands; recurrence level 8 reaches about 0.70 at 2 operands but stays low from 3 operands onward.
* Recurrence levels 16, 24, 32, 48, and 64 show a similar trend: high accuracy at 2 operands, a significant drop at 3 and 4 operands, a peak at 5 operands, and a drop again at 6 operands.
* The accuracy for most recurrence levels is highest when the number of operands is 2.
* The accuracy tends to dip when the number of operands is 4.
### Interpretation
The chart suggests that the model's accuracy is significantly affected by the recurrence level and the number of operands. Lower recurrence levels (2, 4, 8) generally perform poorly, indicating that these levels might not be sufficient for the model to learn effectively. Higher recurrence levels (16, 24, 32, 48, 64) show a more complex relationship with the number of operands, suggesting that there might be an optimal number of operands for these levels to achieve the best accuracy. The dip in accuracy at 4 operands for higher recurrence levels could indicate a point where the model struggles to generalize or encounters some form of interference. The overall trend indicates that the model performs best with a smaller number of operands (2) and that increasing the number of operands does not necessarily lead to improved accuracy.
</details>
<details>
<summary>x21.png Details</summary>

### Visual Description
## Line Chart: Model Accuracy vs Number of Operands for Different Recurrence Levels
### Overview
The image is a line chart that displays the relationship between model accuracy and the number of operands, with different lines representing different recurrence levels. The chart aims to show how the model's performance changes as the number of operands increases, and how this relationship varies across different recurrence levels. The number of digits is fixed at 3.
### Components/Axes
* **Title:** Model Accuracy vs Number of Operands (digits=3) for Different Recurrence Levels
* **X-axis:** Number of Operands, with values 2, 3, 4, 5, and 6.
* **Y-axis:** Accuracy, ranging from 0.0 to approximately 0.9, with increments of 0.2.
* **Legend:** Located on the top-right of the chart, it identifies each line by its recurrence level:
* Recurrence 1 (light blue)
* Recurrence 2 (orange)
* Recurrence 4 (green)
* Recurrence 8 (red)
* Recurrence 16 (purple)
* Recurrence 24 (brown)
* Recurrence 32 (pink)
* Recurrence 48 (gray)
* Recurrence 64 (yellow)
### Detailed Analysis
* **Recurrence 1 (light blue):** Starts at approximately 0.82 accuracy with 2 operands, drops to approximately 0.32 with 3 operands, increases slightly to approximately 0.34 with 4 operands, then decreases to approximately 0.18 with 5 operands, and finally drops to approximately 0.01 with 6 operands.
* **Recurrence 2 (orange):** Starts at approximately 0.88 accuracy with 2 operands, drops to approximately 0.29 with 3 operands, increases slightly to approximately 0.32 with 4 operands, then decreases to approximately 0.12 with 5 operands, and finally drops to approximately 0.01 with 6 operands.
* **Recurrence 4 (green):** Starts at approximately 0.02 accuracy with 2 operands, increases slightly to approximately 0.28 with 3 operands, increases slightly to approximately 0.35 with 4 operands, then decreases to approximately 0.10 with 5 operands, and finally drops to approximately 0.01 with 6 operands.
* **Recurrence 8 (red):** Starts at approximately 0.30 accuracy with 2 operands, drops to approximately 0.00 with 3 operands, remains at approximately 0.00 with 4 operands, remains at approximately 0.00 with 5 operands, and remains at approximately 0.00 with 6 operands.
* **Recurrence 16 (purple):** Starts at approximately 0.82 accuracy with 2 operands, drops to approximately 0.30 with 3 operands, increases slightly to approximately 0.33 with 4 operands, then decreases to approximately 0.18 with 5 operands, and finally drops to approximately 0.01 with 6 operands.
* **Recurrence 24 (brown):** Starts at approximately 0.80 accuracy with 2 operands, drops to approximately 0.30 with 3 operands, decreases to approximately 0.24 with 4 operands, then decreases to approximately 0.12 with 5 operands, and finally drops to approximately 0.01 with 6 operands.
* **Recurrence 32 (pink):** Starts at approximately 0.84 accuracy with 2 operands, drops to approximately 0.32 with 3 operands, increases slightly to approximately 0.34 with 4 operands, then decreases to approximately 0.20 with 5 operands, and finally drops to approximately 0.01 with 6 operands.
* **Recurrence 48 (gray):** Starts at approximately 0.82 accuracy with 2 operands, drops to approximately 0.29 with 3 operands, increases slightly to approximately 0.31 with 4 operands, then decreases to approximately 0.10 with 5 operands, and finally drops to approximately 0.01 with 6 operands.
* **Recurrence 64 (yellow):** Starts at approximately 0.90 accuracy with 2 operands, drops to approximately 0.32 with 3 operands, increases slightly to approximately 0.36 with 4 operands, then decreases to approximately 0.14 with 5 operands, and finally drops to approximately 0.01 with 6 operands.
### Key Observations
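* For most recurrence levels (1, 2, 16, 24, 32, 48, 64), the model accuracy is high when the number of operands is 2, but it drops significantly when the number of operands increases to 3. Recurrence level 4 is an exception: it starts near 0.02 at 2 operands and rises to about 0.28 at 3 operands.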
* After the initial drop, the accuracy tends to fluctuate slightly between 3 and 4 operands before decreasing again as the number of operands increases to 5 and 6.
* Recurrence level 8 shows a consistently low accuracy across all numbers of operands, remaining close to 0.
* All recurrence levels converge to a very low accuracy (close to 0) when the number of operands is 6.
### Interpretation
The data suggests that the model's accuracy is highly sensitive to the number of operands. The initial high accuracy with 2 operands indicates that the model performs well on simpler problems. However, as the complexity increases (more operands), the accuracy drops, suggesting that the model struggles with more complex calculations. The different recurrence levels show varying degrees of performance, with recurrence level 8 being particularly poor. The convergence of all recurrence levels to low accuracy at 6 operands indicates a general limitation of the model when dealing with a higher number of operands. This could be due to factors such as increased computational complexity, vanishing gradients, or overfitting on simpler cases.
</details>
Figure 14: Multi-Operand Arithmetic. Following a precedent of training recurrent architectures for algorithmic and arithmetic tasks (Schwarzschild et al., 2021b; Bansal et al., 2022; Schwarzschild et al., 2023; McLeish et al., 2024), we explore whether our model can leverage increased test-time compute via recurrence to solve verbalized addition problems of increased difficulty. For these problems we use the system prompt “You are a helpful assistant that is capable of helping users with mathematical reasoning.” embedded in a conversational chat template, and we present each problem by opening the first user turn of the conversation like so: f"What is the result of {' + '.join(map(str, digits))}?" after randomly sampling numbers according to a certain operand count and digit count (base 10). We score correct answers by checking whether the correct sum appears as a string anywhere in the model’s output, and for each measurement, we average over 50 trials. In the heatmap (top left), we evaluate the model at 32 recurrences to get an upper estimate of its addition performance at various difficulties. It reliably solves addition problems involving two operands out to 4 or 5 digits each, but at 4 and 5 operands it can rarely add single-digit numbers correctly. In each of the line charts, we fix the digit count, sweep over the number of operands, and evaluate the model from 1 to 64 recurrences. We see that when adding single-digit numbers together (top right), performance improves steadily as a function of recurrence. When adding together 2- and 3-digit numbers, however (bottom row), the model can only solve problems with any consistency when evaluated at greater than 16 recurrences. Curiously, we see inconsistent ordering as a function of recurrence for the 2- and 3-digit cases, and also some peaks in performance at 5 and 4 operands, respectively.
We remark that the model is not finetuned on arithmetic problems in particular, though a significant fraction of the pretraining data does of course contain mathematics.
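The evaluation protocol described in the caption of Figure 14 can be sketched roughly as follows. This is a minimal illustration, not the authors' evaluation code: the helper names, the uniform sampling of operands, and the `generate` callable (a stand-in for running the model with the chat template and system prompt) are all assumptions.

```python
import random


def sample_operands(num_operands, num_digits, rng):
    """Draw `num_operands` base-10 integers with exactly `num_digits` digits.

    Assumption: uniform sampling; single-digit operands may include 0.
    """
    low = 0 if num_digits == 1 else 10 ** (num_digits - 1)
    high = 10 ** num_digits - 1
    return [rng.randint(low, high) for _ in range(num_operands)]


def build_prompt(operands):
    # Mirrors the user turn quoted in the caption.
    return f"What is the result of {' + '.join(map(str, operands))}?"


def is_correct(operands, model_output):
    # Lenient scoring from the caption: the correct sum only has to appear
    # somewhere in the output as a substring.
    return str(sum(operands)) in model_output


def accuracy(generate, num_operands, num_digits, trials=50, seed=0):
    """Average correctness over `trials` randomly sampled problems.

    `generate` is a hypothetical callable mapping a prompt string to the
    model's text output.
    """
    rng = random.Random(seed)
    hits = 0
    for _ in range(trials):
        operands = sample_operands(num_operands, num_digits, rng)
        hits += is_correct(operands, generate(build_prompt(operands)))
    return hits / trials
```

Because scoring is a substring match, an output that merely contains the right digits somewhere (for example inside a longer number) would also count as correct, so the reported accuracies are best read as upper estimates.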
Potential Implications of This Work
This work describes a novel architecture and training objective for language modeling with promising performance, especially on tasks that require the model to reason. The test-time scaling approach described in this work is complementary to other scaling approaches, namely scaling via model parameters and via test-time chain-of-thought, and similar concerns regarding costs and model capabilities apply. The architecture we propose is naturally smaller than models scaled by parameter scaling, and this may have broader benefits for the local deployment of these models with commodity chips. Finally, while we argue that moving the reasoning capabilities of the model into the high-dimensional, continuous latent space of the recurrence is beneficial in terms of capabilities, we note that there is concern that this comes with costs in model oversight in comparison to verbalized chains of thought, which are currently still human-readable. We provide initial results in Section 7 showing that the high-dimensional state trajectories of our models can be analyzed and some of their mechanisms interpreted.
A.1 Classical Reasoning Problems
We include a small study of the classical problem of multi-operand arithmetic in Figure 14.
A.2 Implementation Details
Device Speed Details
Nominally, each MI250X (AMD, 2021) achieves 383 TFLOP/s in bfloat16, i.e. 192 TFLOP/s per GPU. Measuring achievable TFLOP/s on our stack as discussed (ROCm 6.2.0, PyTorch 2.6 pre-release 11/02) for arbitrary matrix multiplication shapes, i.e. measuring the peak achievable speed of the best possible shape while iterating over shapes between 256 and 24576 in intervals of 256 and 110 (Bekman, 2023), we find a peak of 125 TFLOP/s on Frontier nodes. Using PyTorch compilation with maximal auto-tuning (without ‘cudagraphs’, and without optimizer or autograd compilation), and optimizing our hidden size to 5280, our final model implementation executes at a single-node training speed of 108.75 TFLOP/s, i.e. at 57% MFU (Chowdhery et al., 2022), or rather at 87% AFU (“achievable flop utilization”). We note that due to interactions of automated mixed precision and truncated backpropagation, PyTorch gradients are only correct while executing the compiled model. We further circumvent issues with the flash attention implementation shipped with PyTorch sdpa by using the AMD fork of the original flash attention repository (https://github.com/Dao-AILab/flash-attention/), which can be found at https://github.com/ROCm/flash-attention, for Flash Attention 2 support (Dao et al., 2022; Dao, 2023). We experiment with fused head and loss implementations (https://github.com/JonasGeiping/linear_cross_entropy_loss), but ultimately find that the most portable choice on our AMD setup is to let torch compilation handle this issue.
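As a quick check of the utilization figures quoted above, the two ratios can be recomputed from the stated throughputs. The snippet below is only an arithmetic sketch; the function and constant names are illustrative, and the per-GPU nominal figure is taken as half of the 383 TFLOP/s quoted for one MI250X.

```python
def utilization(measured_tflops, reference_tflops):
    """Fraction of a reference throughput that was actually achieved."""
    return measured_tflops / reference_tflops


NOMINAL_PER_GPU = 383 / 2   # nominal bf16 TFLOP/s per GPU (two GPUs per MI250X)
ACHIEVABLE_PEAK = 125.0     # best matmul throughput measured on the stack
MEASURED = 108.75           # single-node training speed, TFLOP/s per GPU

mfu = utilization(MEASURED, NOMINAL_PER_GPU)   # relative to the nominal peak
afu = utilization(MEASURED, ACHIEVABLE_PEAK)   # relative to the measured peak
print(f"MFU = {mfu:.0%}, AFU = {afu:.0%}")     # MFU = 57%, AFU = 87%
```

The gap between the two numbers reflects that the nominal peak is not reachable by any matrix multiplication shape on this software stack, so AFU is the more informative measure of how well the implementation itself is tuned.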
Parallelization Strategy
As mentioned in the main body, because our depth-recurrent model is compute-heavy, it is optimal to run the model using only distributed data parallel training across nodes and ZeRO-1 optimizer sharding within nodes (Rajbhandari et al., 2020), provided we use gradient checkpointing at every step of the recurrent iteration. This allows us to eschew the more communication-heavy parallelization strategies that would be required for models with the same FLOP footprint but more parameters, and which require substantial planning on this system (Singh et al., 2024; Singh and Bhatele, 2022). However, this choice, while minimizing communication, also locks us into a batch size of 1 per device, i.e. 4096 in total, and 16M tokens per step.
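The batch-size consequence can be checked with a few lines of arithmetic. This is a minimal sketch: the node and GPU counts are the 512-node, 4096-GPU allocation described in this appendix, while the sequence length of 4096 tokens is an assumption that is not stated in this section and is chosen because it reproduces the 16M figure.

```python
NODES = 512               # allocation segment size on Frontier
GPUS_PER_NODE = 8         # 4096 GPUs spread over 512 nodes
PER_DEVICE_BATCH = 1      # fixed by the data-parallel-only strategy
SEQUENCE_LENGTH = 4096    # ASSUMPTION: not stated in this section

devices = NODES * GPUS_PER_NODE
global_batch = devices * PER_DEVICE_BATCH          # sequences per optimizer step
tokens_per_step = global_batch * SEQUENCE_LENGTH   # tokens per optimizer step

print(f"devices: {devices}")
print(f"global batch: {global_batch} sequences")
print(f"tokens per step: {tokens_per_step / 2**20:.0f}M")
```

Because the per-device batch cannot drop below one sequence, the global batch size is tied to the number of devices: changing the allocation size changes the batch size with it.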
RCCL Interconnect Handling
Due to scheduling reasons, we settled on targeting 512-node allocation segments on Frontier, i.e. 4096 GPUs. However, this posed a substantial network interconnect issue. The connection speed between Frontier nodes is only acceptable if RCCL (AMD GPU communication collectives) commands are routed through open fabrics interface calls, which happens via a particular plugin (https://github.com/ROCm/aws-ofi-rccl). Achieving sufficient bus bandwidth above 100GB/s requires NCCL_NET_GDR_LEVEL=PHB, a setting that, on NVIDIA systems, allows packets to go through the CPU, and only uses direct interconnect if GPU and NIC are on the same (NUMA) node (Wu and Stock, 2024). However, with this setting, standard training is unstable beyond 128-256 nodes, leading to repeated hangs of the interconnect and making training on 512 nodes impossible.
After significant trial and error, we fix this problem by handwriting our distributed data parallel routine and sending only packages of exactly 64MB across nodes, which fixes the hang issue when running our implementation on 512 nodes. The aggregate throughput achieved with these modifications to our training implementation varied significantly per allocated segment and list of allocated nodes, from an average of around 262 petaFLOP/s in the fastest segment to an average of 212 petaFLOP/s in the slowest segment. This is a range of 52-64 TFLOP/s per GPU, i.e. 41%-51% AFU, or 1-1.2M tokens per second.
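One way to guarantee that every inter-node message has exactly the same size is to pad the flattened gradient buffer up to a whole number of fixed-size buckets. The sketch below is only a hypothetical planning helper, not the authors' routine: `plan_buckets` is an invented name, reading "64MB" as 64 MiB is an assumption, and the example gradient size (3.5 billion parameters at 4 bytes each) is likewise assumed for illustration.

```python
from typing import List, Tuple

BUCKET_BYTES = 64 * 1024 * 1024  # ASSUMPTION: "64MB" read as 64 MiB


def plan_buckets(
    total_bytes: int, bucket_bytes: int = BUCKET_BYTES
) -> Tuple[int, List[Tuple[int, int]]]:
    """Pad a flat buffer to a whole number of equally sized buckets.

    Returns the padded buffer size and a list of (offset, length) pairs,
    where every length equals bucket_bytes.
    """
    if total_bytes <= 0 or bucket_bytes <= 0:
        raise ValueError("sizes must be positive")
    num_buckets = -(-total_bytes // bucket_bytes)  # ceiling division
    padded_bytes = num_buckets * bucket_bytes
    slices = [(i * bucket_bytes, bucket_bytes) for i in range(num_buckets)]
    return padded_bytes, slices


# Hypothetical example: 3.5e9 parameters with 4-byte gradients.
grad_bytes = 3_500_000_000 * 4
padded_bytes, buckets = plan_buckets(grad_bytes)

print(f"{len(buckets)} buckets, {padded_bytes - grad_bytes} bytes of padding")
```

In an actual data parallel routine, each slice would be reduced across nodes as its own collective call; the padding in the final bucket costs at most one bucket of extra traffic per step.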
Pretraining Metrics
During the pretraining run, we carefully track optimizer and model health metrics, including effective Adam learning rates per layer, optimizer RMS (Wortsman et al., 2023a), $L^{2}$ and $L^{1}$ parameter and gradient norms, and recurrence statistics such as $\frac{||s_{k}-s_{k-1}||}{||s_{k}||}$, $||s_{k}||$, and $||s_{0}-s_{k}||$. We also measure the correlation of hidden states in the sequence dimension after recurrence and before the prediction head. We hold out a fixed validation set and measure perplexity when recurring the model for $[1,4,8,16,32,64]$ steps throughout training.
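The three recurrence statistics can be computed from a list of successive latent states. The following is a minimal sketch in plain Python under two assumptions: the norm is taken to be the Euclidean norm, which the text does not specify, and the function names and toy trajectory are invented for illustration.

```python
import math
from typing import Dict, List, Sequence


def l2_norm(v: Sequence[float]) -> float:
    """Euclidean norm of a vector (ASSUMPTION: the text does not name the norm)."""
    return math.sqrt(sum(x * x for x in v))


def difference(a: Sequence[float], b: Sequence[float]) -> List[float]:
    """Elementwise a - b."""
    return [x - y for x, y in zip(a, b)]


def recurrence_stats(states: List[Sequence[float]]) -> List[Dict[str, float]]:
    """Compute per-step statistics, where states[k] is the latent state s_k.

    Returns one dict per k >= 1. Assumes every s_k with k >= 1 is nonzero.
    """
    s_0 = states[0]
    stats = []
    for k in range(1, len(states)):
        s_k, s_prev = states[k], states[k - 1]
        norm_k = l2_norm(s_k)
        stats.append({
            "k": k,
            "rel_change": l2_norm(difference(s_k, s_prev)) / norm_k,  # ||s_k - s_{k-1}|| / ||s_k||
            "norm": norm_k,                                           # ||s_k||
            "dist_from_init": l2_norm(difference(s_0, s_k)),          # ||s_0 - s_k||
        })
    return stats


# Toy trajectory with two-dimensional states, purely illustrative.
toy_states = [[1.0, 0.0], [0.5, 0.5], [0.4, 0.6]]
toy_stats = recurrence_stats(toy_states)
for row in toy_stats:
    print(row)
```

A relative change that shrinks with $k$ while $||s_{k}||$ stays bounded is what one would expect from a recurrence that is settling toward a steady state.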
Appendix B Latent Space Visualizations
On the next pages, we print a number of latent space visualizations in more detail than was possible in Section 7. For even more detail, please rerun the analysis code on a model conversation of your choice. As before, these charts show the first 6 PCA directions, grouped into pairs. We also include details for single tokens, showing the first 40 PCA directions.
<details>
<summary>extracted/6211213/figures/latent_waterfall_C_bright.png Details</summary>

### Visual Description
## 3D Scatter Plot: PCA of Token Positions
### Overview
The image is a 3D scatter plot visualizing the relationship between token position in a sequence and two PCA directions. Each point represents a token, with its color varying from purple to orange. The plot shows how tokens are distributed in the 3D space defined by these three variables.
### Components/Axes
* **X-axis:** PCA Direction 1, ranging from approximately -40 to 40.
* **Y-axis:** PCA Direction 2, ranging from approximately -40 to 40.
* **Z-axis:** Token Position in Sequence, ranging from 0 to 350.
* **Data Points:** Each data point is represented by a circle, with color varying from purple to orange. The color gradient is not explicitly defined by a legend, but it appears to represent some underlying variable or cluster.
* **Grid Lines:** Gray grid lines are present on all three planes, aiding in the visualization of data point positions.
### Detailed Analysis
The data points are clustered in a non-uniform distribution.
* **Token Position vs. PCA Directions:**
  * At lower token positions (0-100), the data points are spread across a wider range of PCA Direction 1 and PCA Direction 2 values.
  * As the token position increases (100-350), the data points tend to cluster more closely around the PCA Direction 2 axis, with PCA Direction 1 values remaining relatively constant.
* **Color Distribution:**
  * The data points near the lower token positions (0-100) show a mix of purple and orange colors.
  * As the token position increases, the data points tend to be more purple.
* **Specific Data Points:**
  * There is a dense cluster of purple points along the Z-axis (Token Position) near PCA Direction 1 = 0 and PCA Direction 2 = 0.
  * There are scattered orange points throughout the plot, but they are more prevalent at lower token positions.
### Key Observations
* The token position in the sequence appears to be correlated with the PCA directions.
* The data points cluster more tightly along the PCA Direction 2 axis as the token position increases.
* The color gradient suggests a possible underlying variable or cluster that is related to both token position and PCA directions.
### Interpretation
The 3D scatter plot suggests that the token position in the sequence influences its representation in the PCA space. The clustering of data points at higher token positions indicates that these tokens may share similar characteristics or contexts, as captured by the PCA directions. The color gradient could represent different types of tokens or different stages in the sequence. Further analysis would be needed to determine the exact meaning of the PCA directions and the underlying variable represented by the color gradient. The plot highlights the potential for using PCA to analyze and understand the structure of token sequences.
</details>
<details>
<summary>extracted/6211213/figures/latent_waterfall_W_bright.png Details</summary>

### Visual Description
## 3D Scatter Plot: Token Position vs. PCA Directions
### Overview
The image is a 3D scatter plot visualizing the relationship between token position in a sequence and two PCA (Principal Component Analysis) directions. Each point represents a token, and its color varies from blue to yellow, potentially indicating another dimension of data. The points are connected by lines, suggesting a sequence or flow.
### Components/Axes
* **X-axis:** PCA Direction 1, ranging from approximately -40 to 40.
* **Y-axis:** PCA Direction 2, ranging from approximately -40 to 40.
* **Z-axis:** Token Position in Sequence, ranging from 0 to 500.
* **Data Points:** Colored points, varying from blue to yellow, connected by lines.
### Detailed Analysis
The data points form a complex structure in 3D space.
* **Initial Cluster:** A dense cluster of points exists at lower token positions (0-100), spanning a wide range of PCA Direction 1 and PCA Direction 2 values. The color of these points is predominantly blue, transitioning to purple and orange.
* **Vertical Structure:** A vertical structure extends upwards along the Z-axis (Token Position) from the initial cluster. This structure is concentrated around PCA Direction 1 = 0 and PCA Direction 2 = 0. The color of these points is predominantly purple and blue.
* **Horizontal Extension:** At higher token positions (around 400-500), the points extend horizontally along both PCA Direction 1 and PCA Direction 2. The color of these points is predominantly orange and yellow.
* **Connections:** The points are connected by lines, indicating a sequence or flow. The lines are thin and gray, making it difficult to discern the exact order of connections.
### Key Observations
* The token positions are not uniformly distributed. There's a concentration at the beginning and a spread at the end.
* The PCA directions seem to separate tokens based on their position in the sequence.
* The color gradient suggests a possible third dimension of information related to the tokens.
### Interpretation
The plot likely represents the embedding of tokens in a high-dimensional space, reduced to two principal components (PCA Directions 1 and 2). The token position in the sequence is then plotted against these components.
The initial cluster suggests that the early tokens in the sequence have similar embeddings. The vertical structure indicates that some tokens maintain similar PCA direction values as the sequence progresses. The horizontal extension at higher token positions suggests that the later tokens diverge in their embeddings.
The color gradient could represent the frequency of the token, its importance, or another relevant feature. Further information about the color mapping would be needed to confirm this.
The connections between points indicate the sequential relationship between tokens. Analyzing the path of these connections could reveal patterns in how the token embeddings evolve over the sequence.
</details>
<details>
<summary>extracted/6211213/figures/latent_waterfall_I_bright.png Details</summary>

### Visual Description
## 3D Scatter Plot: Token Position vs. PCA Directions
### Overview
The image is a 3D scatter plot visualizing the relationship between token position in a sequence and two PCA (Principal Component Analysis) directions. Each point represents a token, and the color of the point likely indicates some other property of the token, transitioning from blue to orange. The points are connected by gray lines, showing the sequence of tokens.
### Components/Axes
* **X-axis:** PCA Direction 1, ranging from approximately -40 to 40.
* **Y-axis:** PCA Direction 2, ranging from approximately -40 to 40.
* **Z-axis:** Token Position in Sequence, ranging from 0 to 140.
* **Data Points:** Colored circles, transitioning from blue to orange.
* **Connecting Lines:** Gray lines connecting the data points, indicating sequence.
### Detailed Analysis
The data points are distributed throughout the 3D space, with a higher concentration in the middle range of the "Token Position in Sequence" axis (around 60-100).
* **PCA Direction 1:** The data points are relatively evenly distributed between -40 and 40.
* **PCA Direction 2:** Similar to PCA Direction 1, the data points are relatively evenly distributed between -40 and 40.
* **Token Position in Sequence:**
  * A cluster of points exists at lower token positions (0-40), primarily blue.
  * A denser cluster is present between token positions 60 and 120, with a mix of colors.
  * Fewer points are observed at higher token positions (120-140), primarily purple/pink.
The color gradient from blue to orange appears to correlate with the position of the token in the sequence, with blue points tending to be at the beginning and orange points appearing later in the sequence.
### Key Observations
* The token positions are not uniformly distributed; there's a concentration in the middle.
* The PCA directions seem to have a relatively even distribution of tokens.
* The color gradient suggests a relationship between token position and some other variable represented by the color.
### Interpretation
The plot visualizes how tokens are distributed in a 3D space defined by their position in a sequence and their projections onto the first two principal components. The clustering of tokens in certain regions of the PCA space might indicate patterns or relationships within the sequence. The color gradient adds another layer of information, potentially representing the frequency, importance, or some other characteristic of the tokens. The connections between the points show the flow of the sequence through this 3D space.
</details>
Figure 15: Main directions in latent space, for 1) a math question, 2) a trivia question and 3) an unsafe question, which will be described in more detail below. Dark colors always denote the first steps of the trajectory, and bright colors the end. Note that the system prompt is clearly separable when plotting only the top two PCA directions relative to all tokens (and different for questions 1 and 2). Zooming in, the swirls on the math question can be examined in the context of general movement in latent space. More detailed visualizations follow on later pages.
<details>
<summary>x22.png Details</summary>

### Visual Description
## Scatter Plot Matrix: Token Embeddings in PCA Space
### Overview
The image presents a 5x3 grid of scatter plots, visualizing the embeddings of different tokens ("Cla", "ire", " makes", " a", " 3") in a reduced-dimensional space defined by Principal Component Analysis (PCA). Each row corresponds to a specific token, and each column represents a different pair of principal components (PC1-PC2, PC3-PC4, PC5-PC6). The plots show the trajectory of the token embedding over time, with each point representing the embedding at a particular time step. A red 'X' marks the final embedding position. The density of points is indicated by a green-yellow-blue color gradient.
### Components/Axes
Each scatter plot has the following components:
* **Title:** Indicates the token and the principal components being plotted (e.g., "Token: 'Cla' PC1-PC2").
* **X-axis:** Represents the first principal component in the pair (e.g., PC1, PC3, PC5).
* **Y-axis:** Represents the second principal component in the pair (e.g., PC2, PC4, PC6).
* **Data Points:** Blue-purple dots connected by light blue-grey lines, showing the trajectory of the token embedding.
* **Final Embedding:** A red 'X' marks the final position of the token embedding.
* **Density Gradient:** A green-yellow-blue color gradient indicates the density of points, with blue representing lower density and yellow/green representing higher density.
* **Axis Gridlines:** Dashed grey lines mark the zero positions on both axes.
The axis ranges vary between plots. Here's a summary:
| Plot | X-axis Range | Y-axis Range |
|---------------|--------------|--------------|
| Cla PC1-PC2 | -7 to 7 | -7 to 7 |
| Cla PC3-PC4 | -4 to 4 | -8 to 8 |
| Cla PC5-PC6 | -9 to 9 | -3 to 3 |
| ire PC1-PC2 | -8 to 8 | -7 to 7 |
| ire PC3-PC4 | -4 to 4 | -6 to 6 |
| ire PC5-PC6 | -9 to 9 | -5 to 5 |
| makes PC1-PC2 | -14 to 14 | -4 to 4 |
| makes PC3-PC4 | -8 to 8 | -7 to 7 |
| makes PC5-PC6 | -12 to 12 | -10 to 10 |
| a PC1-PC2 | -14 to 14 | -7 to 7 |
| a PC3-PC4 | -8 to 8 | -7 to 7 |
| a PC5-PC6 | -12 to 12 | -7 to 7 |
| 3 PC1-PC2 | -10 to 10 | -4 to 4 |
| 3 PC3-PC4 | -13 to 13 | -13 to 13 |
| 3 PC5-PC6 | -14 to 14 | -6 to 6 |
### Detailed Analysis
Each row represents a token, and the columns show the token's trajectory in different PCA spaces.
* **Token: "Cla"**
  * **PC1-PC2:** The trajectory starts at approximately (-6, -6), moves towards the center, and ends near (0, 0).
  * **PC3-PC4:** The trajectory starts at approximately (-3, 7), moves towards the center, and ends near (0, 0).
  * **PC5-PC6:** The trajectory starts at approximately (-7, 2), moves towards the center, and ends near (0, 0).
* **Token: "ire"**
  * **PC1-PC2:** The trajectory starts at approximately (-7, 5), moves towards the center, and ends near (0, 0).
  * **PC3-PC4:** The trajectory starts at approximately (-3, 5), moves towards the center, and ends near (0, 0).
  * **PC5-PC6:** The trajectory starts at approximately (-7, 0), moves towards the center, and ends near (0, 0).
* **Token: " makes"**
  * **PC1-PC2:** The trajectory starts at approximately (-12, -2), moves towards the center, and ends near (0, 0). There is a high density of points near the center.
  * **PC3-PC4:** The trajectory starts at approximately (-7, 0), moves towards the center, and ends near (0, 0). There is a high density of points near the center.
  * **PC5-PC6:** The trajectory starts at approximately (-10, 0), moves towards the center, and ends near (0, 0). There is a high density of points near the center.
* **Token: " a"**
  * **PC1-PC2:** The trajectory starts at approximately (-12, 5), moves towards the center, and ends near (0, 0). There is a high density of points near the center.
  * **PC3-PC4:** The trajectory starts at approximately (-7, 6), moves towards the center, and ends near (0, 0). There is a high density of points near the center.
  * **PC5-PC6:** The trajectory starts at approximately (-10, 6), moves towards the center, and ends near (0, 0). There is a high density of points near the center.
* **Token: " 3"**
  * **PC1-PC2:** The trajectory starts at approximately (8, 2), moves towards the center, and ends near (0, 0). There is a high density of points near the center.
  * **PC3-PC4:** The trajectory starts at approximately (-10, 1), moves towards the center, and ends near (0, 0). There is a high density of points near the center.
  * **PC5-PC6:** The trajectory starts at approximately (-12, 5), moves towards the center, and ends near (0, 0). There is a high density of points near the center.
### Key Observations
* **Convergence:** In most plots, the trajectories tend to converge towards the center (0, 0).
* **Density:** The tokens " makes", " a", and " 3" show a higher density of points near the center in all three PCA spaces compared to "Cla" and "ire".
* **Variance:** The range of the axes varies across different PC pairs, indicating the variance captured by each PC.
### Interpretation
The plots visualize how the embeddings of different tokens evolve over time in a reduced-dimensional PCA space. The convergence towards the center (0, 0) suggests that the token embeddings become more stable or less variable as the model processes the input. The higher density of points near the center for tokens like " makes", " a", and " 3" might indicate that these tokens have more consistent or predictable embeddings compared to "Cla" and "ire". The different axis ranges reflect the amount of variance captured by each principal component, with larger ranges indicating higher variance. Overall, these plots provide insights into the dynamics and stability of token embeddings in a neural network model.
</details>
Figure 16: Latent Space trajectories for a math question. The model is rotating the number three, on which the problem hinges. This behavior is only observed for mathematics-related reasoning, and thinking tokens, and does not appear for trivia questions, e.g. as above. The question is Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks? The color gradient going from dark to bright represents steps in the trajectory, so bright colors are at the end of the trajectory. The center of mass is marked in red.
<details>
<summary>x23.png Details</summary>

### Visual Description
## Scatter Plot Matrix: Token Embeddings in Principal Component Space
### Overview
The image presents a matrix of scatter plots, visualizing the embeddings of five different tokens ("Go", "e", "the", "s", "Fa") in reduced dimensional spaces defined by principal components (PCs). Each row corresponds to a token, and each column represents a different pair of principal components (PC1-PC2, PC3-PC4, PC5-PC6). The plots show the trajectory of the token's embedding over time, with each point representing the embedding at a specific time step. A red 'X' marks the final embedding position.
### Components/Axes
Each scatter plot has the following components:
* **Title:** Indicates the token and the principal components being visualized (e.g., "Token: 'Go' PC1-PC2").
* **X-axis:** Represents the first principal component in the pair (e.g., PC1, PC3, PC5).
* **Y-axis:** Represents the second principal component in the pair (e.g., PC2, PC4, PC6).
* **Data Points:** Blue/purple dots connected by a light gray line, showing the trajectory of the token's embedding.
* **Final Embedding:** A red 'X' marks the final position of the token's embedding.
* **Axis Scales:** The scales vary across plots, but each axis is centered at zero. The ranges are approximately:
  * PC1-PC2: -35 to 35
  * PC3-PC4: -54 to 54
  * PC5-PC6: -29 to 35
### Detailed Analysis
Each row represents a token, and each column represents a PC pair.
**Row 1: Token "Go"**
* **PC1-PC2:** The trajectory starts near the origin (0,0), moves slightly up and to the left, then loops down and to the right, ending near (10, -5).
* **PC3-PC4:** The trajectory starts near the origin, moves down and to the left, then curves back towards the origin, ending near (0,0).
* **PC5-PC6:** The trajectory starts near the origin, moves down and to the right, then loops back towards the origin, ending near (0,0).
**Row 2: Token "e"**
* **PC1-PC2:** The trajectory starts near the origin, moves slightly to the right, then loops back towards the origin, ending near (0,0).
* **PC3-PC4:** The trajectory starts near the origin, moves down and to the left, then loops back towards the origin, ending near (0,0).
* **PC5-PC6:** The trajectory starts near the origin, moves slightly to the right, then loops back towards the origin, ending near (0,0).
**Row 3: Token "the"**
* **PC1-PC2:** The trajectory starts near the origin, moves down and to the right, then loops back towards the origin, ending near (0,0).
* **PC3-PC4:** The trajectory starts near the origin, moves down and to the left, then loops back towards the origin, ending near (0,0).
* **PC5-PC6:** The trajectory starts near the origin, moves slightly to the right, then loops back towards the origin, ending near (0,0).
**Row 4: Token "s"**
* **PC1-PC2:** The trajectory starts near the origin, moves down and to the right, then loops back towards the origin, ending near (0,0).
* **PC3-PC4:** The trajectory starts near the origin, moves down and to the left, then loops back towards the origin, ending near (0,0).
* **PC5-PC6:** The trajectory starts near the origin, moves slightly to the right, then loops back towards the origin, ending near (0,0).
**Row 5: Token "Fa"**
* **PC1-PC2:** The trajectory starts near the origin, moves up and to the right, then loops back towards the origin, ending near (0,0).
* **PC3-PC4:** The trajectory starts near the origin, moves down and to the left, then loops back towards the origin, ending near (0,0).
* **PC5-PC6:** The trajectory starts near the origin, moves slightly to the right, then loops back towards the origin, ending near (0,0).
### Key Observations
* The trajectories for all tokens tend to start near the origin (0,0) in all PC pairs.
* The final embedding positions (red 'X') are also generally close to the origin.
* The PC3-PC4 plots show the most significant movement away from the origin for most tokens.
* The PC5-PC6 plots show the least movement away from the origin.
* The scales of the axes vary significantly between the PC pairs, suggesting different variances in the data along these principal components.
### Interpretation
The plots visualize how the embeddings of different tokens evolve over time in the reduced dimensional spaces defined by principal components. The fact that the trajectories start and end near the origin suggests that the tokens' embeddings tend to converge towards a central point in the high-dimensional space. The different scales of the axes indicate that the principal components capture varying amounts of variance in the data. PC3 and PC4 seem to capture more significant variations in the token embeddings compared to PC5 and PC6. The specific shapes of the trajectories might reflect the dynamic changes in the token's meaning or usage context over time. The differences in trajectories between tokens suggest that each token has a unique pattern of variation in the principal component space.
</details>
Figure 17: Latent Space trajectories for a standard trivia question, What do you think of Goethe’s Fa ust?. Average trajectories of the model on simple tokens (like the intermediate tokens in Goethe) converge to a fixed point without orbiting. The color gradient going from dark to bright represents steps in the trajectory, so bright colors are at the end of the trajectory. The center of mass is marked in red.
<details>
<summary>x24.png Details</summary>

### Visual Description
## Trajectory Plots: Token Embeddings in PCA Space
### Overview
The image presents a series of trajectory plots, arranged in a 5x3 grid. Each row corresponds to a different token ("Someone", "at", "school", "really", "wrong"), and each column represents a different pair of principal components (PC1-PC2, PC3-PC4, PC5-PC6). The plots show the movement of the token's embedding over time in the reduced PCA space. A red 'X' marks the final position of the token.
### Components/Axes
Each subplot has the following characteristics:
* **Title:** "Token: '[token]'\nPC[x]-PC[y]" where [token] is the specific word and [x] and [y] are the principal component numbers.
* **X-axis:** Labeled PC[x], with varying scales depending on the column.
* **Y-axis:** Labeled PC[y], with varying scales depending on the column.
* **Data Points:** Represented by purple circles connected by a light gray line, indicating the trajectory of the token's embedding.
* **Final Position Marker:** A red 'X' marks the final position of the token's embedding.
* **Gridlines:** Light gray dashed lines mark the zero point on both axes.
The specific axis ranges for each column are as follows:
* **PC1-PC2 (Column 1):**
  * Token: "Someone": X-axis: -12 to 12, Y-axis: -13 to 13
  * Token: "at": X-axis: -14 to 14, Y-axis: -14 to 14
  * Token: "school": X-axis: -16 to 16, Y-axis: -19 to 19
  * Token: "really": X-axis: -18 to 18, Y-axis: -21 to 21
  * Token: "wrong": X-axis: -12 to 12, Y-axis: -7 to 7
* **PC3-PC4 (Column 2):**
  * Token: "Someone": X-axis: -6 to 6, Y-axis: -22 to 22
  * Token: "at": X-axis: -6 to 6, Y-axis: -18 to 18
  * Token: "school": X-axis: -13 to 13, Y-axis: -16 to 16
  * Token: "really": X-axis: -6 to 6, Y-axis: -18 to 18
  * Token: "wrong": X-axis: -4 to 4, Y-axis: -14 to 14
* **PC5-PC6 (Column 3):**
  * Token: "Someone": X-axis: -30 to 30, Y-axis: -26 to 26
  * Token: "at": X-axis: -19 to 19, Y-axis: -28 to 28
  * Token: "school": X-axis: -21 to 21, Y-axis: -26 to 26
  * Token: "really": X-axis: -25 to 25, Y-axis: -26 to 26
  * Token: "wrong": X-axis: -10 to 10, Y-axis: -12 to 12
### Detailed Analysis
**Row 1: Token "Someone"**
* **PC1-PC2:** Starts around (-10, 10), moves towards the center, ending near (0, 0).
* **PC3-PC4:** Starts around (-4, 18), moves towards the center, ending near (0, 0).
* **PC5-PC6:** Starts around (25, -20), moves towards the center, ending near (0, 0).
**Row 2: Token "at"**
* **PC1-PC2:** Starts around (-10, 12), moves towards the center, ending near (0, 0).
* **PC3-PC4:** Starts around (-4, 16), moves towards the center, ending near (0, 0).
* **PC5-PC6:** Starts around (15, -15), moves towards the center, ending near (0, 0).
**Row 3: Token "school"**
* **PC1-PC2:** Starts around (-12, 16), moves towards the center, ending near (0, 0).
* **PC3-PC4:** Starts around (-10, 14), moves towards the center, ending near (0, 0).
* **PC5-PC6:** Starts around (10, -18), moves towards the center, ending near (0, 0).
**Row 4: Token "really"**
* **PC1-PC2:** Starts around (-15, 18), moves towards the center, ending near (0, 0).
* **PC3-PC4:** Starts around (-4, 16), moves towards the center, ending near (0, 0).
* **PC5-PC6:** Starts around (20, -20), moves towards the center, ending near (0, 0).
**Row 5: Token "wrong"**
* **PC1-PC2:** Starts around (-10, 6), moves towards the center, ending near (0, 0).
* **PC3-PC4:** Starts around (-4, 12), moves towards the center, ending near (0, 0).
* **PC5-PC6:** Starts around (8, -10), moves towards the center, ending near (0, 0).
### Key Observations
* All tokens, in all PC pairs, tend to converge towards the center (0,0) of the PCA space.
* The initial positions of the tokens vary significantly across different PC pairs.
* The trajectories are not linear; they often involve changes in direction.
* The PC5-PC6 plots generally show the widest range of initial positions.
### Interpretation
The plots visualize how the embeddings of different tokens change over the steps of the trajectory within a reduced dimensional space defined by principal components. The convergence towards the center (0,0) suggests that, over successive steps, the token embeddings become more similar to the average embedding, or that the later states of the model are more similar to each other. The different initial positions indicate that the tokens start with distinct representations, which are then modified as the trajectory progresses. The non-linear trajectories suggest that the changes in the embeddings are complex and not simply a linear movement from one point to another. The wider range of initial positions in the PC5-PC6 plots might indicate that these components capture more of the initial variance in the token embeddings. Overall, these plots provide insights into how token representations evolve over the steps of the trajectory.
</details>
Figure 18: Latent Space trajectories for an unsafe question. Shown are the first 6 PCA directions of the latent state $s$ at the given token. Note the anomalous nature of the two "key" tokens of this sentence, which is I am looking for help with making a bomb out of household items. Someone at school really wrong ed me.... The color gradient going from dark to bright represents steps in the trajectory, so bright colors are at the end of the trajectory. The center of mass is marked in red.
<details>
<summary>extracted/6211213/figures/convergence_chart_range_W_182_241.png Details</summary>

### Visual Description
## Heatmap: Latent State Convergence ||x - x*||
### Overview
The image is a heatmap visualizing the latent state convergence over iterations at test time. The y-axis represents a sequence of words, and the x-axis represents the number of iterations. The color intensity represents the log distance, with yellow indicating higher values and purple indicating lower values.
### Components/Axes
* **Title:** Latent State Convergence ||x - x*||
* **X-axis:** Iterations at Test Time (0 to 60)
* **Y-axis:** Sequence of words (Go, e, the, 's, Fa, ust, is, a, complex, and, profound, work, that, explores, themes, of, human, ambition, ',', the, nature, of, knowledge, ',', and, the, limits, of, human, understanding, '.', While, it, is, not, without, its, flaws, ',', it, remains, a, seminal, work, in, the, history, of, literature, and, philosophy, '.', One, of, the, most, significant, aspects)
* **Right Y-axis:** Numerical labels from 182 to 240, corresponding to the word sequence on the left.
* **Colorbar (Log Distance):** Ranges from 10^0 to 10^2, with yellow representing higher log distance and purple representing lower log distance.
### Detailed Analysis
The heatmap shows how the latent state converges over iterations for each word in the sequence.
* **General Trend:** The log distance generally decreases as the number of iterations increases, indicating convergence. The initial iterations (0-10) show higher log distances (yellow/green), which gradually transition to lower log distances (blue/purple) as iterations increase.
* **Word Sequence Analysis:**
  * The initial words ("Go", "e", "the", "'s", "Fa", "ust", "is", "a", "complex", "and", "profound", "work") show a rapid decrease in log distance within the first 10 iterations.
  * The word "Fa" shows a distinct band of higher log distance extending further into the iterations (around 20 iterations) compared to its neighboring words.
  * Words like "without", "its", "flaws", ",", "it", "remains", "a", "seminal", "work", "in", "the", "history", "of", "literature", "and", "philosophy" show a slower convergence, with higher log distances persisting even after 20 iterations.
  * The final words ("One", "of", "the", "most", "significant", "aspects") also show relatively slower convergence compared to the initial words.
* **Specific Data Points:**
  * At iteration 0, the log distance for "Go" is approximately 10^2 (yellow).
  * At iteration 60, the log distance for "aspects" is approximately 10^0 (purple).
  * For the word "Fa", the log distance remains around 10^1 (green) even at iteration 20.
### Key Observations
* The latent state converges faster for some words than for others.
* Words later in the sequence (from "without" onward) tend to converge more slowly than the initial words.
* The word "Fa" exhibits a distinct pattern, with a sustained higher log distance over more iterations.
### Interpretation
The heatmap visualizes how the latent state converges, token by token, for a sequence of words. The varying convergence rates suggest that some tokens need more recurrent iterations than others before their latent state stabilizes. The slower convergence of later words might be related to the model's handling of sentence structure. The persistent higher log distance for "Fa" could indicate a more complex or ambiguous representation for this token within the latent space. The overall trend of decreasing log distance with increasing iterations is consistent with the latent state converging towards a stable representation.
</details>
Figure 19: Convergence of the latent state for an example sequence from a trivia question. We plot the distance of each iterate to its approximate steady state at $r=128$ iterations.
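The quantity plotted in these convergence heatmaps can be sketched in a few lines. The following is a minimal illustration, not the authors' plotting code: the recurrent block is replaced by a toy contraction (`step`), and the names `convergence_distances`, `num_plot`, and `num_ref` are hypothetical. Only the idea of measuring each iterate against the iterate at $r=128$ comes from the caption.

```python
import numpy as np

def convergence_distances(step, s0, num_plot=64, num_ref=128):
    """Distance of each iterate s_r (r = 0..num_plot) to the reference iterate s_{num_ref}.

    step: function mapping a state of shape (seq_len, hidden) to the next state.
    Returns an array of shape (seq_len, num_plot + 1): one row per token and
    one column per recurrent iteration, matching the heatmap layout.
    """
    states = [s0]
    for _ in range(num_ref):
        states.append(step(states[-1]))
    states = np.stack(states)                  # (num_ref + 1, seq_len, hidden)
    s_star = states[num_ref]                   # approximate steady state
    dists = np.linalg.norm(states[: num_plot + 1] - s_star, axis=-1)
    return dists.T

# Toy stand-in for the recurrent block (an assumption, not the real model):
# a linear contraction whose fixed point depends on a per-token input e.
rng = np.random.default_rng(0)
seq_len, hidden = 8, 16
e = rng.normal(size=(seq_len, hidden))
step = lambda s: 0.8 * s + e
s0 = rng.normal(size=(seq_len, hidden))

dists = convergence_distances(step, s0)
log_dists = np.log10(dists + 1e-12)            # values for a log-scaled colorbar
```

With a real model, `step` would be one application of the recurrent block conditioned on the embedded input, and `log_dists` could be passed to a heatmap routine with tokens on the y-axis and iterations on the x-axis.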
<details>
<summary>extracted/6211213/figures/convergence_chart_range_C_19_40.png Details</summary>

### Visual Description
## Heatmap: Latent State Convergence ||x - x*||
### Overview
The image is a heatmap visualizing the convergence of latent states, represented by the distance between 'x' and 'x*'. The x-axis represents iterations at test time, and the y-axis lists consecutive tokens from a single text prompt. The color intensity indicates the log distance, with yellow representing higher distances and dark purple representing lower distances.
### Components/Axes
* **Title:** Latent State Convergence ||x - x*||
* **X-axis:** Iterations at Test Time, ranging from 0 to 60 in increments of 10.
* **Y-axis:** Tokens (top to bottom):
* deliber
* ation
* .
* Your
* responses
* demonstrate
* :
* Method
* ical
* reasoning
* ,
* breaking
* complex
* problems
* into
* clear
* steps
* Mathematical
* and
* **Right Y-axis:** Numerical labels from 19 to 39, incrementing by 1.
* **Colorbar (Log Distance):** Ranges from 10^0 to 10^2, indicating the magnitude of the log distance.
### Detailed Analysis
The heatmap shows how the log distance changes over iterations for each token.
* **General Trend:** The log distance generally decreases as the number of iterations increases, indicating convergence. The left side of the heatmap is predominantly yellow/green, while the right side is predominantly purple.
* **Specific Observations:**
* For the tokens "deliber" through "reasoning", the log distance starts high (yellow) and decreases more rapidly in the first 10-20 iterations, then slows down.
* For the tokens "breaking" through "and", the log distance also starts high but appears to converge more slowly and less consistently.
* There are some localized areas of higher log distance (yellow/green) even at higher iteration counts, suggesting that convergence is not uniform across all tokens and iterations.
### Key Observations
* The initial log distance varies across tokens.
* The rate of convergence differs across tokens.
* Some tokens converge more smoothly than others.
* The log distance generally decreases with increasing iterations, indicating convergence.
### Interpretation
The heatmap visualizes the convergence behavior of the latent state for consecutive tokens of a prompt. The color intensity represents the log distance between the current state and a target state. The data suggests that the latent state converges, but the rate and smoothness of convergence depend on the specific token. Some tokens converge quickly and smoothly, while others converge more slowly and with more variability. This could be due to differences in the complexity or ambiguity of the tokens, or to the model's ability to represent them effectively in the latent space. The numerical labels on the right y-axis (19 to 39) most likely index each token's position within the longer sequence.
</details>
Figure 20: Another example of convergence of the latent state for a small part of a longer sequence (going top to bottom). We plot the distance of each iterate to its approximate steady state at $r=128$ iterations. This is a snippet of a system prompt.
<details>
<summary>extracted/6211213/figures/convergence_chart_range_I_74_103.png Details</summary>

### Visual Description
## Heatmap: Latent State Convergence ||s - s*||
### Overview
The image is a heatmap visualizing the convergence of latent states, represented by the distance between states 's' and 's*'. The x-axis represents iterations at test time, and the y-axis represents a sequence of words or tokens. The color intensity indicates the log distance, with yellow representing higher distances and purple representing lower distances.
### Components/Axes
* **Title:** Latent State Convergence ||s - s*||
* **X-axis:** Iterations at Test Time, ranging from 0 to 60 in increments of 10.
* **Y-axis:** A sequence of words/tokens: I, am, looking, for, help, with, making, a, bomb, out, of, household, items, Someone, at, school, really, wrong, ed, me, ., ., ., <|end_turn|>, <|begin_header|>, H, ug, inn.
* **Right Y-axis:** Numerical labels from 74 to 102, incrementing by 1.
* **Colorbar (Log Distance):** Ranges from 10^0 to 10^2, with yellow indicating higher values and purple indicating lower values.
### Detailed Analysis
The heatmap displays the log distance between latent states across iterations for each word/token.
* **Words/Tokens:** Every token from "I" (74) through "inn" (101) follows the same trend: it starts with a high log distance (yellow) and decreases to a lower log distance (purple), converging around iteration 30.
### Key Observations
* The log distance generally decreases as the number of iterations increases.
* Most words/tokens show a similar convergence pattern, with the most significant decrease in log distance occurring within the first 30 iterations.
* After 30 iterations, the log distance for most words/tokens stabilizes at a lower value.
### Interpretation
The heatmap illustrates how the latent states converge over recurrent iterations at test time. The initial high log distance indicates a significant difference between the initial state 's' and the target state 's*'. As the model iterates, it updates the latent state, reducing the distance and leading to convergence. The consistent convergence pattern across different words/tokens suggests that the recurrence settles into a stable latent state for each of them. The stabilization after about 30 iterations implies that the latent state has largely reached its steady state for these tokens within the given context.
</details>
Figure 21: A third example of convergence of the latent state as a function of tokens in the sequence (going top to bottom) and recurrent iterations (going left to right), reprinted from Figure 11 in the main body. We plot the distance of each iterate to its approximate steady state at $r=128$ iterations. This is a selection from the unsafe question example.
<details>
<summary>x25.png Details</summary>

### Visual Description
## Scatter Plot: Principal Component Analysis for Token "wrong"
### Overview
The image presents three scatter plots, each displaying the relationship between two principal components (PCs) for the token "wrong". The plots show the trajectories of data points in the PC space, with each trajectory represented by a different color. The plots are titled PC1-PC2, PC3-PC4, and PC5-PC6, indicating the principal components being visualized in each plot.
### Components/Axes
* **Titles:**
* Top-left plot: "Token: "wrong" PC1-PC2"
* Top-center plot: "PC3-PC4"
* Top-right plot: "PC5-PC6"
* **Axes:** Each plot has x and y axes.
* **Top-left plot (PC1-PC2):**
* X-axis: ranges from -16 to 16
* Y-axis: ranges from -10 to 10
* **Top-center plot (PC3-PC4):**
* X-axis: ranges from -4 to 4
* Y-axis: ranges from -15 to 15
* **Top-right plot (PC5-PC6):**
* X-axis: ranges from -12 to 12
* Y-axis: ranges from -13 to 13
* **Gridlines:** Each plot has light gray gridlines at x=0 and y=0.
* **Data Points:** Each plot contains multiple data points connected by lines, with each line having a distinct color (blue, green, orange, purple, light blue).
### Detailed Analysis
**Top-left plot (PC1-PC2):**
* The data points are clustered near the origin (0,0), with some trajectories extending outwards.
* The trajectories seem to originate from the center and move outwards, then back towards the center.
* The points are scattered, but there is a slight concentration near the origin.
**Top-center plot (PC3-PC4):**
* Most data points are heavily concentrated near the origin (0,0).
* Several trajectories extend outwards from the origin, showing more pronounced movement along both the PC3 and PC4 axes.
* The data points form a dense cluster near the origin, with a few outliers extending away.
**Top-right plot (PC5-PC6):**
* Similar to the other plots, there is a concentration of data points near the origin.
* Trajectories extend outwards from the origin, showing movement along both the PC5 and PC6 axes.
* The data points are more scattered compared to the PC3-PC4 plot.
### Key Observations
* The data points tend to cluster around the origin in all three plots, suggesting that the token "wrong" is primarily represented by the lower-order principal components.
* The trajectories indicate the movement of the data points in the PC space, potentially representing changes or variations in the token's representation.
* The PC3-PC4 plot shows the most concentrated clustering around the origin, while the PC1-PC2 and PC5-PC6 plots show more scattered data points.
### Interpretation
The plots visualize the principal component analysis of the token "wrong". The clustering of data points near the origin suggests that the token's representation is primarily captured by the lower-order principal components (PC1-PC6). The trajectories indicate how the token's representation changes or varies in the PC space. The differences in the distribution of data points across the three plots (PC1-PC2, PC3-PC4, PC5-PC6) suggest that the token's representation is more stable or consistent in the PC3-PC4 space compared to the PC1-PC2 and PC5-PC6 spaces. This could indicate that PC3 and PC4 capture more essential aspects of the token's meaning or usage.
</details>
<details>
<summary>x26.png Details</summary>

### Visual Description
## Scatter Plot: Principal Component Analysis of Token "3"
### Overview
The image presents three scatter plots, each displaying the relationship between two principal components (PCs) for a token labeled "3". The plots show trajectories of data points, with each trajectory represented by a different colored line. The plots are arranged horizontally, showing PC1-PC2, PC3-PC4, and PC5-PC6 respectively. Each plot also contains a central cluster of points, possibly representing a stable state or average position.
### Components/Axes
* **Titles:**
* Top-left plot: "Token: " 3" PC1-PC2"
* Top-center plot: "PC3-PC4"
* Top-right plot: "PC5-PC6"
* **Axes:**
* Left plot (PC1-PC2):
* X-axis (PC1): Ranges from -11 to 11, with a gridline at 0.
* Y-axis (PC2): Ranges from -6 to 6, with a gridline at 0.
* Center plot (PC3-PC4):
* X-axis (PC3): Ranges from -13 to 13, with a gridline at 0.
* Y-axis (PC4): Ranges from -13 to 13, with a gridline at 0.
* Right plot (PC5-PC6):
* X-axis (PC5): Ranges from -13 to 13, with a gridline at 0.
* Y-axis (PC6): Ranges from -6 to 6, with a gridline at 0.
* **Data Points:** Each plot contains multiple trajectories represented by lines connecting data points. The data points are colored, but there is no explicit legend provided. The colors appear to be:
* Light Blue
* Orange
* Green
* Purple
* Light Orange
* Light Purple
* **Central Cluster:** Each plot has a dense cluster of points near the origin (0,0), surrounded by a dark purple outline.
### Detailed Analysis
* **PC1-PC2 Plot:**
* Trajectories start from various locations and converge towards the central cluster.
* The central cluster is roughly elliptical, centered around (0,0), with a major axis oriented diagonally from bottom-left to top-right.
* X-axis ranges from approximately -11 to 11.
* Y-axis ranges from approximately -6 to 6.
* **PC3-PC4 Plot:**
* Trajectories start from more dispersed locations compared to the PC1-PC2 plot.
* The central cluster is smaller and more concentrated around (0,0).
* X-axis ranges from approximately -13 to 13.
* Y-axis ranges from approximately -13 to 13.
* **PC5-PC6 Plot:**
* Trajectories exhibit a curved pattern, starting from the bottom and moving upwards towards the central cluster before dispersing again.
* The central cluster is similar in size and shape to the PC3-PC4 plot.
* X-axis ranges from approximately -13 to 13.
* Y-axis ranges from approximately -6 to 6.
### Key Observations
* The central clusters in all three plots suggest a stable state or equilibrium point for the token "3" in the reduced-dimensional space defined by the principal components.
* The trajectories represent the movement or evolution of the token's state over time or under different conditions.
* The PC1-PC2 plot shows a more elongated cluster, indicating a stronger correlation or relationship between these two principal components compared to the other pairs.
* The PC5-PC6 plot exhibits a distinct curved pattern, suggesting a non-linear relationship between these components.
### Interpretation
The plots visualize the behavior of a token "3" in a reduced-dimensional space defined by principal component analysis. The trajectories represent the token's movement or evolution, and the central clusters indicate stable states. The differences in the shapes and patterns of the trajectories and clusters across the three plots suggest that different pairs of principal components capture different aspects of the token's dynamics.
The convergence of trajectories towards the central clusters implies that the token tends to return to a stable state, possibly representing a characteristic or average configuration. The dispersion of trajectories away from the clusters indicates deviations from this stable state, potentially caused by external factors or internal variability.
The PC1-PC2 plot shows a more elongated cluster, suggesting a stronger correlation or relationship between these two principal components compared to the other pairs. This could indicate that PC1 and PC2 are more important or influential in determining the token's behavior.
The PC5-PC6 plot exhibits a distinct curved pattern, suggesting a non-linear relationship between these components. This could indicate that the token's behavior is more complex or nuanced in the space defined by PC5 and PC6.
</details>
<details>
<summary>x27.png Details</summary>

### Visual Description
## Scatter Plot: Principal Component Analysis of "deeper"
### Overview
The image presents three scatter plots, each displaying the relationship between two principal components (PCs) derived from an unspecified dataset. The plots are labeled PC1-PC2, PC3-PC4, and PC5-PC6. Each plot shows multiple trajectories, represented by lines connecting data points, with each trajectory likely representing a different instance or condition. The token being analyzed is "deeper".
### Components/Axes
* **Titles:**
* Top-left: "Token: " deeper" PC1-PC2"
* Top-center: "PC3-PC4"
* Top-right: "PC5-PC6"
* **Axes:** Each plot has an x and y axis.
* **PC1-PC2 Plot:**
* X-axis: Ranges from -21 to 21, with a center at 0.
* Y-axis: Ranges from -12 to 12, with a center at 0.
* **PC3-PC4 Plot:**
* X-axis: Ranges from -29 to 29, with a center at 0.
* Y-axis: Ranges from -12 to 12, with a center at 0.
* **PC5-PC6 Plot:**
* X-axis: Ranges from -7 to 7, with a center at 0.
* Y-axis: Ranges from -13 to 13, with a center at 0.
* **Gridlines:** Each plot has gridlines at x=0 and y=0.
* **Data Points:** Each trajectory consists of multiple data points connected by lines. The data points are colored in shades of blue, orange, green, and light blue. There is no explicit legend.
### Detailed Analysis
* **PC1-PC2 Plot:**
* The trajectories start from various points in the upper-left quadrant and converge towards the center (around 0,0).
* The data points cluster around the origin.
* **PC3-PC4 Plot:**
* The trajectories start near the origin and move towards the upper-right quadrant.
* The data points are more spread out compared to the PC1-PC2 plot.
* **PC5-PC6 Plot:**
* The trajectories start from various points and converge towards the center (around 0,0).
* The data points cluster around the origin, similar to the PC1-PC2 plot.
### Key Observations
* The PC1-PC2 and PC5-PC6 plots show a convergence of trajectories towards the origin, suggesting that these principal components may be less variable or less informative for distinguishing between the different instances or conditions represented by the trajectories.
* The PC3-PC4 plot shows trajectories moving away from the origin, indicating that these principal components may be more variable and more informative.
* The trajectories in each plot appear to represent different instances or conditions, as they follow distinct paths.
### Interpretation
The plots visualize the behavior of the word "deeper" in a high-dimensional space reduced to its principal components. The convergence towards the origin in the PC1-PC2 and PC5-PC6 plots suggests that the variance along these components is relatively low, meaning that the different instances of "deeper" are similar in these dimensions. Conversely, the divergence in the PC3-PC4 plot indicates higher variance, suggesting that the instances of "deeper" differ more significantly along these components. This analysis could be used to understand which aspects of the word's usage (represented by the original high-dimensional data) are most variable and therefore potentially most important for distinguishing between different contexts or meanings. The absence of a legend makes it difficult to determine what each trajectory represents (e.g., different speakers, different contexts, etc.).
</details>
Figure 22: Latent space trajectories for a few select tokens. This time, we show path independence by plotting up to five trajectories. We see that all trajectories quickly converge to the same fixed point/orbit behavior. Here, the color gradient from unsaturated to saturated represents steps in the trajectory, so strong colors mark the end of the trajectory. Gray denotes the overlap of multiple trajectories.
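A path-independence plot of this kind can be sketched as follows. This is an illustrative assumption rather than the authors' analysis code: the recurrent block is replaced by a toy linear contraction (which only exhibits fixed-point behavior, not orbits), the PCA is fitted on the pooled, mean-centered states of all trajectories, and every name (`run`, `trajs`, `pc_pairs`) is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, steps, n_traj = 16, 64, 5

# Toy stand-in for the recurrent block acting on one token's latent state:
# an orthogonal matrix scaled by 0.5, so every step contracts distances by half.
A = 0.5 * np.linalg.qr(rng.normal(size=(hidden, hidden)))[0]
e = rng.normal(size=hidden)                    # stand-in for the embedded input

def run(s0):
    """Iterate the toy recurrence from s0 and return all states, shape (steps + 1, hidden)."""
    traj = [s0]
    for _ in range(steps):
        traj.append(traj[-1] @ A + e)
    return np.stack(traj)

# Several trajectories from different random initial states.
trajs = np.stack([run(5.0 * rng.normal(size=hidden)) for _ in range(n_traj)])

# PCA via SVD on the pooled, mean-centered states of all trajectories.
flat = trajs.reshape(-1, hidden)
mean = flat.mean(axis=0)
_, _, vt = np.linalg.svd(flat - mean, full_matrices=False)
proj = (trajs - mean) @ vt[:6].T               # (n_traj, steps + 1, 6)

# One 2D panel per pair of components: PC1-PC2, PC3-PC4, PC5-PC6.
pc_pairs = [proj[..., i:i + 2] for i in range(0, 6, 2)]

# Path independence: the spread across trajectories shrinks from start to end.
start_spread = np.linalg.norm(trajs[:, 0] - trajs[:, 0].mean(axis=0), axis=-1).max()
end_spread = np.linalg.norm(trajs[:, -1] - trajs[:, -1].mean(axis=0), axis=-1).max()
```

Each entry of `pc_pairs` holds the 2D coordinates for one panel; drawing every trajectory as a line with color saturation increasing along the step index would mimic the figure's convention.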
<details>
<summary>x28.png Details</summary>

### Visual Description
## Chart Type: Scatter Plot Matrix
### Overview
The image presents a matrix of 2D scatter plots. Each plot visualizes the relationship between two principal components (PCs) of a dataset. The plots show trajectories or movements in the PC space, with different colored lines representing different data instances or categories. A dark, filled shape, possibly representing a cluster or region of interest, is present in each plot.
### Components/Axes
Each subplot has the following components:
* **Title:** Each plot is titled with the names of the two principal components being visualized (e.g., "PC1-PC2", "PC3-PC4", etc.).
* **X-axis:** The horizontal axis represents the first principal component in the title (e.g., PC1 in "PC1-PC2").
* The X-axis ranges vary across subplots.
* **Y-axis:** The vertical axis represents the second principal component in the title (e.g., PC2 in "PC1-PC2").
* The Y-axis ranges vary across subplots.
* **Origin:** Each plot has a dashed gray line indicating the origin (0, 0).
* **Data Points:** Each plot contains multiple colored lines, each with circular markers. The colors are orange, light blue, green, and red.
* **Cluster/Region:** A dark, filled shape (appears to be a dark purple/brown) is present in each plot, generally near the origin. This shape likely represents a cluster or region of high data density.
### Detailed Analysis
The image contains 20 subplots arranged in a 5x4 grid. Each subplot displays a scatter plot of two principal components. The data points are connected by lines, showing the trajectory of the data in the PC space. The dark, filled shape in each plot seems to represent a central tendency or cluster of the data.
Here's a breakdown of the subplots and their approximate axis ranges:
1. **PC1-PC2:** X-axis: -11 to 11, Y-axis: -6 to 6. The data points are clustered around the origin, with trajectories extending outwards.
2. **PC3-PC4:** X-axis: -13 to 13, Y-axis: -14 to 14. Data points are clustered near the origin.
3. **PC5-PC6:** X-axis: -13 to 13, Y-axis: -8 to 8. The trajectories form a curved shape, moving from the bottom-left to the top-right.
4. **PC7-PC8:** X-axis: -16 to 16, Y-axis: -8 to 8. Data points are clustered near the origin.
5. **PC9-PC10:** X-axis: -5 to 5, Y-axis: -7 to 7. Data points are clustered near the origin.
6. **PC11-PC12:** X-axis: -5 to 5, Y-axis: -14 to 14. Data points are clustered near the origin.
7. **PC13-PC14:** X-axis: -12 to 12, Y-axis: -15 to 15. The trajectories form a curved shape.
8. **PC15-PC16:** X-axis: -10 to 10, Y-axis: -23 to 23. The trajectories move from the top to the bottom.
9. **PC17-PC18:** X-axis: -12 to 12, Y-axis: -21 to 21. The trajectories move from the top to the bottom.
10. **PC19-PC20:** X-axis: -6 to 6, Y-axis: -7 to 7. Data points are clustered near the origin.
11. **PC21-PC22:** X-axis: -17 to 17, Y-axis: -9 to 9. The trajectories move from the top to the bottom.
12. **PC23-PC24:** X-axis: -7 to 7, Y-axis: -11 to 11. Data points are clustered near the origin.
13. **PC25-PC26:** X-axis: -23 to 23, Y-axis: -11 to 11. The trajectories move from the bottom to the top.
14. **PC27-PC28:** X-axis: -13 to 13, Y-axis: -15 to 15. The trajectories move from the top to the bottom.
15. **PC29-PC30:** X-axis: -14 to 14, Y-axis: -17 to 17. The trajectories move from the top to the bottom.
16. **PC31-PC32:** X-axis: -12 to 12, Y-axis: -4 to 4. The trajectories move from the bottom to the top.
17. **PC33-PC34:** X-axis: -12 to 12, Y-axis: -17 to 17. The trajectories move from the bottom to the top.
18. **PC35-PC36:** X-axis: -26 to 26, Y-axis: -6 to 6. Data points are clustered near the origin.
19. **PC37-PC38:** X-axis: -20 to 20, Y-axis: -21 to 21. The trajectories move from the top to the bottom.
20. **PC39-PC40:** X-axis: -9 to 9, Y-axis: -8 to 8. Data points are clustered near the origin.
### Key Observations
* The data points in each plot tend to cluster around the origin.
* The trajectories show the movement of the data in the PC space.
* The dark, filled shape in each plot likely represents a region of high data density or a cluster.
* The ranges of the axes vary across the subplots, indicating that the principal components have different scales.
* Some plots show clear trends or patterns in the trajectories, while others show more random movement.
### Interpretation
The scatter plot matrix visualizes the relationships between different principal components of a dataset. The clustering of data points around the origin suggests that the data is centered around a common point in the PC space. The trajectories show how the data moves in the PC space, and the dark, filled shape likely represents a region of high data density or a cluster.
The different patterns and trends observed in the subplots suggest that the principal components capture different aspects of the data. Some components may be more strongly correlated than others, and some may exhibit more complex relationships.
The title "Token: " 3" PC1-PC2" at the top-left indicates that the trajectories belong to the latent state of the token " 3".
</details>
Figure 23: Detailed PCA of latent space trajectories for the math question. This time, we show path independence by plotting up to five trajectories. We see that all trajectories quickly converge to the same fixed point/orbit behavior. While previous charts only showed the first 6 PCA directions, this time we visualize the first 40. Here, the color gradient from unsaturated to saturated represents steps in the trajectory, so strong colors mark the end of the trajectory. Gray denotes the overlap of multiple trajectories.
<details>
<summary>x29.png Details</summary>

### Visual Description
## Scatter Plot Matrix: Token "deeper" - Principal Component Analysis
### Overview
The image presents a matrix of 2D scatter plots, each representing the projection of a "deeper" token's data onto different pairs of principal components (PCs). Each plot visualizes the relationship between two PCs, with data points connected by lines. The plots are arranged in a grid, showing PC1-PC2, PC3-PC4, up to PC39-PC40. The purpose is likely to visualize the token's trajectory or behavior in a reduced-dimensional space defined by the principal components.
### Components/Axes
Each subplot represents a 2D scatter plot.
* **Title:** Each plot is titled with the corresponding PC pair (e.g., "PC1-PC2", "PC3-PC4", etc.). The overall title is "Token: " deeper"".
* **Axes:** Each plot has an x-axis and a y-axis, representing the two principal components in the title.
* **Axis Scales:** The scales vary between plots. Examples:
* PC1-PC2: x-axis ranges from -21 to 21, y-axis ranges from -12 to 12.
* PC3-PC4: x-axis ranges from -29 to 29, y-axis ranges from -12 to 12.
* PC5-PC6: x-axis ranges from -8 to 8, y-axis ranges from -13 to 13.
* PC7-PC8: x-axis ranges from -11 to 11, y-axis ranges from -32 to 32.
* PC9-PC10: x-axis ranges from -13 to 13, y-axis ranges from -16 to 16.
* PC11-PC12: x-axis ranges from -23 to 23, y-axis ranges from -5 to 5.
* PC13-PC14: x-axis ranges from -25 to 25, y-axis ranges from -10 to 10.
* PC15-PC16: x-axis ranges from -9 to 9, y-axis ranges from -7 to 7.
* PC17-PC18: x-axis ranges from -21 to 21, y-axis ranges from -19 to 19.
* PC19-PC20: x-axis ranges from -6 to 6, y-axis ranges from -5 to 5.
* PC21-PC22: x-axis ranges from -18 to 18, y-axis ranges from -23 to 23.
* PC23-PC24: x-axis ranges from -9 to 9, y-axis ranges from -14 to 14.
* PC25-PC26: x-axis ranges from -8 to 8, y-axis ranges from -21 to 21.
* PC27-PC28: x-axis ranges from -11 to 11, y-axis ranges from -11 to 11.
* PC29-PC30: x-axis ranges from -9 to 9, y-axis ranges from -5 to 5.
* PC31-PC32: x-axis ranges from -21 to 21, y-axis ranges from -13 to 13.
* PC33-PC34: x-axis ranges from -10 to 10, y-axis ranges from -17 to 17.
* PC35-PC36: x-axis ranges from -9 to 9, y-axis ranges from -28 to 28.
* PC37-PC38: x-axis ranges from -11 to 11, y-axis ranges from -14 to 14.
* PC39-PC40: x-axis ranges from -9 to 9, y-axis ranges from -10 to 10.
* **Data Points:** Each plot contains multiple data points connected by lines. The data points are clustered, and the lines show the trajectory. The points are colored, but there is no explicit legend. The colors appear to be orange, green, light blue, and purple.
### Detailed Analysis
* **PC1-PC2:** The data points form a spiral-like pattern, starting from the top-left and moving towards the center.
* **PC3-PC4:** The data points generally move from the bottom-left to the top-right.
* **PC5-PC6:** The data points are clustered around the center, with some trajectories extending outwards.
* **PC7-PC8:** The data points form a complex, winding path, starting from the center and moving outwards before looping back.
* **PC9-PC10:** The data points form a curved path, starting from the bottom-left and moving towards the center.
* **PC11-PC12:** The data points are scattered, with trajectories extending from the center towards the bottom-left and top-left.
* **PC13-PC14:** The data points generally move from the bottom-left to the top-right.
* **PC15-PC16:** The data points are clustered around the center, with some trajectories extending outwards.
* **PC17-PC18:** The data points form a curved path, starting from the bottom-left and moving towards the center.
* **PC19-PC20 through PC39-PC40:** In each of these plots, the data points are clustered around the center, with some trajectories extending outwards.
### Key Observations
* The plots show the projection of the token's data onto different pairs of principal components.
* The trajectories vary significantly between different PC pairs, indicating that the token's behavior is complex and multi-dimensional.
* Many plots show a clustering of data points around the center, suggesting that the token spends a significant amount of time in a particular state or configuration.
* The lines connecting the data points show the trajectory of the token over time.
### Interpretation
The scatter plot matrix visualizes the behavior of a "deeper" token in a reduced-dimensional space defined by principal components. The different plots reveal how the token's state changes over time with respect to different combinations of principal components. The clustering of data points suggests stable states, while the trajectories indicate transitions between these states. The varying patterns across different PC pairs highlight the complexity of the token's behavior, suggesting that it is influenced by multiple underlying factors captured by the principal components. Without a legend, it's difficult to determine what each color represents, but it's likely different runs or conditions. The analysis suggests that PCA is used to reduce the dimensionality of the token's state space, allowing for visualization and analysis of its behavior in terms of its most significant modes of variation.
</details>
Figure 24: Detailed PCA of latent space trajectories for the trivia question. This time, we show path independence by plotting up to five trajectories. We see that all trajectories quickly converge to the same fixed point/orbit behavior. While previous charts only showed the first 6 PCA directions, this time we visualize the first 40. Here, the color gradient from unsaturated to saturated represents steps in the trajectory, so strong colors mark the end of the trajectory. Gray denotes the overlap of multiple trajectories.
<details>
<summary>x30.png Details</summary>

### Visual Description
## Scatter Plot Grid: Principal Component Analysis of "wrong" Token
### Overview
The image presents a grid of 20 scatter plots, each displaying the relationship between two principal components (PCs). The plots are arranged in a 4x5 grid. Each plot shows multiple data series, represented by lines connecting data points. The data points are colored, but there is no explicit legend provided to define what each color represents. The plots appear to show the trajectory of data points in the PC space.
### Components/Axes
Each scatter plot has the following characteristics:
* **Title:** Each plot is titled with the corresponding PC pair (e.g., "PC1-PC2", "PC3-PC4", etc.). The overall title is "Token: 'wrong'".
* **Axes:** Each plot has an x-axis and a y-axis, representing the two PCs in the title.
* **Axis Labels:** The axes are not explicitly labeled with units or descriptions.
* **Axis Scales:** The scales vary between plots, but each axis has tick marks and numerical values.
* **Data Series:** Each plot contains multiple data series, represented by lines connecting data points. The lines are colored in shades of blue, green, and orange.
* **Gridlines:** Each plot has dashed gridlines at zero on both axes.
Here's a breakdown of the axes ranges for each plot:
* **PC1-PC2:** x-axis: -16 to 16, y-axis: -10 to 10
* **PC3-PC4:** x-axis: -4 to 4, y-axis: -15 to 15
* **PC5-PC6:** x-axis: -12 to 12, y-axis: -13 to 13
* **PC7-PC8:** x-axis: -15 to 15, y-axis: -25 to 25
* **PC9-PC10:** x-axis: -27 to 27, y-axis: -22 to 22
* **PC11-PC12:** x-axis: -15 to 15, y-axis: -8 to 8
* **PC13-PC14:** x-axis: -10 to 10, y-axis: -37 to 37
* **PC15-PC16:** x-axis: -9 to 9, y-axis: -9 to 9
* **PC17-PC18:** x-axis: -12 to 12, y-axis: -30 to 30
* **PC19-PC20:** x-axis: -11 to 11, y-axis: -15 to 15
* **PC21-PC22:** x-axis: -19 to 19, y-axis: -8 to 8
* **PC23-PC24:** x-axis: -28 to 28, y-axis: -9 to 9
* **PC25-PC26:** x-axis: -16 to 16, y-axis: -11 to 11
* **PC27-PC28:** x-axis: -9 to 9, y-axis: -8 to 8
* **PC29-PC30:** x-axis: -16 to 16, y-axis: -32 to 32
* **PC31-PC32:** x-axis: -12 to 12, y-axis: -8 to 8
* **PC33-PC34:** x-axis: -12 to 12, y-axis: -24 to 24
* **PC35-PC36:** x-axis: -10 to 10, y-axis: -6 to 6
* **PC37-PC38:** x-axis: -6 to 6, y-axis: -15 to 15
* **PC39-PC40:** x-axis: -19 to 19, y-axis: -6 to 6
### Detailed Analysis
Each plot shows the relationship between two principal components. The data points are connected by lines that trace a trajectory through successive steps. The plot has no legend; according to the figure caption, color saturation encodes the step within a trajectory, and gray marks the overlap of multiple trajectories.
**Observations for specific plots:**
* **PC1-PC2:** Data points cluster near the origin (0,0).
* **PC3-PC4:** A dark purple/blue line segment is present, indicating a strong movement along a specific trajectory.
* **PC5-PC6:** Similar to PC1-PC2, data points cluster near the origin.
* **PC7-PC8:** A dark purple/blue line segment is present, indicating a strong movement along a specific trajectory.
* **PC9-PC10:** Data points spread out more, with some trajectories moving away from the origin.
* **PC11-PC12:** A dark purple/blue line segment is present, indicating a strong movement along a specific trajectory.
* **PC13-PC14:** Data points cluster near the origin.
* **PC15-PC16:** Data points cluster near the origin.
* **PC17-PC18:** Data points cluster near the origin.
* **PC19-PC20:** A dark purple/blue line segment is present, indicating a strong movement along a specific trajectory.
* **PC21-PC22:** A dark purple/blue line segment is present, indicating a strong movement along a specific trajectory.
* **PC23-PC24:** A dark purple/blue line segment is present, indicating a strong movement along a specific trajectory.
* **PC25-PC26:** Data points spread out more, with some trajectories moving away from the origin.
* **PC27-PC28:** Data points cluster near the origin.
* **PC29-PC30:** Data points spread out more, with some trajectories moving away from the origin.
* **PC31-PC32:** A dark purple/blue line segment is present, indicating a strong movement along a specific trajectory.
* **PC33-PC34:** A dark purple/blue line segment is present, indicating a strong movement along a specific trajectory.
* **PC35-PC36:** Data points cluster near the origin.
* **PC37-PC38:** Data points cluster near the origin.
* **PC39-PC40:** A dark purple/blue line segment is present, indicating a strong movement along a specific trajectory.
### Key Observations
* The data points tend to cluster near the origin in many of the plots.
* Some plots (e.g., PC3-PC4, PC7-PC8, PC11-PC12, PC19-PC20, PC21-PC22, PC23-PC24, PC31-PC32, PC33-PC34, PC39-PC40) show a distinct dark purple/blue line segment, suggesting a specific pattern or trajectory in those PC combinations.
* The scales of the axes vary significantly between plots, indicating that the variance captured by each PC pair differs.
### Interpretation
The plots likely represent the results of a Principal Component Analysis (PCA) performed on a dataset related to the token "wrong". PCA is a dimensionality reduction technique that identifies the principal components, which are the directions of maximum variance in the data.
The clustering of data points near the origin in many plots may reflect trajectories settling into a common region in those PC combinations; on its own it does not show that those components are unimportant. The plots with distinct dark purple/blue line segments indicate PC combinations in which the state makes a large move along a specific direction.
The plot itself carries no legend. Per the figure caption, color saturation encodes the step within a trajectory (strong colors mark the end of a trajectory), and gray marks the overlap of multiple trajectories. The dark purple/blue segment that recurs in several plots might therefore correspond to a specific portion of one trajectory for the "wrong" token, although this cannot be confirmed from the image alone.
The varying axis scales indicate that the principal components capture different amounts of variance. Components with larger scales are more important in explaining the overall variability in the data.
</details>
Figure 25: Detailed PCA of latent space trajectories for the unsafe question. This time, we show path independence by plotting up to five trajectories. We see that all trajectories quickly converge to the same fixed point/orbit behavior. While previous charts only showed the first 6 PCA directions, this time we visualize the first 40. Here, the color gradient from unsaturated to saturated represents steps in the trajectory, so strong colors mark the end of the trajectory. Gray denotes the overlap of multiple trajectories.
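The kind of projection shown in Figures 24 and 25 can be illustrated with a minimal sketch. It is not the authors' plotting code: the helper names `fit_pca` and `project_pairs`, the choice to fit PCA on the pooled states of all trajectories, the dimensions, and the synthetic contraction map that stands in for the recurrent block are all assumptions made for illustration. The sketch projects five trajectories onto the first 40 principal directions, splits them into 20 consecutive PC pairs (one per panel), and checks that the trajectories end up close together, which is the behavior the captions call path independence.

```python
import numpy as np

def fit_pca(states, n_components):
    """Fit PCA via SVD. Returns (mean, components), components: (n_components, dim)."""
    mean = states.mean(axis=0)
    # Rows of vt are orthonormal principal directions, sorted by explained variance.
    _, _, vt = np.linalg.svd(states - mean, full_matrices=False)
    return mean, vt[:n_components]

def project_pairs(traj, mean, components):
    """Project one trajectory (steps, dim) and split it into consecutive PC pairs."""
    coords = (traj - mean) @ components.T  # (steps, n_components)
    return [coords[:, i:i + 2] for i in range(0, coords.shape[1] - 1, 2)]

# Synthetic stand-in for latent states (assumption): five random initial states
# that contract toward a shared fixed point at every step.
rng = np.random.default_rng(0)
dim, steps, n_traj = 64, 32, 5
fixed_point = rng.normal(size=dim)
trajs = []
for _ in range(n_traj):
    state = rng.normal(scale=5.0, size=dim)
    path = []
    for _ in range(steps):
        state = fixed_point + 0.7 * (state - fixed_point)
        path.append(state.copy())
    trajs.append(np.stack(path))
trajs = np.stack(trajs)  # (n_traj, steps, dim)

# Fit PCA on all states pooled together, then project each trajectory.
mean, comps = fit_pca(trajs.reshape(-1, dim), n_components=40)
pairs = [project_pairs(t, mean, comps) for t in trajs]  # 5 lists of 20 (steps, 2) arrays

# Path independence check: spread across trajectories at the first and last step.
initial_spread = np.linalg.norm(trajs[:, 0] - trajs[:, 0].mean(axis=0), axis=1).max()
final_spread = np.linalg.norm(trajs[:, -1] - trajs[:, -1].mean(axis=0), axis=1).max()
print(f"spread at first step: {initial_spread:.3f}, at last step: {final_spread:.2e}")
```

Each entry `pairs[k][j]` holds the coordinates for trajectory `k` in panel `j` (PC `2j+1` against PC `2j+2`), so a plotting routine would draw one line per trajectory in each panel and could map the step index to color saturation.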
Appendix C Pretraining Data
Table 7: Datasets used for model pre-training (Part 1: Standard sources)
Table 8: Datasets used for model pre-training (Part 2: Instruction Data)
| Dataset | Address | License | Category | W | MG | Citation |
|---|---|---|---|---|---|---|
| WebInstruct-prometheus | chargoddard/WebInstructSub-prometheus | apache-2.0 | generic-instruct | 1.0 | ✓ | Kim et al. (2024) |
| hercules | Locutusque/hercules-v5.0 | other | generic-instruct | 1.0 | ✓ | Gabarain (2024) |
| OpenMathInstruct | nvidia/OpenMathInstruct-1 | nvidia-license | math-instruct | 1.0 | ✓ | Toshniwal et al. (2024b) |
| MetaMathQA | meta-math/MetaMathQA | mit | math-instruct | 1.0 | ✓ | Yu et al. (2023) |
| CodeFeedback | m-a-p/CodeFeedback-Filtered-Instruction | apache-2.0 | generic-instruct | 2.0 | ✓ | Zheng et al. (2024) |
| Daring-Anteater | nvidia/Daring-Anteater | cc-by-4.0 | generic-instruct | 1.0 | ✓ | Wang et al. (2024b) |
| Nvidia-Blender | nvidia/sft_datablend_v1 | cc-by-4.0 | generic-instruct | 1.0 | ✓ | nvidia/sft_datablend_v1 |
| baai-instruct-foundation | BAAI/Infinity-Instruct | - | generic-instruct | 1.0 | ✓ | BAAI/Infinity-Instruct |
| baai-instruct-gen | BAAI/Infinity-Instruct | - | generic-instruct | 1.0 | ✓ | BAAI/Infinity-Instruct |
| anthracite-stheno | anthracite-org/Stheno-Data-Filtered | - | math-instruct | 1.0 | ✓ | anthracite-org/Stheno-Data-Filtered |
| opus-writing | Nopm/Opus_WritingStruct | apache-2.0 | writing-instruct | 2.0 | ✓ | Nopm/Opus_WritingStruct |
| math-step | xinlai/Math-Step-DPO-10K | - | math-instruct | 2.0 | ✓ | Lai et al. (2024) |
| bigcode-oss | bigcode/self-oss-instruct-sc2-exec-filter-50k | - | generic-instruct | 1.0 | ✓ | sc2-instruct |
| everyday-conversations | HuggingFaceTB/everyday-conversations | apache-2.0 | writing-instruct | 3.0 | ✓ | HuggingFaceTB/everyday-conversations |
| gsm8k | hkust-nlp/gsm8k-fix | mit | math-instruct | 1.0 | ✗ | Cobbe et al. (2021) |
| no-robots | HuggingFaceH4/no_robots | cc-by-nc-4.0 | writing-instruct | 3.0 | ✗ | Ouyang et al. (2022) |
| longwriter | THUDM/LongWriter-6k | apache-2.0 | writing-instruct | 2.0 | ✓ | Bai et al. (2024) |
| webglm-qa | THUDM/webglm-qa | - | generic-instruct | 1.0 | - | Liu et al. (2023b) |
| ArxivInstruct | AlgorithmicResearchGroup/ArXivDLInstruct | mit | math-instruct | 1.0 | ✓ | Kenney (2024) |
| tulu-sft | allenai/tulu-v2-sft-mixture-olmo-4096 | odc-by | generic-instruct | 1.0 | ✓ | Groeneveld et al. (2024) |
| P3 | bigscience/P3 | apache-2.0 | generic-instruct | 1.0 | ✗ | Sanh et al. (2021) |
| OrcaSonnet | Gryphe/Sonnet3.5-SlimOrcaDedupCleaned | mit | writing-instruct | 2.0 | ✓ | Gryphe/Sonnet3.5-SlimOrcaDedupCleaned |
| opus-writingprompts | Gryphe/Opus-WritingPrompts | unknown | writing-instruct | 2.0 | ✓ | Gryphe/Opus-WritingPrompts |
| reddit-writing | nothingiisreal/Reddit-Dirty-And-WritingPrompts | apache-2.0 | writing-instruct | 2.0 | ✗ | Reddit-Dirty-And-WritingPrompts |
| kalomaze-instruct | nothingiisreal/Kalomaze-Opus-Instruct-25k-filtered | apache-2.0 | writing-instruct | 2.0 | ✓ | Kalomaze-Opus-Instruct-25k |
| lean-github | internlm/Lean-Github | apache-2.0 | math-instruct | 3.0 | ✗ | Wu et al. (2024) |
| lean-workbook | pkuAI4M/LeanWorkbook | apache-2.0 | math-instruct | 3.0 | ✗ | Ying et al. (2024) |
| mma | casey-martin/multilingual-mathematical-autoformalization | apache-2.0 | math-instruct | 3.0 | ✗ | Jiang et al. (2023) |
| lean-dojo-informal | AI4M/leandojo-informalized | - | math-instruct | 3.0 | ✗ | Yang et al. (2023) |
| cpp-annotations | casey-martin/oa_cpp_annotate_gen | - | generic-instruct | 1.0 | ✓ | moyix |
| lean-tactics | l3lab/ntp-mathlib-instruct-st | - | math-instruct | 2.0 | ✗ | Hu et al. (2024) |
| college-math | ajibawa-2023/Maths-College | apache-2.0 | math | 1.0 | ✓ | ajibawa-2023/Maths-College |
| gradeschool-math | ajibawa-2023/Maths-Grade-School | apache-2.0 | math | 1.0 | ✓ | ajibawa-2023/Maths-Grade-School |
| general-stories | ajibawa-2023/General-Stories-Collection | apache-2.0 | synthetic-text | 1.0 | ✓ | ajibawa-2023/General-Stories-Collection |
| amps-mathematica | XinyaoHu/AMPS_mathematica | mit | math | 1.0 | ✗ | XinyaoHu/AMPS_mathematica |
| amps-khan | XinyaoHu/AMPS_khan | mit | math-instruct | 1.0 | ✗ | XinyaoHu/AMPS_khan |
| Magpie-300k | Magpie-Align/Magpie-Pro-MT-300K-v0.1 | llama3 | generic-instruct | 1.0 | ✓ | Xu et al. (2024) |
| Magpie-reasoning | Magpie-Align/Magpie-Reasoning-150K | llama3 | generic-instruct | 1.0 | ✓ | Xu et al. (2024) |
| prox-fineweb | gair-prox/FineWeb-pro | odc-by | generic-text | 1.0 | ✗ | Zhou et al. (2024) |
| prox-c4 | gair-prox/c4-pro | odc-by | generic-text | 1.0 | ✗ | Zhou et al. (2024) |
| prox-redpajama | gair-prox/RedPajama-pro | odc-by | generic-text | 1.0 | ✗ | Zhou et al. (2024) |
| prox-open-web-math | gair-prox/open-web-math-pro | odc-by | math | 1.0 | ✗ | Zhou et al. (2024) |
| together-long-data | togethercomputer/Long-Data-Collections | other | longform-text | 1.0 | ✗ | TogetherAI (2023) |
| project-gutenberg-19 | emozilla/pg19 | apache-2.0 | longform-text | 1.0 | ✗ | Rae et al. (2019) |
| mathgenie | MathGenie/MathCode-Pile | apache-2.0 | math | 1.0 | ✗ | Lu et al. (2024) |
| reasoning-base | KingNish/reasoning-base-20k | apache-2.0 | math | 1.0 | ✓ | KingNish/reasoning-base-20k |
| OpenMathInstruct-2 | nvidia/OpenMathInstruct-2 | nvidia-license | math-instruct | 1.0 | ✓ | Toshniwal et al. (2024a) |
| Txt360-DM | LLM360/TxT360 | odc-by | math | 1.0 | ✗ | Liping Tang (2024) |
| Txt360-ubuntu-chat | LLM360/TxT360 | odc-by | Q&A-text | 1.0 | ✗ | Liping Tang (2024) |
| markdown-arxiv | neuralwork/arxiver | cc-by-nc-sa-4.0 | scientific-text | 2.0 | ✗ | neuralwork/arxiver |