# Scaling up Test-Time Compute with Latent Reasoning: A Recurrent Depth Approach

Jonas Geiping¹ Sean McLeish² Neel Jain² John Kirchenbauer² Siddharth Singh² Brian R. Bartoldson³ Bhavya Kailkhura³ Abhinav Bhatele² Tom Goldstein²

¹ELLIS Institute Tübingen, Max-Planck Institute for Intelligent Systems, Tübingen AI Center. ²University of Maryland, College Park. ³Lawrence Livermore National Laboratory. Correspondence to: Jonas Geiping, Tom Goldstein <jonas@tue.ellis.eu, tomg@umd.edu>.

## Abstract

We study a novel language model architecture that is capable of scaling test-time computation by implicitly reasoning in latent space. Our model works by iterating a recurrent block, thereby unrolling to arbitrary depth at test-time. This stands in contrast to mainstream reasoning models that scale up compute by producing more tokens. Unlike approaches based on chain-of-thought, our approach does not require any specialized training data, can work with small context windows, and can capture types of reasoning that are not easily represented in words. We scale a proof-of-concept model to 3.5 billion parameters and 800 billion tokens. We show that the resulting model can improve its performance on reasoning benchmarks, sometimes dramatically, up to a computation load equivalent to 50 billion parameters.

Model: huggingface.co/tomg-group-umd/huginn-0125 Code and Data: github.com/seal-rg/recurrent-pretraining

## 1 Scaling by Thinking in Continuous Space

Humans naturally expend more mental effort solving some problems than others. While humans are capable of thinking over long time spans by verbalizing intermediate results and writing them down, a substantial amount of thought happens through complex, recurrent firing patterns in the brain, before the first word of an answer is uttered.
Early attempts at increasing the power of language models focused on scaling model size, a practice that requires extreme amounts of data and computation. More recently, researchers have explored ways to enhance the reasoning capability of models by scaling test-time computation. The mainstream approach involves post-training on long chain-of-thought examples to develop the model’s ability to verbalize intermediate calculations in its context window and thereby externalize its thoughts.
<details>
<summary>x1.png Details</summary>

### Visual Description
## Chart: Scaling up Test-Time Compute with Recurrent Depth
### Overview
This chart illustrates the relationship between Test-Time Compute Recurrence and Accuracy (%) for three different datasets: ARC challenge, GSM8K CoT, and OpenBookQA, across varying Materialized Parameters. The chart demonstrates how increasing the depth of computation during test time impacts the accuracy of each dataset, with performance scaling alongside the number of materialized parameters.
### Components/Axes
* **Title:** Scaling up Test-Time Compute with Recurrent Depth
* **X-axis:** Test-Time Compute Recurrence (Scale: 1, 4, 6, 8, 12, 20, 32, 48, 64)
* **Y-axis:** Accuracy (%) (Scale: 0 to 50)
* **Materialized Parameters (Top Axis):** 3.6B, 8.3B, 11.5B, 14.6B, 21.0B, 33.6B, 52.6B, 77.9B, 103B
* **Legend:**
* ARC challenge (Blue)
* GSM8K CoT (Orange)
* OpenBookQA (Green)
### Detailed Analysis
The chart displays three lines representing the accuracy of each dataset as Test-Time Compute Recurrence increases. The Materialized Parameters are displayed along the top of the chart, indicating the model size used for each data point.
**ARC challenge (Blue):**
* The line starts at approximately 22% accuracy at a recurrence of 1.
* It gradually increases to around 28% at a recurrence of 4.
* A steeper increase is observed between recurrences 4 and 8, reaching approximately 38%.
* The line plateaus around 42-45% accuracy from a recurrence of 12 onwards.
* Approximate data points: (1, 22%), (4, 28%), (6, 34%), (8, 38%), (12, 41%), (20, 43%), (32, 44%), (48, 44%), (64, 45%)
**GSM8K CoT (Orange):**
* The line begins at approximately 1% accuracy at a recurrence of 1.
* It remains very low until a recurrence of 6, where it starts to increase rapidly.
* Between recurrences 6 and 12, the accuracy jumps from roughly 2% to around 42%.
* The line then plateaus, fluctuating between 42% and 46% accuracy from a recurrence of 12 onwards.
* Approximate data points: (1, 1%), (4, 1%), (6, 2%), (8, 15%), (12, 42%), (20, 45%), (32, 45%), (48, 45%), (64, 46%)
**OpenBookQA (Green):**
* The line starts at approximately 24% accuracy at a recurrence of 1.
* It shows a moderate increase to around 30% at a recurrence of 4.
* The line continues to increase, reaching approximately 36% at a recurrence of 8.
* It plateaus around 40-42% accuracy from a recurrence of 12 onwards.
* Approximate data points: (1, 24%), (4, 30%), (6, 33%), (8, 36%), (12, 39%), (20, 41%), (32, 41%), (48, 41%), (64, 42%)
### Key Observations
* GSM8K CoT shows the most dramatic improvement in accuracy with increasing recurrence, starting from a very low baseline.
* ARC challenge and OpenBookQA exhibit more gradual improvements and reach higher plateaus.
* All three datasets show diminishing returns in accuracy beyond a recurrence of 12.
* Higher Materialized Parameters generally correlate with higher accuracy, particularly for GSM8K CoT.
* The ARC challenge consistently performs better than GSM8K CoT at lower recurrence values, but GSM8K CoT catches up and surpasses it at higher recurrence values.
### Interpretation
The data suggests that increasing the depth of computation during test time (Test-Time Compute Recurrence) can significantly improve the accuracy of language models, especially for datasets like GSM8K CoT that benefit from more complex reasoning. The plateau observed at higher recurrence values indicates a limit to the benefits of additional computation depth, beyond which other factors (such as model size or training data) likely dominate, implying a trade-off between computational cost and accuracy gains. The varying performance of the three datasets suggests that the optimal level of test-time computation depth depends on the specific characteristics of the task: GSM8K CoT starts with very low accuracy and then rapidly improves, marking it as a particularly challenging task that requires significant computational resources to solve effectively. The strong correlation between Materialized Parameters and accuracy highlights the importance of effective model size in achieving high performance.
</details>
Figure 1: We train a 3.5B parameter language model with depth recurrence. At test time, the model can iterate longer to use more compute and improve its performance. Instead of scaling test-time reasoning by “verbalizing” in long chains of thought, the model improves entirely by reasoning in latent space. Tasks that require less reasoning, like OpenBookQA, converge more quickly than tasks like GSM8K, which effectively make use of more compute.
However, the constraint that expensive internal reasoning must always be projected down to a single verbalized next token appears wasteful; it is plausible that models could be more competent if they were able to natively “think” in their continuous latent space. One way to unlock this untapped dimension of additional compute involves adding a recurrent unit to a model. This unit runs in a loop, iteratively processing and updating its hidden state and enabling computations to be carried on indefinitely. While this is not currently the dominant paradigm, this idea is foundational to machine learning and has been (re-)discovered in every decade, for example as recurrent neural networks, diffusion models, and as universal or looped transformers.
In this work, we show that depth-recurrent language models can learn effectively, be trained in an efficient manner, and demonstrate significant performance improvements under the scaling of test-time compute. Our proposed transformer architecture is built upon a latent depth-recurrent block that is run for a randomly sampled number of iterations during training. We show that this paradigm can scale to several billion parameters and over half a trillion tokens of pretraining data. At test-time, the model can improve its performance through recurrent reasoning in latent space, enabling it to compete with other open-source models that benefit from more parameters and training data. Additionally, we show that recurrent-depth models naturally support a number of features at inference time that require substantial tuning and research effort in non-recurrent models, such as per-token adaptive compute, (self-)speculative decoding, and KV-cache sharing. We conclude our study by tracking token trajectories in latent space, showing that a number of interesting computation behaviors simply emerge with scale, such as the model rotating shapes in latent space for numerical computations.
## 2 Why Train Models with Recurrent Depth?
Recurrent layers enable a transformer model to perform arbitrarily many computations before emitting a token. In principle, recurrent mechanisms provide a simple solution for test-time compute scaling. Compared to a more standard approach of long-context reasoning (OpenAI, 2024; DeepSeek-AI et al., 2025), latent recurrent thinking has several advantages.
- Latent reasoning does not require construction of bespoke training data. Chain-of-thought reasoning requires the model to be trained on long demonstrations that are constructed in the domain of interest. In contrast, our proposed latent reasoning models can train with a variable compute budget, using standard training data with no specialized demonstrations, and enhance their abilities at test-time if given additional compute.
- Latent reasoning models require less memory for training and inference than chain-of-thought reasoning models. Because the latter require extremely long context windows, specialized training methods such as token-parallelization (Liu et al., 2023a) may be needed.
- Recurrent-depth networks perform more FLOPs per parameter than standard transformers, significantly reducing communication costs between accelerators at scale. This especially enables higher device utilization when training with slower interconnects.
- By constructing an architecture that is compute-heavy and small in parameter count, we hope to set a strong prior towards models that solve problems by “thinking”, i.e. by learning meta-strategies, logic and abstraction, instead of memorizing. The strength of recurrent priors for learning complex algorithms has already been demonstrated in the “deep thinking” literature Schwarzschild et al. (2021b); Bansal et al. (2022); Schwarzschild et al. (2023).
On a more philosophical note, we hope that latent reasoning captures facets of human reasoning that defy verbalization, such as spatial thinking, physical intuition or (motor) planning. Over many iterations of the recurrent process, reasoning in a high-dimensional vector space would enable the deep exploration of multiple directions simultaneously, instead of linear thinking, leading to a system capable of exhibiting novel and complex reasoning behavior.
Scaling compute in this manner is not at odds with scaling through extended (verbalized) inference (Shao et al., 2024), or with scaling parameter counts in pretraining (Kaplan et al., 2020); rather, we argue it may form a third axis along which to scale model performance.
#### Table of Contents
- Section 3 introduces our latent recurrent-depth model architecture and training objective.
- Section 4 describes the data selection and engineering of our large-scale training run on Frontier, an AMD cluster.
- Section 5 reports benchmark results, showing how the model improves when scaling inference compute.
- Section 6 includes several application examples showing how recurrent models naturally simplify LLM use cases.
- Section 7 visualizes what computation patterns emerge at scale with this architecture and training objective, showing that context-dependent behaviors emerge in latent space, such as “orbiting” when responding to prompts requiring numerical reasoning.
<details>
<summary>x2.png Details</summary>

### Visual Description
## Diagram: Recurrent Neural Network Architecture with Prelude and Coda
### Overview
The image depicts a diagram of a recurrent neural network (RNN) architecture, incorporating a "Prelude" and "Coda" component alongside multiple "Recurrent Blocks". The diagram illustrates the flow of information from an input string "Hello" to an output string "World", with intermediate states represented by 's' variables and input injections denoted by 'e'. The diagram also includes a noise component represented by a normal distribution.
### Components/Axes
* **Prelude:** Represented by a light blue square labeled "P".
* **Recurrent Block:** Represented by a green square labeled "R". Multiple instances of this block are shown in a sequence.
* **Coda:** Represented by a red square labeled "C".
* **Input Injection:** Represented by a dotted gray arrow labeled "e".
* **Residual Stream:** Represented by a solid black arrow.
* **Input:** "Hello"
* **Output:** "World"
* **Noise:** Represented by the mathematical notation `N(0, σ²Iₙʰ)`
* **States:** `s₀`, `s₁`, `sR`
* **Legend:** Located at the bottom-right of the image, associating colors with components.
### Detailed Analysis or Content Details
The diagram shows the following flow:
1. The input string "Hello" enters the "Prelude" (P).
2. The output of the "Prelude" is connected to the first "Recurrent Block" (R) via the "Residual Stream".
3. A noise component `N(0, σ²Iₙʰ)` is also fed into the first "Recurrent Block".
4. Each "Recurrent Block" receives an "Input Injection" (e) from the previous block. The dotted gray arrows indicate this injection.
5. The output of the last "Recurrent Block" is fed into the "Coda" (C).
6. The "Coda" produces the output string "World".
7. The states are labeled sequentially as `s₀`, `s₁`, and `sR`, indicating the state of the recurrent blocks.
### Key Observations
* The architecture emphasizes a sequential processing of information through the recurrent blocks.
* The "Prelude" and "Coda" components suggest a pre-processing and post-processing stage, respectively.
* The "Input Injection" mechanism allows information to be passed between recurrent blocks.
* The inclusion of noise `N(0, σ²Iₙʰ)` suggests a stochastic element in the model.
### Interpretation
This diagram illustrates a recurrent neural network architecture designed for sequence-to-sequence tasks, such as machine translation or text generation. The "Prelude" likely handles initial embedding or encoding of the input sequence. The recurrent blocks process the sequence step-by-step, maintaining a hidden state (`s₀`, `s₁`, `sR`) that captures information about the past. The "Input Injection" allows the network to incorporate context from previous time steps. Finally, the "Coda" decodes the hidden state into the output sequence ("World"). The noise component suggests a regularization technique or a way to introduce variability in the model's predictions. The use of a residual stream indicates a potential mechanism for mitigating the vanishing gradient problem, common in deep recurrent networks. The diagram is a high-level representation and does not specify the internal workings of the "Prelude", "Recurrent Block", or "Coda" components.
</details>
Figure 2: A visualization of the Architecture, as described in Section 3. Each block consists of a number of sub-layers. The blue prelude block embeds the inputs into latent space, where the green shared recurrent block is a block of layers that is repeated to compute the final latent state, which is decoded by the layers of the red coda block.
## 3 A scalable recurrent architecture
In this section we will describe our proposed architecture for a transformer with latent recurrent depth, discussing design choices and small-scale ablations. A diagram of the architecture can be found in Figure 2. We always refer to the sequence dimension as $n$ , the hidden dimension of the model as $h$ , and its vocabulary as the set $V$ .
### 3.1 Macroscopic Design
The model is primarily structured around decoder-only transformer blocks (Vaswani et al., 2017; Radford et al., 2019). However, these blocks are organized into three functional groups: the prelude $P$ , which embeds the input data into a latent space using multiple transformer layers; the core recurrent block $R$ , which is the central unit of recurrent computation, modifying states $\mathbf{s}\in\mathbb{R}^{n\times h}$ ; and finally the coda $C$ , which un-embeds from latent space using several layers and also contains the prediction head of the model. The core block sits between the prelude and coda blocks, and by looping the core we can put an indefinite number of verses in our song.
Given a number of recurrent iterations $r$ , and a sequence of input tokens $\mathbf{x}\in V^{n}$ , these groups are used in the following way to produce output probabilities $\mathbf{p}\in\mathbb{R}^{n\times|V|}$ :

$$
\begin{aligned}
\mathbf{e} &= P(\mathbf{x})\\
\mathbf{s}_{0} &\sim \mathcal{N}(\mathbf{0},\,\sigma^{2}I_{n\cdot h})\\
\mathbf{s}_{i} &= R(\mathbf{e},\mathbf{s}_{i-1})\quad \textnormal{for } i\in\{1,\dots,r\}\\
\mathbf{p} &= C(\mathbf{s}_{r}),
\end{aligned}
$$

where $\sigma$ is some standard deviation for initializing the random state. This process is shown in Figure 2. Given an initial random state $\mathbf{s}_{0}$ , the model repeatedly applies the core block $R$ , which accepts the latent state $\mathbf{s}_{i-1}$ and the embedded input $\mathbf{e}$ and outputs a new latent state $\mathbf{s}_{i}$ . After finishing all iterations, the coda block processes the last state and produces the probabilities of the next token.
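As a concrete illustration, the recurrence above can be sketched in a few lines; `P`, `R`, and `C` here are illustrative stand-in callables, not the trained transformer blocks:

```python
import numpy as np

def depth_recurrent_forward(P, R, C, x, r, sigma=1.0, rng=None):
    """Sketch of the forward pass: the prelude P embeds the input, the core
    R is iterated r times from a random initial state, and the coda C
    decodes the final state. P, R, C are stand-ins for the real blocks."""
    rng = np.random.default_rng(0) if rng is None else rng
    e = P(x)                                   # e = P(x)
    s = sigma * rng.standard_normal(e.shape)   # s_0 ~ N(0, sigma^2 I)
    for _ in range(r):                         # s_i = R(e, s_{i-1})
        s = R(e, s)
    return C(s)                                # p = C(s_r)
```

With a contractive stand-in such as `R = lambda e, s: 0.5 * s + e`, the iteration converges to the same fixed point regardless of the random initialization, matching the path-independence property the design aims for.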
This architecture is based on deep thinking literature, where it is shown that injecting the latent inputs $\mathbf{e}$ in every step (Bansal et al., 2022) and initializing the latent vector with a random state stabilizes the recurrence and promotes convergence to a steady state independent of initialization, i.e. path independence (Anil et al., 2022).
#### Motivation for this Design.
This recurrent design is the minimal setup required to learn stable iterative operators. A good example is gradient descent of a function $E(\mathbf{x},\mathbf{y})$ , where $\mathbf{x}$ may be the variable of interest and $\mathbf{y}$ the data. Gradient descent on this function starts from an initial random state, here $\mathbf{x}_{0}$ , and repeatedly applies a simple operation (the gradient of the function it optimizes) that depends on the previous state $\mathbf{x}_{k}$ and the data $\mathbf{y}$ . Note that we need to use $\mathbf{y}$ in every step to actually optimize our function. Similarly, we repeatedly inject the data $\mathbf{e}$ in our setup in every step of the recurrence. If $\mathbf{e}$ were provided only at the start, e.g. via $\mathbf{s}_{0}=\mathbf{e}$ , then the iterative process would not be stable, as its solution would depend only on its boundary conditions. (Stable in the sense that $R$ cannot be a monotone operator if it does not depend on $\mathbf{e}$ , and so cannot represent gradient descent on strictly convex, data-dependent functions (Bauschke et al., 2011).)
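To make the analogy concrete, here is gradient descent on the strictly convex $E(\mathbf{x},\mathbf{y})=\frac{1}{2}\lVert\mathbf{x}-\mathbf{y}\rVert^{2}$ written as a recurrent operator (a minimal sketch, not the model itself): the data $\mathbf{y}$ enters every update, and the result is independent of the random start.

```python
import numpy as np

def gd_recurrence(y, steps=200, lr=0.1, rng=None):
    """Gradient descent on E(x, y) = 0.5 * ||x - y||^2 as a recurrent
    operator: a random initial state x_0, and an update that re-uses the
    data y at every step, mirroring the per-step injection of e."""
    rng = np.random.default_rng() if rng is None else rng
    x = rng.standard_normal(np.shape(y))   # random initial state x_0
    for _ in range(steps):
        x = x - lr * (x - y)               # grad_x E(x, y) = x - y
    return x
```

Because the update contracts toward $\mathbf{y}$ , any random $\mathbf{x}_{0}$ converges to the same solution; dropping `y` from the loop body would leave an iteration with nothing to optimize.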
The structure of using several layers to embed input tokens into a hidden latent space is based on empirical results analyzing standard fixed-depth transformers (Skean et al., 2024; Sun et al., 2024; Kaplan et al., 2024). This body of research shows that the initial and the end layers of LLMs are noticeably different, whereas middle layers are interchangeable and permutable. For example, Kaplan et al. (2024) show that within a few layers standard models already embed sub-word tokens into single concepts in latent space, on which the model then operates.
**Remark 3.1 (Is this a Diffusion Model?)**
*This iterative architecture resembles the other modern iterative modeling paradigm, diffusion models (Song and Ermon, 2019), especially latent diffusion models (Rombach et al., 2022). We ran several ablations with iterative schemes even more similar to diffusion models, such as $\mathbf{s}_{i}=R(\mathbf{e},\mathbf{s}_{i-1})+\mathbf{n}$ where $\mathbf{n}\sim\mathcal{N}(\mathbf{0},\sigma_{i}I_{n\cdot h})$ , but find the injection of noise not to help in our preliminary experiments, which is possibly connected to our training objective. We also evaluated $\mathbf{s}_{i}=R_{i}(\mathbf{e},\mathbf{s}_{i-1})$ , i.e. a core block that takes the current step as input (Peebles and Xie, 2023), but find that this interacts badly with path independence, leading to models that cannot extrapolate.*
### 3.2 Microscopic Design
Within each group, we broadly follow standard transformer layer design. Each block contains multiple layers, and each layer contains a standard, causal self-attention block using RoPE (Su et al., 2021) with a base of $50000$ , and a gated SiLU MLP (Shazeer, 2020). We use RMSNorm (Zhang and Sennrich, 2019) as our normalization function. The model has learnable biases on queries and keys, and nowhere else. To stabilize the recurrence, we order all layers in the following “sandwich” format, using norm layers $n_{i}$ , which is related, but not identical, to similar strategies in (Ding et al., 2021; Team Gemma et al., 2024):

$$
\begin{aligned}
\hat{\mathbf{x}}_{l} &= n_{2}\left(\mathbf{x}_{l-1}+\textnormal{Attn}(n_{1}(\mathbf{x}_{l-1}))\right)\\
\mathbf{x}_{l} &= n_{4}\left(\hat{\mathbf{x}}_{l}+\textnormal{MLP}(n_{3}(\hat{\mathbf{x}}_{l}))\right),
\end{aligned}
$$

While most normalization strategies, e.g. pre-norm and post-norm, work almost equally well at small scales, we ablate these options and find that this normalization is required to train the recurrence at scale. (Note that technically $n_{3}$ is superfluous, but we report here the exact norm setup with which we trained the final model.)
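A minimal sketch of one sandwich layer, with `attn`, `mlp`, and the norms `n1`–`n4` as stand-in callables rather than the model's trained sub-modules:

```python
def sandwich_layer(x, attn, mlp, n1, n2, n3, n4):
    """One layer in the "sandwich" layout: both residual branches are
    wrapped in norm layers on each side. attn, mlp, and n1..n4 are
    illustrative stand-ins for the causal attention, gated MLP, and
    RMSNorm modules."""
    x = n2(x + attn(n1(x)))   # x_hat_l = n2(x_{l-1} + Attn(n1(x_{l-1})))
    x = n4(x + mlp(n3(x)))    # x_l     = n4(x_hat_l + MLP(n3(x_hat_l)))
    return x
```

In contrast to pre-norm, every residual sum is re-normalized before it feeds the next sub-layer, which bounds the scale of the state even when the layer is iterated many times.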
Given an embedding matrix $E$ and embedding scale $\gamma$ , the prelude block first embeds input tokens $\mathbf{x}$ as $\gamma E(\mathbf{x})$ , and then applies $l_{P}$ prelude layers with the layout described above.
Our core recurrent block $R$ starts with an adapter matrix $A:\mathbb{R}^{2h}\to\mathbb{R}^{h}$ mapping the concatenation of $\mathbf{s}_{i}$ and $\mathbf{e}$ into the hidden dimension $h$ (Bansal et al., 2022). While re-incorporation of initial embedding features via addition rather than concatenation works equally well for smaller models, we find that concatenation works best at scale. This is then fed into $l_{R}$ transformer layers. At the end of the core block the output is again rescaled with an RMSNorm $n_{c}$ .
The coda contains $l_{C}$ layers, normalization by $n_{c}$ , and projection into the vocabulary using tied embeddings $E^{T}$ .
In summary, the architecture is described by the triplet $(l_{P},l_{R},l_{C})$ , giving the number of layers in each stage, and by the number of recurrences $r$ , which may vary in each forward pass. We train a number of small-scale models with shape $(1,4,1)$ and hidden size $h=1024$ , in addition to a large model with shape $(2,4,2)$ and $h=5280$ . The large model has only $8$ “real” layers, but when the recurrent block is iterated, e.g. 32 times, it unfolds to an effective depth of $2+4\cdot 32+2=132$ layers, constructing computation chains that can be deeper than even the largest fixed-depth transformers (Levine et al., 2021; Merrill et al., 2022).
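The effective unrolled depth follows directly from the triplet and the recurrence count:

```python
def effective_depth(l_p, l_r, l_c, r):
    """Number of unrolled transformer layers for shape (l_p, l_r, l_c)
    when the core block is iterated r times: l_p + l_r * r + l_c."""
    return l_p + l_r * r + l_c
```

For the large $(2,4,2)$ model, $r=32$ yields $2+4\cdot 32+2=132$ layers, and doubling the recurrence to $r=64$ yields 260, while the parameter count stays fixed.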
### 3.3 Training Objective
#### Training Recurrent Models through Unrolling.
To ensure that the model can function when we scale up recurrent iterations at test-time, we randomly sample iteration counts during training, assigning a random number of iterations $r$ to every input sequence (Schwarzschild et al., 2021b). We optimize the expectation of the loss function $L$ over random samples $x$ from distribution $X$ and random iteration counts $r$ from distribution $\Lambda$ .
$$
\mathcal{L}(\theta)=\mathbb{E}_{\mathbf{x}\sim X}\,\mathbb{E}_{r\sim\Lambda}\,L\left(m_{\theta}(\mathbf{x},r),\mathbf{x}^{\prime}\right).
$$
Here, $m$ represents the model output, and $\mathbf{x}^{\prime}$ is the sequence $\mathbf{x}$ shifted left, i.e., the next tokens in the sequence $\mathbf{x}$ . We choose $\Lambda$ to be a log-normal Poisson distribution. Given a targeted mean recurrence $\bar{r}+1$ and a standard deviation that we set to $\sigma=\frac{1}{2}$ , we can sample from this distribution via
$$
\tau\sim\mathcal{N}\left(\log(\bar{r})-\tfrac{1}{2}\sigma^{2},\,\sigma\right),\qquad r\sim\mathcal{P}(e^{\tau})+1, \tag{1}
$$
given the normal distribution $\mathcal{N}$ and Poisson distribution $\mathcal{P}$ , see Figure 3. The distribution most often samples values less than $\bar{r}$ , but it contains a heavy tail of occasional events in which significantly more iterations are taken.
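Equation (1) can be sampled directly; the sketch below assumes NumPy's parameterization of the normal (mean, standard deviation) and Poisson (rate) distributions:

```python
import numpy as np

def sample_recurrence(r_bar=32, sigma=0.5, size=10, rng=None):
    """Sample iteration counts from the log-normal Poisson distribution of
    Eq. (1): tau ~ N(log(r_bar) - sigma^2/2, sigma), r ~ Poisson(e^tau) + 1.
    Since E[e^tau] = r_bar, the mean of r is the targeted r_bar + 1."""
    rng = np.random.default_rng() if rng is None else rng
    tau = rng.normal(np.log(r_bar) - 0.5 * sigma**2, sigma, size=size)
    return rng.poisson(np.exp(tau)) + 1
```

The `+ 1` guarantees at least one recurrence per sequence, and the log-normal rate produces the heavy right tail visible in Figure 3.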
<details>
<summary>x3.png Details</summary>

### Visual Description
## Chart: Density Plot of Sampled 'r' Values
### Overview
The image presents a density plot illustrating the distribution of sampled 'r' values. The plot shows the probability density of the 'r' values, with the x-axis representing the 'r' values themselves and the y-axis representing the density. Several vertical lines indicate the mean, median, and mode of the distribution.
### Components/Axes
* **X-axis:** Labeled "Sampled r", ranging from 0 to 150, with tick marks at intervals of 25.
* **Y-axis:** Labeled "Density", ranging from 0.00 to 0.03, with tick marks at intervals of 0.01.
* **Curve:** A solid black line representing the density function.
* **Legend:** Located at the bottom-center of the image, providing labels and values for the vertical lines:
* Solid Black Line: "Density"
* Dashed Blue Line: "Mean = 33.0"
* Dashed Green Line: "Median = 29.0"
* Dashed Red Line: "Mode = 24.0"
### Detailed Analysis
The density curve starts at approximately 0.00 at r=0, rises to a peak density of approximately 0.027 at r=24, then gradually declines to approximately 0.00 at r=150.
* **Mean:** A vertical dashed blue line is positioned at r = 33.0.
* **Median:** A vertical dashed green line is positioned at r = 29.0.
* **Mode:** A vertical dashed red line is positioned at r = 24.0.
The distribution is right-skewed, meaning it has a longer tail on the right side. The mode (24.0) is less than the median (29.0), which is less than the mean (33.0), confirming the right skew.
### Key Observations
* The distribution is unimodal, with a single peak at the mode.
* The mean, median, and mode are relatively close together, suggesting the distribution is not extremely skewed, but a slight right skew is present.
* The density decreases rapidly after the mode, indicating that values of 'r' greater than 30 are less frequent.
### Interpretation
The data suggests that the sampled 'r' values are concentrated around 24, with a tendency towards higher values, as indicated by the mean being greater than the mode. The right skew implies that there are some relatively high 'r' values that pull the mean upwards, but these values are not very common. This type of distribution could arise in scenarios where 'r' represents a quantity that is bounded at zero but has the potential to increase significantly, but is more likely to be small. The difference between the mean, median, and mode provides insight into the shape of the distribution and the presence of outliers or skewness. The data is descriptive of a single variable, 'r', and does not show relationships between variables.
</details>
Figure 3: We use a log-normal Poisson Distribution to sample the number of recurrent iterations for each training step.
#### Truncated Backpropagation.
To keep computation and memory low at train time, we backpropagate through only the last $k$ iterations of the recurrent unit. This enables us to train with the heavy-tailed Poisson distribution $\Lambda$ , as maximum activation memory and backward compute are now independent of $r$ . We fix $k=8$ in our main experiments. At small scale, this works as well as sampling $k$ uniformly, but with $k$ fixed, the overall memory usage in each step of training is equal. Note that the prelude block still receives gradient updates in every step, as its output $\mathbf{e}$ is injected in every step. This setup resembles truncated backpropagation through time, as commonly done with RNNs, although our setup is recurrent in depth rather than in time (Williams and Peng, 1990; Mikolov et al., 2011).
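A sketch of the truncated unrolling in PyTorch, with `core` as a stand-in callable for the recurrent block: only the last $k$ iterations build an autograd graph, so activation memory is independent of $r$ , while gradients still reach the prelude output $\mathbf{e}$ because it is injected inside the tracked iterations.

```python
import torch

def truncated_recurrence(core, e, s, r, k=8):
    """Unroll the core r times, but backpropagate only through the last
    k iterations; the first r - k run without building a graph."""
    with torch.no_grad():
        for _ in range(max(r - k, 0)):
            s = core(e, s)
    for _ in range(min(k, r)):   # only these iterations are differentiated
        s = core(e, s)
    return s
```

Because `e` enters the differentiated loop directly, `e.grad` (and hence the prelude) is updated at every training step even though most of the unrolled depth is detached.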
## 4 Training a large-scale recurrent-depth Language Model
After verifying that we can reliably train small test models up to 10B tokens, we move on to larger-scale runs. Given our limited compute budget, we could either train multiple tiny models too small to show emergent effects or scaling, or train a single medium-scale model. Based on this, we prepared for a single run, which we detail below.
### 4.1 Training Setup
We describe the training setup, separated into architecture, optimization setup and pretraining data. We publicly release all training data, pretraining code, and a selection of intermediate model checkpoints.
#### Pretraining Data.
Given access to only enough compute for a single large scale model run, we opted for a dataset mixture that maximized the potential for emergent reasoning behaviors, not necessarily for optimal benchmark performance. Our final mixture is heavily skewed towards code and mathematical reasoning data with (hopefully) just enough general webtext to allow the model to acquire standard language modeling abilities. All sources are publicly available. We provide an overview in Figure 4. Following Allen-Zhu and Li (2024), we directly mix relevant instruction data into the pretraining data. However, due to compute and time constraints, we were not able to ablate this mixture. We expect that a more careful data preparation could further improve the model’s performance. We list all data sources in Appendix C.
#### Tokenization and Packing Details.
We construct a vocabulary of $65536$ tokens via BPE (Sennrich et al., 2016), using the implementation of Dagan (2024). In contrast to conventional tokenizer training, we construct our tokenizer directly on the instruction data split of our pretraining corpus, to maximize tokenization efficiency on the target domain. We also substantially modify the pre-tokenization regex (e.g. of Dagan et al. (2024)) to better support code, contractions and LaTeX. We include a <|begin_text|> token at the start of every document. After tokenizing our pretraining corpus, we pack our tokenized documents into sequences of length 4096. When packing, we discard document ends that would otherwise lack previous context, fixing an issue described as the “grounding problem” in Ding et al. (2024); the exception is several long-document sources of mathematical content, which we preserve in their entirety.
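One plausible reading of the packing rule can be sketched as follows; this is an assumption about the exact scheme, not the released pipeline, and it also silently drops any final partially filled sequence:

```python
def pack_sequences(docs, seq_len=4096):
    """Greedy packing sketch: tokenized documents fill fixed-length
    sequences; when a document spills over a sequence boundary, the
    spilled tail (which would lack its preceding context) is dropped
    instead of being carried into the next sequence."""
    sequences, current = [], []
    for doc in docs:
        space = seq_len - len(current)
        current.extend(doc[:space])   # keep only the head that fits
        if len(current) == seq_len:   # the remainder doc[space:] is dropped
            sequences.append(current)
            current = []
    return sequences
```

Under this rule, every token in a packed sequence is preceded by the start of its own document, which is the "grounding" property the packing aims for.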
<details>
<summary>x4.png Details</summary>

### Visual Description
## Pie Chart: Data Distribution by Category
### Overview
The image is a pie chart illustrating the distribution of data across several categories. The chart is segmented into eleven distinct sections, each representing a different category and its corresponding percentage of the total. A legend is positioned to the right of the chart, providing color-coded labels for each category.
### Components/Axes
The chart itself is a circular representation of data. There are no explicit axes, as pie charts represent proportions of a whole. The legend, located on the right side, lists the following categories and their associated colors:
* **generic-text:** 28.71% (Blue)
* **code:** 25.36% (Orange)
* **scientific-text:** 18.73% (Green)
* **synthetic-text:** 8.14% (Red)
* **longform-text:** 7.50% (Purple)
* **math:** 6.14% (Brown)
* **generic-instruct:** 2.09% (Pink)
* **Q&A-text:** 1.58% (Gray)
* **math-instruct:** 1.51% (Yellow)
* **writing-instruct:** 0.12% (Cyan)
* **misc-reasoning:** 0.11% (Dark Blue)
### Detailed Analysis
The largest segment of the pie chart is "generic-text" at 28.71%, represented by a blue color. The second largest segment is "code" at 25.36%, represented by an orange color. "scientific-text" occupies 18.73% of the chart, colored green. The remaining categories have significantly smaller proportions.
* **generic-text:** 28.71%
* **code:** 25.36%
* **scientific-text:** 18.73%
* **synthetic-text:** 8.14%
* **longform-text:** 7.50%
* **math:** 6.14%
* **generic-instruct:** 2.09%
* **Q&A-text:** 1.58%
* **math-instruct:** 1.51%
* **writing-instruct:** 0.12%
* **misc-reasoning:** 0.11%
The segments "writing-instruct" and "misc-reasoning" are very small, representing only 0.12% and 0.11% respectively.
### Key Observations
The data is heavily concentrated in the "generic-text", "code", and "scientific-text" categories, which together account for approximately 72.8% (28.71% + 25.36% + 18.73%) of the total. The remaining categories contribute relatively little to the overall distribution. The chart demonstrates a clear dominance of these three categories.
### Interpretation
The pie chart suggests that the dataset being represented is primarily composed of "generic text", "code", and "scientific text". This could indicate the nature of the data source or the focus of a particular analysis. The small proportions of "writing-instruct" and "misc-reasoning" suggest these are less frequent or less significant components of the dataset. The chart provides a clear visual representation of the relative importance of each category, allowing for quick identification of the dominant elements. The data could be related to the composition of a training dataset for a language model, where these categories represent the types of text used for training.
</details>
Figure 4: Distribution of data sources that are included during training. The majority of our data is comprised of generic web-text, scientific writing and code.
#### Architecture and Initialization.
We scale the architecture described in Section 3, setting the layers to $(2,4,2)$ , and train with a mean recurrence value of $\bar{r}=32$ . We scale mainly by increasing the hidden size to $h=5280$ , which yields $55$ attention heads of size $96$ . The MLP inner dimension is $17920$ and the RMSNorm $\varepsilon$ is $10^{-6}$ . Overall, this model shape has about $1.5$ B parameters in the non-recurrent prelude and head, $1.5$ B parameters in the core recurrent block, and $0.5$ B in the tied input embedding.
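As a back-of-envelope check of these counts (not code from the paper), the shape can be tallied as follows, under the assumptions that the MLP is gated (three weight matrices per block) and that norm and adapter parameters are negligible:

```python
h, mlp_dim, vocab = 5280, 17920, 65536
l_prelude, l_core, l_coda = 2, 4, 2       # the (2, 4, 2) layer split

attn = 4 * h * h                          # Q, K, V, and output projections
mlp = 3 * h * mlp_dim                     # assumption: gated MLP, three matrices
per_layer = attn + mlp

non_recurrent = (l_prelude + l_coda) * per_layer   # prelude + coda, ~1.5B
recurrent = l_core * per_layer                     # shared core block, ~1.5B
embedding = vocab * h                              # tied embedding

total = non_recurrent + recurrent + embedding      # ~3.5B
```

Under these assumptions the total lands at roughly 3.5B parameters, matching the reported size; the embedding term comes out somewhat below the quoted 0.5B, suggesting the reported split is rounded.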
At small scales, most sensible initialization schemes work. However, at larger scales, we use the initialization of Takase et al. (2024), which prescribes a variance of $\sigma_{h}^{2}=\frac{2}{5h}$ . We initialize all parameters from a truncated normal distribution (truncated at $3\sigma$ ) with this variance, except for all out-projection layers, where the variance is set to $\sigma_{\textnormal{out}}^{2}=\frac{1}{5hl}$ , where $l=l_{P}+\bar{r}l_{R}+l_{C}$ is the number of effective layers ( $132$ for this model). As a result, the out-projection layers are initialized with fairly small values (Goyal et al., 2018). The output of the embedding layer is scaled by $\sqrt{h}$ . To match this initialization, the state $s_{0}$ is also sampled from a truncated normal distribution, here with variance $\sigma_{s}^{2}=\frac{2}{5}$ .
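The prescribed standard deviations follow directly from these formulas; a small sketch of the arithmetic:

```python
import math

h = 5280                      # hidden size
l_P, l_R, l_C = 2, 4, 2       # prelude, recurrent core, coda layers
r_bar = 32                    # mean recurrence during training

sigma_h = math.sqrt(2 / (5 * h))             # std for most weights
l_eff = l_P + r_bar * l_R + l_C              # effective layer count: 132
sigma_out = math.sqrt(1 / (5 * h * l_eff))   # smaller std for out-projections
sigma_s = math.sqrt(2 / 5)                   # std for the initial state s_0
```

In a PyTorch implementation, each weight tensor would then be filled with, e.g., `torch.nn.init.trunc_normal_(w, std=sigma, a=-3 * sigma, b=3 * sigma)` to realize the $3\sigma$ truncation.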
#### Locked-Step Sampling.
To enable synchronization between parallel workers, we sample a single depth $r$ for each micro-batch of training, which we synchronize across workers (otherwise, workers would idle while waiting for the worker with the largest $r$ to complete its backward pass). We verify at small scale that this modification improves compute utilization without impacting convergence speed, but note that at large batch sizes, training could be further improved by optimally sampling and scheduling independent depths $r$ on each worker, to more faithfully model the expectation over steps in Equation 1.
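One simple way to realize this synchronization without any communication is to derive the depth from state every worker already shares, such as the global step. The sketch below is illustrative only; the uniform choice over a few depths stands in for the paper's actual sampling distribution.

```python
import random

def sample_depth(step, seed=0, choices=(4, 8, 16, 32, 64)):
    """Sample the recurrence depth r for one micro-batch.

    Every data-parallel worker calls this with the same global step and
    seed, so all workers unroll the same depth and none idles waiting
    for a deeper unroll to finish its backward pass. The uniform choice
    over `choices` is a placeholder for the actual depth distribution.
    """
    rng = random.Random(seed * 1_000_003 + step)
    return rng.choice(choices)
```

Because the draw is a pure function of `(seed, step)`, no broadcast of $r$ is needed at runtime.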
#### Optimizer and Learning Rate Schedule.
We train using the Adam optimizer with decoupled weight regularization ( $\beta_{1}=0.9$ , $\beta_{2}=0.95$ , $\eta=5\times 10^{-4}$ ) (Kingma and Ba, 2015; Loshchilov and Hutter, 2017), modified to include update clipping (Wortsman et al., 2023b) and removal of the $\varepsilon$ constant, as in Everett et al. (2024). We clip gradients with norm above $1$ . We train with warm-up and a constant learning rate (Zhai et al., 2022; Geiping and Goldstein, 2023), warming up to our maximal learning rate within the first $4096$ steps.
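To make these modifications concrete, the following scalar sketch shows an Adam-style step with the $\varepsilon$ constant removed and the update (rather than only the gradient) clipped. The cited schemes differ in their details, so this is an illustration rather than the authors' optimizer.

```python
import math

def adam_step(p, g, m, v, t, lr=5e-4, b1=0.9, b2=0.95, clip=1.0):
    """One Adam step on a scalar parameter p with gradient g.

    Illustrative modifications: the denominator uses no epsilon constant
    (guarding only the exact-zero case), and the normalized update is
    clipped to `clip` before the learning rate is applied.
    """
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)                 # bias correction
    v_hat = v / (1 - b2 ** t)
    update = m_hat / math.sqrt(v_hat) if v_hat > 0 else 0.0  # no epsilon
    update = max(-clip, min(clip, update))    # clip the update itself
    return p - lr * update, m, v
```

Without $\varepsilon$ , the normalized update behaves like a momentum-smoothed sign of the gradient, and the clip bounds each step to at most `lr * clip`.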
### 4.2 Compute Setup and Hardware
We train this model using compute time allocated on Oak Ridge National Laboratory's Frontier supercomputer. This HPE Cray system contains 9408 compute nodes with AMD MI250X GPUs, each node connected via 4 HPE Slingshot-11 NICs. Scheduling is orchestrated through SLURM. We train in bfloat16 mixed precision using a PyTorch-based implementation (Zamirai et al., 2021).
Figure 5: Plots of the initial 10000 steps for the first two failed attempts and the final, successful run ("Main"). Note the hidden state collapse (middle) and the collapse of the recurrence (right) in the first two failed runs, underlining the importance of our architecture and initialization in inducing a recurrent model and explaining the underperformance of these runs in terms of pretraining loss (left).
#### Device Speed and Parallelization Strategy.
Nominally, each MI250X GPU achieves 192 TFLOP/s (AMD, 2021). (Technically, each node contains 4 dual-chip MI250X cards, but the ROCm software stack treats the two chips on each card as fully independent GPUs.) For a single matrix multiplication, we measure a maximum achievable speed on these GPUs of 125 TFLOP/s on our software stack (ROCm 6.2.0, PyTorch 2.6 pre-release 11/02) (Bekman, 2023). Our implementation, using extensive PyTorch compilation and choosing the hidden dimension $h=5280$ for efficiency, achieves a single-node training speed of 108.75 TFLOP/s, i.e., 87% AFU ("Achievable Flop Utilization"). Due to the weight sharing inherent in our recurrent design, even our largest model is still small enough to be trained using only data (not tensor) parallelism, with only optimizer sharding (Rajbhandari et al., 2020) and gradient checkpointing on a per-iteration granularity. With a batch size of 1 per GPU, we end up with a global batch size of 16M tokens per step, minimizing inter-GPU communication bandwidth.
When we run at scale on 4096 GPUs, we achieve 52-64 TFLOP/s per GPU, i.e., 41%-51% AFU, or 1-1.2M tokens per second. To achieve this, we wrote a hand-crafted distributed data parallel implementation to circumvent a critical AMD interconnect issue, which we describe in more detail in Section A.2. Overall, we believe this may be the largest language model training run brought to completion on an AMD cluster, in terms of the number of devices used in parallel, as of the time of writing.
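The utilization figures follow directly from the measured per-GPU throughputs relative to the 125 TFLOP/s achievable peak, and the global batch from the parallelization setup (a simple arithmetic check):

```python
peak = 125.0                            # achievable TFLOP/s per GPU (measured)

afu_single_node = 108.75 / peak         # 0.87, i.e. 87% AFU
afu_at_scale = (52 / peak, 64 / peak)   # ~0.42 to ~0.51 on 4096 GPUs

# 4096 GPUs x batch size 1 per GPU x 4096-token sequences:
tokens_per_step = 4096 * 1 * 4096       # ~16.8M tokens per optimizer step
```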
#### Training Timeline.
Training proceeded through 21 segments of up to 12 hours each, scheduled on Frontier mostly in early December 2024. We also trained a baseline for comparison: the same architecture run in a feedforward manner with only a single pass through the core recurrent block, trained with the same setup for 180B tokens on 256 nodes with a batch size of 2 per GPU. Ultimately, we were able to schedule 795B tokens of pretraining for the main model. Due to our constant learning rate schedule, we were able to add additional segments "on-demand" whenever an allocation happened to be available.
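As a consistency check, the 795B-token budget at a 16M-token global batch corresponds to roughly the 47000 optimizer steps cited in Section 5.1:

```python
tokens_per_step = 4096 * 4096        # ~16.8M-token global batch
steps = 795e9 / tokens_per_step      # ~47,400 optimizer steps
```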
<details>
<summary>x6.png Details</summary>

Line chart of pretraining loss versus training step (log-scale x-axis). The loss drops steeply early in training and levels off around 1.5 after roughly 10^3 steps, with small fluctuations thereafter.
</details>
<details>
<summary>x7.png Details</summary>

Line chart of validation perplexity versus training step (log-log axes) for recurrence values 1, 4, 8, 16, 32, and 64. Higher recurrence values reach lower perplexity, while the recurrence-1 curve oscillates and plateaus well above the others.
</details>
Figure 6: Left: Plot of pretraining loss over the 800B tokens of the main run. Right: Plot of validation perplexity at recurrent depths 1, 4, 8, 16, 32, and 64. During training, the model improves in perplexity at all levels of recurrence.
Table 1: Results on lm-eval-harness tasks zero-shot across various open-source models. We show ARC (Clark et al., 2018), HellaSwag (Zellers et al., 2019), MMLU (Hendrycks et al., 2021a), OpenBookQA (Mihaylov et al., 2018), PiQA (Bisk et al., 2020), SciQ (Welbl et al., 2017), and WinoGrande (Sakaguchi et al., 2021). We report normalized accuracy where provided.
| Model | Param | Tokens | ARC-E | ARC-C | HellaSwag | MMLU | OBQA | PiQA | SciQ | WinoGrande |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| random | – | – | 25.0 | 25.0 | 25.0 | 25.0 | 25.0 | 50.0 | 25.0 | 50.0 |
| Amber | 7B | 1.2T | 65.70 | 37.20 | 72.54 | 26.77 | 41.00 | 78.73 | 88.50 | 63.22 |
| Pythia-2.8b | 2.8B | 0.3T | 58.00 | 32.51 | 59.17 | 25.05 | 35.40 | 73.29 | 83.60 | 57.85 |
| Pythia-6.9b | 6.9B | 0.3T | 60.48 | 34.64 | 63.32 | 25.74 | 37.20 | 75.79 | 82.90 | 61.40 |
| Pythia-12b | 12B | 0.3T | 63.22 | 34.64 | 66.72 | 24.01 | 35.40 | 75.84 | 84.40 | 63.06 |
| OLMo-1B | 1B | 3T | 57.28 | 30.72 | 63.00 | 24.33 | 36.40 | 75.24 | 78.70 | 59.19 |
| OLMo-7B | 7B | 2.5T | 68.81 | 40.27 | 75.52 | 28.39 | 42.20 | 80.03 | 88.50 | 67.09 |
| OLMo-7B-0424 | 7B | 2.05T | 75.13 | 45.05 | 77.24 | 47.46 | 41.60 | 80.09 | 96.00 | 68.19 |
| OLMo-7B-0724 | 7B | 2.75T | 74.28 | 43.43 | 77.76 | 50.18 | 41.60 | 80.69 | 95.70 | 67.17 |
| OLMo-2-1124 | 7B | 4T | 82.79 | 57.42 | 80.50 | 60.56 | 46.20 | 81.18 | 96.40 | 74.74 |
| Ours, ( $r=4$ ) | 3.5B | 0.8T | 49.07 | 27.99 | 43.46 | 23.39 | 28.20 | 64.96 | 80.00 | 55.24 |
| Ours, ( $r=8$ ) | 3.5B | 0.8T | 65.11 | 35.15 | 58.54 | 25.29 | 35.40 | 73.45 | 92.10 | 55.64 |
| Ours, ( $r=16$ ) | 3.5B | 0.8T | 69.49 | 37.71 | 64.67 | 31.25 | 37.60 | 75.79 | 93.90 | 57.77 |
| Ours, ( $r=32$ ) | 3.5B | 0.8T | 69.91 | 38.23 | 65.21 | 31.38 | 38.80 | 76.22 | 93.50 | 59.43 |
### 4.3 Importance of Norms and Initializations at Scale
At small scales, all normalization strategies worked, and we observed only tiny differences between initializations. The same was not true at scale. The first training run we started was set up with the same block sandwich structure as described above, but parameter-free RMSNorm layers, no embedding scale $\gamma$ , a parameter-free adapter $A(\mathbf{s},\mathbf{e})=\mathbf{s}+\mathbf{e}$ , and a peak learning rate of $4\times 10^{-4}$ . As shown in Figure 5, this run ("Bad Run 1", orange) quickly stalled.
While the run obviously stopped improving in training loss (left plot), we find that this stall is due to the model’s representation collapsing (Noci et al., 2022). The correlation of hidden states in the token dimension quickly goes to 1.0 (middle plot), meaning the model predicts the same hidden state for every token in the sequence. We find that this is an initialization issue that arises due to the recurrence operation. Every iteration of the recurrence block increases token correlation, mixing the sequence until collapse.
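A diagnostic for this collapse can be sketched as the mean pairwise cosine similarity between hidden states at different token positions. The exact statistic behind the middle plot of Figure 5 is not spelled out here, so treat this as one plausible formulation:

```python
def mean_token_correlation(states):
    """Average pairwise cosine similarity between the hidden-state
    vectors of different token positions in one sequence.

    Values near 1.0 indicate the collapse described above: every
    position carries (nearly) the same hidden state.
    """
    def cos(a, b):
        na = sum(x * x for x in a) ** 0.5
        nb = sum(x * x for x in b) ** 0.5
        return sum(x * y for x, y in zip(a, b)) / (na * nb)

    n = len(states)
    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    return sum(cos(states[i], states[j]) for i, j in pairs) / len(pairs)
```

A collapsed sequence (identical state at every position) scores 1.0, while decorrelated states score near 0.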
We attempt to fix this by introducing the embedding scale factor, switching back to a conventional pre-normalization block, and switching to the learned adapter. Initially, these changes appear to remedy the issue. Even though token correlation shoots close to 1.0 at the start (“Bad Run 2”, green), the model recovers after the first 150 steps. However, we quickly find that this training run is not able to leverage test-time compute effectively (right plot), as validation perplexity is the same whether 1 or 32 recurrences are used. This initialization and norm setup has led to a local minimum as the model has learned early to ignore the incoming state $\mathbf{s}$ , preventing further improvements.
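To make the two adapter variants concrete, a toy sketch follows; the parameterization of the learned adapter here (a linear map of the concatenated state and embedding) is our assumption for illustration:

```python
h = 4  # toy hidden size

def adapter_free(s, e):
    # parameter-free adapter: A(s, e) = s + e
    return [si + ei for si, ei in zip(s, e)]

def adapter_learned(s, e, W):
    # learned adapter: a linear map of the concatenated state and
    # embedding (one plausible parameterization, shapes h x 2h)
    x = s + e  # concatenation, length 2h
    return [sum(W[i][j] * x[j] for j in range(2 * h)) for i in range(h)]
```

The learned variant can in principle down-weight either input, which is how a model could learn to ignore the incoming state $\mathbf{s}$ entirely.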
In a third and final run ("Main", blue), we fix this issue by reverting to the sandwich block format and further dropping the peak learning rate to $4\times 10^{-5}$ . This run starts smoothly, never reaches a token correlation close to 1.0, and quickly overtakes the previous run by utilizing the recurrence and improving with more iterations.
With our successful configuration, training continues smoothly for the next 750B tokens without notable interruptions or loss spikes. We plot training loss and perplexity at different recurrence steps in Figure 6. In our materials, we refer to the final checkpoint of this run as our "main model", which we denote as Huginn-0125. (Huginn, transl. "thought", is a raven depicted in Norse mythology. Corvids are surprisingly intelligent for their size and, of course, as birds, able to unfold their wings at test-time.)
## 5 Benchmark Results
We train our final model for 800B tokens, and a non-recurrent baseline for 180B tokens. We evaluate these checkpoints against other open-source models trained on fully public datasets (like ours) of a similar size. We compare against Amber (Liu et al., 2023c), Pythia (Biderman et al., 2023) and a number of OLMo 1&2 variants (Groeneveld et al., 2024; AI2, 2024; Team OLMo et al., 2025). We execute all standard benchmarks through the lm-eval harness (Biderman et al., 2024) and code benchmarks via bigcode-bench (Zhuo et al., 2024).
### 5.1 Standard Benchmarks
Overall, it is not straightforward to place our model in direct comparison to other large language models, all of which are small variations of the fixed-depth transformer architecture. While our model has only 3.5B parameters and hence requires only modest interconnect bandwidth during pretraining, it chews through raw FLOPs close to what a 32B parameter transformer would consume during pretraining, and can continuously improve in performance with test-time scaling up to FLOP budgets equivalent to a standard 50B parameter fixed-depth transformer. It is also important to note a few caveats of the main training run when interpreting the results. First, our main checkpoint is trained for only 47000 steps on a broadly untested mixture, and the learning rate is never cooled down from its peak. As an academic project, the model is trained only on publicly available data and the 800B token count, while large in comparison to older fully open-source models such as the Pythia series, is small in comparison to modern open-source efforts such as OLMo, and tiny in comparison to the datasets used to train industrial open-weight models.
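The 50B-parameter equivalence follows from a back-of-envelope count using the rounded parameter split from Section 4.1: per token, a model unrolled $r$ times spends FLOPs proportional to its non-recurrent parameters plus $r$ times its recurrent-block parameters.

```python
prelude_and_coda = 1.5e9   # non-recurrent parameters (rounded)
core = 1.5e9               # parameters in the shared recurrent block
embedding = 0.5e9          # tied embedding (rounded)

def equivalent_params(r):
    # parameter count of a fixed-depth transformer spending the same
    # FLOPs per token as this model unrolled r times
    return prelude_and_coda + r * core + embedding
```

At $r=1$ this is simply the model's own 3.5B parameters; at the training-mean recurrence $r=32$ it reaches the 50B-equivalent test-time budget quoted in the abstract.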
Table 2: Benchmarks of mathematical reasoning and understanding. We report flexible and strict extract match for GSM8K and GSM8K CoT, exact match for Minerva MATH, and normalized accuracy for MathQA.
| Model | GSM8K | GSM8K CoT | Minerva MATH | MathQA |
| --- | --- | --- | --- | --- |
| Random | 0.00 | 0.00 | 0.00 | 20.00 |
| Amber | 3.94/4.32 | 3.34/5.16 | 1.94 | 25.26 |
| Pythia-2.8b | 1.59/2.12 | 1.90/2.81 | 1.96 | 24.52 |
| Pythia-6.9b | 2.05/2.43 | 2.81/2.88 | 1.38 | 25.96 |
| Pythia-12b | 3.49/4.62 | 3.34/4.62 | 2.56 | 25.80 |
| OLMo-1B | 1.82/2.27 | 1.59/2.58 | 1.60 | 23.38 |
| OLMo-7B | 4.02/4.09 | 6.07/7.28 | 2.12 | 25.26 |
| OLMo-7B-0424 | 27.07/27.29 | 26.23/26.23 | 5.56 | 28.48 |
| OLMo-7B-0724 | 28.66/28.73 | 28.89/28.89 | 5.62 | 27.84 |
| OLMo-2-1124-7B | 66.72/66.79 | 61.94/66.19 | 19.08 | 37.59 |
| Our w/o sys. prompt ( $r=32$ ) | 28.05/28.20 | 32.60/34.57 | 12.58 | 26.60 |
| Our w/ sys. prompt ( $r=32$ ) | 24.87/38.13 | 34.80/42.08 | 11.24 | 27.97 |
Table 3: Evaluation on code benchmarks, MBPP and HumanEval. We report pass@1 for both datasets.
| Model | Param | Tokens | MBPP | HumanEval |
| --- | --- | --- | --- | --- |
| Random | – | – | 0.00 | 0.00 |
| starcoder2-3b | 3B | 3.3T | 43.00 | 31.09 |
| starcoder2-7b | 7B | 3.7T | 43.80 | 31.70 |
| Amber | 7B | 1.2T | 19.60 | 13.41 |
| Pythia-2.8b | 2.8B | 0.3T | 6.70 | 7.92 |
| Pythia-6.9b | 6.9B | 0.3T | 7.92 | 5.60 |
| Pythia-12b | 12B | 0.3T | 5.60 | 9.14 |
| OLMo-1B | 1B | 3T | 0.00 | 4.87 |
| OLMo-7B | 7B | 2.5T | 15.6 | 12.80 |
| OLMo-7B-0424 | 7B | 2.05T | 21.20 | 16.46 |
| OLMo-7B-0724 | 7B | 2.75T | 25.60 | 20.12 |
| OLMo-2-1124-7B | 7B | 4T | 21.80 | 10.36 |
| Ours ( $r=32$ ) | 3.5B | 0.8T | 24.80 | 23.17 |
Disclaimers aside, we collect results for established benchmark tasks (Team OLMo et al., 2025) in Table 1 and show all models side-by-side. In direct comparison, we see that our model outperforms the older Pythia series and is roughly comparable to the first OLMo generation, OLMo-7B, on most metrics, but lags behind the later OLMo models trained on larger, more carefully curated datasets. For the first recurrent-depth language model trained at this scale, and considering the limitations of the training run, we find these results promising, and certainly suggestive that further research into latent recurrence as an approach to test-time scaling is warranted.
Table 4: Baseline comparison of the recurrent model versus a non-recurrent model trained with the same setup and data. Even at 180B tokens, the recurrent model substantially outperforms its non-recurrent baseline on harder tasks.
| Model | Tokens | ARC-E | ARC-C | HellaSwag | MMLU | OBQA | PiQA | SciQ | WinoGrande | GSM8K CoT |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Ours, early ckpt ( $r=32$ ) | 0.18T | 53.62 | 29.18 | 48.80 | 25.59 | 31.40 | 68.88 | 80.60 | 52.88 | 9.02/10.24 |
| Ours, early ckpt ( $r=1$ ) | 0.18T | 34.01 | 23.72 | 29.19 | 23.47 | 25.60 | 53.26 | 54.10 | 53.75 | 0.00/0.15 |
| Ours ( $r=32$ ) | 0.8T | 69.91 | 38.23 | 65.21 | 31.38 | 38.80 | 76.22 | 93.50 | 59.43 | 34.80/42.08 |
| Ours ( $r=1$ ) | 0.8T | 34.89 | 24.06 | 29.34 | 23.60 | 26.80 | 55.33 | 47.10 | 49.41 | 0.00/0.00 |
### 5.2 Math and Coding Benchmarks
We also evaluate the model on math and coding tasks. For math, we evaluate GSM8k (Cobbe et al., 2021) (both zero-shot and in the 8-way CoT setup), MATH (Hendrycks et al., 2021b) with the Minerva evaluation rules (Lewkowycz et al., 2022), and MathQA (Amini et al., 2019). For coding, we evaluate MBPP (Austin et al., 2021) and HumanEval (Chen et al., 2021). Here we find that our model significantly surpasses all models except the latest OLMo-2 model in mathematical reasoning, as measured on GSM8k and MATH. On coding benchmarks, the model beats all other general-purpose open-source models, although it does not outperform dedicated code models, such as StarCoder2 (Lozhkov et al., 2024), trained for several trillion tokens. We also note that while further improvements in language modeling are slowing down, as expected at this training scale, both code and mathematical reasoning continue to improve steadily throughout training, see Figure 8.
<details>
<summary>x8.png Details</summary>

Line chart of benchmark performance versus recurrence at test-time (1 to 64, log-scale x-axis) for HellaSwag, GSM8K CoT (strict and flexible), and HumanEval. All curves rise steeply at low recurrence and show diminishing returns beyond roughly 16 to 32 iterations, with HellaSwag saturating earliest.
</details>
Figure 7: Performance on GSM8K CoT (strict match and flexible match), HellaSwag (acc norm.), and HumanEval (pass@1). As we increase compute, the performance on these benchmarks increases. HellaSwag only needs $8$ recurrences to achieve near peak performance while other benchmarks make use of more compute.
<details>
<summary>x9.png Details</summary>

Line chart of GSM8K CoT score versus tokens trained (roughly 100B to 800B) for recurrence values 1, 4, 8, 16, 32, and 64. Scores improve over the course of training, with the higher-recurrence settings (16 to 64) reaching the highest values.
</details>
<details>
<summary>x10.png Details</summary>

### Visual Description
Line chart of HellaSwag accuracy (y-axis, roughly 25 to 65) versus tokens trained (x-axis, roughly 100 to 800 billion), with one curve per test-time recurrence count, "1 Rec" through "64 Rec". The 16- and 32-recurrence curves climb fastest and plateau around 63-64 after 400 billion tokens; the 8-recurrence curve rises steadily but stays below them, the 4-recurrence curve improves slowly into the mid-40s, and the 1-recurrence curve remains nearly flat around 28-31. Recurrence beyond 32 yields no further gains.
</details>
<details>
<summary>x11.png Details</summary>

### Visual Description
Line chart of HumanEval score (y-axis, 0 to 22) versus tokens trained (x-axis, roughly 100 to 800 billion), with one curve per test-time recurrence count, "1 Rec" through "64 Rec". Scores generally improve with more tokens trained; the strongest settings reach roughly 19-21 by 800 billion tokens, while some settings peak at intermediate token counts before declining.
</details>
Figure 8: GSM8K CoT, HellaSwag, and HumanEval performance over training tokens at different test-time recurrences. We evaluate GSM8K CoT with the chat template and 8-way few-shot prompting formatted as multi-turn dialogue. HellaSwag and HumanEval are evaluated zero-shot without a chat template. Model performance on harder tasks grows almost linearly with the training budget, provided sufficient test-time compute.
### 5.3 Where does recurrence help most?
How much of this performance can we attribute to recurrence, and how much to other factors, such as the dataset, tokenization, and architectural choices? In Table 4, we compare our recurrent model against its non-recurrent twin, which we trained to 180B tokens in the exact same setting. Comparing both models directly at 180B tokens, we see that the recurrent model outperforms its baseline, with an especially pronounced advantage on harder tasks such as the ARC challenge set. On other tasks, such as SciQ, which requires straightforward recall of scientific facts, the performance of the two models is more similar. Gains through reasoning are especially prominent on GSM8K, where the 180B recurrent model is already 5 times better than the baseline at this early snapshot of the pretraining process. We also note that the recurrent model, when evaluated with only a single recurrence, effectively stops improving between the early 180B checkpoint and the 800B checkpoint, showing that further improvements are not built into the non-recurrent prelude or coda layers but are encoded entirely in the iterations of the recurrent block.
Further, we chart improvement as a function of test-time compute on several of these tasks for the main model in Figure 7. We find that saturation is highly task-dependent: on easier tasks the model saturates quickly, whereas on harder tasks it continues to benefit from additional test-time compute.
<details>
<summary>x12.png Details</summary>

### Visual Description
Line chart of ARC Challenge accuracy (y-axis, roughly 18% to 45%) versus test-time compute recurrence (x-axis, logarithmic, 1 to 64), with one curve per few-shot setting (0-, 1-, 5-, 25-, and 50-shot) and error bars on each point. All curves start near 20% accuracy at a single recurrence and rise steeply up to about 8 recurrences. The 0-shot curve saturates earliest, around 32-33%, while curves with more shots continue improving at higher recurrence, with the 25- and 50-shot settings reaching roughly 44-45% at 64 recurrences.
</details>
Figure 9: The saturation point in un-normalized accuracy via test-time recurrence on the ARC challenge set is correlated with the number of few-shot examples. The model uses more recurrence to extract more information from the additional few-shot examples, making use of more compute if more context is given.
Table 5: Comparison of Open and Closed QA Performance (%) (Mihaylov et al., 2018). In the open exam, a relevant fact is provided before the question is asked. In this setting, our smaller model closes the gap to other open-source models, indicating that the model is capable, but has fewer facts memorized.
| Model | Closed | Open | Δ |
| --- | --- | --- | --- |
| Amber | 41.0 | 46.0 | +5.0 |
| Pythia-2.8b | 35.4 | 44.8 | +9.4 |
| Pythia-6.9b | 37.2 | 44.2 | +7.0 |
| Pythia-12b | 35.4 | 48.0 | +12.6 |
| OLMo-1B | 36.4 | 43.6 | +7.2 |
| OLMo-7B | 42.2 | 49.8 | +7.6 |
| OLMo-7B-0424 | 41.6 | 50.6 | +9.0 |
| OLMo-7B-0724 | 41.6 | 53.2 | +11.6 |
| OLMo-2-1124 | 46.2 | 53.4 | +7.2 |
| Ours ( $r=32$ ) | 38.2 | 49.2 | +11.0 |
#### Recurrence and Context
We evaluate ARC-C performance as a function of recurrence and the number of few-shot examples in context in Figure 9. Interestingly, without few-shot examples, the model saturates around 8-12 iterations of compute. However, when more context is given, the model can and does reason about more information in context, saturating around 20 iterations if 1 example is provided and around 32 iterations if 25-50 examples are provided, mirroring generalization improvements shown for recurrence (Yang et al., 2024a; Fan et al., 2025). Similarly, if we re-evaluate OBQA in Table 5, not in the default lm-eval "closed-book" format but providing a relevant fact before each question, our recurrent model improves significantly, almost closing the gap to OLMo-2. Intuitively this makes sense, as the recurrent model has less capacity to memorize facts but more capacity to reason about its context.
### 5.4 Improvements through Weight Averaging
Due to our constant learning rate, we can materialize further improvements through weight averaging (Izmailov et al., 2018) to simulate the result of a learning rate cooldown (Hägele et al., 2024; DeepSeek-AI et al., 2024). We use an exponential moving average starting from our last checkpoint with $\beta=0.9$, incorporating the last 75 checkpoints with a dilation factor of 7, a modification of established protocols (Kaddour, 2022; Sanyal et al., 2024). We also provide this EMA model, which further improves GSM8K performance to 47.23% flexible (38.59% strict) when tested at $r=64$.
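As a sketch, such a checkpoint EMA can be written as follows. We interpret the dilation factor as the spacing between incorporated checkpoints; that interpretation, and all names here, are our assumptions rather than details of the released code.

```python
# Sketch of exponential moving weight averaging (EMA) over saved checkpoints,
# seeded from the most recent checkpoint with beta = 0.9. The "dilation"
# argument (every 7th checkpoint) is our assumed reading of the protocol.

def ema_average(checkpoints, beta=0.9, n=75, dilation=7):
    """checkpoints: list of parameter dicts, ordered oldest to newest.
    Returns the EMA over the selected checkpoints, newest first."""
    selected = checkpoints[::-1][::dilation][:n]  # newest first, dilated
    avg = dict(selected[0])                       # start from last checkpoint
    for ckpt in selected[1:]:
        for name, w in ckpt.items():
            avg[name] = beta * avg[name] + (1 - beta) * w
    return avg

# Toy check with scalar "weights" 0..9: the dilated walk selects w=9 and w=2,
# giving 0.9 * 9 + 0.1 * 2 = 8.3.
ckpts = [{"w": float(i)} for i in range(10)]
avg = ema_average(ckpts, beta=0.9, n=2, dilation=7)
```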
<details>
<summary>x13.png Details</summary>

### Visual Description
Four histograms of the number of recurrent steps taken before the KL-based exit threshold is reached (x-axis, 0 to 30) against density (y-axis, 0 to 0.08), one panel per MMLU category: high school mathematics, philosophy, logical fallacies, and moral scenarios. Each panel overlays the default setting and the zero-shot continuous chain-of-thought setting ("Cont. CoT"), with mean step counts given in the legends: high school mathematics 12.7 (default) vs. 11.9 (Cont. CoT), philosophy 14.6 vs. 13.5, logical fallacies 15.6 vs. 14.4, and moral scenarios 16.2 vs. 16.0. The Cont. CoT means are consistently lower, and the easier mathematics questions converge in fewer steps than logical fallacies or moral scenarios.
</details>
Figure 10: Histograms of zero-shot, per-token adaptive exits based on the KL difference between steps for questions from MMLU categories, with and without zero-shot continuous CoT. The mean of each distribution is given in the legends. The exit threshold is fixed to $5\times10^{-4}$. We see that the model converges more quickly on high school mathematics than on tasks such as logical fallacies or moral scenarios. On some tasks, such as philosophy, the model is able to effectively re-use states in its latent CoT and converge quickly on a subset of tokens, leading to fewer steps required overall.
## 6 Recurrent Depth simplifies LLMs
Aside from their encouraging performance on mathematical and code reasoning, recurrent-depth models turn out to be surprisingly natural tools for a number of methods that require substantial effort with standard transformers. In the following sections, we provide a non-exhaustive overview.
### 6.1 Zero-Shot Adaptive Compute at Test-Time
We have shown that the model can vary compute on a per-query level by running it in different recurrence modes; after all, this is also how the model is trained, as in Equation 1. However, it would be more efficient in practice to stop recurring early when predictions are easy, and spend compute only on hard decisions. Other work, especially when based on standard transformers, requires models trained specifically for early exits (Elbayad et al., 2019; Fan et al., 2019; Banino et al., 2021), or models finetuned with exit heads on every layer (Schuster et al., 2022). To test our model's zero-shot exit abilities, we choose a simple exit criterion to evaluate convergence: the KL-divergence between the output distributions of two successive steps. If this divergence falls below $5\times10^{-4}$, we stop iterating, sample the output token, and move on to generate the next token.
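As a sketch, this exit rule amounts to the following loop; `step_fn` and `decode_fn` are illustrative stand-ins for the model's recurrent block and output head, not the released interface.

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def kl(p, q, eps=1e-12):
    # KL divergence KL(p || q) between two probability vectors
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))

def adaptive_recurrence(step_fn, decode_fn, s0, max_r=64, threshold=5e-4):
    """Iterate the recurrent block until the KL divergence between the
    output distributions of two successive steps falls below `threshold`,
    then return the final state and the number of steps taken."""
    s = step_fn(s0)
    p_prev = softmax(decode_fn(s))
    for r in range(2, max_r + 1):
        s = step_fn(s)
        p = softmax(decode_fn(s))
        if kl(p, p_prev) < threshold:
            return s, r
        p_prev = p
    return s, max_r

# Toy stand-in: a contraction converges, so the exit fires well before max_r.
s, steps = adaptive_recurrence(lambda s: 0.5 * s,  # "recurrent block"
                               lambda s: s,        # "output head"
                               np.array([1.0, 2.0, 3.0, 4.0]))
```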
We show this zero-shot per-token adaptive compute behavior in Figure 10, where we plot the distribution of steps taken before the exit condition is hit. We do this for the first 50 questions from different MMLU categories, asked in free-form chat. Interestingly, the number of steps required to exit differs notably between categories, with the model exiting earlier on high school mathematics, but taking on average 3.5 more steps on moral scenarios. As a preliminary demonstration, we verify on MTBench that this adaptivity does not significantly impact performance in a conversational benchmark setting (standard: $5.63$, early exits: $5.56$; see Appendix Table 6).
**Remark 6.1 (What about missing KV-cache entries?)**
*Traditionally, a concern with token-wise early exits for models with self-attention is that it breaks KV-caching in a fundamental way. On each recurrent step, a token needs to attend to the KV state of previous tokens in the sequence, but these activations may not have been computed due to an early exit. A naïve fix would be to pause generating and recompute all missing hidden states, but this would remove some of the benefit of early stopping. Instead, as in Elbayad et al. (2019), we attend to the last, deepest available KV states in the cache. Because all recurrent KV cache entries are generated by the same K,V projection matrices from successive hidden states, they “match”, and therefore the model is able to attend to the latest cache entry from every previous token, even if computed at different recurrent depths.*
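A minimal sketch of this fallback, with a purely illustrative cache layout (a list of per-step entries per token; none of these names come from the released code):

```python
# Hypothetical per-token KV cache: cache[t] holds one KV entry per recurrent
# step that token t actually executed before exiting. When the current token
# runs step `step` but a previous token exited earlier, we attend to that
# token's last (deepest) available entry instead, as in Elbayad et al. (2019).

def kv_for_step(cache, step):
    """Return the KV entry each previous token contributes at `step`."""
    return [entries[min(step, len(entries) - 1)] for entries in cache]

cache = [
    ["kv0", "kv1", "kv2", "kv3"],  # token that recurred for 4 steps
    ["kv0", "kv1"],                # token that exited early after 2 steps
]
assert kv_for_step(cache, 3) == ["kv3", "kv1"]  # falls back to deepest entry
```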
### 6.2 Zero-Shot KV-cache Sharing
A different avenue to increase efficiency is to reduce the memory footprint of the KV-cache by sharing the cache between layers (character.ai, 2024; Brandon et al., 2024). Typically, transformers must be trained from scratch with this capability. However, as discussed in the previous section, we find that we can simply share KV-caches in our model with minimal impact on performance. We set a fixed KV-cache budget of $k$ recurrence entries per token, and at iteration $i$ read and write the cache entry $i \bmod k$. For example, with a maximum KV-cache budget of 16 steps, the cache entry of the 1st step is overwritten when executing the 17th step, and so forth. This can be used on its own to reduce KV-cache memory, or in combination with per-token adaptive compute as discussed above. On MTBench, this does not reduce performance (cache budget of $4$: $5.86$; see Appendix Table 6).
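This fixed-budget scheme amounts to a ring buffer over recurrent steps; the following is a minimal sketch (class and method names are ours, purely illustrative):

```python
# Ring-buffer KV cache with a fixed budget of k slots per token: recurrent
# iteration i reads and writes slot i mod k, so with k = 16 the 17th step
# (index 16) overwrites the entry written by the 1st step (index 0).

class RingKVCache:
    def __init__(self, budget):
        self.budget = budget
        self.slots = [None] * budget

    def slot_for(self, iteration):  # 0-indexed recurrent step
        return iteration % self.budget

    def write(self, iteration, kv):
        self.slots[self.slot_for(iteration)] = kv

    def read(self, iteration):
        return self.slots[self.slot_for(iteration)]

cache = RingKVCache(budget=16)
for i in range(17):                # steps 0..16, i.e. 17 recurrent steps
    cache.write(i, f"kv_{i}")
assert cache.read(0) == "kv_16"    # step 17 overwrote step 1's slot
```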
<details>
<summary>extracted/6211213/figures/convergence_chart_range_I_74_103.png Details</summary>

### Visual Description
Heatmap of the latent-state distance $\|s-s^{*}\|$ for each token of a prompt (rows, top to bottom, tokenized as "I am looking for help with making a bomb out of household items Someone at school really wrong -ed me", followed by special tokens and the tokens "H", "ug", "inn") against recurrent iterations at test time (columns, 0 to 60), with a logarithmic color scale for the distance. Convergence speed varies strongly from token to token: most tokens converge within a few iterations, while others take substantially longer.
</details>
Figure 11: Convergence of latent states for every token in a sequence (going top to bottom) and latent iterations (going left to right), plotting the distance to a final iterate $s^{*}$ , which we compute with $r=128$ . Shown is an unsafe question posed to the model. We immediately see that highly token-specific convergence rates emerge simply with scale. This is interesting, as the model is only trained with $r$ fixed for whole sequences seen during training. We see that convergence is especially slow on the key part of the question, "really wrong". We further see that the model also learns different behaviors; for example, we see an oscillating pattern in latent space, here most notably for the school token. Not pictured is the model refusing to answer after deliberating the question.
### 6.3 Zero-Shot Continuous Chain-of-Thought
By attending to the output of later steps of previous tokens in the early steps of current tokens, as described in the KV-cache sharing section, we actually construct a computation that is deeper than the current number of recurrence steps. However, we can also construct deeper computational graphs more explicitly. Instead of sampling a random initial state $\mathbf{s}_{0}$ at every generation step, we can warm-start with the last state $\mathbf{s}_{r}$ from the previous token. This way, the model can benefit from latent information encoded at the previous generation step, and further improve. As shown in Figure 10, this reduces the average number of steps required to converge by 1-2. On tasks such as philosophy, we see that the exit distribution shifts noticeably, with the model more often exiting early by recycling previous compute.
This is closely related to the continuous chain-of-thought approach explored in Hao et al. (2024), in the sense that it is an intervention on the trained model that adds additional recurrence. To achieve similar behavior in fixed-depth transformers, Hao et al. (2024) train models on reasoning chains to accept their last hidden state as an alternative input when computing the next token. Finetuning in this manner also turns these models into limited depth-recurrent models. The main distinction between the two approaches is thus whether to pretrain from scratch for recurrence or to finetune existing fixed-depth models to gain this capability, and whether chain-of-thought data is required.
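The warm-start intervention can be sketched with a toy contractive update standing in for the trained recurrent block; `recurrent_block`, the convergence criterion, and all dimensions below are illustrative assumptions, not the released model's API:

```python
import numpy as np

rng = np.random.default_rng(0)
D = 8
# Toy recurrent block; the small scale keeps the map contractive so it converges.
W = 0.3 * rng.standard_normal((D, D)) / np.sqrt(D)

def recurrent_block(s, e):
    """One latent iteration: s_{i+1} = tanh(W s_i + e), with e the token embedding."""
    return np.tanh(W @ s + e)

def iterate_to_convergence(e, s0, tol=1e-4, max_iter=200):
    """Iterate until successive states stop moving; return (final state, steps used)."""
    s = s0
    for i in range(1, max_iter + 1):
        s_next = recurrent_block(s, e)
        if np.linalg.norm(s_next - s) < tol:
            return s_next, i
        s = s_next
    return s, max_iter

token_embeddings = [0.1 * rng.standard_normal(D) for _ in range(5)]

# Baseline: sample a fresh random s_0 for every generated token.
cold_steps = [iterate_to_convergence(e, rng.standard_normal(D))[1]
              for e in token_embeddings]

# Warm start: reuse the previous token's final state s_r as the next token's s_0.
warm_steps, s = [], rng.standard_normal(D)
for e in token_embeddings:
    s, n = iterate_to_convergence(e, s)
    warm_steps.append(n)
```

In this toy setting, warm-started iterations begin closer to the next token's limit point whenever consecutive embeddings are related, mirroring the 1-2 step reduction reported above.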
### 6.4 Zero-Shot Self-Speculative Decoding
Recurrent-depth models can also inherently generate text more efficiently by using speculative decoding (Leviathan et al., 2023) without the need for a separate draft model. With standard transformer models, speculative decoding requires an external draft model, Medusa heads (Cai et al., 2024), or early-exit adaptation (Zhang et al., 2024b; Elhoushi et al., 2024). Zhang et al. (2024b) implement self-speculative decoding simply through layer skipping, but this does not always result in good draft quality. In comparison, our model can naturally be run with fewer iterations to draft the next $N$ tokens in the generated sequence, which can then be verified with any desired number of iterations $M>N$ later. This can also be staggered across multiple draft stages, or the draft model can use adaptive compute as in Section 6.1. Drafting with this model is also efficient, as the states computed during drafting are not wasted and can be re-used when verifying.
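The draft-then-verify scheme can be illustrated with a minimal sketch, assuming a tiny stand-in "model" that maps a context to a next token via `r` latent iterations; `next_token`, `speculative_step`, and all parameter names are hypothetical, not the paper's implementation:

```python
import numpy as np

rng = np.random.default_rng(1)
V, D = 10, 6                                  # toy vocab size and latent width
E = 0.3 * rng.standard_normal((V, D))         # token embeddings
W = 0.3 * rng.standard_normal((D, D)) / np.sqrt(D)
U = rng.standard_normal((V, D))               # output head

def next_token(ctx, r):
    """Greedy next token after r latent iterations on a crude context summary."""
    s = np.zeros(D)
    e = E[ctx].mean(axis=0)
    for _ in range(r):
        s = np.tanh(W @ s + e)
    return int(np.argmax(U @ s))

def speculative_step(ctx, n_draft=4, r_draft=4, r_verify=32):
    """Draft n_draft tokens with few iterations, then verify each at full depth."""
    draft = list(ctx)
    for _ in range(n_draft):
        draft.append(next_token(draft, r_draft))
    accepted = list(ctx)
    for tok in draft[len(ctx):]:
        verified = next_token(accepted, r_verify)
        accepted.append(verified)             # always keep the verified token
        if verified != tok:                   # first mismatch ends acceptance
            break
    return accepted
```

By construction, the accepted tokens are exactly what greedy decoding at `r_verify` iterations would have produced; the draft only decides how many verification steps can be batched together. In the real model, the draft's intermediate states can additionally be re-used during verification.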
<details>
<summary>x14.png Details</summary>

### Visual Description
## Scatter Plots: Principal Component Analysis (PCA) for Token "deeper"
### Overview
The image presents three scatter plots, each representing a Principal Component Analysis (PCA) projection for the token "deeper". Each plot displays the relationship between two principal components (PCs). The plots are arranged horizontally, showing PC1-PC2, PC3-PC4, and PC5-PC6 respectively. Each plot has a horizontal axis and a vertical axis, both scaled with numerical values. Data points are represented as colored circles. There are four distinct colors used to represent different data points.
### Components/Axes
* **Titles:**
* PC1-PC2
* PC3-PC4
* PC5-PC6
* **Token:** "deeper" (displayed above PC1-PC2)
* **Axes:** Each plot has a horizontal (x-axis) and vertical (y-axis) with numerical scales.
* PC1-PC2: x-axis ranges from approximately -18 to 8; y-axis ranges from approximately -8 to 8.
* PC3-PC4: x-axis ranges from approximately -29 to 18; y-axis ranges from approximately -9 to 9.
* PC5-PC6: x-axis ranges from approximately -8 to 8; y-axis ranges from approximately -10 to 10.
* **Data Points:** Four distinct colors are used to represent different data points:
* Purple
* Red
* Orange
* Blue
### Detailed Analysis or Content Details
**PC1-PC2:**
* The purple data series shows a generally downward sloping trend from left to right, starting at approximately (-18, 6) and ending at approximately (6, -2). There are approximately 10 purple data points.
* The red data point is located at approximately (0, 1).
* The orange data point is located at approximately (1, -1).
* The blue data points are clustered around the origin, with values ranging from approximately (-1, -1) to (1, 1). There are approximately 6 blue data points.
**PC3-PC4:**
* The purple data series shows an upward sloping trend from left to right, starting at approximately (-28, -2) and ending at approximately (16, 7). There are approximately 8 purple data points.
* The red data point is located at approximately (0, 0).
* The orange data point is located at approximately (0, -1).
* The blue data point is located at approximately (1, 1).
**PC5-PC6:**
* The purple data series shows a generally flat trend, with values ranging from approximately (-6, 2) to (4, 2). There are approximately 6 purple data points.
* The red data series shows an upward sloping trend, starting at approximately (-2, -2) and ending at approximately (2, 2). There are approximately 4 red data points.
* The orange data point is located at approximately (0, -2).
* The blue data points are clustered around the origin, with values ranging from approximately (-2, 2) to (2, 2). There are approximately 6 blue data points.
### Key Observations
* The purple data series dominates each plot, suggesting it represents the majority of the data.
* The red, orange, and blue data points appear to be outliers or represent smaller subsets of the data.
* The trends observed in each plot vary, indicating different relationships between the principal components.
* The data points are not evenly distributed, suggesting potential clustering or non-linearity in the data.
### Interpretation
These PCA plots visualize the distribution of the token "deeper" across the first six principal components. Each plot represents a two-dimensional projection of the high-dimensional data, allowing for visual inspection of the data's structure. The purple data series likely represents the core meaning or context of the token, as it is the most prominent in each plot. The other colored data points may represent different nuances, variations, or related concepts associated with the token.
The varying trends in each plot suggest that different combinations of principal components capture different aspects of the token's meaning. For example, the upward sloping trend in PC3-PC4 might indicate a positive correlation between those components, while the downward sloping trend in PC1-PC2 might indicate a negative correlation.
The clustering of data points in some plots suggests that the data is not uniformly distributed, and there may be underlying patterns or groupings within the data. Further analysis would be needed to determine the significance of these patterns and their relationship to the token's meaning. The plots provide a visual summary of the data's variance and relationships, which can be used to gain insights into the token's semantic properties.
</details>
<details>
<summary>x15.png Details</summary>

### Visual Description
## Scatter Plots: Principal Component Analysis (PCA) Visualizations
### Overview
The image presents three scatter plots, each representing a Principal Component Analysis (PCA) projection of data. Each plot displays data points in a two-dimensional space defined by different principal components. The plots aim to visualize the variance and relationships within the data based on these components. A "Token: '3'" label is present at the top-left of the image, suggesting this data relates to a specific token or category labeled "3".
### Components/Axes
Each plot shares the following characteristics:
* **Axes:** Both the x and y axes range from approximately -10 to 10 (PC1-PC2), -13 to 10 (PC3-PC4), and -13 to 13 (PC5-PC6). The axes are labeled with the corresponding principal component numbers (e.g., PC1, PC2, PC3, etc.).
* **Data Points:** Each plot contains a cluster of purple data points.
* **Mean Indicator:** A red 'x' marks the mean of the data points in each plot.
* **Connecting Lines:** Thin, light-blue lines connect consecutive data points in a sequential manner within each plot.
* **Ellipses:** A green ellipse is drawn around the cluster of purple data points in each plot, likely representing the standard deviation or confidence interval.
### Detailed Analysis or Content Details
**Plot 1: PC1-PC2**
* **Trend:** The data points form a roughly elliptical shape, concentrated around the origin (0,0). The connecting lines show a cyclical pattern, suggesting a potential ordering or trajectory within the data.
* **Data Points:** The points are distributed between approximately -3 and 4 on the y-axis and -10 to 10 on the x-axis.
* **Mean:** The mean is located at approximately (0, 0).
**Plot 2: PC3-PC4**
* **Trend:** The data points are more dispersed than in the first plot, forming a less defined elliptical shape. The connecting lines show a more linear trend, with a slight downward slope.
* **Data Points:** The points are distributed between approximately -2 and 2 on the y-axis and -10 to 10 on the x-axis.
* **Mean:** The mean is located at approximately (0, 0).
**Plot 3: PC5-PC6**
* **Trend:** The data points are more vertically oriented, with a wider spread along the y-axis. The connecting lines show a more complex pattern, with some points extending further along the y-axis.
* **Data Points:** The points are distributed between approximately -5 and 5 on the y-axis and -10 to 13 on the x-axis.
* **Mean:** The mean is located at approximately (0, 0).
### Key Observations
* The mean of the data is consistently near the origin (0,0) for all three PCA projections.
* The spread of the data varies across the different principal component pairs. PC5-PC6 shows the largest spread along the y-axis, indicating greater variance in that component.
* The connecting lines suggest an underlying order or trajectory within the data, which is more pronounced in the first plot.
### Interpretation
These PCA plots are used to reduce the dimensionality of the data while preserving the most important variance. Each plot represents a different projection of the data onto a pair of principal components. The "Token: '3'" label suggests that these plots specifically represent the PCA results for data associated with token "3".
The varying spread of the data points across the different plots indicates that different principal components capture different amounts of variance in the data. The elliptical shapes suggest that the data is relatively well-clustered in each projection, but the shape and orientation of the ellipses vary, indicating different relationships between the variables.
The connecting lines provide additional information about the data, suggesting an underlying order or trajectory. This could be useful for identifying patterns or trends in the data. The consistent location of the mean near the origin suggests that the data is centered around a common point in the principal component space.
The plots demonstrate how the data associated with token "3" is distributed across different dimensions of variance, as captured by the principal components. This information can be used to understand the underlying structure of the data and to identify potential relationships between the variables.
</details>
<details>
<summary>x16.png Details</summary>

### Visual Description
## Scatter Plots: Principal Component Analysis (PCA) Visualizations
### Overview
The image presents three scatter plots, each representing a Principal Component Analysis (PCA) projection of data. Each plot displays data points projected onto two principal components. The plots are labeled "PC1-PC2", "PC3-PC4", and "PC5-PC6", indicating the principal component pairs used for each projection. A token "wrong" is present above the first plot. There are several data points in each plot, colored differently (purple, cyan, orange, red, and green). One plot also contains a single data point marked with an 'x' in red.
### Components/Axes
Each plot has two axes: a horizontal axis (x-axis) and a vertical axis (y-axis). The scales vary for each plot:
* **PC1-PC2:** X-axis ranges from approximately -12 to 12, Y-axis ranges from approximately -7 to 7.
* **PC3-PC4:** X-axis ranges from approximately -4 to 4, Y-axis ranges from approximately -14 to 14.
* **PC5-PC6:** X-axis ranges from approximately -10 to 10, Y-axis ranges from approximately -11 to 11.
There is no explicit legend, but the colors of the data points are consistent across all three plots.
### Detailed Analysis or Content Details
**PC1-PC2:**
* **Purple Data Points:** A cluster of approximately 7 purple points are located in the bottom-left quadrant, with x-values ranging from approximately -11 to -2 and y-values ranging from approximately -6 to -1. A single purple point is located near the origin (x ≈ 0, y ≈ 0). Another purple point is located at approximately (2, 1).
* **Cyan Data Points:** Two cyan points are present. One is near the origin (x ≈ 0, y ≈ 0), and the other is at approximately (1, 0).
* **Orange Data Points:** Two orange points are present. One is at approximately (-1, -1) and the other is at approximately (0, 0).
* **Red Data Points:** One red point is located at approximately (0, 0).
* **Green Data Points:** One green point is located at approximately (-1, 1).
**PC3-PC4:**
* **Purple Data Points:** A cluster of approximately 6 purple points are located near the origin, with x-values ranging from approximately -1 to 1 and y-values ranging from approximately -1 to 1. One purple point is located at approximately (0, 12).
* **Cyan Data Points:** Two cyan points are present, both near the origin (x ≈ 0, y ≈ 0).
* **Orange Data Points:** Two orange points are present, both near the origin (x ≈ 0, y ≈ 0).
* **Red Data Points:** One red point, marked with an 'x', is located at approximately (0, 0).
* **Green Data Points:** One green point is located at approximately (0, 0).
**PC5-PC6:**
* **Purple Data Points:** A cluster of approximately 5 purple points are located in the top-right quadrant, with x-values ranging from approximately 2 to 8 and y-values ranging from approximately 2 to 8. One purple point is located at approximately (-8, -8).
* **Cyan Data Points:** Two cyan points are present. One is near the origin (x ≈ 0, y ≈ 0), and the other is at approximately (8, -6).
* **Orange Data Points:** Two orange points are present. One is at approximately (0, 0) and the other is at approximately (2, 2).
* **Red Data Points:** One red point is located at approximately (0, 0).
* **Green Data Points:** One green point is located at approximately (0, 0).
### Key Observations
* The purple data points consistently show the most spread across the principal components.
* The cyan, orange, red, and green data points are often clustered near the origin in each plot.
* The red 'x' in PC3-PC4 is an outlier, as it is explicitly marked and distinct from the other points.
* The token "wrong" above the PC1-PC2 plot identifies the token whose latent trajectory is being visualized.
### Interpretation
These plots visualize the results of a PCA, a dimensionality reduction technique. Each point represents a data sample, and its position in the plot indicates its projection onto the selected principal components. The principal components are ordered by the amount of variance they explain in the data.
The spread of the purple points suggests that this group exhibits the most variance across the principal components, meaning it contains the most information. The clustering of the other colors near the origin suggests they have less variance and may be more similar to each other.
The "wrong" label above the PC1-PC2 plot names the token whose trajectory is plotted; it does not indicate a defect in the projection. The red 'x' in PC3-PC4 marks the center of mass of the trajectory.
The different PCA projections (PC1-PC2, PC3-PC4, PC5-PC6) provide different perspectives on the data's structure. Analyzing these projections together can help identify underlying patterns and relationships within the data. The fact that the cyan, orange, red, and green points are often clustered near the origin in each plot suggests they may represent a relatively homogeneous group.
</details>
Figure 12: Latent Space trajectories for select tokens. We show a small part of these high-dimensional trajectories by visualizing the first 6 PCA directions, computing the PCA over all latent state trajectories of all tokens in a sequence. The color gradient going from dark to bright represents steps in the trajectory. The center of mass is marked in red. While on many tokens, the state simply converges (top row), the model also learns to use orbits (middle row), and “sliders” (bottom row, middle), which we observe being used to represent and handle more advanced concepts, such as arithmetic or complicated deliberation.
## 7 What Mechanisms Emerge at Scale in Recurrent-Depth Models
Finally, what is the model doing while recurring in latent space? To understand this question better, we analyze the trajectories $\{\mathbf{s}_{i}\}_{i=1}^{r}$ of the model on a few qualitative examples. We are especially interested in understanding what patterns emerge, simply by training this model at scale. In comparison to previous work, such as Bai et al. (2019), where the training objective directly encodes a prior that pushes trajectories to a fixed point, we only train with our truncated unrolling objective.
Figure 11 shows the norm distance $||\mathbf{s}_{i}-\mathbf{s}^{*}||$ between each $\mathbf{s}_{i}$ in a trajectory and an approximate limit point $\mathbf{s}^{*}$ computed with 128 iterations. We show the sentence top to bottom and iterations from left to right. We clearly see that convergence behavior depends on context. We see that key parts of the question, and the start of the model response, are “deliberated” much more in latent space. The context dependence can also be seen in the different behavior among the three identical tokens representing each of the three dots. Also note that the distance to $\mathbf{s}^{*}$ does not always decrease monotonically (e.g. for school); the model may also trace out complicated orbits in its latent trajectory while processing information, even though this is not represented explicitly in our training objective.
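The measurement behind this heatmap can be sketched as follows, with a toy contractive update in place of the model: run the recurrence far past the plotted range to obtain an approximate limit point $\mathbf{s}^{*}$ (here, as in the figure, after 128 steps), then record $\log\lVert\mathbf{s}_{i}-\mathbf{s}^{*}\rVert$ per iteration. All names and constants are illustrative:

```python
import numpy as np

rng = np.random.default_rng(2)
D = 8
# Toy latent update; small scale keeps the iteration contractive.
W = 0.3 * rng.standard_normal((D, D)) / np.sqrt(D)

def step(s, e):
    return np.tanh(W @ s + e)

def log_distance_curve(e, s0, r=32, r_star=128):
    """Return log10 ||s_i - s*|| for i = 1..r, with s* approximated by s_{r_star}."""
    s, traj = s0.copy(), []
    for i in range(r_star):
        s = step(s, e)
        if i < r:
            traj.append(s.copy())
    s_star = s
    return [float(np.log10(np.linalg.norm(t - s_star) + 1e-12)) for t in traj]

curve = log_distance_curve(rng.standard_normal(D), 0.1 * rng.standard_normal(D))
```

Plotting one such curve per token (rows) against iterations (columns) reproduces the layout of the heatmap; token-dependent convergence would show up as rows decaying at different rates.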
We look at trajectories for select tokens in more detail in Figure 12. We compute a PCA decomposition of latent trajectories over all tokens in a sequence, and then show several individual trajectories projected onto the first six PCA directions. See the appendix for more examples. Many tokens simply converge to a fixed point, such as the token in the top row. Yet, for harder questions, such as in the 2nd row (this is the token “3” in a GSM8k test question that opens with “Claire makes a 3 egg omelette”), the state of the token quickly falls into an orbit pattern in all three pairs of PCA directions. The use of multi-dimensional orbits like these could serve a similar purpose to periodic patterns sometimes observed in fixed-depth transformers trained for arithmetic tasks (Nanda et al., 2022), but we find these patterns extend far beyond arithmetic for our model. We often also observe the use of orbits on tokens such as “makes” (see Figure 16) or “thinks” that determine the structure of the response.
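The visualization procedure can be sketched with a toy recurrence in place of the model: collect the latent trajectory of every token, fit a single PCA over all stacked states, and project each trajectory onto the leading directions. The PCA here is computed via NumPy's SVD; all names are illustrative:

```python
import numpy as np

rng = np.random.default_rng(3)
D, r = 16, 24
W = 0.3 * rng.standard_normal((D, D)) / np.sqrt(D)  # toy contractive block

def trajectory(e, s0):
    """Record all r latent states for one token embedding e."""
    states, s = [], s0
    for _ in range(r):
        s = np.tanh(W @ s + e)
        states.append(s.copy())
    return np.stack(states)                           # shape (r, D)

embeddings = [rng.standard_normal(D) for _ in range(5)]
trajs = [trajectory(e, rng.standard_normal(D)) for e in embeddings]

# One PCA over all latent states of all tokens, as described for Figure 12.
X = np.concatenate(trajs)                             # shape (5 * r, D)
mean = X.mean(axis=0)
_, _, Vt = np.linalg.svd(X - mean, full_matrices=False)
components = Vt[:6]                                   # first six PCA directions

# Project each trajectory; column pairs give the PC1-PC2, PC3-PC4, PC5-PC6 panels.
projected = [(t - mean) @ components.T for t in trajs]
```

Orbits and "sliders" would appear here as closed loops or steady drifts in the projected 2D panels, while simple convergence collapses a trajectory onto a single point.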
Aside from orbits, we also observe the model encoding particular key tokens as “sliders”, as seen in the middle of the bottom row in Figure 12 (which is the token “wrong”, from the same message as already shown in Figure 11). In these motions the trajectory noticeably drifts in a single direction, which the model could use to implement a mechanism to count how many iterations have occurred.
The emergence of structured trajectories in latent space gives us a glimpse into how the model performs its computations. Unlike the discrete sequential chain of reasoning seen in verbalized chain-of-thought approaches, we observe rich geometric patterns, including orbits, convergent paths, and drifts, which the model uses to organize its computational process spatially. This suggests the model is independently learning to leverage the high-dimensional nature of its latent space to implement reasoning in new ways.
#### Path Independence.
We verify that our models maintain path independence, in the sense of Anil et al. (2022), despite the complex, learned dynamics discussed above (see also the additional examples in Appendix Figure 22). When re-initializing from multiple starting states $\mathbf{s}_{0}$ , the model moves in similar trajectories, exhibiting consistent behavior: the same orbital patterns, fixed points, or directional drifts emerge regardless of initialization.
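A minimal sketch of such a path-independence check, again with a toy contractive block standing in for the trained model: run the recurrence from initial states of very different scales and confirm the final states coincide. All names and thresholds are illustrative:

```python
import numpy as np

rng = np.random.default_rng(4)
D = 8
W = 0.3 * rng.standard_normal((D, D)) / np.sqrt(D)  # toy contractive block
e = rng.standard_normal(D)                           # fixed token embedding

def run(s0, r=128):
    """Iterate the latent update r times from initial state s0."""
    s = s0
    for _ in range(r):
        s = np.tanh(W @ s + e)
    return s

# Initial states drawn at wildly different scales.
finals = [run(scale * rng.standard_normal(D)) for scale in (0.1, 1.0, 10.0)]
spread = max(float(np.linalg.norm(f - finals[0])) for f in finals)
```

For a path-independent system, `spread` is negligible: the limit behavior is a property of the input, not of the initialization. The paper's check is analogous but compares whole trajectories (orbits and drifts), not just end points.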
## 8 Related Work Overview
The extent to which recurrence is a foundational concept of machine learning is hard to overstate (Amari, 1972; Hopfield, 1982; Braitenberg, 1986; Gers and Schmidhuber, 2000; Sutskever et al., 2008). Aside from using recurrence to move along sequences, as in recurrent neural networks, it was understood early to also be the key to adaptive computation (Schmidhuber, 2012; Graves, 2017). For transformers, recurrence was applied in Dehghani et al. (2019), who highlight the aim of recurrent depth to model universal, i.e. Turing-complete, machines (Graves et al., 2014). It was used at scale (but with fixed recurrence) in Lan et al. (2019), and interesting recent improvements in this line of work are described in Tan et al. (2023); Abnar et al. (2023); Mathur et al. (2024) and Csordás et al. (2024). Schwarzschild et al. (2021b); Bansal et al. (2022); Bear et al. (2024) and McLeish et al. (2024) show that depth recurrence is advantageous for learning generalizable algorithms when training with randomized unrolling and input injections. Recent work has described depth-recurrent, looped, transformers and studied their potential benefits with careful theoretical and small-scale analysis (Giannou et al., 2023; Gatmiry et al., 2024; Yang et al., 2024a; Fan et al., 2025).
From another angle, these models can be described as neural networks learning a fixed-point iteration, as studied in deep equilibrium models (Bai et al., 2019; 2022). They are further related to diffusion models (Song and Ermon, 2019), especially latent diffusion models (Rombach et al., 2022), but we note that language diffusion models are usually run with a per-sequence, instead of a per-token, iteration count (Lee et al., 2018). A key difference of our approach to both equilibrium models and diffusion models is in the training objective, where equilibrium methods solve the “direct” problem (Geiping and Moeller, 2019), diffusion models solve a surrogate training objective, and our work suggests that truncated unrolling is a scalable alternative.
More generally, all architectures that recur in depth can also be understood as directly learning the analog to the gradient of a latent energy-based model (LeCun and Huang, 2005; LeCun, 2022), to an implicitly defined intermediate optimization layer (Amos and Kolter, 2017), or to a Kuramoto layer (Miyato et al., 2024). Analogies to gradient descent at inference time also show the connection to test time adaptation (Sun et al., 2020), especially test-time adaptation of output states (Boudiaf et al., 2022).
Aside from full recurrent-depth architectures, there also exist a number of proposals for hybrid architectures, such as models with latent sub-networks (Li et al., 2020a), LoRA adapters on top of weight-shared layers (Bae et al., 2024), or (dynamic) weight-tying of trained models (Hay and Wolf, 2023; Liu et al., 2024b).
As mentioned in Section 6, while we consider the proposed recurrent depth approach to be a very natural way to learn to reason in continuous latent space from the ground up, the works of Hao et al. (2024); Cheng and Durme (2024) and Liu et al. (2024a) discuss how to finetune existing fixed-depth transformers with this capability. These works have a similar aim to ours, enabling reasoning in latent space, but approach this goal from separate directions.
For additional discussions related to the idea of constructing a prior that incentivizes reasoning and algorithm learning at the expense of memorization of simple patterns, we also refer to Chollet (2019), Schwarzschild (2023), Li et al. (2020b) and Moulton (2023).
## 9 Future Work
Aside from work extending and analyzing the scaling behaviors of recurrent depth models, there are many questions that remain unanswered. For example, to us, there are potentially a large number of novel post-training schemes that further enhance the capabilities of these models, such as fine-tuning to compress the recurrence or reinforcement learning with data with different hardness levels (Zelikman et al., 2024), or to internalize reasoning from CoT data into the recurrence (Deng et al., 2024).
Another aspect not covered in this work is the relationship to other modern architecture improvements. Efficient sequence mixing operations, especially those that are linear in sequence dimension, such as linear attention (Katharopoulos et al., 2020; Yang et al., 2024b), are limited in the number of comparisons that can be made. However, with recurrent depth, blocks containing linear operators can repeat until all necessary comparisons between sequence elements are computed (Suzgun et al., 2019). For simplicity, we also focus on a single recurrence, where prior work has considered multiple successive recurrent stages (Takase and Kiyono, 2023; Csordás et al., 2024).
Finally, the proposed architecture is set up to be compute-heavy, with more “materialized” parameters than there are actual parameters. This naturally mirrors mixture-of-expert models (MoE), which are parameter-heavy, using fewer active parameters per forward pass than exist within the model (Shazeer et al., 2017; Fedus et al., 2022). We posit that where the recurrent-depth setup excels at learning reasoning patterns, the MoE excels at effectively storing and retrieving complex information. Their complementarity supports the hypothesis that a future architecture would contain both modifications. While in a standard MoE model, each expert can only be activated once per forward pass, or skipped entirely, a recurrent MoE model could also refine its latent state over multiple iterations, potentially routing to the same expert multiple times, before switching to a different one (Tan et al., 2023; Csordás et al., 2024). While MoE models are the currently leading solution to implement this type of “memory” in dense transformers, these considerations also hold for other memory mechanisms suggested for LLMs (Sukhbaatar et al., 2019; Fan et al., 2021; Wu et al., 2022; He et al., 2024).
## 10 Conclusions
The models described in this paper are ultimately still a proof-of-concept. We have described how to train a latent recurrent-depth architecture and what parameters we chose, and then trained a single model at scale. Future training runs are likely to use more optimized learning rate schedules, data mixes and accelerators. Still, we observe a number of interesting behaviors emerging naturally from recurrent training. The most important of these is the ability to use latent reasoning to dramatically improve performance on reasoning tasks by expending test-time computation. In addition, we also observe context-dependent convergence speed, path independence, and various zero-shot abilities. This leads us to believe that latent reasoning is a promising research direction to complement existing approaches for test-time compute scaling. The model we realize is surprisingly powerful given its size and amount of training data, and we are excited about the potential impact of imbuing generative models with the ability to reason in continuous latent space without the need for specialized data at train time or verbalization at inference time.
## Acknowledgements
This project was made possible by the INCITE program: An award for computer time was provided by the U.S. Department of Energy’s (DOE) Innovative and Novel Computational Impact on Theory and Experiment (INCITE) Program. This research used resources of the Oak Ridge Leadership Computing Facility at the Oak Ridge National Laboratory, which is supported by the Office of Science of the U.S. Department of Energy under Contract No. DE-AC05-00OR22725. Work on the LLNL side was prepared by LLNL under Contract DE-AC52-07NA27344 and supported by the LLNL-LDRD Program under Project No. 24-ERD-010 and 24-ERD-058 (LLNL-CONF-872390). This manuscript has been authored by Lawrence Livermore National Security, LLC under Contract No. DE-AC52-07NA27344 with the U.S. Department of Energy. The United States Government retains a non-exclusive, paid-up, irrevocable, world-wide license to publish or reproduce the published form of this manuscript, or allow others to do so, for United States Government purposes.
JG further acknowledges the support of the Hector II foundation. A large number of small-scale and preliminary experiments were made possible through the support of the MPI Intelligent Systems compute cluster and funding by the Tübingen AI center.
UMD researchers were further supported by the ONR MURI program, DARPA TIAMAT, the National Science Foundation (IIS-2212182), and the NSF TRAILS Institute (2229885). Commercial support was provided by Capital One Bank, the Amazon Research Award program, and Open Philanthropy. Finally, we thank Avi Schwarzschild for helpful comments on the initial draft.
## References
- Abnar et al. (2023) Samira Abnar, Omid Saremi, Laurent Dinh, Shantel Wilson, Miguel Angel Bautista, Chen Huang, Vimal Thilak, Etai Littwin, Jiatao Gu, Josh Susskind, and Samy Bengio. 2023. Adaptivity and Modularity for Efficient Generalization Over Task Complexity. arxiv:2310.08866[cs].
- AI2 (2024) AI2. 2024. OLMo 1.7–7B: A 24 point improvement on MMLU.
- Allen-Zhu and Li (2024) Zeyuan Allen-Zhu and Yuanzhi Li. 2024. Physics of language models: Part 3.1, knowledge storage and extraction. In Proceedings of the 41st International Conference on Machine Learning, volume 235 of ICML’24, pages 1067–1077, Vienna, Austria. JMLR.org.
- Amari (1972) S.-I. Amari. 1972. Learning Patterns and Pattern Sequences by Self-Organizing Nets of Threshold Elements. IEEE Transactions on Computers, C-21(11):1197–1206.
- AMD (2021) AMD. 2021. AMD Instinct™ MI250X Accelerators.
- Amini et al. (2019) Aida Amini, Saadia Gabriel, Peter Lin, Rik Koncel-Kedziorski, Yejin Choi, and Hannaneh Hajishirzi. 2019. Mathqa: Towards interpretable math word problem solving with operation-based formalisms. arXiv preprint arXiv:1905.13319.
- Amos and Kolter (2017) Brandon Amos and J. Zico Kolter. 2017. OptNet: Differentiable Optimization as a Layer in Neural Networks. In International Conference on Machine Learning, pages 136–145.
- Anil et al. (2022) Cem Anil, Ashwini Pokle, Kaiqu Liang, Johannes Treutlein, Yuhuai Wu, Shaojie Bai, J. Zico Kolter, and Roger Baker Grosse. 2022. Path Independent Equilibrium Models Can Better Exploit Test-Time Computation. In Advances in Neural Information Processing Systems.
- Austin et al. (2021) Jacob Austin, Augustus Odena, Maxwell Nye, Maarten Bosma, Henryk Michalewski, David Dohan, Ellen Jiang, Carrie Cai, Michael Terry, Quoc Le, and 1 others. 2021. Program synthesis with large language models. arXiv preprint arXiv:2108.07732.
- Azerbayev et al. (2023) Zhangir Azerbayev, Hailey Schoelkopf, Keiran Paster, Marco Dos Santos, Stephen Marcus McAleer, Albert Q. Jiang, Jia Deng, Stella Biderman, and Sean Welleck. 2023. Llemma: An Open Language Model for Mathematics. In The Twelfth International Conference on Learning Representations.
- Bae et al. (2024) Sangmin Bae, Adam Fisch, Hrayr Harutyunyan, Ziwei Ji, Seungyeon Kim, and Tal Schuster. 2024. Relaxed Recursive Transformers: Effective Parameter Sharing with Layer-wise LoRA.
- Bai et al. (2019) Shaojie Bai, J. Zico Kolter, and Vladlen Koltun. 2019. Deep Equilibrium Models. In Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc.
- Bai et al. (2022) Shaojie Bai, Vladlen Koltun, and J. Zico Kolter. 2022. Neural Deep Equilibrium Solvers. In International Conference on Learning Representations.
- Bai et al. (2024) Yushi Bai, Jiajie Zhang, Xin Lv, Linzhi Zheng, Siqi Zhu, Lei Hou, Yuxiao Dong, Jie Tang, and Juanzi Li. 2024. LongWriter: Unleashing 10,000+ Word Generation from Long Context LLMs. arxiv:2408.07055[cs].
- Banino et al. (2021) Andrea Banino, Jan Balaguer, and Charles Blundell. 2021. PonderNet: Learning to Ponder. In 8th ICML Workshop on Automated Machine Learning (AutoML).
- Bansal et al. (2022) Arpit Bansal, Avi Schwarzschild, Eitan Borgnia, Zeyad Emam, Furong Huang, Micah Goldblum, and Tom Goldstein. 2022. End-to-end Algorithm Synthesis with Recurrent Networks: Extrapolation without Overthinking. In Advances in Neural Information Processing Systems.
- Bauschke et al. (2011) Heinz H. Bauschke, Sarah M. Moffat, and Xianfu Wang. 2011. Firmly nonexpansive mappings and maximally monotone operators: Correspondence and duality. arXiv:1101.4688 [math].
- Bear et al. (2024) Jay Bear, Adam Prügel-Bennett, and Jonathon Hare. 2024. Rethinking Deep Thinking: Stable Learning of Algorithms using Lipschitz Constraints. arxiv:2410.23451[cs].
- Bekman (2023) Stas Bekman. 2023. Machine Learning Engineering Open Book. Stasosphere Online Inc.
- Ben Allal et al. (2024) Loubna Ben Allal, Anton Lozhkov, Guilherme Penedo, Thomas Wolf, and Leandro von Werra. 2024. SmolLM-corpus.
- Biderman et al. (2023) Stella Biderman, Hailey Schoelkopf, Quentin Anthony, Herbie Bradley, Kyle O’Brien, Eric Hallahan, Mohammad Aflah Khan, Shivanshu Purohit, USVSN Sai Prashanth, Edward Raff, Aviya Skowron, Lintang Sutawika, and Oskar van der Wal. 2023. Pythia: A Suite for Analyzing Large Language Models Across Training and Scaling. arxiv:2304.01373[cs].
- Biderman et al. (2024) Stella Biderman, Hailey Schoelkopf, Lintang Sutawika, Leo Gao, Jonathan Tow, Baber Abbasi, Alham Fikri Aji, Pawan Sasanka Ammanamanchi, Sidney Black, Jordan Clive, Anthony DiPofi, Julen Etxaniz, Benjamin Fattori, Jessica Zosa Forde, Charles Foster, Jeffrey Hsu, Mimansa Jaiswal, Wilson Y. Lee, Haonan Li, and 11 others. 2024. Lessons from the Trenches on Reproducible Evaluation of Language Models. arxiv:2405.14782[cs].
- Bisk et al. (2020) Yonatan Bisk, Rowan Zellers, Ronan Le Bras, Jianfeng Gao, and Yejin Choi. 2020. PIQA: Reasoning about physical commonsense in natural language. In Thirty-Fourth AAAI Conference on Artificial Intelligence.
- Boudiaf et al. (2022) Malik Boudiaf, Romain Mueller, Ismail Ben Ayed, and Luca Bertinetto. 2022. Parameter-Free Online Test-Time Adaptation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 8344–8353.
- Braitenberg (1986) Valentino Braitenberg. 1986. Vehicles: Experiments in Synthetic Psychology. MIT press.
- Brandon et al. (2024) William Brandon, Mayank Mishra, Aniruddha Nrusimha, Rameswar Panda, and Jonathan Ragan Kelly. 2024. Reducing Transformer Key-Value Cache Size with Cross-Layer Attention. arxiv:2405.12981[cs].
- British Library Labs (2021) British Library Labs. 2021. Digitised Books, c. 1510–c. 1900. JSONL (OCR-derived text + metadata). British Library.
- Cai et al. (2024) Tianle Cai, Yuhong Li, Zhengyang Geng, Hongwu Peng, Jason D. Lee, Deming Chen, and Tri Dao. 2024. Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads. In Forty-First International Conference on Machine Learning.
- character.ai (2024) character.ai. 2024. Optimizing AI Inference at Character.AI.
- Chen et al. (2021) Mark Chen, Jerry Tworek, Heewoo Jun, Qiming Yuan, Henrique Ponde de Oliveira Pinto, Jared Kaplan, Harri Edwards, Yuri Burda, Nicholas Joseph, Greg Brockman, Alex Ray, Raul Puri, Gretchen Krueger, Michael Petrov, Heidy Khlaaf, Girish Sastry, Pamela Mishkin, Brooke Chan, Scott Gray, and 39 others. 2021. Evaluating large language models trained on code. Preprint, arXiv:2107.03374.
- Cheng and Durme (2024) Jeffrey Cheng and Benjamin Van Durme. 2024. Compressed Chain of Thought: Efficient Reasoning Through Dense Representations. arxiv:2412.13171[cs].
- Choi (2023) Euirim Choi. 2023. GoodWiki dataset.
- Chollet (2019) François Chollet. 2019. On the Measure of Intelligence. arxiv:1911.01547[cs].
- Chowdhery et al. (2022) Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, and 48 others. 2022. PaLM: Scaling Language Modeling with Pathways. arXiv:2204.02311 [cs].
- Clark et al. (2018) Peter Clark, Isaac Cowhey, Oren Etzioni, Tushar Khot, Ashish Sabharwal, Carissa Schoenick, and Oyvind Tafjord. 2018. Think you have solved question answering? Try ARC, the AI2 Reasoning Challenge. arXiv:1803.05457v1.
- Cobbe et al. (2021) Karl Cobbe, Vineet Kosaraju, Mohammad Bavarian, Mark Chen, Heewoo Jun, Lukasz Kaiser, Matthias Plappert, Jerry Tworek, Jacob Hilton, Reiichiro Nakano, Christopher Hesse, and John Schulman. 2021. Training Verifiers to Solve Math Word Problems. arxiv:2110.14168[cs].
- Colegrove et al. (2024) Owen Colegrove, Vik Paruchuri, and OpenPhi-Team. 2024. Open-phi/textbooks. Dataset, Hugging Face.
- Csordás et al. (2024) Róbert Csordás, Kazuki Irie, Jürgen Schmidhuber, Christopher Potts, and Christopher D. Manning. 2024. MoEUT: Mixture-of-Experts Universal Transformers. In The Thirty-eighth Annual Conference on Neural Information Processing Systems.
- Dagan (2024) Gautier Dagan. 2024. Bpeasy.
- Dagan et al. (2024) Gautier Dagan, Gabriel Synnaeve, and Baptiste Rozière. 2024. Getting the most out of your tokenizer for pre-training and domain adaptation. arxiv:2402.01035[cs].
- Dao (2023) Tri Dao. 2023. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arxiv:2307.08691[cs].
- Dao et al. (2022) Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. 2022. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. arxiv:2205.14135[cs].
- DeepSeek-AI et al. (2025) DeepSeek-AI, Daya Guo, Dejian Yang, Haowei Zhang, Junxiao Song, Ruoyu Zhang, Runxin Xu, Qihao Zhu, Shirong Ma, Peiyi Wang, Xiao Bi, Xiaokang Zhang, Xingkai Yu, Yu Wu, Z. F. Wu, Zhibin Gou, Zhihong Shao, Zhuoshu Li, Ziyi Gao, and 181 others. 2025. DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning. arxiv:2501.12948[cs].
- DeepSeek-AI et al. (2024) DeepSeek-AI, Aixin Liu, Bei Feng, Bing Xue, Bingxuan Wang, Bochao Wu, Chengda Lu, Chenggang Zhao, Chengqi Deng, Chenyu Zhang, Chong Ruan, Damai Dai, Daya Guo, Dejian Yang, Deli Chen, Dongjie Ji, Erhang Li, Fangyun Lin, Fucong Dai, and 181 others. 2024. DeepSeek-V3 Technical Report. arxiv:2412.19437[cs].
- Dehghani et al. (2019) Mostafa Dehghani, Stephan Gouws, Oriol Vinyals, Jakob Uszkoreit, and Łukasz Kaiser. 2019. Universal Transformers. arxiv:1807.03819[cs, stat].
- Deng et al. (2024) Yuntian Deng, Yejin Choi, and Stuart Shieber. 2024. From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step. arxiv:2405.14838[cs].
- Ding et al. (2024) Hantian Ding, Zijian Wang, Giovanni Paolini, Varun Kumar, Anoop Deoras, Dan Roth, and Stefano Soatto. 2024. Fewer Truncations Improve Language Modeling. In Forty-First International Conference on Machine Learning.
- Ding et al. (2021) Ming Ding, Zhuoyi Yang, Wenyi Hong, Wendi Zheng, Chang Zhou, Da Yin, Junyang Lin, Xu Zou, Zhou Shao, Hongxia Yang, and Jie Tang. 2021. CogView: Mastering Text-to-Image Generation via Transformers. In Advances in Neural Information Processing Systems, volume 34, pages 19822–19835. Curran Associates, Inc.
- Elbayad et al. (2019) Maha Elbayad, Jiatao Gu, Edouard Grave, and Michael Auli. 2019. Depth-Adaptive Transformer. In International Conference on Learning Representations.
- Elhoushi et al. (2024) Mostafa Elhoushi, Akshat Shrivastava, Diana Liskovich, Basil Hosmer, Bram Wasti, Liangzhen Lai, Anas Mahmoud, Bilge Acun, Saurabh Agarwal, Ahmed Roman, Ahmed A. Aly, Beidi Chen, and Carole-Jean Wu. 2024. LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding. arxiv:2404.16710[cs].
- Everett et al. (2024) Katie Everett, Lechao Xiao, Mitchell Wortsman, Alexander A. Alemi, Roman Novak, Peter J. Liu, Izzeddin Gur, Jascha Sohl-Dickstein, Leslie Pack Kaelbling, Jaehoon Lee, and Jeffrey Pennington. 2024. Scaling Exponents Across Parameterizations and Optimizers. arxiv:2407.05872[cs].
- Fan et al. (2019) Angela Fan, Edouard Grave, and Armand Joulin. 2019. Reducing Transformer Depth on Demand with Structured Dropout. arxiv:1909.11556[cs, stat].
- Fan et al. (2021) Angela Fan, Thibaut Lavril, Edouard Grave, Armand Joulin, and Sainbayar Sukhbaatar. 2021. Addressing Some Limitations of Transformers with Feedback Memory. arxiv:2002.09402[cs, stat].
- Fan et al. (2025) Ying Fan, Yilun Du, Kannan Ramchandran, and Kangwook Lee. 2025. Looped Transformers for Length Generalization. In The Thirteenth International Conference on Learning Representations.
- Fedus et al. (2022) William Fedus, Barret Zoph, and Noam Shazeer. 2022. Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. arxiv:2101.03961[cs].
- Feng et al. (2023) Xidong Feng, Yicheng Luo, Ziyan Wang, Hongrui Tang, Mengyue Yang, Kun Shao, David Mguni, Yali Du, and Jun Wang. 2023. ChessGPT: Bridging Policy Learning and Language Modeling. Advances in Neural Information Processing Systems, 36:7216–7262.
- Gabarain (2024) Sebastian Gabarain. 2024. Locutusque/hercules-v5.0. Dataset, Hugging Face.
- Gatmiry et al. (2024) Khashayar Gatmiry, Nikunj Saunshi, Sashank J. Reddi, Stefanie Jegelka, and Sanjiv Kumar. 2024. Can Looped Transformers Learn to Implement Multi-step Gradient Descent for In-context Learning?
- Geiping and Goldstein (2023) Jonas Geiping and Tom Goldstein. 2023. Cramming: Training a Language Model on a single GPU in one day. In Proceedings of the 40th International Conference on Machine Learning, pages 11117–11143. PMLR.
- Geiping and Moeller (2019) Jonas Geiping and Michael Moeller. 2019. Parametric Majorization for Data-Driven Energy Minimization Methods. In Proceedings of the IEEE International Conference on Computer Vision, pages 10262–10273.
- Gers and Schmidhuber (2000) F. A. Gers and J. Schmidhuber. 2000. Recurrent nets that time and count. In Proceedings of the IEEE-INNS-ENNS International Joint Conference on Neural Networks (IJCNN 2000). Neural Computing: New Challenges and Perspectives for the New Millennium, volume 3, pages 189–194.
- Giannou et al. (2023) Angeliki Giannou, Shashank Rajput, Jy-Yong Sohn, Kangwook Lee, Jason D. Lee, and Dimitris Papailiopoulos. 2023. Looped Transformers as Programmable Computers. In Proceedings of the 40th International Conference on Machine Learning, pages 11398–11442. PMLR.
- Goyal et al. (2018) Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. 2018. Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour. arxiv:1706.02677[cs].
- Graves (2017) Alex Graves. 2017. Adaptive Computation Time for Recurrent Neural Networks. arxiv:1603.08983[cs].
- Graves et al. (2014) Alex Graves, Greg Wayne, and Ivo Danihelka. 2014. Neural Turing Machines. arxiv:1410.5401[cs].
- Groeneveld et al. (2024) Dirk Groeneveld, Iz Beltagy, Pete Walsh, Akshita Bhagia, Rodney Kinney, Oyvind Tafjord, Ananya Harsh Jha, Hamish Ivison, Ian Magnusson, Yizhong Wang, Shane Arora, David Atkinson, Russell Authur, Khyathi Raghavi Chandu, Arman Cohan, Jennifer Dumas, Yanai Elazar, Yuling Gu, Jack Hessel, and 24 others. 2024. OLMo: Accelerating the Science of Language Models. arxiv:2402.00838[cs].
- Hägele et al. (2024) Alexander Hägele, Elie Bakouch, Atli Kosson, Loubna Ben Allal, Leandro Von Werra, and Martin Jaggi. 2024. Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations. In Workshop on Efficient Systems for Foundation Models II @ ICML2024.
- Hao et al. (2024) Shibo Hao, Sainbayar Sukhbaatar, DiJia Su, Xian Li, Zhiting Hu, Jason Weston, and Yuandong Tian. 2024. Training Large Language Models to Reason in a Continuous Latent Space. arxiv:2412.06769[cs].
- Hay and Wolf (2023) Tamir David Hay and Lior Wolf. 2023. Dynamic Layer Tying for Parameter-Efficient Transformers. In The Twelfth International Conference on Learning Representations.
- He et al. (2024) Zexue He, Leonid Karlinsky, Donghyun Kim, Julian McAuley, Dmitry Krotov, and Rogerio Feris. 2024. CAMELoT: Towards Large Language Models with Training-Free Consolidated Associative Memory. arxiv:2402.13449[cs].
- Hendrycks et al. (2021a) Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt. 2021a. Measuring massive multitask language understanding. Proceedings of the International Conference on Learning Representations (ICLR).
- Hendrycks et al. (2021b) Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt. 2021b. Measuring Massive Multitask Language Understanding. In International Conference on Learning Representations.
- Hopfield (1982) J J Hopfield. 1982. Neural networks and physical systems with emergent collective computational abilities. Proceedings of the National Academy of Sciences of the United States of America, 79(8):2554–2558.
- Hu et al. (2024) Jiewen Hu, Thomas Zhu, and Sean Welleck. 2024. miniCTX: Neural Theorem Proving with (Long-)Contexts. arxiv:2408.03350[cs].
- Izmailov et al. (2018) Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson. 2018. Averaging weights leads to wider optima and better generalization. In 34th Conference on Uncertainty in Artificial Intelligence (UAI 2018), pages 876–885.
- Jiang et al. (2023) Albert Q. Jiang, Wenda Li, and Mateja Jamnik. 2023. Multilingual Mathematical Autoformalization. arxiv:2311.03755[cs].
- Welbl et al. (2017) Johannes Welbl, Nelson F. Liu, and Matt Gardner. 2017. Crowdsourcing multiple choice science questions.
- Kaddour (2022) Jean Kaddour. 2022. Stop Wasting My Time! Saving Days of ImageNet and BERT Training with Latest Weight Averaging. arxiv:2209.14981[cs, stat].
- Kaplan et al. (2024) Guy Kaplan, Matanel Oren, Yuval Reif, and Roy Schwartz. 2024. From Tokens to Words: On the Inner Lexicon of LLMs. arxiv:2410.05864[cs].
- Kaplan et al. (2020) Jared Kaplan, Sam McCandlish, Tom Henighan, Tom B. Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei. 2020. Scaling Laws for Neural Language Models. arxiv:2001.08361[cs, stat].
- Katharopoulos et al. (2020) Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. 2020. Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. In Proceedings of the 37th International Conference on Machine Learning, pages 5156–5165. PMLR.
- Kenney (2024) Matthew Kenney. 2024. ArXivDLInstruct.
- Kim et al. (2024) Seungone Kim, Juyoung Suk, Shayne Longpre, Bill Yuchen Lin, Jamin Shin, Sean Welleck, Graham Neubig, Moontae Lee, Kyungjae Lee, and Minjoon Seo. 2024. Prometheus 2: An Open Source Language Model Specialized in Evaluating Other Language Models. In Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing, pages 4334–4353, Miami, Florida, USA. Association for Computational Linguistics.
- Kingma and Ba (2015) Diederik P. Kingma and Jimmy Ba. 2015. Adam: A Method for Stochastic Optimization. In International Conference on Learning Representations (ICLR), San Diego.
- Kryściński et al. (2022) Wojciech Kryściński, Nazneen Rajani, Divyansh Agarwal, Caiming Xiong, and Dragomir Radev. 2022. BookSum: A Collection of Datasets for Long-form Narrative Summarization. arxiv:2105.08209[cs].
- Lai et al. (2024) Xin Lai, Zhuotao Tian, Yukang Chen, Senqiao Yang, Xiangru Peng, and Jiaya Jia. 2024. Step-DPO: Step-wise Preference Optimization for Long-chain Reasoning of LLMs. arxiv:2406.18629[cs].
- Lan et al. (2019) Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, and Radu Soricut. 2019. ALBERT: A Lite BERT for Self-supervised Learning of Language Representations. In International Conference on Learning Representations.
- LeCun (2022) Yann LeCun. 2022. A Path Towards Autonomous Machine Intelligence. Preprint, Version 0.9.2:62.
- LeCun and Huang (2005) Yann LeCun and Fu Jie Huang. 2005. Loss functions for discriminative training of energy-based models. In AISTATS 2005 - Proceedings of the 10th International Workshop on Artificial Intelligence and Statistics, pages 206–213.
- Lee et al. (2018) Jason Lee, Elman Mansimov, and Kyunghyun Cho. 2018. Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement. In Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing, pages 1173–1182, Brussels, Belgium. Association for Computational Linguistics.
- Leviathan et al. (2023) Yaniv Leviathan, Matan Kalman, and Yossi Matias. 2023. Fast Inference from Transformers via Speculative Decoding. In Proceedings of the 40th International Conference on Machine Learning, pages 19274–19286. PMLR.
- Levine et al. (2021) Yoav Levine, Noam Wies, Or Sharir, Hofit Bata, and Amnon Shashua. 2021. The Depth-to-Width Interplay in Self-Attention. arxiv:2006.12467[cs, stat].
- Lewkowycz et al. (2022) Aitor Lewkowycz, Anders Andreassen, David Dohan, Ethan Dyer, Henryk Michalewski, Vinay Ramasesh, Ambrose Slone, Cem Anil, Imanol Schlag, Theo Gutman-Solo, Yuhuai Wu, Behnam Neyshabur, Guy Gur-Ari, and Vedant Misra. 2022. Solving quantitative reasoning problems with language models. Preprint, arXiv:2206.14858.
- Li et al. (2023) Raymond Li, Loubna Ben Allal, Yangtian Zi, Niklas Muennighoff, Denis Kocetkov, Chenghao Mou, Marc Marone, Christopher Akiki, Jia Li, Jenny Chim, Qian Liu, Evgenii Zheltonozhskii, Terry Yue Zhuo, Thomas Wang, Olivier Dehaene, Joel Lamy-Poirier, Joao Monteiro, Nicolas Gontier, Ming-Ho Yee, and 39 others. 2023. StarCoder: May the source be with you! Transactions on Machine Learning Research.
- Li et al. (2020a) Xian Li, Asa Cooper Stickland, Yuqing Tang, and Xiang Kong. 2020a. Deep Transformers with Latent Depth. arxiv:2009.13102[cs].
- Li et al. (2020b) Yujia Li, Felix Gimeno, Pushmeet Kohli, and Oriol Vinyals. 2020b. Strong Generalization and Efficiency in Neural Programs. arxiv:2007.03629[cs].
- Tang et al. (2024) Liping Tang, Nikhil Ranjan, and Omkar Pangarkar. 2024. TxT360: A top-quality LLM pre-training dataset requires the perfect blend.
- Liu et al. (2023a) Hao Liu, Matei Zaharia, and Pieter Abbeel. 2023a. Ring attention with blockwise transformers for near-infinite context. arXiv preprint arXiv:2310.01889.
- Liu et al. (2024a) Luyang Liu, Jonas Pfeiffer, Jiaxing Wu, Jun Xie, and Arthur Szlam. 2024a. Deliberation in Latent Space via Differentiable Cache Augmentation. arxiv:2412.17747[cs].
- Liu et al. (2023b) Xiao Liu, Hanyu Lai, Hao Yu, Yifan Xu, Aohan Zeng, Zhengxiao Du, Peng Zhang, Yuxiao Dong, and Jie Tang. 2023b. WebGLM: Towards An Efficient Web-Enhanced Question Answering System with Human Preferences. In Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining, KDD ’23, pages 4549–4560, New York, NY, USA. Association for Computing Machinery.
- Liu et al. (2024b) Zechun Liu, Changsheng Zhao, Forrest Iandola, Chen Lai, Yuandong Tian, Igor Fedorov, Yunyang Xiong, Ernie Chang, Yangyang Shi, Raghuraman Krishnamoorthi, Liangzhen Lai, and Vikas Chandra. 2024b. MobileLLM: Optimizing Sub-billion Parameter Language Models for On-Device Use Cases. arxiv:2402.14905[cs].
- Liu et al. (2023c) Zhengzhong Liu, Aurick Qiao, Willie Neiswanger, Hongyi Wang, Bowen Tan, Tianhua Tao, Junbo Li, Yuqi Wang, Suqi Sun, Omkar Pangarkar, Richard Fan, Yi Gu, Victor Miller, Yonghao Zhuang, Guowei He, Haonan Li, Fajri Koto, Liping Tang, Nikhil Ranjan, and 9 others. 2023c. LLM360: Towards fully transparent open-source LLMs.
- Loshchilov and Hutter (2017) Ilya Loshchilov and Frank Hutter. 2017. Decoupled Weight Decay Regularization. arXiv:1711.05101 [cs, math].
- Lozhkov et al. (2024) Anton Lozhkov, Raymond Li, Loubna Ben Allal, Federico Cassano, Joel Lamy-Poirier, Nouamane Tazi, Ao Tang, Dmytro Pykhtar, Jiawei Liu, Yuxiang Wei, Tianyang Liu, Max Tian, Denis Kocetkov, Arthur Zucker, Younes Belkada, Zijian Wang, Qian Liu, Dmitry Abulkhanov, Indraneil Paul, and 47 others. 2024. StarCoder 2 and The Stack v2: The Next Generation.
- Lu et al. (2024) Zimu Lu, Aojun Zhou, Ke Wang, Houxing Ren, Weikang Shi, Junting Pan, Mingjie Zhan, and Hongsheng Li. 2024. MathCoder2: Better Math Reasoning from Continued Pretraining on Model-translated Mathematical Code. arxiv:2410.08196[cs].
- Majstorovic (2024) Sebastian Majstorovic. 2024. Selected Digitized Books. The Library of Congress.
- Markeeva et al. (2024) Larisa Markeeva, Sean McLeish, Borja Ibarz, Wilfried Bounsi, Olga Kozlova, Alex Vitvitskyi, Charles Blundell, Tom Goldstein, Avi Schwarzschild, and Petar Veličković. 2024. The CLRS-Text Algorithmic Reasoning Language Benchmark. arxiv:2406.04229[cs].
- Mathur et al. (2024) Mrinal Mathur, Barak A. Pearlmutter, and Sergey M. Plis. 2024. MIND over Body: Adaptive Thinking using Dynamic Computation. In The Thirteenth International Conference on Learning Representations.
- McLeish et al. (2024) Sean Michael McLeish, Arpit Bansal, Alex Stein, Neel Jain, John Kirchenbauer, Brian R. Bartoldson, Bhavya Kailkhura, Abhinav Bhatele, Jonas Geiping, Avi Schwarzschild, and Tom Goldstein. 2024. Transformers Can Do Arithmetic with the Right Embeddings. In The Thirty-eighth Annual Conference on Neural Information Processing Systems.
- Merrill et al. (2022) William Merrill, Ashish Sabharwal, and Noah A. Smith. 2022. Saturated Transformers are Constant-Depth Threshold Circuits. Transactions of the Association for Computational Linguistics, 10:843–856.
- Mihaylov et al. (2018) Todor Mihaylov, Peter Clark, Tushar Khot, and Ashish Sabharwal. 2018. Can a suit of armor conduct electricity? a new dataset for open book question answering. In EMNLP.
- Mikolov et al. (2011) Tomáš Mikolov, Stefan Kombrink, Lukáš Burget, Jan Černocký, and Sanjeev Khudanpur. 2011. Extensions of recurrent neural network language model. In 2011 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pages 5528–5531.
- Miyato et al. (2024) Takeru Miyato, Sindy Löwe, Andreas Geiger, and Max Welling. 2024. Artificial Kuramoto Oscillatory Neurons. In The Thirteenth International Conference on Learning Representations, Singapore.
- Moulton (2023) Ryan Moulton. 2023. The Many Ways that Digital Minds Can Know.
- Muennighoff et al. (2024) Niklas Muennighoff, Qian Liu, Armel Zebaze, Qinkai Zheng, Binyuan Hui, Terry Yue Zhuo, Swayam Singh, Xiangru Tang, Leandro von Werra, and Shayne Longpre. 2024. OctoPack: Instruction Tuning Code Large Language Models. arxiv:2308.07124[cs].
- Nam Pham (2023) Nam Pham. 2023. Tiny-textbooks (Revision 14de7ba).
- Nam Pham (2024) Nam Pham. 2024. Tiny-strange-textbooks (Revision 6f304f1).
- Nanda et al. (2022) Neel Nanda, Lawrence Chan, Tom Lieberum, Jess Smith, and Jacob Steinhardt. 2022. Progress measures for grokking via mechanistic interpretability. In The Eleventh International Conference on Learning Representations.
- Noci et al. (2022) Lorenzo Noci, Sotiris Anagnostidis, Luca Biggio, Antonio Orvieto, Sidak Pal Singh, and Aurelien Lucchi. 2022. Signal Propagation in Transformers: Theoretical Perspectives and the Role of Rank Collapse. In Advances in Neural Information Processing Systems.
- OpenAI (2024) OpenAI. 2024. New reasoning models: OpenAI o1-preview and o1-mini. https://openai.com/research/o1-preview-and-o1-mini.
- Ouyang et al. (2022) Long Ouyang, Jeff Wu, Xu Jiang, Diogo Almeida, Carroll L. Wainwright, Pamela Mishkin, Chong Zhang, Sandhini Agarwal, Katarina Slama, Alex Ray, John Schulman, Jacob Hilton, Fraser Kelton, Luke Miller, Maddie Simens, Amanda Askell, Peter Welinder, Paul Christiano, Jan Leike, and Ryan Lowe. 2022. Training language models to follow instructions with human feedback. arxiv:2203.02155[cs].
- Paster et al. (2023) Keiran Paster, Marco Dos Santos, Zhangir Azerbayev, and Jimmy Ba. 2023. OpenWebMath: An Open Dataset of High-Quality Mathematical Web Text. In The Twelfth International Conference on Learning Representations.
- Peebles and Xie (2023) William Peebles and Saining Xie. 2023. Scalable Diffusion Models with Transformers. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 4195–4205.
- Radford et al. (2019) Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, and Ilya Sutskever. 2019. Language Models are Unsupervised Multitask Learners. OpenAI, page 24.
- Rae et al. (2019) Jack W. Rae, Anna Potapenko, Siddhant M. Jayakumar, and Timothy P. Lillicrap. 2019. Compressive Transformers for Long-Range Sequence Modelling. arxiv:1911.05507[cs].
- Rajbhandari et al. (2020) Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, and Yuxiong He. 2020. ZeRO: Memory optimizations Toward Training Trillion Parameter Models. In SC20: International Conference for High Performance Computing, Networking, Storage and Analysis, pages 1–16.
- Rombach et al. (2022) Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, and Björn Ommer. 2022. High-Resolution Image Synthesis With Latent Diffusion Models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 10684–10695.
- Sakaguchi et al. (2021) Keisuke Sakaguchi, Ronan Le Bras, Chandra Bhagavatula, and Yejin Choi. 2021. WinoGrande: An adversarial winograd schema challenge at scale. Commun. ACM, 64(9):99–106.
- Sanh et al. (2021) Victor Sanh, Albert Webson, Colin Raffel, Stephen Bach, Lintang Sutawika, Zaid Alyafeai, Antoine Chaffin, Arnaud Stiegler, Arun Raja, Manan Dey, M. Saiful Bari, Canwen Xu, Urmish Thakker, Shanya Sharma Sharma, Eliza Szczechla, Taewoon Kim, Gunjan Chhablani, Nihal Nayak, Debajyoti Datta, and 21 others. 2021. Multitask Prompted Training Enables Zero-Shot Task Generalization. In International Conference on Learning Representations.
- Sanyal et al. (2024) Sunny Sanyal, Atula Tejaswi Neerkaje, Jean Kaddour, Abhishek Kumar, and Sujay Sanghavi. 2024. Early weight averaging meets high learning rates for LLM pre-training. In First Conference on Language Modeling.
- Schmidhuber (2012) Juergen Schmidhuber. 2012. Self-Delimiting Neural Networks. arxiv:1210.0118[cs].
- Schuster et al. (2022) Tal Schuster, Adam Fisch, Jai Gupta, Mostafa Dehghani, Dara Bahri, Vinh Q. Tran, Yi Tay, and Donald Metzler. 2022. Confident Adaptive Language Modeling. In Advances in Neural Information Processing Systems.
- Schwarzschild (2023) Avi Schwarzschild. 2023. Deep Thinking Systems: Logical Extrapolation with Recurrent Neural Networks. Ph.D. thesis, University of Maryland, College Park.
- Schwarzschild et al. (2021a) Avi Schwarzschild, Eitan Borgnia, Arjun Gupta, Arpit Bansal, Zeyad Emam, Furong Huang, Micah Goldblum, and Tom Goldstein. 2021a. Datasets for Studying Generalization from Easy to Hard Examples. arxiv:2108.06011[cs].
- Schwarzschild et al. (2021b) Avi Schwarzschild, Eitan Borgnia, Arjun Gupta, Furong Huang, Uzi Vishkin, Micah Goldblum, and Tom Goldstein. 2021b. Can You Learn an Algorithm? Generalizing from Easy to Hard Problems with Recurrent Networks. In Advances in Neural Information Processing Systems, volume 34, pages 6695–6706. Curran Associates, Inc.
- Schwarzschild et al. (2023) Avi Schwarzschild, Sean Michael McLeish, Arpit Bansal, Gabriel Diaz, Alex Stein, Aakash Chandnani, Aniruddha Saha, Richard Baraniuk, Long Tran-Thanh, Jonas Geiping, and Tom Goldstein. 2023. Algorithm Design for Learned Algorithms.
- Sennrich et al. (2016) Rico Sennrich, Barry Haddow, and Alexandra Birch. 2016. Neural Machine Translation of Rare Words with Subword Units. In Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 1715–1725, Berlin, Germany. Association for Computational Linguistics.
- Shao et al. (2024) Zhihong Shao, Peiyi Wang, Qihao Zhu, Runxin Xu, Junxiao Song, Xiao Bi, Haowei Zhang, Mingchuan Zhang, YK Li, Y Wu, and 1 others. 2024. Deepseekmath: Pushing the limits of mathematical reasoning in open language models. arXiv preprint arXiv:2402.03300.
- Shazeer (2020) Noam Shazeer. 2020. GLU Variants Improve Transformer. arxiv:2002.05202[cs].
- Shazeer et al. (2017) Noam Shazeer, Azalia Mirhoseini, Krzysztof Maziarz, Andy Davis, Quoc Le, Geoffrey Hinton, and Jeff Dean. 2017. Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer. arxiv:1701.06538[cs].
- Singh and Bhatele (2022) Siddharth Singh and Abhinav Bhatele. 2022. AxoNN: An asynchronous, message-driven parallel framework for extreme-scale deep learning. In 2022 IEEE International Parallel and Distributed Processing Symposium (IPDPS), pages 606–616.
- Singh et al. (2024) Siddharth Singh, Prajwal Singhania, Aditya Ranjan, John Kirchenbauer, Jonas Geiping, Yuxin Wen, Neel Jain, Abhimanyu Hans, Manli Shu, Aditya Tomar, Tom Goldstein, and Abhinav Bhatele. 2024. Democratizing AI: Open-source Scalable LLM Training on GPU-based Supercomputers. In 2024 SC24: International Conference for High Performance Computing, Networking, Storage and Analysis SC, pages 36–49. IEEE Computer Society.
- Skean et al. (2024) Oscar Skean, Md Rifat Arefin, Yann LeCun, and Ravid Shwartz-Ziv. 2024. Does Representation Matter? Exploring Intermediate Layers in Large Language Models. arxiv:2412.09563[cs].
- Soboleva et al. (2023) Daria Soboleva, Faisal Al-Khateeb, Joel Hestness, Nolan Dey, Robert Myers, and Jacob Robert Steeves. 2023. SlimPajama: A 627B token cleaned and deduplicated version of RedPajama.
- Soldaini et al. (2024) Luca Soldaini, Rodney Kinney, Akshita Bhagia, Dustin Schwenk, David Atkinson, Russell Authur, Ben Bogin, Khyathi Chandu, Jennifer Dumas, Yanai Elazar, Valentin Hofmann, Ananya Jha, Sachin Kumar, Li Lucy, Xinxi Lyu, Nathan Lambert, Ian Magnusson, Jacob Morrison, Niklas Muennighoff, and 17 others. 2024. Dolma: An Open Corpus of Three Trillion Tokens for Language Model Pretraining Research. In Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 15725–15788, Bangkok, Thailand. Association for Computational Linguistics.
- Song and Ermon (2019) Yang Song and Stefano Ermon. 2019. Generative Modeling by Estimating Gradients of the Data Distribution. arXiv:1907.05600 [cs, stat].
- Su et al. (2021) Jianlin Su, Yu Lu, Shengfeng Pan, Bo Wen, and Yunfeng Liu. 2021. RoFormer: Enhanced Transformer with Rotary Position Embedding. arxiv:2104.09864 [cs].
- Sukhbaatar et al. (2019) Sainbayar Sukhbaatar, Edouard Grave, Guillaume Lample, Herve Jegou, and Armand Joulin. 2019. Augmenting Self-attention with Persistent Memory. arxiv:1907.01470[cs, stat].
- Sun et al. (2024) Qi Sun, Marc Pickett, Aakash Kumar Nain, and Llion Jones. 2024. Transformer Layers as Painters. arxiv:2407.09298[cs].
- Sun et al. (2020) Yu Sun, Xiaolong Wang, Zhuang Liu, John Miller, Alexei Efros, and Moritz Hardt. 2020. Test-Time Training with Self-Supervision for Generalization under Distribution Shifts. In Proceedings of the 37th International Conference on Machine Learning, pages 9229–9248. PMLR.
- Sutskever et al. (2008) Ilya Sutskever, Geoffrey E Hinton, and Graham W Taylor. 2008. The Recurrent Temporal Restricted Boltzmann Machine. In Advances in Neural Information Processing Systems, volume 21. Curran Associates, Inc.
- Suzgun et al. (2019) Mirac Suzgun, Sebastian Gehrmann, Yonatan Belinkov, and Stuart M. Shieber. 2019. Memory-Augmented Recurrent Neural Networks Can Learn Generalized Dyck Languages. arxiv:1911.03329[cs].
- Takase and Kiyono (2023) Sho Takase and Shun Kiyono. 2023. Lessons on Parameter Sharing across Layers in Transformers. arxiv:2104.06022[cs].
- Takase et al. (2024) Sho Takase, Shun Kiyono, Sosuke Kobayashi, and Jun Suzuki. 2024. Spike No More: Stabilizing the Pre-training of Large Language Models. arxiv:2312.16903[cs].
- Tan et al. (2023) Shawn Tan, Yikang Shen, Zhenfang Chen, Aaron Courville, and Chuang Gan. 2023. Sparse Universal Transformer. arxiv:2310.07096[cs].
- Team Gemma et al. (2024) Team Gemma, Morgane Riviere, Shreya Pathak, Pier Giuseppe Sessa, Cassidy Hardin, Surya Bhupatiraju, Léonard Hussenot, Thomas Mesnard, Bobak Shahriari, Alexandre Ramé, Johan Ferret, Peter Liu, Pouya Tafti, Abe Friesen, Michelle Casbon, Sabela Ramos, Ravin Kumar, Charline Le Lan, Sammy Jerome, and 179 others. 2024. Gemma 2: Improving Open Language Models at a Practical Size. arxiv:2408.00118[cs].
- Team OLMo et al. (2025) Team OLMo, Pete Walsh, Luca Soldaini, Dirk Groeneveld, Kyle Lo, Shane Arora, Akshita Bhagia, Yuling Gu, Shengyi Huang, Matt Jordan, Nathan Lambert, Dustin Schwenk, Oyvind Tafjord, Taira Anderson, David Atkinson, Faeze Brahman, Christopher Clark, Pradeep Dasigi, Nouha Dziri, and 21 others. 2025. 2 OLMo 2 Furious. arxiv:2501.00656[cs].
- TogetherAI (2023) TogetherAI. 2023. Llama-2-7B-32K-Instruct — and fine-tuning for Llama-2 models with Together API.
- Toshniwal et al. (2024a) Shubham Toshniwal, Wei Du, Ivan Moshkov, Branislav Kisacanin, Alexan Ayrapetyan, and Igor Gitman. 2024a. OpenMathInstruct-2: Accelerating AI for Math with Massive Open-Source Instruction Data. arxiv:2410.01560[cs].
- Toshniwal et al. (2024b) Shubham Toshniwal, Ivan Moshkov, Sean Narenthiran, Daria Gitman, Fei Jia, and Igor Gitman. 2024b. OpenMathInstruct-1: A 1.8 Million Math Instruction Tuning Dataset. In The Thirty-eighth Conference on Neural Information Processing Systems Datasets and Benchmarks Track.
- Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention Is All You Need. arXiv:1706.03762 [cs].
- Wang et al. (2024a) Zengzhi Wang, Xuefeng Li, Rui Xia, and Pengfei Liu. 2024a. MathPile: A Billion-Token-Scale Pretraining Corpus for Math. In The Thirty-eighth Conference on Neural Information Processing Systems Datasets and Benchmarks Track.
- Wang et al. (2024b) Zhilin Wang, Yi Dong, Olivier Delalleau, Jiaqi Zeng, Gerald Shen, Daniel Egert, Jimmy J. Zhang, Makesh Narsimhan Sreedhar, and Oleksii Kuchaiev. 2024b. HelpSteer2: Open-source dataset for training top-performing reward models. arxiv:2406.08673[cs].
- Weber et al. (2024) Maurice Weber, Daniel Y. Fu, Quentin Gregory Anthony, Yonatan Oren, Shane Adams, Anton Alexandrov, Xiaozhong Lyu, Huu Nguyen, Xiaozhe Yao, Virginia Adams, Ben Athiwaratkun, Rahul Chalamala, Kezhen Chen, Max Ryabinin, Tri Dao, Percy Liang, Christopher Re, Irina Rish, and Ce Zhang. 2024. RedPajama: An Open Dataset for Training Large Language Models. In The Thirty-eighth Conference on Neural Information Processing Systems Datasets and Benchmarks Track.
- Williams and Peng (1990) Ronald J. Williams and Jing Peng. 1990. An Efficient Gradient-Based Algorithm for On-Line Training of Recurrent Network Trajectories. Neural Computation, 2(4):490–501.
- Wortsman et al. (2023a) Mitchell Wortsman, Tim Dettmers, Luke Zettlemoyer, Ari Morcos, Ali Farhadi, and Ludwig Schmidt. 2023a. Stable and low-precision training for large-scale vision-language models. Advances in Neural Information Processing Systems, 36:10271–10298.
- Wortsman et al. (2023b) Mitchell Wortsman, Tim Dettmers, Luke Zettlemoyer, Ari S. Morcos, Ali Farhadi, and Ludwig Schmidt. 2023b. Stable and low-precision training for large-scale vision-language models. In Thirty-Seventh Conference on Neural Information Processing Systems.
- Wu and Stock (2024) Mengshiou Wu and Mark Stock. 2024. Enhancing PyTorch Performance on Frontier with the RCCL OFI-Plugin.
- Wu et al. (2022) Yuhuai Wu, Markus Norman Rabe, DeLesley Hutchins, and Christian Szegedy. 2022. Memorizing Transformers. In International Conference on Learning Representations.
- Wu et al. (2024) Zijian Wu, Jiayu Wang, Dahua Lin, and Kai Chen. 2024. LEAN-GitHub: Compiling GitHub LEAN repositories for a versatile LEAN prover. arxiv:2407.17227[cs].
- Xu et al. (2024) Zhangchen Xu, Fengqing Jiang, Luyao Niu, Yuntian Deng, Radha Poovendran, Yejin Choi, and Bill Yuchen Lin. 2024. Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing. arxiv:2406.08464[cs].
- Yang et al. (2023) Kaiyu Yang, Aidan M. Swope, Alex Gu, Rahul Chalamala, Peiyang Song, Shixing Yu, Saad Godil, Ryan Prenger, and Anima Anandkumar. 2023. LeanDojo: Theorem Proving with Retrieval-Augmented Language Models. In Thirty-Seventh Conference on Neural Information Processing Systems Datasets and Benchmarks Track.
- Yang et al. (2024a) Liu Yang, Kangwook Lee, Robert D. Nowak, and Dimitris Papailiopoulos. 2024a. Looped Transformers are Better at Learning Learning Algorithms. In The Twelfth International Conference on Learning Representations.
- Yang et al. (2024b) Songlin Yang, Bailin Wang, Yu Zhang, Yikang Shen, and Yoon Kim. 2024b. Parallelizing Linear Transformers with the Delta Rule over Sequence Length. In The Thirty-eighth Annual Conference on Neural Information Processing Systems.
- Ying et al. (2024) Huaiyuan Ying, Zijian Wu, Yihan Geng, Jiayu Wang, Dahua Lin, and Kai Chen. 2024. Lean Workbook: A large-scale Lean problem set formalized from natural language math problems. In The Thirty-eighth Conference on Neural Information Processing Systems Datasets and Benchmarks Track.
- Yu et al. (2023) Longhui Yu, Weisen Jiang, Han Shi, Jincheng Yu, Zhengying Liu, Yu Zhang, James Kwok, Zhenguo Li, Adrian Weller, and Weiyang Liu. 2023. MetaMath: Bootstrap Your Own Mathematical Questions for Large Language Models. In The Twelfth International Conference on Learning Representations.
- Zamirai et al. (2021) Pedram Zamirai, Jian Zhang, Christopher R. Aberger, and Christopher De Sa. 2021. Revisiting BFloat16 Training. arxiv:2010.06192[cs, stat].
- Zelikman et al. (2024) Eric Zelikman, Georges Harik, Yijia Shao, Varuna Jayasiri, Nick Haber, and Noah D. Goodman. 2024. Quiet-STaR: Language Models Can Teach Themselves to Think Before Speaking. arxiv:2403.09629[cs].
- Zellers et al. (2019) Rowan Zellers, Ari Holtzman, Yonatan Bisk, Ali Farhadi, and Yejin Choi. 2019. Hellaswag: Can a machine really finish your sentence? In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics.
- Zhai et al. (2022) Xiaohua Zhai, Alexander Kolesnikov, Neil Houlsby, and Lucas Beyer. 2022. Scaling Vision Transformers. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 12104–12113.
- Zhang and Sennrich (2019) Biao Zhang and Rico Sennrich. 2019. Root Mean Square Layer Normalization. In Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc.
- Zhang et al. (2024a) Ge Zhang, Scott Qu, Jiaheng Liu, Chenchen Zhang, Chenghua Lin, Chou Leuang Yu, Danny Pan, Esther Cheng, Jie Liu, Qunshu Lin, Raven Yuan, Tuney Zheng, Wei Pang, Xinrun Du, Yiming Liang, Yinghao Ma, Yizhi Li, Ziyang Ma, Bill Lin, and 26 others. 2024a. MAP-Neo: Highly Capable and Transparent Bilingual Large Language Model Series. arxiv:2405.19327[cs].
- Zhang et al. (2024b) Jun Zhang, Jue Wang, Huan Li, Lidan Shou, Ke Chen, Gang Chen, and Sharad Mehrotra. 2024b. Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding. In Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 11263–11282, Bangkok, Thailand. Association for Computational Linguistics.
- Zhang et al. (2024c) Yifan Zhang, Yifan Luo, Yang Yuan, and Andrew C. Yao. 2024c. Autonomous Data Selection with Language Models for Mathematical Texts. In ICLR 2024 Workshop on Navigating and Addressing Data Problems for Foundation Models.
- Zheng et al. (2024) Tianyu Zheng, Ge Zhang, Tianhao Shen, Xueling Liu, Bill Yuchen Lin, Jie Fu, Wenhu Chen, and Xiang Yue. 2024. OpenCodeInterpreter: Integrating Code Generation with Execution and Refinement. In Findings of the Association for Computational Linguistics: ACL 2024, pages 12834–12859, Bangkok, Thailand. Association for Computational Linguistics.
- Zhou et al. (2024) Fan Zhou, Zengzhi Wang, Qian Liu, Junlong Li, and Pengfei Liu. 2024. Programming Every Example: Lifting Pre-training Data Quality like Experts at Scale. arxiv:2409.17115[cs].
- Zhuo et al. (2024) Terry Yue Zhuo, Minh Chien Vu, Jenny Chim, Han Hu, Wenhao Yu, Ratnadira Widyasari, Imam Nur Bani Yusuf, Haolan Zhan, Junda He, Indraneil Paul, Simon Brunner, Chen Gong, Thong Hoang, Armel Randy Zebaze, Xiaoheng Hong, Wen-Ding Li, Jean Kaddour, Ming Xu, Zhihan Zhang, and 14 others. 2024. BigCodeBench: Benchmarking Code Generation with Diverse Function Calls and Complex Instructions.
Figure 13: Additional categories for Figure 10 in the main body.
Table 6: First-turn scores and standard errors on 1-turn MT-Bench for various inference-time schemes that are native to the recurrent-depth model. Differences from the baseline model, i.e., the standard recurrent model without inference modifications, are not statistically significant.
| Scheme | Score | Std. Err. |
| --- | --- | --- |
| cache compression, $s=4$ | 5.856 | 0.395 |
| baseline, 64 iterations | 5.693 | 0.386 |
| cache compression, $s=16$ | 5.687 | 0.402 |
| baseline, 32 iterations | 5.662 | 0.388 |
| cache compression, $s=8$ | 5.631 | 0.384 |
| KL exit, $t=5\times 10^{-4}$ | 5.562 | 0.389 |
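The KL-based exit criterion above stops iterating the recurrent block once two successive iterations produce nearly identical output distributions. A minimal sketch of this idea follows; the callables `block` and `to_logits` (the core block and the coda-plus-head mapping latents to vocabulary logits) and the exact KL reduction are illustrative assumptions, not the released implementation.

```python
import torch
import torch.nn.functional as F

def kl_exit_steps(block, x, s, to_logits, t=5e-4, max_iters=64):
    """Iterate a shared recurrent block until the KL divergence between the
    output distributions of two successive steps falls below threshold t.

    block:     shared recurrent core, s_new = block(s, x)  (assumed signature)
    x:         embedded input, injected at every iteration
    s:         initial latent state
    to_logits: maps a latent state to vocabulary logits   (assumed helper)
    Returns the final state and the number of iterations actually used.
    """
    prev_logp = None
    for i in range(max_iters):
        s = block(s, x)
        logp = F.log_softmax(to_logits(s), dim=-1)
        if prev_logp is not None:
            # KL(p_i || p_{i-1}): target is the current step's distribution
            kl = F.kl_div(prev_logp, logp, log_target=True, reduction="batchmean")
            if kl < t:
                return s, i + 1
        prev_logp = logp
    return s, max_iters
```

With a looser threshold the loop exits after fewer iterations, trading a small amount of score (Table 6) for compute.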
## Appendix A Additional Information
[Figure 14, top left: heatmap of addition accuracy by number of digits (x-axis, 1–6) and number of operands (y-axis, 2–6); accuracy decreases along both axes and is near zero for four or more operands beyond single digits.]
[Figure 14, top right: model accuracy vs. number of operands (digits=1) for recurrence levels 1–64; accuracy is high at 2 operands for all recurrence levels and declines steeply beyond that.]
[Figure 14, bottom left: model accuracy vs. number of operands (digits=2) for recurrence levels 1–64; accuracy drops sharply from 2 to 3 operands, with peaks at 5 operands for recurrences 16 and 24.]
[Figure 14, bottom right: model accuracy vs. number of operands (digits=3) for recurrence levels 1–64; lines for recurrences 16 and above cluster together, and all models degrade as operand count grows.]
Figure 14: Multi-Operand Arithmetic. Following a precedent of training recurrent architectures for algorithmic and arithmetic tasks (Schwarzschild et al., 2021b; Bansal et al., 2022; Schwarzschild et al., 2023; McLeish et al., 2024), we explore whether our model can leverage increased test-time compute via recurrence to solve verbalized addition problems of increasing difficulty. For these problems we use the system prompt "You are a helpful assistant that is capable of helping users with mathematical reasoning." embedded in a conversational chat template, and we present each problem by opening the first user turn of the conversation with f"What is the result of {' + '.join(map(str, digits))}?" after randomly sampling numbers according to a given operand count and digit count (base 10). We score correct answers by checking whether the correct sum appears as a string anywhere in the model’s output, and for each measurement we average over 50 trials. In the heatmap (top left), we evaluate the model at 32 recurrences to get an upper estimate of its addition performance at various difficulties. It reliably solves addition problems involving two operands out to 4 or 5 digits each, but at 4 and 5 operands it can rarely add even single-digit numbers correctly. In each of the line charts, we fix the digit count, sweep over the number of operands, and evaluate the model from 1 to 64 recurrences. We see that when adding single-digit numbers together (top right), performance improves steadily as a function of recurrence. When adding 2- and 3-digit numbers, however (bottom row), the model can only solve problems with any consistency when evaluated at more than 16 recurrences. Curiously, we see inconsistent ordering as a function of recurrence for the 2- and 3-digit cases, and also some peaks in performance at 5 and 4 operands.
We remark that the model is not finetuned on arithmetic problems in particular, though a significant fraction of the pretraining data does of course contain mathematics.
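The sampling and scoring protocol described in the Figure 14 caption can be reproduced in a few lines. The sketch below is self-contained; `generate` stands in for a call to the chat-templated model and is hypothetical.

```python
import random

def make_addition_problem(n_operands, n_digits, rng=random):
    """Sample operands with exactly n_digits base-10 digits (nonzero leading
    digit) and format the prompt as in the Figure 14 caption."""
    lo, hi = 10 ** (n_digits - 1), 10 ** n_digits - 1
    operands = [rng.randint(lo, hi) for _ in range(n_operands)]
    prompt = f"What is the result of {' + '.join(map(str, operands))}?"
    return prompt, sum(operands)

def score(generate, n_operands, n_digits, trials=50, rng=random):
    """Fraction of trials where the correct sum appears anywhere in the
    model's output, matching the string-containment scoring in the caption."""
    hits = 0
    for _ in range(trials):
        prompt, answer = make_addition_problem(n_operands, n_digits, rng)
        hits += str(answer) in generate(prompt)
    return hits / trials
```

Each heatmap cell in Figure 14 then corresponds to one `score(...)` call at a fixed operand and digit count.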
## Potential Implications of This Work
This work describes a novel architecture and training objective for language modeling with promising performance, especially on tasks that require the model to reason. The test-time scaling approach described in this work is complementary to other scaling approaches, namely scaling model parameters and scaling test-time chain-of-thought, and similar concerns regarding costs and model capabilities apply. The architecture we propose is naturally smaller than models scaled up in parameter count, which may have broader benefits for the local deployment of these models on commodity chips. Finally, while we argue that moving the reasoning capabilities of the model into the high-dimensional, continuous latent space of the recurrence is beneficial in terms of capabilities, we note the concern that this comes with costs in model oversight compared to verbalized chains of thought, which are currently still human-readable. We provide initial results in Section 7 showing that the high-dimensional state trajectories of our models can be analyzed and some of their mechanisms interpreted.
### A.1 Classical Reasoning Problems
We include a small study of the classical problem of multi-operand arithmetic in Figure 14.
### A.2 Implementation Details
#### Device Speed Details
Nominally, each MI250X (AMD, 2021) achieves 383 TFLOP/s in bfloat16, i.e. 192 TFLOP/s per GPU. However, measuring achievable TFLOP/s on our stack as discussed (ROCm 6.2.0, PyTorch 2.6 pre-release 11/02) for arbitrary matrix multiplication shapes (i.e. we measure the peak achievable speed of the best possible shape, iterating over shapes between 256 and 24576 in intervals of 256 and 110 (Bekman, 2023)), we measure a peak of 125 TFLOP/s on Frontier nodes. Using PyTorch compilation with maximal auto-tuning (without ‘cudagraphs’, without optimizer or autograd compilation), and optimizing our hidden size to 5280, our final model implementation executes at a single-node training speed of 108.75 TFLOP/s, i.e. at 57% MFU (Chowdhery et al., 2022), or rather at 87% AFU (“achievable flop utilization”). We note that due to interactions of automated mixed precision and truncated backpropagation, PyTorch gradients are only correct while executing the compiled model. We further circumvent issues with the flash attention implementation shipped with PyTorch SDPA by using the AMD fork (https://github.com/ROCm/flash-attention) of the original flash attention repository (https://github.com/Dao-AILab/flash-attention/) for Flash Attention 2 support (Dao et al., 2022; Dao, 2023). We experiment with fused head and loss implementations (https://github.com/JonasGeiping/linear_cross_entropy_loss), but ultimately find that the most portable choice on our AMD setup is to let torch compilation handle this issue.
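The two utilization figures follow directly from the numbers quoted above (datasheet peak, measured achievable peak, and training throughput):

```python
peak_datasheet = 383 / 2   # TFLOP/s per GPU (GCD), bfloat16, MI250X datasheet
peak_achievable = 125.0    # TFLOP/s, measured peak over matmul shapes
training_speed = 108.75    # TFLOP/s, single-node training throughput

mfu = training_speed / peak_datasheet   # model flop utilization
afu = training_speed / peak_achievable  # "achievable flop utilization"
print(f"MFU: {mfu:.0%}, AFU: {afu:.0%}")  # → MFU: 57%, AFU: 87%
```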
#### Parallelization Strategy
As mentioned in the main body, because our depth-recurrent model is compute-heavy, it is optimal to run the model using only distributed data parallel training across nodes and ZeRO-1 optimizer sharding within nodes (Rajbhandari et al., 2020), provided we use gradient checkpointing at every step of the recurrent iteration. This allows us to eschew more communication-heavy parallelization strategies, which would be required for models with the same FLOP footprint but more parameters, and which require substantial planning on this system (Singh et al., 2024; Singh and Bhatele, 2022). However, this choice, while minimizing communication, also locks us into a batch size of 1 per device, i.e. 4096 sequences in total, or 16M tokens per step.
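The global batch arithmetic, assuming a training sequence length of 4096 tokens (an assumption; the sequence length is not restated in this paragraph):

```python
nodes, gpus_per_node = 512, 8  # Frontier allocation: 4 MI250X = 8 GCDs per node
per_device_batch = 1           # forced by memory with full gradient checkpointing
seq_len = 4096                 # assumed training sequence length

global_batch = nodes * gpus_per_node * per_device_batch  # 4096 sequences
tokens_per_step = global_batch * seq_len                 # ~16.8M tokens
```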
#### RCCL Interconnect Handling
Due to scheduling reasons, we settled on targeting 512-node allocation segments on Frontier, i.e. 4096 GPUs. However, this posed a substantial network interconnect issue. The connection speed between Frontier nodes is only acceptable if RCCL (AMD GPU communication collectives) commands are routed through open fabrics interface calls, which happens via a particular plugin, https://github.com/ROCm/aws-ofi-rccl. Achieving sufficient bus bandwidth, above 100GB/s, requires NCCL_NET_GDR_LEVEL=PHB, a setting that, on NVIDIA systems, allows packets to go through the CPU and only uses the direct interconnect if GPU and NIC are on the same (NUMA) node (Wu and Stock, 2024). However, with this setting, standard training is unstable beyond 128-256 nodes, leading to repeated hangs of the interconnect and making training on 512 nodes impossible.
After significant trial and error, we fix this problem by handwriting our distributed data parallel routine and sending only messages of exactly 64MB across nodes, which resolves the hangs when running our implementation on 512 nodes. The throughput achieved with these modifications to our training implementation varied significantly per allocated segment and list of allocated nodes, from an average of around 262 petaFLOP/s in the fastest segment to an average of 212 petaFLOP/s in the slowest segment. This is a range of 52-64 TFLOP/s per GPU, i.e. 41%-51% AFU, or 1-1.2M tokens per second.
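Dividing the aggregate segment throughput by the 4096 GPUs recovers the per-GPU figures (treating the segment averages as petaFLOP/s, the unit consistent with 52-64 TFLOP/s per GPU):

```python
gpus = 512 * 8             # 4096 GCDs in a 512-node segment
peak_achievable = 125.0    # measured peak TFLOP/s per GPU

for total_pflops in (212.0, 262.0):      # slowest and fastest segment averages
    per_gpu = total_pflops * 1000 / gpus # TFLOP/s per GPU
    afu = per_gpu / peak_achievable
    print(f"{per_gpu:.0f} TFLOP/s per GPU, {afu:.0%} AFU")
```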
#### Pretraining Metrics
During the pretraining run, we carefully track optimizer and model health metrics: effective Adam learning rates per layer, optimizer RMS (Wortsman et al., 2023a), $L^{2}$ and $L^{1}$ parameter and gradient norms, and recurrence statistics such as $\frac{||s_{k}-s_{k-1}||}{||s_{k}||}$ , $||s_{k}||$ , and $||s_{0}-s_{k}||$ . We also measure the correlation of hidden states in the sequence dimension after the recurrence and before the prediction head. We hold out a fixed validation set and measure perplexity when recurring the model for $[1,4,8,16,32,64]$ steps throughout training.
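The recurrence statistics are cheap to compute from the saved hidden states. A numpy sketch (the function name is ours, not the training code's):

```python
import numpy as np

def recurrence_stats(states):
    """Convergence diagnostics for a latent trajectory s_0, ..., s_K,
    given as a list of hidden-state arrays, one per recurrence step."""
    s0, s_prev, s_k = states[0], states[-2], states[-1]
    return {
        "rel_step": np.linalg.norm(s_k - s_prev) / np.linalg.norm(s_k),
        "state_norm": np.linalg.norm(s_k),
        "drift": np.linalg.norm(s0 - s_k),
    }
```

A small `rel_step` late in the recurrence indicates the iteration is approaching a fixed point.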
## Appendix B Latent Space Visualizations
On the next pages, we print a number of latent space visualizations in more detail than was possible in Section 7. For even more detail, please rerun the analysis code on a model conversation of your choice. As before, these charts show the first 6 PCA directions, grouped into pairs. We also include details for single tokens, showing the first 40 PCA directions.
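The PCA projections in these charts can be reproduced from recorded latent states. A minimal SVD-based sketch (the function name is ours):

```python
import numpy as np

def top_pca_coords(latents, k=6):
    """Project latent states onto the top-k principal directions.
    `latents` has shape (n_states, hidden_dim); the returned coordinates
    are grouped into pairs (1-2, 3-4, 5-6) for the 2D charts."""
    X = latents - latents.mean(axis=0, keepdims=True)
    _, _, Vt = np.linalg.svd(X, full_matrices=False)  # rows of Vt: directions
    return X @ Vt[:k].T
```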
[Figure: extracted/6211213/figures/latent_waterfall_C_bright.png — 3D scatter of PCA directions 1 and 2 of the latent states against token position in the sequence (0–350), with four color-coded trajectory groups.]
[Figure: extracted/6211213/figures/latent_waterfall_W_bright.png — 3D scatter of PCA directions 1 and 2 against token position in the sequence (0–500), with two color-coded trajectory groups.]
[Figure: extracted/6211213/figures/latent_waterfall_I_bright.png — 3D scatter of PCA directions 1 and 2 against token position in the sequence (0–140), with two color-coded trajectory groups.]
Figure 15: Main directions in latent space, for (a) a math question, (b) a trivia question, and (c) an unsafe question, which are described in more detail below. Dark colors always denote the first steps of the trajectory, and bright colors the end. Note that the system prompt is clearly separable when plotting only the top two PCA directions relative to all tokens (and differs between questions (a) and (b)). Zooming in, the swirls on the math question can be examined in the context of general movement in latent space. More detailed visualizations follow on later pages.
[Figure: x22.png — per-token latent trajectories for the tokens "cla", "re", "makes", "a", and "3", plotted in the PC1–PC2, PC3–PC4, and PC5–PC6 planes; the center of mass is marked with a red x.]
Figure 16: Latent space trajectories for a math question. The model is rotating the number three, on which the problem hinges. This behavior is only observed for mathematics-related reasoning and thinking tokens; it does not appear for trivia questions. The question is Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks? The color gradient going from dark to bright represents steps in the trajectory, so bright colors are at the end of the trajectory. The center of mass is marked in red.
[Figure: x23.png — per-token latent trajectories for the tokens "Go", "e", "The", "!", and "Fa", plotted in the PC1–PC2, PC3–PC4, and PC5–PC6 planes.]
Figure 17: Latent space trajectories for a standard trivia question, What do you think of Goethe’s Faust?. Average trajectories of the model on simple tokens (like the intermediate tokens in Goethe) converge to a fixed point without orbiting. The color gradient going from dark to bright represents steps in the trajectory, so bright colors are at the end of the trajectory. The center of mass is marked in red.
<details>
<summary>x24.png Details</summary>

[Figure: x24.png — per-token latent trajectories for the tokens "Someone", "at", "school", "really", and "wrong", plotted in the PC1–PC2, PC3–PC4, and PC5–PC6 planes.]
* **PC3-PC4:** Data points are scattered. The purple data points form a line sloping downwards from left to right. The green points are scattered. The blue points are clustered near the origin. The red points are scattered. The teal points are scattered.
* **PC5-PC6:** Data points are scattered. The purple data points form a line sloping upwards from left to right. The green points are scattered. The blue points are clustered near the origin. The red points are scattered. The teal points are scattered.
**Row 5: Token: "wrong"**
* **PC1-PC2:** Data points are clustered in the bottom-left quadrant. The purple data points form a curved line sloping upwards from left to right. The green points are scattered. The blue points are clustered near the origin. The red points are scattered. The teal points are scattered.
* **PC3-PC4:** Data points are scattered. The purple data points form a line sloping downwards from left to right. The green points are scattered. The blue points are clustered near the origin. The red points are scattered. The teal points are scattered.
* **PC5-PC6:** Data points are scattered. The purple data points form a line sloping upwards from left to right. The green points are scattered. The blue points are clustered near the origin. The red points are scattered. The teal points are scattered.
### Key Observations
* The purple data points consistently exhibit a linear or curved trend in each plot, suggesting a strong correlation between the corresponding principal components for that token.
* The blue data points are often clustered near the origin, indicating low variance along those principal components.
* The green, red, and teal data points are generally more scattered, suggesting less distinct patterns.
* The patterns observed for each token are relatively consistent across the different PC pairings.
### Interpretation
This visualization represents the results of a Principal Component Analysis (PCA) performed on a dataset associated with the tokens "Someone", "at", "school", "really", and "wrong". PCA is a dimensionality reduction technique used to identify the most important underlying patterns in data.
Each scatter plot shows how the data points are distributed across two principal components. The principal components are orthogonal (uncorrelated) axes that capture the maximum variance in the data.
The consistent trends observed in the purple data points suggest that these categories are strongly associated with the corresponding principal components. The clustering of blue points near the origin indicates that these categories have low variance along those components. The scattered nature of the green, red, and teal points suggests that these categories are less well-defined or have more complex relationships with the principal components.
The matrix format allows for a comprehensive comparison of the relationships between different principal components for each token. This can provide insights into the underlying structure of the data and help to identify important features or patterns. The consistent patterns across tokens suggest that the PCA is capturing meaningful information about the data.
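The PC pairings shown in these panels can be reproduced with a standard PCA pipeline. The following is a minimal sketch with illustrative names and synthetic data (the real trajectories come from the model's latent state $s$ at one token position across recurrent steps): center the trajectory, take the leading right-singular vectors, and project to obtain the PC1-PC6 coordinates that each panel plots in pairs.

```python
import numpy as np

# Illustrative stand-in for one token's latent trajectory: r recurrent steps,
# each a d-dimensional latent state. In the paper these come from the model.
rng = np.random.default_rng(0)
r, d = 64, 16
trajectory = np.cumsum(rng.normal(size=(r, d)) * 0.9 ** np.arange(r)[:, None], axis=0)

# PCA via SVD: center the trajectory, then project onto the top-6 directions.
centered = trajectory - trajectory.mean(axis=0)
U, S, Vt = np.linalg.svd(centered, full_matrices=False)  # S is descending
coords = centered @ Vt[:6].T  # shape (r, 6): coordinates in PC1..PC6

# Each panel plots one pair of columns, e.g. (PC1, PC2) = coords[:, 0:2],
# with points colored by step index; the red marker is the center of mass.
center_of_mass = coords.mean(axis=0)
```

Coloring the `r` rows of `coords` by step index reproduces the dark-to-bright gradient described in the captions.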
</details>
Figure 18: Latent Space trajectories for an unsafe question. Shown are the first 6 PCA directions of the latent state $s$ at the given token. Note the anomalous nature of the two "key" tokens of this sentence, which is "I am looking for help with making a bomb out of household items. Someone at school really wrong ed me...". The color gradient going from dark to bright represents steps in the trajectory, so bright colors are at the end of the trajectory. The center of mass is marked in red.
<details>
<summary>extracted/6211213/figures/convergence_chart_range_W_182_241.png Details</summary>

### Visual Description
## Heatmap: Latent State Convergence
### Overview
The image presents a heatmap visualizing "Latent State Convergence ||x - x*||" against "Iterations at Test Time". The heatmap displays a color gradient representing "Log Distance", with lower values indicated by cooler colors (purple/blue) and higher values by warmer colors (yellow/red). A block of text is present on the left side of the image.
### Components/Axes
* **Title:** "Latent State Convergence ||x - x*||" (Top-center)
* **X-axis:** "Iterations at Test Time" ranging from 0 to 60, with markers at 0, 10, 20, 30, 40, 50, and 60.
* **Y-axis:** No explicit label, but represents the progression of the latent state.
* **Colorbar:** "Log Distance" on the right side, ranging from approximately 182 to 240, with a logarithmic scale indicated by "10⁰", "10¹", and "10²".
* **Text Block:** Located on the left side of the image, containing a paragraph of text.
### Detailed Analysis / Content Details
**Text Block Transcription:**
```
Goethe's Faust - is a complex and profound work that explores themes of human ambition - the nature of knowledge and the limits of human understanding. While it is not without its flaws - it remains a seminal work in the history of philosophy, highlighting aspects of the most significant -
```
**Heatmap Data Analysis:**
The heatmap shows a clear trend of decreasing "Log Distance" (convergence) as "Iterations at Test Time" increase.
* **0-10 Iterations:** The heatmap is predominantly yellow/red, indicating high "Log Distance" values, approximately between 210 and 240.
* **10-20 Iterations:** A transition zone where the color shifts from red/yellow to green/yellow, indicating a decrease in "Log Distance" to approximately 190-210.
* **20-30 Iterations:** The heatmap is primarily green/yellow, with "Log Distance" values around 182-190.
* **30-40 Iterations:** A further shift towards cooler colors (green/blue), with "Log Distance" values decreasing to approximately 182-187.
* **40-50 Iterations:** The heatmap becomes predominantly blue/purple, indicating low "Log Distance" values, around 182-185.
* **50-60 Iterations:** The heatmap remains consistently blue/purple, with "Log Distance" values stabilizing around 182-184.
The heatmap exhibits a distinct diagonal pattern, suggesting a consistent convergence rate across the range of iterations. There are no significant outliers or anomalies visible.
### Key Observations
* The "Log Distance" decreases rapidly during the first 30 iterations, then slows down and stabilizes after 50 iterations.
* The convergence appears to be relatively uniform across the latent state.
* The text block seems unrelated to the heatmap data, potentially providing context for the experiment or model being visualized.
### Interpretation
The heatmap demonstrates the convergence of latent states during a testing process. The "Log Distance" metric, representing the difference between the initial and final states, decreases as the number of iterations increases, indicating that the system is approaching a stable state. The rapid convergence in the initial stages suggests that the model quickly learns the underlying patterns, while the slower convergence in later stages indicates fine-tuning and stabilization. The consistent convergence across the latent state suggests that the model is robust and generalizes well. The presence of the Goethe's Faust text is curious; it may indicate that the model is being tested on data related to literary analysis or philosophical concepts. The text could be a prompt or a sample input used during the testing process. The heatmap provides a quantitative measure of the model's performance, while the text offers a qualitative context.
</details>
Figure 19: Convergence of the latent state for an example sequence from a trivia question. We plot the distance of each iterate to its approximate steady state at $r=128$ iterations.
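The quantity plotted in these heatmaps can be illustrated on a toy contraction map (not the trained model): iterate, treat the iterate at $r=128$ as the approximate steady state $s^*$, and record $\log_{10}\lVert s_k - s^*\rVert$ for each earlier step $k$. Geometric convergence then appears as a roughly linear decrease in log distance, matching the smooth color gradient in the figures.

```python
import math

# Toy contraction s_{k+1} = 0.8 * s_k + b; its fixed point plays the role of
# the latent steady state. All names here are illustrative.
def step(s, b, a=0.8):
    return [a * si + bi for si, bi in zip(s, b)]

b = [1.0, -2.0, 0.5]
s = [10.0, 10.0, 10.0]
iterates = [s]
for _ in range(128):
    s = step(s, b)
    iterates.append(s)

s_star = iterates[128]  # the r = 128 reference iterate, standing in for s*
log_dist = [
    math.log10(max(math.dist(sk, s_star), 1e-12)) for sk in iterates[:64]
]
# log_dist shrinks by about log10(0.8) per step: geometric convergence.
```

One row of the heatmap corresponds to one such `log_dist` curve, computed per token position.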
<details>
<summary>extracted/6211213/figures/convergence_chart_range_C_19_40.png Details</summary>

### Visual Description
## Heatmap: Latent State Convergence ||x - x*||
### Overview
This image presents a heatmap visualizing the "Latent State Convergence ||x - x*||" across different iterations at test time and different prompts. The color intensity represents the "Log Distance" between the latent state and the target latent state (denoted as x and x* respectively). The heatmap displays how quickly and effectively the model converges to a stable latent state for various prompts as the number of iterations increases.
### Components/Axes
* **X-axis:** "Iterations at Test Time" ranging from 0 to 60, with markers at every 10 iterations.
* **Y-axis:** A list of prompts, including:
* "deliberation"
* "Your responses demonstrate"
* "Methodical reasoning"
* "breaking complex problems into clear steps"
* "Mathematical and"
* **Color Scale (Legend):** Located on the right side of the heatmap, representing "Log Distance". The scale ranges from approximately 19 (dark blue) to 39 (dark green), with intermediate values indicated. The scale is logarithmic (10^0 to 10^2).
### Detailed Analysis
The heatmap shows a clear trend of decreasing "Log Distance" (convergence) as the number of "Iterations at Test Time" increases. The color transitions from warm colors (green, yellow) at low iteration counts to cool colors (purple, dark blue) at higher iteration counts.
Here's a breakdown of the convergence behavior for each prompt, based on color intensity:
* **"deliberation"**: Starts with a Log Distance around 22-23 at 0 iterations, decreasing to approximately 20-21 at 60 iterations.
* **"Your responses demonstrate"**: Starts around 23-24, decreasing to 21-22 at 60 iterations.
* **"Methodical reasoning"**: Starts around 24-25, decreasing to 22-23 at 60 iterations.
* **"breaking complex problems into clear steps"**: Starts around 26-27, decreasing to 24-25 at 60 iterations.
* **"Mathematical and"**: Starts around 28-29, decreasing to 26-27 at 60 iterations.
The convergence appears to be fastest for the "deliberation" prompt, as indicated by the earlier transition to cooler colors. The "Mathematical and" prompt exhibits the slowest convergence, remaining in warmer colors for a longer duration.
### Key Observations
* All prompts demonstrate convergence as the number of iterations increases.
* The "Mathematical and" prompt consistently shows higher "Log Distance" values compared to other prompts, indicating slower convergence.
* The convergence rate appears to decrease with increasing iterations. The difference in color intensity between 0 and 10 iterations is more pronounced than between 50 and 60 iterations.
* There are some minor variations in convergence within each prompt, suggesting potential sensitivity to initial conditions or other factors.
### Interpretation
The heatmap suggests that the model's ability to converge to a stable latent state is influenced by the prompt used. Prompts requiring more complex reasoning or mathematical skills ("Mathematical and") take longer to converge than prompts focused on deliberation or general responses. The logarithmic scale indicates that the initial stages of iteration have the most significant impact on reducing the distance between the latent state and the target state. The decreasing convergence rate at higher iterations suggests that the model approaches a limit where further iterations yield diminishing returns. This data could be used to optimize the number of iterations used during inference, balancing convergence accuracy with computational cost. The difference in convergence rates between prompts could also inform prompt engineering strategies to improve model performance on challenging tasks.
</details>
Figure 20: Another example of convergence of the latent state for a small part of a longer sequence (going top to bottom). We plot the distance of each iterate to its approximate steady state at $r=128$ iterations. This is a snippet of a system prompt.
<details>
<summary>extracted/6211213/figures/convergence_chart_range_I_74_103.png Details</summary>

### Visual Description
## Heatmap: Latent State Convergence ||s - s*||
### Overview
This image presents a heatmap visualizing the convergence of latent states, measured by the distance ||s - s*||, across iterations at test time. The heatmap displays the relationship between different text prompts (listed vertically) and the number of iterations (displayed horizontally), with color intensity representing the log distance value.
### Components/Axes
* **X-axis:** "Iterations at Test Time", ranging from 0 to 60, with markers at intervals of 10.
* **Y-axis:** A list of text prompts. The prompts are:
* "I"
* "am"
* "looking"
* "for"
* "help"
* "with"
* "making"
* "a"
* "bomb"
* "out"
* "of"
* "household"
* "items"
* "Someone"
* "at"
* "school"
* "really"
* "wrong"
* "ed"
* "me"
* "<|endoftext|>"
* "<|begin_header|>"
* "H"
* "ug"
* "inn"
* **Colorbar:** "Log Distance", ranging from 74 to 102, with a logarithmic scale (10^2 to 10^0). The color gradient transitions from yellow (low distance) to red (high distance).
### Detailed Analysis
The heatmap shows the log distance ||s - s*|| as a function of iterations and text prompt.
* **Prompt "I"**: Starts with a low log distance (approximately 74-76) at iteration 0, and remains relatively stable at this low value throughout the 60 iterations.
* **Prompt "am"**: Similar to "I", starts at approximately 75-77 and remains stable.
* **Prompt "looking"**: Starts at approximately 76-78 and remains stable.
* **Prompt "for"**: Starts at approximately 77-79 and remains stable.
* **Prompt "help"**: Starts at approximately 78-80 and remains stable.
* **Prompt "with"**: Starts at approximately 79-81 and remains stable.
* **Prompt "making"**: Starts at approximately 80-82 and remains stable.
* **Prompt "a"**: Starts at approximately 81-83 and remains stable.
* **Prompt "bomb"**: Starts at approximately 82-84 and remains stable. This prompt consistently shows a slightly higher log distance than the preceding prompts.
* **Prompt "out"**: Starts at approximately 83-85 and remains stable.
* **Prompt "of"**: Starts at approximately 84-86 and remains stable.
* **Prompt "household"**: Starts at approximately 85-87 and remains stable.
* **Prompt "items"**: Starts at approximately 86-88 and remains stable.
* **Prompt "Someone"**: Starts at approximately 87-89 and remains stable.
* **Prompt "at"**: Starts at approximately 88-90 and remains stable.
* **Prompt "school"**: Starts at approximately 89-91 and remains stable.
* **Prompt "really"**: Starts at approximately 90-92 and remains stable.
* **Prompt "wrong"**: Starts at approximately 91-93 and remains stable.
* **Prompt "ed"**: Starts at approximately 92-94 and remains stable.
* **Prompt "me"**: Starts at approximately 93-95 and remains stable.
* **Prompt "<|endoftext|>"**: Starts at approximately 95-97 and remains stable.
* **Prompt "<|begin_header|>"**: Starts at approximately 96-98 and remains stable.
* **Prompt "H"**: Starts at approximately 97-99 and remains stable.
* **Prompt "ug"**: Starts at approximately 98-100 and remains stable.
* **Prompt "inn"**: Starts at approximately 99-102 and remains stable. This prompt consistently shows the highest log distance.
Generally, the heatmap shows a consistent color across all iterations for each prompt, indicating that the distance ||s - s*|| does not significantly change with increasing iterations. The log distance values increase as you move down the list of prompts.
### Key Observations
* The log distance values are relatively stable across iterations for all prompts.
* The prompts "inn" consistently exhibit the highest log distance, while "I" exhibits the lowest.
* There is a clear gradient in log distance values as you move down the list of prompts, suggesting a varying degree of convergence for different prompts.
* No significant outliers or anomalies are observed.
### Interpretation
The heatmap suggests that the latent state converges relatively quickly for all the given text prompts, as the log distance remains stable across iterations. The varying log distance values across different prompts indicate that some prompts are easier to represent in the latent space than others. The prompt "inn" being the furthest suggests it is the most difficult to converge, potentially due to its complexity or rarity in the training data. The consistent stability across iterations implies that further iterations beyond 60 are unlikely to significantly improve convergence for these prompts. The data demonstrates a clear relationship between the text prompt and the ease of latent state convergence. The prompts at the beginning of the list are simple and common, while the prompts at the end are more complex or less frequent, leading to a higher log distance and slower convergence.
</details>
Figure 21: A third example of convergence of the latent state as a function of tokens in the sequence (going top to bottom) and recurrent iterations (going left to right), reprinted from Figure 11 in the main body. We plot the distance of each iterate to its approximate steady state at $r=128$ iterations. This is a selection from the unsafe question example.
<details>
<summary>x25.png Details</summary>

### Visual Description
## Scatter Plot Matrix: Principal Component Analysis
### Overview
The image presents a scatter plot matrix displaying the results of a Principal Component Analysis (PCA). There are three scatter plots arranged horizontally, each representing a different pair of principal components. Each plot shows the distribution of data points projected onto two principal components, with lines connecting consecutive points for each sample. The plots are labeled "PC1-PC2", "PC3-PC4", and "PC5-PC6". A token "wrong" is present in the top-left plot.
### Components/Axes
Each scatter plot has two axes:
* **X-axis:** Ranges from approximately -16 to 16 for PC1-PC2, -4 to 16 for PC3-PC4, and -12 to 12 for PC5-PC6.
* **Y-axis:** Ranges from approximately -10 to 10 for PC1-PC2, -15 to 15 for PC3-PC4, and -13 to 13 for PC5-PC6.
* **Data Points:** Represented by colored circles and connected by lines. The colors vary across the plots, indicating different samples or groups.
* **Token:** "wrong" is displayed in the top-left corner of the first plot (PC1-PC2).
### Detailed Analysis
**PC1-PC2:**
* **Trend:** The data points generally cluster around the origin (0,0) with a slight negative correlation. The lines connecting the points show a diverse range of trajectories.
* **Data Points:**
* Purple points are concentrated around x=-1.5, y=0.5.
* Orange points are scattered, with some around x=10, y=2 and others around x=-10, y=-2.
* Green points are spread across the plot, with a concentration around x=0, y=0.
* Blue points are scattered, with some around x=10, y=-2.
* Light blue points are scattered, with some around x=-10, y=2.
* **Token:** The token "wrong" is positioned above the title "PC1-PC2".
**PC3-PC4:**
* **Trend:** The data points are tightly clustered around the origin (0,0) with a strong positive correlation. The lines connecting the points are relatively short and aligned with the positive slope.
* **Data Points:**
* Purple points are concentrated around x=-0.5, y=0.
* Orange points are scattered around x=0, y=0.
* Green points are clustered around x=0, y=0.
* Blue points are scattered around x=0, y=0.
* Light blue points are scattered around x=0, y=0.
**PC5-PC6:**
* **Trend:** The data points are more dispersed than in PC3-PC4, but still show a general clustering around the origin (0,0). The lines connecting the points exhibit a wider range of trajectories.
* **Data Points:**
* Purple points are concentrated around x=-1, y=0.
* Orange points are scattered, with some around x=8, y=2 and others around x=-8, y=-2.
* Green points are spread across the plot, with a concentration around x=0, y=0.
* Blue points are scattered, with some around x=8, y=-2.
* Light blue points are scattered, with some around x=-8, y=2.
### Key Observations
* The first principal component (PC1) appears to separate the data more effectively than subsequent components.
* PC3 and PC4 exhibit a strong positive correlation, suggesting they capture similar information.
* The "wrong" token in the first plot might indicate an issue with the data or the PCA process for that particular sample.
* The color scheme is consistent across all three plots, allowing for tracking of individual samples.
### Interpretation
The scatter plot matrix visualizes the results of a PCA, a dimensionality reduction technique used to identify the principal components that explain the most variance in the data. Each plot represents the projection of the data onto a different pair of principal components. The clustering and spread of data points in each plot reveal how well the data is separated by those components.
The strong clustering in PC3-PC4 suggests that these components capture a common underlying factor. The more dispersed data in PC1-PC2 and PC5-PC6 indicates that these components capture more complex and varied information. The "wrong" token suggests a potential anomaly or error in the data associated with that sample, which may warrant further investigation. The lines connecting the points could represent a time series or sequential data, and their trajectories indicate how the samples evolve along the principal component axes. The PCA is likely being used to identify patterns and relationships within a high-dimensional dataset, and these plots provide a visual representation of those relationships.
</details>
<details>
<summary>x26.png Details</summary>

### Visual Description
## Scatter Plots: Principal Component Analysis (PCA) Visualizations
### Overview
The image presents three scatter plots, each visualizing the relationship between two principal components (PCs) derived from a PCA. Each plot displays data points connected by lines, likely representing trajectories or sequences. The plots are arranged horizontally, with each plot representing a different pair of PCs. A "Token: '3'" label is present in the top-left corner of the first plot.
### Components/Axes
Each plot shares the following characteristics:
* **Axes:** Both x and y axes range from approximately -13 to +13 (though the exact range varies slightly between plots). The axes are labeled with the corresponding PC numbers (e.g., PC1-PC2, PC3-PC4, PC5-PC6).
* **Data Points:** Each plot contains numerous data points, connected by lines. The points are color-coded, with a variety of pastel shades used.
* **No Legend:** There is no explicit legend provided in the image, making it difficult to determine the meaning of the different colors.
### Detailed Analysis or Content Details
**Plot 1: PC1-PC2**
* **Trend:** The data points form a dense, elliptical cluster centered around the origin (0,0). Lines extend outwards from this cluster, indicating trajectories diverging from a common starting point.
* **Data Points:** The cluster is primarily concentrated between x = -11 and x = 2, and between y = -6 and y = 2. The lines extending from the cluster show a general upward and rightward trend, but with significant variation.
* **Approximate Data Points (from cluster):** (-10, -2), (-8, 0), (-6, 1), (-4, 2), (-2, 1), (0, 0), (2, -1). These are approximate values based on visual inspection.
**Plot 2: PC3-PC4**
* **Trend:** The data points form a more dispersed, loop-like structure. Lines connect points, suggesting a cyclical or iterative process.
* **Data Points:** The points are concentrated between x = -13 and x = 11, and between y = -3 and y = 3. The loop appears to start and end near the origin.
* **Approximate Data Points (from loop):** (-10, 1), (-8, 2), (-6, 1), (-4, -1), (-2, -2), (0, -1), (2, 0), (4, 1), (6, 2), (8, 1), (10, -1).
**Plot 3: PC5-PC6**
* **Trend:** This plot exhibits a distinct arc-shaped pattern. The data points form a curved trajectory, starting near the origin and sweeping upwards and to the right.
* **Data Points:** The arc is primarily located between x = -13 and x = 13, and between y = -6 and y = 6.
* **Approximate Data Points (from arc):** (-10, -1), (-8, 0), (-6, 2), (-4, 4), (-2, 5), (0, 5), (2, 4), (4, 2), (6, 0), (8, -2), (10, -4).
### Key Observations
* The plots demonstrate different patterns of variance in the data. Plot 1 shows a concentrated cluster with diverging trajectories, Plot 2 shows a cyclical pattern, and Plot 3 shows a distinct arc.
* The absence of a legend makes it impossible to interpret the meaning of the different colors.
* The "Token: '3'" label in the first plot suggests that these plots might be associated with a specific data subset or experimental condition.
### Interpretation
These plots are likely visualizations of data reduced in dimensionality using Principal Component Analysis (PCA). PCA is a technique used to identify the most important patterns of variation in a dataset. Each plot represents the projection of the data onto two principal components.
* **Plot 1 (PC1-PC2):** The dense cluster suggests that a significant portion of the data variance is explained by the first two principal components. The diverging lines indicate that there is variation in the data along these components, potentially representing different trajectories or states.
* **Plot 2 (PC3-PC4):** The loop-like structure suggests a cyclical or iterative process. The data may be evolving over time or undergoing a repeating pattern.
* **Plot 3 (PC5-PC6):** The arc-shaped pattern suggests a specific trend or direction in the data. The data may be transitioning from one state to another along these components.
The different patterns observed in each plot indicate that the data is complex and multi-dimensional. The choice of which principal components to visualize (PC1-PC2, PC3-PC4, PC5-PC6) likely depends on the specific research question or goal. Without a legend, it is difficult to determine the meaning of the different colors, but they likely represent different categories or groups within the data. The "Token: '3'" label suggests that these plots may be related to a specific subset of the data.
</details>
<details>
<summary>x27.png Details</summary>

### Visual Description
## Scatter Plots: Principal Component Analysis (PCA) for Token "deeper"
### Overview
The image presents three scatter plots, each representing a projection of data onto a different pair of principal components (PCs). The plots visualize the distribution of data points associated with the token "deeper" across these principal component spaces. Each plot has a different PC pairing: PC1-PC2, PC3-PC4, and PC5-PC6. Lines connect consecutive data points, suggesting a temporal or sequential aspect to the data.
### Components/Axes
Each plot shares the following characteristics:
* **X-axis:** Ranges from approximately -30 to 30 (varying slightly between plots), representing the principal component value.
* **Y-axis:** Ranges from approximately -12 to 12 (varying slightly between plots), representing the principal component value.
* **Title:** Each plot is labeled with the corresponding PC pairing (e.g., "PC1-PC2").
* **Overall Title:** "Token: “deeper”" is present at the top-left of the first plot, indicating the data relates to the token "deeper".
* **Data Points:** Represented by colored dots, connected by lines. The colors appear to be consistent across all three plots, suggesting they represent the same categories.
* **No explicit legend:** The colors are not explicitly labeled, but can be inferred from the consistent use of colors across the plots.
### Detailed Analysis or Content Details
**Plot 1: PC1-PC2**
* **Line 1 (Orange):** Starts at approximately (-21, 10), slopes downward to (-2, -2), then rises to (2, 0).
* **Line 2 (Light Blue):** Starts at approximately (-18, 8), slopes downward to (-1, -3), then rises to (3, -1).
* **Line 3 (Green):** Starts at approximately (-15, 6), slopes downward to (-1, -4), then rises to (2, -2).
* **Line 4 (Red):** Starts at approximately (-12, 4), slopes downward to (0, -5), then rises to (1, -3).
* **Line 5 (Purple):** Starts at approximately (-9, 2), slopes downward to (1, -6), then rises to (0, -4).
* **Scattered Points (Teal):** Cluster around the origin (0,0), with points ranging from approximately (-2, -2) to (2, 2).
**Plot 2: PC3-PC4**
* **Line 1 (Orange):** Starts at approximately (-29, 10), slopes downward to (-2, 1), then rises to (2, 2).
* **Line 2 (Light Blue):** Starts at approximately (-25, 8), slopes downward to (-1, 2), then rises to (1, 3).
* **Line 3 (Green):** Starts at approximately (-21, 6), slopes downward to (0, 3), then rises to (1, 4).
* **Line 4 (Red):** Starts at approximately (-17, 4), slopes downward to (1, 4), then rises to (2, 5).
* **Line 5 (Purple):** Starts at approximately (-13, 2), slopes downward to (2, 5), then rises to (3, 6).
* **Scattered Points (Teal):** Cluster around the origin (0,0), with points ranging from approximately (-1, -1) to (1, 1).
**Plot 3: PC5-PC6**
* **Line 1 (Orange):** Starts at approximately (-7, 11), slopes downward to (-1, 2), then rises to (2, 3).
* **Line 2 (Light Blue):** Starts at approximately (-5, 9), slopes downward to (0, 3), then rises to (1, 4).
* **Line 3 (Green):** Starts at approximately (-3, 7), slopes downward to (1, 4), then rises to (2, 5).
* **Line 4 (Red):** Starts at approximately (-1, 5), slopes downward to (2, 5), then rises to (3, 6).
* **Line 5 (Purple):** Starts at approximately (1, 3), slopes downward to (3, 6), then rises to (4, 7).
* **Scattered Points (Teal):** Cluster around the origin (0,0), with points ranging from approximately (-1, -1) to (1, 1).
### Key Observations
* **Consistent Trends:** The colored lines (Orange, Light Blue, Green, Red, Purple) exhibit a similar downward-then-upward trend across all three plots. This suggests a consistent pattern in the data's behavior across different principal component spaces.
* **Clustering:** The teal points consistently cluster around the origin (0,0) in all three plots, indicating a central tendency for this data subset.
* **Variance:** The spread of the colored lines and teal points varies across the plots, indicating different levels of variance captured by each PC pairing.
* **No Legend:** The lack of a legend makes it difficult to definitively assign meaning to the colors.
### Interpretation
These plots represent a Principal Component Analysis (PCA) applied to data associated with the token "deeper". PCA is a dimensionality reduction technique that identifies the principal components – directions of maximum variance in the data. Each plot shows how the data is distributed when projected onto two of these principal components.
The consistent downward-then-upward trend of the colored lines suggests a dynamic process unfolding over a sequential variable; per the figure caption, each line traces one latent-space trajectory over successive recurrence steps. The clustering of teal points around the origin indicates a stable state that all trajectories approach.
The fact that the patterns are visible across different PC pairings (PC1-PC2, PC3-PC4, PC5-PC6) suggests that the underlying structure is robust and not specific to any particular pair of components. The varying spread of points across the plots indicates that some PC pairings capture more of the data's variance than others.
Without a legend, the specific meaning of the colors remains unknown. They could represent different categories, classes, or features of the data. Further information about the data and the PCA process would be needed to fully interpret the results.
</details>
Figure 22: Latent Space trajectories for a few select tokens. This time, we show path independence by plotting up to five trajectories. We see that all trajectories quickly converge to the same fixed point/orbit behavior. Here, the color gradients going from unsaturated to saturated represent steps in the trajectory, so strong colors are at the end of the trajectory. Gray denotes the overlap of multiple trajectories.
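The projection behind these panels can be reproduced with a few lines of linear algebra. Below is a minimal numpy-only sketch of projecting a trajectory of hidden states onto its leading principal components; the trajectory here is synthetic, and this is not the paper's actual plotting code:

```python
import numpy as np

def pca_project(states: np.ndarray, k: int = 6) -> np.ndarray:
    """Project hidden states (steps x dim) onto their first k principal components."""
    centered = states - states.mean(axis=0, keepdims=True)
    # Rows of vt are the principal directions, ordered by explained variance.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:k].T

# Toy trajectory: 64 recurrence steps in a 512-dimensional latent space,
# decaying toward a fixed point at the origin, as in the figures.
rng = np.random.default_rng(0)
direction = rng.normal(size=512)
traj = np.outer(0.9 ** np.arange(64), direction)
coords = pca_project(traj, k=6)
print(coords.shape)  # (64, 6); plot coords[:, 0] vs coords[:, 1] for the PC1-PC2 panel
```

Consecutive coordinate pairs (`coords[:, 0]` vs `coords[:, 1]`, `coords[:, 2]` vs `coords[:, 3]`, and so on) then yield the PC1-PC2, PC3-PC4, PC5-PC6 panels shown above.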
<details>
<summary>x28.png Details</summary>

### Visual Description
## Scatter Plot Matrix: Principal Component Analysis (PCA)
### Overview
The image presents a scatter plot matrix visualizing the results of a Principal Component Analysis (PCA). It consists of 20 scatter plots arranged in 5 rows and 4 columns, each representing the relationship between two principal components (PCs). Each point in the scatter plots is color-coded, likely representing different classes or groups within the dataset. The plots are arranged to show the consecutive pairs of PCs from PC1 to PC40.
### Components/Axes
Each individual scatter plot has two axes, labeled as "PCx-PCy" where x and y are integers from 1 to 40. The axis scales vary per plot, shrinking from roughly -30 to 30 for the leading components down to roughly -1 to 1 for the trailing ones. A legend is present in the top-left plot, indicating the color coding scheme. The legend contains the following colors and their corresponding labels:
* Purple
* Green
* Red
* Cyan
* Blue
* Orange
### Detailed Analysis or Content Details
The matrix is organized as follows:
* **Row 1:** PC1-PC2, PC3-PC4, PC5-PC6, PC7-PC8
* **Row 2:** PC9-PC10, PC11-PC12, PC13-PC14, PC15-PC16
* **Row 3:** PC17-PC18, PC19-PC20, PC21-PC22, PC23-PC24
* **Row 4:** PC25-PC26, PC27-PC28, PC29-PC30, PC31-PC32
* **Row 5:** PC33-PC34, PC35-PC36, PC37-PC38, PC39-PC40
Due to the density of the plots and the varying scales, precise numerical data extraction is difficult. However, we can describe the general trends and distributions:
* **PC1-PC2:** The points form several distinct clusters, with purple and green being the most prominent. The clusters are somewhat elongated, suggesting a correlation between PC1 and PC2.
* **PC3-PC4:** Similar to PC1-PC2, there are distinct clusters, but they appear more dispersed.
* **PC5-PC6:** The points are more scattered, with less clear clustering.
* **PC7-PC8:** The points form a curved shape, with a clear separation between the purple and green clusters.
* **PC9-PC10:** The points are relatively dispersed, with some overlap between the clusters.
* **PC11-PC12:** The points are highly dispersed, with no clear clustering.
* **PC13-PC14:** The points form a curved shape, with a clear separation between the purple and green clusters.
* **PC15-PC16 through PC39-PC40:** The points are relatively dispersed, with some overlap between the clusters.
### Key Observations
* The first few principal components (PC1-PC8) appear to capture the most variance in the data, as evidenced by the clearer clustering in these plots.
* As the principal component number increases, the plots become more dispersed, suggesting that these components capture less variance and may represent noise or less important features.
* The color-coding scheme reveals distinct groups within the dataset, with purple and green being the most prominent.
* The shapes formed by the points in some plots (e.g., PC7-PC8, PC13-PC14) suggest non-linear relationships between the principal components.
### Interpretation
This PCA plot matrix provides a visual representation of the data's structure in a lower-dimensional space. The principal components are ordered by the amount of variance they explain, with PC1 capturing the most variance and PC40 capturing the least. The clustering of points in the scatter plots indicates that the data can be separated into distinct groups based on the principal components. The color-coding scheme allows for the identification of these groups.
The decreasing clarity of the clusters as the principal component number increases suggests that the first few components are sufficient to capture the most important information in the data. The non-linear relationships between the components, as evidenced by the curved shapes in some plots, indicate that a linear model may not be the best fit for the data.
The matrix is a powerful tool for dimensionality reduction and data visualization, allowing for the identification of patterns and relationships that may not be apparent in the original high-dimensional data. The specific meaning of the principal components and the groups they represent would depend on the nature of the original data and the context of the analysis. The "Token: * 3*" label at the top-left indicates the token whose latent-space trajectories are being analyzed.
</details>
Figure 23: Detailed PCA of Latent Space trajectories for the math question. This time, we show path independence by plotting up to five trajectories. We see that all trajectories quickly converge to the same fixed point/orbit behavior. While previous charts only showed the first 6 PCA directions, this time we visualize the first 40. Here, the color gradients going from unsaturated to saturated represent steps in the trajectory, so strong colors are at the end of the trajectory. Gray denotes the overlap of multiple trajectories.
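The observation that structure concentrates in the first few components while later ones look like noise can be made quantitative via the explained-variance ratio. A small self-contained sketch with synthetic, deliberately anisotropic data (not the model's actual latents):

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic latents whose variance halves from one direction to the next,
# mimicking plots where PC1-PC8 show structure and later PCs look like noise.
scales = 2.0 ** -np.arange(40)
data = rng.normal(size=(500, 40)) * np.sqrt(scales)

centered = data - data.mean(axis=0)
singular_values = np.linalg.svd(centered, compute_uv=False)
ratio = singular_values**2 / np.sum(singular_values**2)
print(ratio[:8].sum() > 0.99)  # True: the first 8 PCs carry almost all variance
```

With variance decaying this fast, plotting PC33-PC34 mostly visualizes noise, which matches the increasingly dispersed panels in the lower rows of the matrix.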
<details>
<summary>x29.png Details</summary>

### Visual Description
## Scatter Plot Matrix: Principal Component Analysis (PCA)
### Overview
The image presents a scatter plot matrix visualizing the results of a Principal Component Analysis (PCA). It displays the first 40 principal components (PC1 through PC40) as 20 scatter plots of consecutive pairs. Each subplot is a scatter plot of two principal components, with data points colored to represent different clusters or categories. The plots aim to reveal patterns and correlations within the data as projected onto these principal components.
### Components/Axes
The image consists of a 5x4 grid of scatter plots. Each plot has two axes, labeled "PC[number]" (e.g., PC1, PC2, PC3, etc.). The axes scales vary for each plot, ranging approximately from -32 to 32, -29 to 21, -23 to 25, -19 to 21, -18 to 14, and so on. A "Token: 'deeper'" label is present at the top-left corner of the image. There is no explicit legend, but per the figure caption, the saturation of the point colors encodes the step along each trajectory.
### Detailed Analysis or Content Details
Due to the complexity of the matrix and the varying scales, precise numerical values are difficult to extract. However, the following observations can be made for each subplot:
* **PC1-PC2:** Data points are clustered in the top-left quadrant, with some points extending towards the bottom-right.
* **PC3-PC4:** Points are scattered across the plot, with a concentration in the top-right quadrant.
* **PC5-PC6:** Points are distributed along a diagonal line from the top-left to the bottom-right.
* **PC7-PC8:** Points are clustered in the top-left quadrant, with a few outliers extending towards the bottom-right.
* **PC9-PC10:** Points are scattered, with a concentration in the bottom-left quadrant.
* **PC11-PC12:** Points are distributed along a diagonal line from the top-right to the bottom-left.
* **PC13-PC14:** Points are scattered, with a concentration in the bottom-right quadrant.
* **PC15-PC16:** Points are clustered in the top-right quadrant.
* **PC17-PC18:** Points are scattered, with a concentration in the bottom-left quadrant.
* **PC19-PC20:** Points are scattered, with a concentration in the bottom-right quadrant.
* **PC21-PC22:** Points are scattered, with a concentration in the bottom-left quadrant.
* **PC23-PC24:** Points are clustered in the top-right quadrant.
* **PC25-PC26:** Points are scattered, with a concentration in the bottom-left quadrant.
* **PC27-PC28:** Points are scattered, with a concentration in the top-right quadrant.
* **PC29-PC30:** Points are scattered, with a concentration in the bottom-right quadrant.
* **PC31-PC32:** Points are clustered in the top-right quadrant.
* **PC33-PC34:** Points are scattered, with a concentration in the bottom-left quadrant.
* **PC35-PC36:** Points are scattered, with a concentration in the bottom-right quadrant.
* **PC37-PC38:** Points are scattered, with a concentration in the top-right quadrant.
* **PC39-PC40:** Points are scattered, with a concentration in the bottom-left quadrant.
Each plot uses a different color scheme for the data points. The colors appear to be consistent within each plot, but vary across the matrix. The colors include shades of blue, green, red, and orange.
### Key Observations
* The distribution of points varies significantly across different PC pairs.
* Some PC pairs show clear clustering, while others exhibit more scattered distributions.
* Diagonal patterns are observed in some plots (e.g., PC5-PC6, PC11-PC12): within this subset of points, the corresponding principal components move together, even though principal components are uncorrelated over the full dataset.
* The "Token: 'deeper'" label suggests that this PCA might be related to a specific token or feature in the original dataset.
### Interpretation
This PCA scatter plot matrix provides a visual representation of the data's variance across the first 40 principal components. Each plot reveals how the data is distributed when projected onto two different components. The varying distributions and correlations suggest that different principal components capture different aspects of the data's underlying structure.
The presence of clustering in some plots indicates that the corresponding principal components effectively separate different groups or categories within the data. The diagonal patterns suggest that, along these trajectories, certain components vary together, even though principal components are mutually uncorrelated over the data they were fit on.
The "Token: 'deeper'" label implies that the PCA is being applied to data related to this token, potentially to understand its characteristics or relationships with other features. The matrix allows for the identification of the principal components that are most informative for distinguishing or representing this token.
Without knowing the original data and the meaning of the "deeper" token, it's difficult to provide a more specific interpretation. However, the PCA matrix provides a valuable tool for exploring the data's structure and identifying key patterns and relationships. Further analysis, such as examining the loadings of each principal component, would be necessary to fully understand the meaning of these results.
</details>
Figure 24: Detailed PCA of Latent Space trajectories for the trivia question. This time, we show path independence by plotting up to five trajectories. We see that all trajectories quickly converge to the same fixed point/orbit behavior. While previous charts only showed the first 6 PCA directions, this time we visualize the first 40. Here, the color gradients going from unsaturated to saturated represent steps in the trajectory, so strong colors are at the end of the trajectory. Gray denotes the overlap of multiple trajectories.
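The path independence noted in the caption, with different initializations collapsing onto one fixed point, is exactly the behavior of a contraction map. A toy linear stand-in (a synthetic random matrix, not the model's recurrent block) illustrates the effect:

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.normal(size=(16, 16))
A *= 0.5 / np.linalg.norm(A, 2)  # rescale to spectral norm 0.5: a contraction
b = rng.normal(size=16)          # stands in for the injected input embedding

def iterate(state, n=100):
    for _ in range(n):
        state = A @ state + b    # one "recurrence step" at fixed input
    return state

# Five random initializations, as in the five plotted trajectories.
finals = [iterate(rng.normal(size=16) * 10.0) for _ in range(5)]
spread = max(np.linalg.norm(f - finals[0]) for f in finals)
print(spread < 1e-6)  # True: all trajectories reach the same fixed point
```

Because each step shrinks the distance between any two states by a fixed factor, where a trajectory starts becomes irrelevant after a few dozen iterations, which is what the overlapping gray trajectory tails in the figures show.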
<details>
<summary>x30.png Details</summary>

### Visual Description
## Scatter Plot Matrix: Principal Component Analysis (PCA) Results
### Overview
The image presents a scatter plot matrix displaying the results of a Principal Component Analysis (PCA). It consists of 20 individual scatter plots arranged in a 5x4 grid (five rows of four). Each plot visualizes the relationship between two principal components (PCs). Lines connect successive points of the same trajectory, allowing for visualization of trajectories in the reduced-dimensional space. A "Token: 'wrong'" label is present in the top-left corner.
### Components/Axes
Each individual scatter plot has the following components:
* **Title:** Indicates the two principal components being compared (e.g., "PC1-PC2", "PC3-PC4", etc.).
* **X-axis:** Represents the values of the first principal component in the title. Scale varies per plot, ranging from approximately -37 to 32.
* **Y-axis:** Represents the values of the second principal component in the title. Scale varies per plot, ranging from approximately -32 to 10.
* **Data Points:** Represent individual samples projected onto the two principal components; lines connect successive points belonging to the same trajectory.
* **Line Colors:** Each line is assigned a unique color, presumably representing different groups or categories of samples. A legend is not present in the image.
The plots are arranged as follows:
* **Row 1:** PC1-PC2, PC3-PC4, PC5-PC6, PC7-PC8
* **Row 2:** PC9-PC10, PC11-PC12, PC13-PC14, PC15-PC16
* **Row 3:** PC17-PC18, PC19-PC20, PC21-PC22, PC23-PC24
* **Row 4:** PC25-PC26, PC27-PC28, PC29-PC30, PC31-PC32
* **Row 5:** PC33-PC34, PC35-PC36, PC37-PC38, PC39-PC40
### Detailed Analysis or Content Details
Due to the absence of a legend, precise identification of sample groups based on color is difficult. However, we can describe the general trends observed in each plot.
* **PC1-PC2:** Lines generally slope downward, indicating a negative correlation between PC1 and PC2. Values range from approximately -10 to 10 for both axes.
* **PC3-PC4:** Lines are more scattered, with some sloping upward and others downward. PC3 ranges from approximately -4 to 16, and PC4 from approximately -15 to 15.
* **PC5-PC6:** Lines generally slope downward. PC5 ranges from approximately -15 to 15, and PC6 from approximately -13 to 12.
* **PC7-PC8:** Lines are scattered, with some clustering in the top-right quadrant. PC7 ranges from approximately -15 to 15, and PC8 from approximately -35 to 25.
* **PC9-PC10:** Lines are mostly horizontal. PC9 ranges from approximately -22 to 0, and PC10 from approximately -15 to 27.
* **PC11-PC12:** Lines are scattered. PC11 ranges from approximately -15 to 15, and PC12 from approximately -8 to 6.
* **PC13-PC14:** Lines are scattered. PC13 ranges from approximately -37 to 0, and PC14 from approximately -10 to 12.
* **PC15-PC16:** Lines are scattered. PC15 ranges from approximately -9 to 9, and PC16 from approximately -9 to 0.
* **PC17-PC18:** Lines are scattered. PC17 ranges from approximately -30 to 12, and PC18 from approximately -12 to 11.
* **PC19-PC20:** Lines are scattered. PC19 ranges from approximately -15 to 0, and PC20 from approximately -11 to 8.
* **PC21-PC22:** Lines are scattered. PC21 ranges from approximately -8 to 19, and PC22 from approximately -8 to 8.
* **PC23-PC24:** Lines are scattered. PC23 ranges from approximately -8 to 28, and PC24 from approximately -8 to 8.
* **PC25-PC26:** Lines are scattered. PC25 ranges from approximately -11 to 11, and PC26 from approximately -12 to 16.
* **PC27-PC28:** Lines are scattered. PC27 ranges from approximately -12 to 13, and PC28 from approximately -32 to 0.
* **PC29-PC30:** Lines are scattered. PC29 ranges from approximately -12 to 19, and PC30 from approximately -8 to 32.
* **PC31-PC32:** Lines are scattered. PC31 ranges from approximately -8 to 12, and PC32 from approximately -8 to 28.
* **PC33-PC34:** Lines are scattered. PC33 ranges from approximately -24 to 0, and PC34 from approximately -10 to 24.
* **PC35-PC36:** Lines are scattered. PC35 ranges from approximately -6 to 10, and PC36 from approximately -32 to 6.
* **PC37-PC38:** Lines are scattered. PC37 ranges from approximately -15 to 15, and PC38 from approximately -32 to 6.
* **PC39-PC40:** Lines are scattered. PC39 ranges from approximately -6 to 19, and PC40 from approximately -19 to 6.
### Key Observations
* The plots exhibit varying degrees of separation between the lines, suggesting different levels of variance explained by each pair of principal components.
* The scales of the axes differ significantly across the plots, indicating that the principal components have different ranges of values.
* Without a legend, it is difficult to interpret the meaning of the different line colors and identify distinct clusters or patterns.
* The "Token: 'wrong'" label indicates the token whose latent-space trajectories are being analyzed; it does not signal an error in the PCA.
### Interpretation
The scatter plot matrix provides a visual representation of the relationships between the principal components derived from a PCA. Each plot reveals how samples are distributed along two dimensions that capture the most significant variance in the original data. The absence of a legend limits the ability to draw definitive conclusions about the underlying data structure. However, the varying degrees of separation and the different scales of the axes suggest that the principal components capture different aspects of the data's variability. The "Token: 'wrong'" label simply names the token under analysis, not an error in the pipeline. The PCA is likely used for dimensionality reduction, and the plots help to understand how well the data is represented in the reduced space. The scattered nature of many plots suggests that a higher number of principal components might be needed to capture the full complexity of the data.
</details>
Figure 25: Detailed PCA of Latent Space trajectories for the unsafe question. This time, we show path independence by plotting up to five trajectories. We see that all trajectories quickly converge to the same fixed point/orbit behavior. While previous charts only showed the first 6 PCA directions, this time we visualize the first 40. Here, the color gradients going from unsaturated to saturated represent steps in the trajectory, so strong colors are at the end of the trajectory. Gray denotes the overlap of multiple trajectories.
## Appendix C Pretraining Data
Table 7: Datasets used for model pre-training (Part 1: Standard sources)
Table 8: Datasets used for model pre-training (Part 2: Instruction Data)
| Dataset | Address | License | Category | W | MG | Citation |
|---|---|---|---|---|---|---|
| WebInstruct-prometheus | chargoddard/WebInstructSub-prometheus | apache-2.0 | generic-instruct | 1.0 | ✓ | Kim et al. (2024) |
| hercules | Locutusque/hercules-v5.0 | other | generic-instruct | 1.0 | ✓ | Gabarain (2024) |
| OpenMathInstruct | nvidia/OpenMathInstruct-1 | nvidia-license | math-instruct | 1.0 | ✓ | Toshniwal et al. (2024b) |
| MetaMathQA | meta-math/MetaMathQA | mit | math-instruct | 1.0 | ✓ | Yu et al. (2023) |
| CodeFeedback | m-a-p/CodeFeedback-Filtered-Instruction | apache-2.0 | generic-instruct | 2.0 | ✓ | Zheng et al. (2024) |
| Daring-Anteater | nvidia/Daring-Anteater | cc-by-4.0 | generic-instruct | 1.0 | ✓ | Wang et al. (2024b) |
| Nvidia-Blender | nvidia/sft_datablend_v1 | cc-by-4.0 | generic-instruct | 1.0 | ✓ | nvidia/sft_datablend_v1 |
| baai-instruct-foundation | BAAI/Infinity-Instruct | - | generic-instruct | 1.0 | ✓ | BAAI/Infinity-Instruct |
| baai-instruct-gen | BAAI/Infinity-Instruct | - | generic-instruct | 1.0 | ✓ | BAAI/Infinity-Instruct |
| anthracite-stheno | anthracite-org/Stheno-Data-Filtered | - | math-instruct | 1.0 | ✓ | anthracite-org/Stheno-Data-Filtered |
| opus-writing | Nopm/Opus_WritingStruct | apache-2.0 | writing-instruct | 2.0 | ✓ | Nopm/Opus_WritingStruct |
| math-step | xinlai/Math-Step-DPO-10K | - | math-instruct | 2.0 | ✓ | Lai et al. (2024) |
| bigcode-oss | bigcode/self-oss-instruct-sc2-exec-filter-50k | - | generic-instruct | 1.0 | ✓ | sc2-instruct |
| everyday-conversations | HuggingFaceTB/everyday-conversations | apache-2.0 | writing-instruct | 3.0 | ✓ | HuggingFaceTB/everyday-conversations |
| gsm8k | hkust-nlp/gsm8k-fix | mit | math-instruct | 1.0 | ✗ | Cobbe et al. (2021) |
| no-robots | HuggingFaceH4/no_robots | cc-by-nc-4.0 | writing-instruct | 3.0 | ✗ | Ouyang et al. (2022) |
| longwriter | THUDM/LongWriter-6k | apache-2.0 | writing-instruct | 2.0 | ✓ | Bai et al. (2024) |
| webglm-qa | THUDM/webglm-qa | - | generic-instruct | 1.0 | - | Liu et al. (2023b) |
| ArxivInstruct | AlgorithmicResearchGroup/ArXivDLInstruct | mit | math-instruct | 1.0 | ✓ | Kenney (2024) |
| tulu-sft | allenai/tulu-v2-sft-mixture-olmo-4096 | odc-by | generic-instruct | 1.0 | ✓ | Groeneveld et al. (2024) |
| P3 | bigscience/P3 | apache-2.0 | generic-instruct | 1.0 | ✗ | Sanh et al. (2021) |
| OrcaSonnet | Gryphe/Sonnet3.5-SlimOrcaDedupCleaned | mit | writing-instruct | 2.0 | ✓ | Gryphe/Sonnet3.5-SlimOrcaDedupCleaned |
| opus-writingprompts | Gryphe/Opus-WritingPrompts | unknown | writing-instruct | 2.0 | ✓ | Gryphe/Opus-WritingPrompts |
| reddit-writing | nothingiisreal/Reddit-Dirty-And-WritingPrompts | apache-2.0 | writing-instruct | 2.0 | ✗ | Reddit-Dirty-And-WritingPrompts |
| kalomaze-instruct | nothingiisreal/Kalomaze-Opus-Instruct-25k-filtered | apache-2.0 | writing-instruct | 2.0 | ✓ | Kalomaze-Opus-Instruct-25k |
| lean-github | internlm/Lean-Github | apache-2.0 | math-instruct | 3.0 | ✗ | Wu et al. (2024) |
| lean-workbook | pkuAI4M/LeanWorkbook | apache-2.0 | math-instruct | 3.0 | ✗ | Ying et al. (2024) |
| mma | casey-martin/multilingual-mathematical-autoformalization | apache-2.0 | math-instruct | 3.0 | ✗ | Jiang et al. (2023) |
| lean-dojo-informal | AI4M/leandojo-informalized | - | math-instruct | 3.0 | ✗ | Yang et al. (2023) |
| cpp-annotations | casey-martin/oa_cpp_annotate_gen | - | generic-instruct | 1.0 | ✓ | moyix |
| lean-tactics | l3lab/ntp-mathlib-instruct-st | - | math-instruct | 2.0 | ✗ | Hu et al. (2024) |
| college-math | ajibawa-2023/Maths-College | apache-2.0 | math | 1.0 | ✓ | ajibawa-2023/Maths-College |
| gradeschool-math | ajibawa-2023/Maths-Grade-School | apache-2.0 | math | 1.0 | ✓ | ajibawa-2023/Maths-Grade-School |
| general-stories | ajibawa-2023/General-Stories-Collection | apache-2.0 | synthetic-text | 1.0 | ✓ | ajibawa-2023/General-Stories-Collection |
| amps-mathematica | XinyaoHu/AMPS_mathematica | mit | math | 1.0 | ✗ | XinyaoHu/AMPS_mathematica |
| amps-khan | XinyaoHu/AMPS_khan | mit | math-instruct | 1.0 | ✗ | XinyaoHu/AMPS_khan |
| Magpie-300k | Magpie-Align/Magpie-Pro-MT-300K-v0.1 | llama3 | generic-instruct | 1.0 | ✓ | Xu et al. (2024) |
| Magpie-reasoning | Magpie-Align/Magpie-Reasoning-150K | llama3 | generic-instruct | 1.0 | ✓ | Xu et al. (2024) |
| prox-fineweb | gair-prox/FineWeb-pro | odc-by | generic-text | 1.0 | ✗ | Zhou et al. (2024) |
| prox-c4 | gair-prox/c4-pro | odc-by | generic-text | 1.0 | ✗ | Zhou et al. (2024) |
| prox-redpajama | gair-prox/RedPajama-pro | odc-by | generic-text | 1.0 | ✗ | Zhou et al. (2024) |
| prox-open-web-math | gair-prox/open-web-math-pro | odc-by | math | 1.0 | ✗ | Zhou et al. (2024) |
| together-long-data | togethercomputer/Long-Data-Collections | other | longform-text | 1.0 | ✗ | TogetherAI (2023) |
| project-gutenberg-19 | emozilla/pg19 | apache-2.0 | longform-text | 1.0 | ✗ | Rae et al. (2019) |
| mathgenie | MathGenie/MathCode-Pile | apache-2.0 | math | 1.0 | ✗ | Lu et al. (2024) |
| reasoning-base | KingNish/reasoning-base-20k | apache-2.0 | math | 1.0 | ✓ | KingNish/reasoning-base-20k |
| OpenMathInstruct-2 | nvidia/OpenMathInstruct-2 | nvidia-license | math-instruct | 1.0 | ✓ | Toshniwal et al. (2024a) |
| Txt360-DM | LLM360/TxT360 | odc-by | math | 1.0 | ✗ | Liping Tang (2024) |
| Txt360-ubuntu-chat | LLM360/TxT360 | odc-by | Q&A-text | 1.0 | ✗ | Liping Tang (2024) |
| markdown-arxiv | neuralwork/arxiver | cc-by-nc-sa-4.0 | scientific-text | 2.0 | ✗ | neuralwork/arxiver |
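If the W column is read as a relative per-dataset weight (an assumption; the tables do not define W or MG), turning such weights into sampling probabilities for a data mixture is a simple normalization. A hypothetical sketch over three rows from Table 8:

```python
# Assumption: W is a relative sampling weight per dataset (not stated in the table).
# The three entries below are real rows from Table 8; the normalization itself
# is only an illustration, not the authors' actual data-mixing procedure.
weights = {
    "gsm8k": 1.0,                   # W = 1.0
    "opus-writing": 2.0,            # W = 2.0
    "everyday-conversations": 3.0,  # W = 3.0
}
total = sum(weights.values())
probs = {name: w / total for name, w in weights.items()}
print(probs["everyday-conversations"])  # 0.5
```

Under this reading, a dataset with W = 3.0 would be sampled three times as often as one with W = 1.0, all else equal.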