# Learning richness modulates equality reasoning in neural networks
**Authors**: Double blind review
## Abstract
Equality reasoning is ubiquitous and purely abstract: sameness or difference may be evaluated no matter the nature of the underlying objects. As a result, same-different (SD) tasks have been extensively studied as a starting point for understanding abstract reasoning in humans and across animal species. With the rise of neural networks that exhibit striking apparent proficiency for abstractions, equality reasoning in these models has also gained interest. Yet despite extensive study, conclusions about equality reasoning vary widely, with little consensus. To clarify the underlying principles in learning SD tasks, we develop a theory of equality reasoning in multi-layer perceptrons (MLPs). Following observations in comparative psychology, we propose a spectrum of behavior that ranges from conceptual to perceptual outcomes. Conceptual behavior is characterized by task-specific representations, efficient learning, and insensitivity to spurious perceptual details. Perceptual behavior is characterized by strong sensitivity to spurious perceptual details, accompanied by the need for exhaustive training to learn the task. We develop a mathematical theory to show that an MLP’s behavior is driven by learning richness. Rich-regime MLPs exhibit conceptual behavior, whereas lazy-regime MLPs exhibit perceptual behavior. We validate our theoretical findings in vision SD experiments, showing that rich feature learning promotes success by encouraging hallmarks of conceptual behavior. Overall, our work identifies feature learning richness as a key parameter modulating equality reasoning, and suggests that equality reasoning in humans and animals may similarly depend on learning richness in neural circuits.
Keywords: equality reasoning; same-different; neural network; conceptual and perceptual behavior
## 1 Introduction
The ability to reason abstractly is a hallmark of human intelligence. Fluency with abstractions drives both our highest intellectual achievements and many of our daily necessities like telling time, navigating traffic, and planning leisure. At the same time, neural networks have grown tremendously in sophistication and scale. The latest examples exhibit increasingly impressive competency, and the potential to automate the reasoning process itself seems imminent (OpenAI, 2024, 2023; Bubeck et al., 2023; Guo et al., 2025). Nonetheless, it remains unclear to what extent these models are able to reason abstractly, and how consistently they behave (McCoy et al., 2023; Mahowald et al., 2024; Ullman, 2023). To begin answering these questions, we require a principled understanding of how neural networks can reason.
A particularly simple and salient form of abstract reasoning is equality reasoning: determining whether two objects are the same or different. The “sense of sameness is the very keel and backbone of our thinking” (James, 1905), which has made its study a tractable window into abstract reasoning across humans and animals (E. A. Wasserman & Young, 2010). Despite many decades of study, the history of equality reasoning abounds with widely varying conclusions. Success at same-different (SD) tasks has been documented in a large number of animals, including non-human primates (Vonk, 2003), honeybees (Giurfa et al., 2001), pigeons (E. A. Wasserman & Young, 2010), crows (Smirnova et al., 2015), and parrots (Obozova et al., 2015). Others, however, have argued that animals solve these tasks with perceptual shortcuts, such as stimulus variability, and lack a true conception of sameness or difference (Penn et al., 2008). Competence at equality reasoning may require exposure to language or some form of symbolic training (Premack, 1983). Meanwhile, pre-lingual human infants have demonstrated sensitivity to same-different relations (G. F. Marcus et al., 1999; Saffran & Thiessen, 2003; Rabagliati et al., 2019).
Equality reasoning in neural networks is no less debated. G. F. Marcus et al. (1999) found that seven-month-old infants succeed at an SD task on which neural networks fail, launching a lively debate that continues to the present day (Seidenberg et al., 1999; Seidenberg & Elman, 1999; Alhama & Zuidema, 2019). Others have demonstrated severe shortcomings in neural networks directed to solve visual same-different and relational reasoning tasks (Kim et al., 2018; Stabinger et al., 2021; Vaishnav et al., 2022; Webb et al., 2023). Such failures motivate a growing literature on bespoke architectural advances geared towards relational reasoning (Webb et al., 2023, 2020; Santoro et al., 2017; Battaglia et al., 2018). At the same time, modern large language models routinely solve complex reasoning problems (Bubeck et al., 2023). Their surprising success tempers earlier categorical claims against neural networks’ reasoning abilities. Even simple models like multi-layer perceptrons (MLPs) have recently been shown to solve equality and relational reasoning tasks with surprising efficacy (A. Geiger et al., 2023; Tong & Pehlevan, 2024).
The lack of consensus on equality reasoning in either organic or silicon brains speaks to the need for a stronger theoretical foundation. To this end, we present a theory of equality reasoning in MLPs that highlights the central role of a hitherto overlooked parameter: learning richness, a measure of how much internal representations change over the course of training (Chizat et al., 2019). We find that MLPs in a rich learning regime exhibit conceptual behavior, where they develop salient, conceptual representations of sameness and difference, learn the task from few training examples, and remain largely insensitive to spurious perceptual details. In contrast, lazy regime MLPs exhibit perceptual behavior, where they solve the task only after exhaustive training and show strong sensitivity to perceptual variations. Our specific contributions are the following.
Contributions
- We hand-craft a solution to our same-different task that is expressible by an MLP, demonstrating that our model can, in principle, solve this task. Our solution suggests what conceptual representations may look like, guiding subsequent analysis.
- We argue that an MLP trained in a rich feature learning regime attains the hand-crafted solution, and exhibits three hallmarks of conceptual behavior: conceptual representations, efficient learning, and insensitivity to spurious perceptual details.
- We prove that an MLP trained in a lazy learning regime can also solve an equality reasoning task, but exhibits perceptual behavior: it requires exhaustive training data and shows strong sensitivity to spurious perceptual details.
- We extend our results to same-different tasks with noise, calculating Bayes optimal performance under priors that either generalize to arbitrary inputs or memorize the training set. We demonstrate that rich MLPs attain Bayes optimal performance under the generalizing prior.
- We validate our results on complex visual SD tasks, showing that our theoretical predictions continue to hold.
Our theory clarifies the understudied role of learning richness in driving successful reasoning, with potential implications for both neural network design and animal cognition.
### 1.1 Related work
In studying same-different tasks comparatively across animal species, E. Wasserman et al. (2017) observe a continuum between perceptual and conceptual behavior. Some animals focus on spurious perceptual details in the task stimuli like image variability, and slowly gain competence through exhaustive repetition. Other animals and humans appear to develop a conceptual understanding of sameness, allowing them to learn the task quickly and ignore irrelevant percepts. Many others fall somewhere in between, exhibiting behavior with both perceptual and conceptual components. These observations lend themselves to a theory where representations and learning mechanisms operate over a continuous domain (Carstensen & Frank, 2021).
Neural networks offer a natural instantiation of such a continuous theory. However, the extent to which neural networks can reason at all remains a hotly contested topic. Famously, Fodor & Pylyshyn (1988) argue that connectionist models are poorly equipped to describe human reasoning. G. F. Marcus (1998) further contends that neural networks are altogether incapable of solving many simple symbolic tasks (see also G. F. Marcus et al., 1999; G. F. Marcus, 2003; G. Marcus, 2020). Boix-Adsera et al. (2023) have also argued that MLPs are unable to generalize on relational problems like our same-different task, though this finding has been contested (Tong & Pehlevan, 2024; A. Geiger et al., 2023).
Negative assertions about neural network reasoning appear to weaken when considering modern LLMs, which routinely solve complex math and logic problems (Bubeck et al., 2023; OpenAI, 2024; Guo et al., 2025). But even here, doubts remain about whether LLMs truly reason or merely reproduce superficial aspects of their enormous training sets (McCoy et al., 2023; Mahowald et al., 2024). Nonetheless, A. Geiger et al. (2023) found that simple MLPs convincingly solve same-different tasks after moderate training. Tong & Pehlevan (2024) further showed that MLPs solve a wide variety of relational reasoning tasks. We support these findings by arguing that MLPs solve a same-different task, but that their performance is modulated by learning richness. In resonance with E. Wasserman et al. (2017), varying richness pushes an MLP along a spectrum between perceptual and conceptual solutions to the task.
Learning richness itself refers to the degree of change in a neural network’s internal representations during training. A number of network parameters were recently discovered to control learning richness, including the readout scale, initialization scheme, and learning rate (Chizat et al., 2019; Woodworth et al., 2020; Yang & Hu, 2021; Bordelon & Pehlevan, 2022). In the brain, learning richness may correspond to forming adaptive representations that encode task-specific variables, in contrast to fixed representations that remain task agnostic (Farrell et al., 2023). Studies have used learning richness to understand the neural representations underlying diverse phenomena like context-dependent decision making (Flesch et al., 2022), multitask cognition (Ito & Murray, 2023), generalizing knowledge to new tasks (Johnston & Fusi, 2023), and even consciousness (Mastrovito et al., 2024).
## 2 Setup
We consider the following same-different task, inspired by the setup in A. Geiger et al. (2023). The task consists of input pairs $\mathbf{z}_{1},\mathbf{z}_{2}\in\mathbb{R}^{d}$, where $\mathbf{z}_{i}=\mathbf{s}_{i}+\mathbf{\eta}_{i}$. The labeling function $y$ is given by
$$
y(\mathbf{z}_{1},\mathbf{z}_{2})=\begin{cases}1&\mathbf{s}_{1}=\mathbf{s}_{2}\\
0&\mathbf{s}_{1}\neq\mathbf{s}_{2}\end{cases}\,.
$$
Quantities $\mathbf{s}$ correspond to “symbols,” perturbed by a small amount of noise $\mathbf{\eta}$. Noise is distributed as $\mathbf{\eta}\sim\mathcal{N}(\mathbf{0},\sigma^{2}\mathbf{I}/d)$, for some choice of $\sigma^{2}$. Initially we will take $\sigma^{2}=0$ so $\mathbf{z}=\mathbf{s}$, but we will allow $\sigma^{2}$ to be nonzero when considering a noisy extension to the task. Our definition of equality implies exact identity, up to possible noise. Other commonly studied variants include equality up to transformation (Fleuret et al., 2011), hierarchical equality (Premack, 1983), and context-dependent equality (Raven, 2003), among many others. We pursue exact identity for its tractability and ubiquity in the literature, and investigate more general notions of equality later with experiments in the noisy case and in vision tasks.
The model consists of a two-layer MLP without bias parameters
$$
f(\mathbf{x})=\frac{1}{\gamma\sqrt{d}}\sum_{i=1}^{m}a_{i}\,\phi(\mathbf{w}_{i}\cdot\mathbf{x})\,, \tag{1}
$$
where $\phi$ is a ReLU activation applied point-wise to its inputs. We use the standard logit link function to produce predictions $\hat{y}=1/(1+e^{-f})$ . Inputs are concatenated as $\mathbf{x}=(\mathbf{z}_{1};\mathbf{z}_{2})\in\mathbb{R}^{2d}$ before being passed to $f$ . The model is trained using binary cross entropy loss with a learning rate $\alpha=\gamma^{2}d\,\alpha_{0}$ , for a fixed $\alpha_{0}$ . Hidden weight vectors are initialized as $\mathbf{w}_{i}\sim\mathcal{N}(\mathbf{0},\mathbf{I}/m)$ , and readouts as $a_{i}\sim\mathcal{N}(0,1/m)$ . To enable interpolation between rich and lazy learning regimes, the MLP is centered such that $f(\mathbf{x})=0$ at initialization, for all inputs $\mathbf{x}$ . We use a standard procedure for centering, described in Appendix F. Occasionally, we gather all readouts $a$ and hidden weights $\mathbf{w}$ into a single set $\mathbf{\theta}$ , and write $f(\mathbf{x};\mathbf{\theta})$ to mean an MLP $f$ parameterized by weights $\mathbf{\theta}$ . We avoid considering bias parameters to simplify the analysis. In practice, because the task is symmetric about the origin, we find that bias plays little role.
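To make Eq (1) concrete, here is a minimal NumPy sketch of the forward pass, initialization, logistic link, and learning-rate scaling. The centering step is an assumption: it subtracts a frozen copy of the network at initialization, which is one common way to guarantee $f(\mathbf{x})=0$ for every input at the start of training; the procedure actually used is the one in Appendix F. All function names are hypothetical, and $\gamma$ must be strictly positive here, since the lazy limit $\gamma\rightarrow 0$ is only reached asymptotically.

```python
import numpy as np

def init_params(d, m, rng):
    """Hidden weights w_i ~ N(0, I/m) in R^{2d}; readouts a_i ~ N(0, 1/m)."""
    W = rng.normal(0.0, np.sqrt(1.0 / m), size=(m, 2 * d))
    a = rng.normal(0.0, np.sqrt(1.0 / m), size=m)
    return W, a

def mlp(X, W, a, gamma):
    """Eq (1) for each row x of X: (1 / (gamma * sqrt(d))) * sum_i a_i * ReLU(w_i . x)."""
    d = X.shape[1] // 2                        # inputs are concatenated pairs (z1; z2)
    hidden = np.maximum(X @ W.T, 0.0)          # (n, m) ReLU activations, no biases
    return hidden @ a / (gamma * np.sqrt(d))   # (n,) logits f(x)

def centered_mlp(X, params, params0, gamma):
    """ASSUMED centering: subtract a frozen copy of the initial network, so f = 0 at init."""
    return mlp(X, *params, gamma) - mlp(X, *params0, gamma)

def predict_proba(f):
    """Logistic link y_hat = 1 / (1 + exp(-f))."""
    return 1.0 / (1.0 + np.exp(-f))

def learning_rate(gamma, d, alpha0):
    """Richness-dependent step size alpha = gamma^2 * d * alpha0 (gamma > 0)."""
    return gamma ** 2 * d * alpha0

rng = np.random.default_rng(0)
d, m = 8, 32
params = init_params(d, m, rng)
params0 = tuple(p.copy() for p in params)          # frozen initial weights
X = rng.normal(0.0, np.sqrt(1.0 / d), size=(5, 2 * d))
f0 = centered_mlp(X, params, params0, gamma=1.0)   # all zeros at initialization
```

Because the step size grows as $\gamma^{2}d$ while the output shrinks as $1/(\gamma\sqrt{d})$, the per-step change in the logits stays comparable across $\gamma$, whereas the relative change in the hidden weights scales with $\gamma$.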
The parameter $\gamma$ controls learning richness, where higher values of $\gamma$ correspond to greater richness (Chizat et al., 2019; M. Geiger et al., 2020; Woodworth et al., 2020; Bordelon & Pehlevan, 2022). A neural network trained in a rich regime experiences significant changes to its hidden activations $\phi(\mathbf{w}_{i}\cdot\mathbf{x})$, resulting in task-specific representations. In contrast, a neural network trained in a lazy regime retains task-agnostic representations determined by its initialization. The limit $\gamma\rightarrow 0$ induces lazy behavior, and increasing $\gamma$ increases learning richness. For our tasks, we find that $\gamma=1$ produces sufficiently rich learning, and increasing $\gamma$ beyond 1 does not qualitatively change our results (Figure C3). The choice $\gamma=1$ produces a scaling similar to $\mu$P or the mean-field parametrization common elsewhere in the rich learning literature (Yang & Hu, 2021; Mei et al., 2018; Rotskoff & Vanden-Eijnden, 2022). However, these scalings technically consider an infinite width limit, whereas our setting considers an infinite input dimension limit (Biehl & Schwarze, 1995; Saad & Solla, 1995; Goldt et al., 2019), resulting in an extra $1/\sqrt{d}$ prefactor that is not present in these other scalings. Appendix F elaborates on our scaling scheme.
Crucially, the training set consists of a finite number of symbols $\mathbf{s}_{1},\mathbf{s}_{2},\ldots,\mathbf{s}_{L}$ . These $L$ symbols are sampled before training begins as $\mathbf{s}\sim\mathcal{N}(\mathbf{0},\mathbf{I}/d)$ , then used exclusively to train the model. Training examples are balanced such that half consist of same examples and half consist of different examples. During testing, symbols $\mathbf{s}$ are sampled afresh for every input, measuring the model’s ability to generalize on unseen test examples. If a model has learned equality reasoning, then it should attain perfect test accuracy despite having never witnessed the particular inputs. When $\sigma^{2}=0$ , this procedure is precisely equivalent to using one-hot encoded symbol inputs with a fixed embedding matrix, where the model is trained on a subset of all possible symbols. Additional details on our model and setup are enumerated in Appendix G.
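A minimal NumPy sketch of this data protocol: training pairs are drawn from a fixed pool of $L$ symbols, while test pairs use fresh symbols for every example. The function names are hypothetical, and the way pairs are drawn from the pool (uniformly with replacement, with a random nonzero index offset to force $\mathbf{s}_{1}\neq\mathbf{s}_{2}$ in different examples) is our assumption; Appendix G gives the actual details.

```python
import numpy as np

def sample_symbols(L, d, rng):
    """Draw the L training symbols s ~ N(0, I/d) once, before training."""
    return rng.normal(0.0, np.sqrt(1.0 / d), size=(L, d))

def make_train_pairs(symbols, n, rng, sigma2=0.0):
    """n balanced pairs x = (z1; z2) drawn from a fixed symbol pool (requires L >= 2).

    First half: 'same' (s1 = s2, label 1). Second half: 'different' (s1 != s2, label 0).
    Each component gets independent noise eta ~ N(0, sigma2 * I / d).
    """
    L, d = symbols.shape
    n_same = n // 2
    i = rng.integers(0, L, size=n)
    j = i.copy()
    # A nonzero random offset guarantees a different symbol index (assumed sampling scheme).
    j[n_same:] = (i[n_same:] + rng.integers(1, L, size=n - n_same)) % L
    noise_std = np.sqrt(sigma2 / d)
    z1 = symbols[i] + rng.normal(0.0, noise_std, size=(n, d))
    z2 = symbols[j] + rng.normal(0.0, noise_std, size=(n, d))
    y = np.concatenate([np.ones(n_same), np.zeros(n - n_same)])
    return np.concatenate([z1, z2], axis=1), y

def make_test_pairs(n, d, rng, sigma2=0.0):
    """n balanced pairs whose symbols are sampled afresh for every example."""
    n_same = n // 2
    s1 = rng.normal(0.0, np.sqrt(1.0 / d), size=(n, d))
    s2 = rng.normal(0.0, np.sqrt(1.0 / d), size=(n, d))
    s2[:n_same] = s1[:n_same]                  # 'same' examples share one fresh symbol
    noise_std = np.sqrt(sigma2 / d)
    z1 = s1 + rng.normal(0.0, noise_std, size=(n, d))
    z2 = s2 + rng.normal(0.0, noise_std, size=(n, d))
    y = np.concatenate([np.ones(n_same), np.zeros(n - n_same)])
    return np.concatenate([z1, z2], axis=1), y

rng = np.random.default_rng(0)
symbols = sample_symbols(L=8, d=16, rng=rng)
X_train, y_train = make_train_pairs(symbols, 100, rng)
X_test, y_test = make_test_pairs(100, 16, rng)
```

With `sigma2=0` the noise terms are exactly zero, so every "same" training row has identical halves, and no test symbol is ever drawn from the training pool.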
### 2.1 Conceptual and perceptual behavior
Central to our framework is the distinction between conceptual and perceptual behavior. Conceptual behavior refers to a facility with abstract concepts, enabling the reasoner to learn an abstract task quickly and generalize with limited dependency on spurious details. Perceptual behavior refers to the opposite, where the reasoner solves a task through sensory association. Such learning is typically characterized by exhaustive training and marked sensitivity to spurious perceptual details.
We posit that learning richness moves an MLP between conceptual and perceptual behavior. We identify three specific characteristics of a conceptual outcome:
1. Conceptual representations. We look for evidence of task-specific representations that denote sameness or difference. Such representations should be crucial to solving the task, and contribute towards the model’s efficiency and insensitivity to spurious perceptual details (below).
1. Efficiency. We measure learning efficiency using the number of different symbols $L$ observed during training. A conceptual reasoner should solve the task with a smaller $L$ than a perceptual reasoner.
1. Insensitivity to spurious perceptual details. Spurious perceptual details refer to aspects of the task that influence the input but not the correct output. A readily measurable example is the input dimension $d$: sameness or difference can be evaluated regardless of $d$. A conceptual reasoner should perform equally well when trained on tasks across a variety of $d$, whereas a perceptual reasoner may find some values of $d$ harder to learn than others. We therefore evaluate this insensitivity by comparing the test accuracy of models trained across a large range of input dimensions.
A perceptual solution is characterized by the negation of each point: it does not develop task-specific representations, it requires a large $L$ to solve the task, and its test accuracy changes substantially with $d$. Although a mixed solution exhibiting only a subset of these properties is possible in principle, we do not observe one in practice, and the conceptual/perceptual distinction suffices to describe our model.
[Figure 1 panels: (a, d) hidden weight alignment vs. readout weight $a_i$ in the rich ($\gamma=1$) and lazy ($\gamma\approx 0$) regimes; (b, e) heatmaps of test accuracy over input dimension $d$ and number of symbols $L$; (c) test accuracy vs. number of symbols $L$ for varying $\gamma$, with theory and chance lines; (f) test accuracy vs. input dimension for varying $\gamma$.]
Figure 1: Rich and lazy regime simulations. We confirm our theoretical predictions with numeric simulations. (a) Hidden weight alignment plotted against readout weights for a rich model. Weights become parallel or antiparallel, with generally higher magnitudes among negative readouts. (b) Test accuracy across different input dimensions and training symbols for a rich model ($m=4096$). Accuracy is not affected by input dimension. (c) Test accuracy across different numbers of training symbols, for varying learning richness ($d=256$, $m=1024$). Richer models attain high performance with substantially fewer training symbols. The theoretically predicted rich test accuracy shows excellent agreement with our richest model. Finer-grained validation is plotted in Figure C3. (d) Hidden weight alignment plotted against readout weights for a lazy model. There is some correlation between alignment and readout weight, but the weights are nowhere near as close to being parallel/antiparallel as in the rich regime. (e) Test accuracy across input dimensions and training symbols for a lazy model ($m=4096$). Accuracy is substantially affected by input dimension. Theory predicts that the number of training symbols required to maintain high accuracy scales at worst as $L\propto d^{2}$, plotted in black. (f) Test accuracy across different input dimensions, for varying learning richness ($L=16$, $m=1024$). Richer models show less performance decay with increasing dimension. (all) Results are computed across six runs. Shading corresponds to empirical 95 percent confidence intervals.
## 3 Same-different task analysis
We present our analysis of the SD task. We first hand-craft a solution that is expressible by our MLP, and in the process suggest what conceptual representations of sameness and difference may look like (Section 3.1). We proceed to argue that a rich-regime MLP attains the hand-crafted solution through training. It leverages its conceptual representations to learn the task with few training symbols and insensitivity to the input dimension (Section 3.2), exhibiting conceptual behavior. In contrast, a lazy-regime MLP is unable to adapt its representations to the task, and consequently incurs a high training cost and substantial sensitivity to input dimension (Section 3.3), exhibiting perceptual behavior. In an extension to a noisy version of our task, we show that a rich MLP approaches Bayes optimal performance under a generalizing prior across different noise variances $\sigma^{2}$ (Section 3.4). We validate our results on more complex, image-based tasks in Section 4, and discuss broader implications in Section 5.
### 3.1 Hand-crafted solution
To establish whether an MLP can solve the same-different task at all, we first outline a hand-crafted solution using $m=4$ hidden units. Let $\mathbf{1}=(1,1,\ldots,1)\in\mathbb{R}^{d}$ . Define the weight vector $\mathbf{w}_{1}^{+}$ by concatenation: $\mathbf{w}_{1}^{+}=(\mathbf{1};\mathbf{1})\in\mathbb{R}^{2d}$ . Further define $\mathbf{w}_{2}^{+}=(-\mathbf{1};-\mathbf{1})$ , $\mathbf{w}_{1}^{-}=(\mathbf{1};-\mathbf{1})$ , and $\mathbf{w}_{2}^{-}=(-\mathbf{1};\mathbf{1})$ . Let $a^{+}=1$ and $a^{-}=\rho$ , for some value $\rho>0$ . Our MLP is given by
$$
f(\mathbf{x})=a^{+}\big(\phi(\mathbf{w}_{1}^{+}\cdot\mathbf{x})+\phi(\mathbf{w}_{2}^{+}\cdot\mathbf{x})\big)-a^{-}\big(\phi(\mathbf{w}_{1}^{-}\cdot\mathbf{x})+\phi(\mathbf{w}_{2}^{-}\cdot\mathbf{x})\big)\,. \tag{2}
$$
Note that the weight vectors $\mathbf{w}_{1}^{+},\mathbf{w}_{2}^{+}$ , which correspond to the positive readout $a^{+}$ , are parallel: their components point in the same direction with the same magnitude. Meanwhile, weight vectors $\mathbf{w}_{1}^{-},\mathbf{w}_{2}^{-}$ corresponding to the negative readout $a^{-}$ are antiparallel: their components point in exact opposite directions with the same magnitude. Only the sign of $f$ impacts the classification, so we assign $a^{+}=1$ and $a^{-}=\rho$ to represent the relative magnitude of $a^{-}$ against $a^{+}$ .
To see how this weight configuration solves the same-different task, suppose we receive a same example $\mathbf{x}=(\mathbf{z},\mathbf{z})$ . Plugging this into Eq (2) reveals that the negative terms vanish through our antiparallel weights, leaving $f(\mathbf{x})=2\,|\mathbf{1}\cdot\mathbf{z}|>0$ , correctly classifying this example.
Now suppose we receive a different example $\mathbf{x}^{\prime}=(\mathbf{z},\mathbf{z}^{\prime})$ . Recall that these quantities are sampled independently as $\mathbf{z},\mathbf{z}^{\prime}\sim\mathcal{N}(\mathbf{0},\mathbf{I}/d)$ . As a result, we can no longer rely on a convenient cancellation. The quantity $\phi(\mathbf{w}_{1}^{+}\cdot\mathbf{x}^{\prime})+\phi(\mathbf{w}_{2}^{+}\cdot\mathbf{x}^{\prime})$ is equal in distribution to the quantity $\phi(\mathbf{w}_{1}^{-}\cdot\mathbf{x}^{\prime})+\phi(\mathbf{w}_{2}^{-}\cdot\mathbf{x}^{\prime})$ , with respect to the randomness in $\mathbf{x}^{\prime}$ . Hence, to implement a consistent negative classification, we need to raise the relative magnitude $\rho$ of our negative readout weight. Indeed, we calculate $p(f(\mathbf{x}^{\prime})<0)=\frac{2}{\pi}\tan^{-1}(\rho)$ , which approaches $1$ for $\rho\gg 1$ . Full details are recorded in Appendix B. Hence, by maintaining a large negative readout, we classify negative examples correctly with high probability. An illustration of this solution is provided in Figure B1.
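The argument above can be checked numerically. This NumPy sketch builds the four hand-crafted weight vectors of Eq (2), confirms that every same example receives $f>0$, and compares the Monte Carlo fraction of correctly classified different examples with $\frac{2}{\pi}\tan^{-1}(\rho)$. The values $d=16$, $\rho=3$, the sample size, and the helper names are arbitrary illustrative choices.

```python
import numpy as np

def relu(u):
    return np.maximum(u, 0.0)

def handcrafted_f(X, rho):
    """Eq (2): m = 4 hidden units with a+ = 1 and a- = rho, applied to rows x = (z1; z2)."""
    d = X.shape[1] // 2
    one = np.ones(d)
    w1_pos = np.concatenate([one, one])    # parallel weights (1; 1) and (-1; -1)
    w2_pos = -w1_pos
    w1_neg = np.concatenate([one, -one])   # antiparallel weights (1; -1) and (-1; 1)
    w2_neg = -w1_neg
    positive = relu(X @ w1_pos) + relu(X @ w2_pos)
    negative = relu(X @ w1_neg) + relu(X @ w2_neg)
    return positive - rho * negative

rng = np.random.default_rng(0)
d, n, rho = 16, 50_000, 3.0
z = rng.normal(0.0, np.sqrt(1.0 / d), size=(n, d))
z_prime = rng.normal(0.0, np.sqrt(1.0 / d), size=(n, d))

same_correct = np.mean(handcrafted_f(np.concatenate([z, z], axis=1), rho) > 0)
diff_correct = np.mean(handcrafted_f(np.concatenate([z, z_prime], axis=1), rho) < 0)
predicted = 2.0 / np.pi * np.arctan(rho)   # p(f(x') < 0) from the text
print(same_correct, diff_correct, predicted)
```

The closed form follows because $\mathbf{1}\cdot\mathbf{z}$ and $\mathbf{1}\cdot\mathbf{z}^{\prime}$ are independent standard normals, so $f(\mathbf{x}^{\prime})<0$ reduces to a ratio of two independent half-normals, i.e. a half-Cauchy variable, falling below $\rho$.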
An MLP need not implement this precise weight configuration to solve the SD task. Rather, our hand-crafted solution suggests two general conditions:
1. Parallel/antiparallel weight vectors. Weights associated with positive readouts must be parallel, and weights associated with negative readouts must be antiparallel. This allows us to classify any same example by canceling the contribution from negative readouts.
1. Large negative readouts. The cumulative magnitude of the negative readouts must be larger than that of the positive readouts. This allows us to classify any different example by raising the contribution from negative readouts.
Observe also that parallel and antiparallel weights are suggestive of conceptual representations for sameness and difference. Parallel weights contribute to a same classification, and exemplify the structure of a same example: the two components point in the same direction. Antiparallel weights contribute to a different classification, and exemplify the structure of a different example: the two components point as far apart as possible. We look for parallel/antiparallel weight vectors as evidence for conceptual representations of our SD task.
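A hypothetical helper for detecting this structure in trained weights: split each hidden weight $\mathbf{w}_{i}=(\mathbf{v}_{i}^{1};\mathbf{v}_{i}^{2})$ into its two halves and compute their cosine, which is $+1$ for parallel and $-1$ for antiparallel weights. This is the alignment $\mathbf{v}_{i}^{1}\cdot\mathbf{v}_{i}^{2}/\ell_{i}$ that appears in Theorem 1; we assume it is also the quantity plotted against $a_{i}$ in Figure 1 (a, d). The sketch assumes NumPy and a weight matrix with one hidden unit per row.

```python
import numpy as np

def alignment(W):
    """Cosine between the halves of each hidden weight w_i = (v_i^1; v_i^2).

    Returns v1 . v2 / (||v1|| ||v2||) per unit: +1 if parallel, -1 if antiparallel.
    """
    d = W.shape[1] // 2
    v1, v2 = W[:, :d], W[:, d:]
    dots = np.sum(v1 * v2, axis=1)
    return dots / (np.linalg.norm(v1, axis=1) * np.linalg.norm(v2, axis=1))

# The hand-crafted weights of Eq (2) with d = 2: two parallel rows, then two antiparallel rows.
W_hand = np.array([[ 1.0,  1.0,  1.0,  1.0],
                   [-1.0, -1.0, -1.0, -1.0],
                   [ 1.0,  1.0, -1.0, -1.0],
                   [-1.0, -1.0,  1.0,  1.0]])
print(alignment(W_hand))
```

Applied to trained weights, values clustering near $\pm 1$ with signs matching the readouts would indicate conceptual representations, whereas values spread around $0$ would not.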
### 3.2 Rich regime
The rich learning regime is characterized by substantial weight changes throughout the course of training. For the MLP given in Eq (1), larger values of $\gamma$ lead to rich learning behavior. We allow $\gamma$ to vary between 0 and 1. The range $\gamma>1$ is considered in Figure C3, where we see that no qualitative changes to our results occur for larger values of $\gamma$ .
To study the rich regime, we take two approaches. First, recent theoretical work (Morwani et al., 2023; Wei et al., 2019; Chizat & Bach, 2020) suggests that MLPs trained in a rich learning regime on a classification task discover a max margin solution: the weights maximize the distance between training points of different classes. We derive the max margin weights for an MLP with quadratic activations in Theorem 1, finding that the max margin solution consists of parallel/antiparallel weight vectors, just as our hand-crafted solution requires. We defer the proof of this theorem to Appendix C.
**Theorem 1**
*Let $\mathcal{D}=\{\mathbf{x}_{n},y_{n}\}_{n=1}^{P}$ be a training set consisting of $P$ points sampled across $L$ training symbols, as specified in Section 2. Let $f$ be the MLP given by Eq (1), with two changes:
1. Fix the readouts $a_{i}=\pm 1$ , where exactly $m/2$ readouts are positive and the remaining are negative.
1. Use quadratic activations $\phi(\cdot)=(\cdot)^{2}$ .
For weights $\mathbf{\theta}=\left\{\mathbf{w}_{i}\right\}_{i=1}^{m}$ , define the max margin set $\Delta(\mathbf{\theta})$ to be
$$
\Delta(\mathbf{\theta})=\operatorname*{arg\,max}_{\mathbf{\theta}}\frac{1}{P}\sum_{n=1}^{P}\left[(2y_{n}-1)f(\mathbf{x}_{n};\mathbf{\theta})\right]\,,
$$
subject to the norm constraints $\left|\left|\mathbf{w}_{i}\right|\right|=1$ . If $P,L\rightarrow\infty$ , then for any $\mathbf{w}_{i}=(\mathbf{v}_{i}^{1};\mathbf{v}_{i}^{2})\in\Delta(\mathbf{\theta})$ and $\ell_{i}=\left|\left|\mathbf{v}_{i}^{1}\right|\right|\,\left|\left|\mathbf{v}_{i}^{2}\right|\right|$ , we have that $\mathbf{v}_{i}^{1}\cdot\mathbf{v}_{i}^{2}/\ell_{i}=1$ if $a_{i}=1$ and $\mathbf{v}_{i}^{1}\cdot\mathbf{v}_{i}^{2}/\ell_{i}=-1$ if $a_{i}=-1$ . Further, $\left|\left|\mathbf{v}_{i}^{1}\right|\right|=\left|\left|\mathbf{v}_{i}^{2}\right|\right|$ .*
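The following heuristic calculation, offered as intuition rather than as the proof in Appendix C, suggests why the max margin weights take this form. Assume balanced same/different examples, which contributes the factor $\frac{1}{2}$, and replace training averages by expectations over symbols $\mathbf{z},\mathbf{z}_{1},\mathbf{z}_{2}\sim\mathcal{N}(\mathbf{0},\mathbf{I}/d)$ as $P,L\rightarrow\infty$. Writing $\mathbf{w}_{i}\cdot\mathbf{x}=\mathbf{v}_{i}^{1}\cdot\mathbf{z}_{1}+\mathbf{v}_{i}^{2}\cdot\mathbf{z}_{2}$ with quadratic activations,

$$
\begin{aligned}
\mathbb{E}\big[(\mathbf{w}_{i}\cdot\mathbf{x})^{2}\,\big|\,\text{same}\big] &= \mathbb{E}\big[((\mathbf{v}_{i}^{1}+\mathbf{v}_{i}^{2})\cdot\mathbf{z})^{2}\big] = \tfrac{1}{d}\,\lVert\mathbf{v}_{i}^{1}+\mathbf{v}_{i}^{2}\rVert^{2}\,,\\
\mathbb{E}\big[(\mathbf{w}_{i}\cdot\mathbf{x})^{2}\,\big|\,\text{different}\big] &= \mathbb{E}\big[(\mathbf{v}_{i}^{1}\cdot\mathbf{z}_{1}+\mathbf{v}_{i}^{2}\cdot\mathbf{z}_{2})^{2}\big] = \tfrac{1}{d}\big(\lVert\mathbf{v}_{i}^{1}\rVert^{2}+\lVert\mathbf{v}_{i}^{2}\rVert^{2}\big)\,,\\
\frac{1}{P}\sum_{n=1}^{P}(2y_{n}-1)\,f(\mathbf{x}_{n};\mathbf{\theta}) &\;\rightarrow\; \frac{1}{2\gamma\sqrt{d}}\sum_{i=1}^{m}\frac{a_{i}}{d}\Big(\lVert\mathbf{v}_{i}^{1}+\mathbf{v}_{i}^{2}\rVert^{2}-\lVert\mathbf{v}_{i}^{1}\rVert^{2}-\lVert\mathbf{v}_{i}^{2}\rVert^{2}\Big) = \frac{1}{\gamma d^{3/2}}\sum_{i=1}^{m}a_{i}\,\mathbf{v}_{i}^{1}\cdot\mathbf{v}_{i}^{2}\,.
\end{aligned}
$$

Since $\lVert\mathbf{w}_{i}\rVert^{2}=\lVert\mathbf{v}_{i}^{1}\rVert^{2}+\lVert\mathbf{v}_{i}^{2}\rVert^{2}=1$, the Cauchy-Schwarz and AM-GM inequalities give $|\mathbf{v}_{i}^{1}\cdot\mathbf{v}_{i}^{2}|\leq\lVert\mathbf{v}_{i}^{1}\rVert\,\lVert\mathbf{v}_{i}^{2}\rVert\leq\frac{1}{2}$, with equality only when the two halves are parallel or antiparallel and have equal norms. The limiting margin separates into independent terms $a_{i}\,\mathbf{v}_{i}^{1}\cdot\mathbf{v}_{i}^{2}$, so it is maximized by $\mathbf{v}_{i}^{1}\cdot\mathbf{v}_{i}^{2}=\frac{1}{2}$ when $a_{i}=1$ and by $\mathbf{v}_{i}^{1}\cdot\mathbf{v}_{i}^{2}=-\frac{1}{2}$ when $a_{i}=-1$, consistent with the statement of Theorem 1.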
However, the max margin result applies to quadratic rather than ReLU activations, relies on fixed readouts $a_{i}$, and says nothing about learning efficiency or insensitivity to spurious perceptual details, two additional properties we require of a conceptual solution. To address these shortcomings, our second approach is a heuristic construction that approximates a rich ReLU MLP as an ensemble of independent Markov processes (Section C.2). This construction enables a deeper characterization of rich learning dynamics and yields the following approximation of the test accuracy. For an unseen test point $(\mathbf{x},y)$ with prediction $\hat{y}$,
$$
p(y=\hat{y}(\mathbf{x}))\approx\frac{1}{2}+\frac{1}{2}\Phi\left(\sqrt{\frac{2(L^{2}-L)}{13(\pi-2)}}\right)\,, \tag{3}
$$
where $L$ is the number of training symbols and $\Phi$ is the CDF of a standard normal distribution. This estimate holds for $L\geq 3$; for $L=2$, $p(y=\hat{y})=3/4$ (see Section C.5 for details). The estimate suggests that the model attains over 95 percent test accuracy with as few as $L=5$ training symbols, and, since Eq. (3) does not involve $d$, that test accuracy is independent of the input dimension.
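Eq. (3) is easy to evaluate numerically. The helper below (its name is chosen here for illustration) writes $\Phi$ in terms of the error function, $\Phi(t)=\tfrac12\left[1+\operatorname{erf}(t/\sqrt2)\right]$, and returns $3/4$ at $L=2$ as stated above. It gives roughly 0.949 at $L=4$ and 0.975 at $L=5$, so the 95 percent threshold is first crossed at $L=5$; the input dimension $d$ never enters.

```python
import math

def rich_test_accuracy(L: int) -> float:
    """Evaluate the rich-regime test accuracy approximation of Eq. (3).

    Eq. (3) applies for L >= 3; for L = 2 the value 3/4 is returned.
    The input dimension d does not appear anywhere in the formula.
    """
    if L < 2:
        raise ValueError("need at least two training symbols")
    if L == 2:
        return 0.75
    arg = math.sqrt(2 * (L**2 - L) / (13 * (math.pi - 2)))
    std_normal_cdf = 0.5 * (1 + math.erf(arg / math.sqrt(2)))  # Phi(arg)
    return 0.5 + 0.5 * std_normal_cdf

for L in range(2, 9):
    print(L, round(rich_test_accuracy(L), 3))
```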
We confirm our theoretical predictions with simulations in Figure 1. At the end of training, the hidden weights indeed become parallel and antiparallel, with negative coefficients gaining larger magnitude (Figure 1 a). Figures 1 b and c show that the rich model learns the same-different task with substantially fewer training symbols than lazier models, and exhibits excellent agreement with our theoretical test accuracy prediction. As predicted, the rich model’s performance does not vary with input dimension (Figure 1 b).
Altogether, the rich model develops conceptual representations, learns the same-different task given only a small number of training symbols, and exhibits clear insensitivity to input dimension. In this way, it exhibits conceptual behavior on the same-different task.
[Figure 2 panels: test accuracy versus number of training symbols $L$ (log scale), one panel each for $\sigma^2 = 0, 1, 2, 4$; curves for $\gamma = 10^{0}, 10^{-1}, 10^{-2}, 10^{-3}, 10^{-4}$ and $\gamma = 0$, together with Bayes optimal accuracy under the generalizing ("Bayes gen") and memorizing ("Bayes mem") priors.]
Figure 2: Bayesian simulations. Test accuracy across different numbers of training symbols, for varying richness and noise ($d=64$, $m=1024$). Bayes optimal accuracy under both the generalizing and memorizing priors is plotted with dashed lines. In all cases, rich models attain the Bayes optimal test accuracy under the generalizing prior after sufficiently many training symbols. Shaded error regions are computed across six runs and correspond to empirical 95 percent confidence intervals.
### 3.3 Lazy regime
The lazy learning regime is characterized by vanishingly small changes in the model’s hidden representations over training. Smaller values of $\gamma$ lead to lazy learning behavior. The limit $\gamma\rightarrow 0$ corresponds to the Neural Tangent Kernel (NTK) regime, where the network is well described by its linearization around initialization (Jacot et al., 2018). In our numerics, we approximate this limit with $\gamma=10^{-5}/\sqrt{d}$.
Because a lazy neural network cannot adapt its representations to an arbitrary pattern, a lazy MLP cannot learn parallel/antiparallel weights. However, because the statistics of same examples differ from those of different examples, a lazy MLP may still succeed at the task given enough training data. For example, for a same input $\mathbf{x}=(\mathbf{z};\mathbf{z})$ and a different input $\mathbf{x}^{\prime}=(\mathbf{z}_{1};\mathbf{z}_{2})$, the variance of $\mathbf{1}\cdot\mathbf{x}$ is twice that of $\mathbf{1}\cdot\mathbf{x}^{\prime}$, since $\mathbf{1}\cdot\mathbf{x}=2\,\mathbf{1}\cdot\mathbf{z}$ while $\mathbf{1}\cdot\mathbf{x}^{\prime}$ is a sum of two independent terms. Leveraging distinct statistics like this may still allow the lazy model to learn the task. Using standard kernel arguments (Cho & Saul, 2009; Jacot et al., 2018), we bound the test error of a lazy MLP in Theorem 2. The proof is deferred to Appendix D.
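The variance gap can be checked numerically. Writing $s=\operatorname{Var}(\mathbf{1}\cdot\mathbf{z})$, a same input gives $\operatorname{Var}(2\,\mathbf{1}\cdot\mathbf{z})=4s$ and a different input gives $\operatorname{Var}(\mathbf{1}\cdot\mathbf{z}_1+\mathbf{1}\cdot\mathbf{z}_2)=2s$. The sketch below assumes, purely for illustration, symbols with i.i.d. standard normal coordinates (so $s=d$); the dimension and sample size are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 32, 50_000   # illustrative symbol dimension and number of samples

# Assumption for illustration: symbol coordinates are i.i.d. N(0, 1)
z = rng.standard_normal((n, d))
z1 = rng.standard_normal((n, d))
z2 = rng.standard_normal((n, d))

x_same = np.concatenate([z, z], axis=1)     # x  = (z; z)
x_diff = np.concatenate([z1, z2], axis=1)   # x' = (z1; z2)

var_same = x_same.sum(axis=1).var()         # Var(1 . x),  theory: 4d = 128
var_diff = x_diff.sum(axis=1).var()         # Var(1 . x'), theory: 2d = 64
print(var_same / var_diff)                  # close to 2
```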
**Theorem 2 (informal)**
*Let $f$ be an infinite-width ReLU MLP. If $f$ is trained on a dataset consisting of $P$ points constructed from $L$ symbols with input dimension $d$ , then the test error of $f$ is upper bounded by $\mathcal{O}\left(\exp\left\{-L/d^{2}\right\}\right)$ .*
This bound suggests that to maintain a consistently low test error (or equivalently, high test accuracy), the number of training symbols $L$ needs to scale quadratically (at worst) with the input dimension: $L\propto d^{2}$ .
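To make the scaling explicit, suppose the bound takes the form $C\,e^{-L/d^{2}}$ for some constant $C>0$; the informal statement hides such constants, so this form is an illustrative assumption. Requiring the test error to stay below a target $\varepsilon$ then gives

$$
C\,e^{-L/d^{2}}\le\varepsilon
\quad\Longleftrightarrow\quad
L\ge d^{2}\log\frac{C}{\varepsilon}\,,
$$

so, for a fixed $\varepsilon$, the sufficient number of training symbols quadruples each time $d$ doubles. By contrast, the rich-regime estimate in Eq. (3) does not depend on $d$ at all.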
We support our theoretical predictions with simulations in Figure 1. Because the model is in a lazy regime, the hidden weights do not move far from initialization, and no clear parallel/antiparallel structure emerges (Figure 1 d). Figure 1 c shows how models require increasingly more training data as richness decreases. Lazier models are also substantially more impacted by changes in input dimension (Figure 1 e and f), and the scaling of training symbols with input dimension is consistent with our theory (Figure 1 e).
Altogether, the lazy model is unable to learn conceptual representations, relying instead on statistical associations that require a large amount of training data and are strongly sensitive to input dimension. In this way, the lazy model exhibits perceptual behavior on the same-different task.
### 3.4 Same-different with noise
Up until now, we defined equality by exact identity: even a minuscule deviation in a single coordinate is enough to break equality and classify an example as different. Reality is far less clean, and real-world objects are rarely equal up to exact identity. As a first step towards this broader setting, we relax our dependence on exact identity and consider a noisy SD task. In the notation of our setup (Section 2), we allow $\sigma^{2}>0$ .
To understand optimal performance under noise, we apply the following Bayesian framework. As a baseline, we consider a prior corresponding to an idealized model that memorizes the training symbols. This memorizing prior assumes that every input symbol is distributed uniformly over the training symbols. In contrast, we consider a gold-standard prior corresponding to a model that generalizes to novel symbols. This generalizing prior assumes that every input symbol follows the true underlying distribution. By comparing the test accuracy of the trained models to the Bayes optimal accuracy under these two priors, we identify which prior more closely reflects the models’ operation. The posterior calculations are given in Appendix E.
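As an illustration of the generalizing prior, the sketch below computes a Bayes-optimal same/different decision on novel symbols under a hypothetical noise model that may differ from Section 2 and Appendix E: symbols are drawn as $\mathbf{z}\sim\mathcal{N}(\mathbf{0},\mathbf{I}_d/d)$, each observed item adds independent noise $\mathcal{N}(\mathbf{0},\sigma^2\mathbf{I}_d/d)$, and both classes are equally likely. Each coordinate pair $(x_{1j},x_{2j})$ is then zero-mean Gaussian with variance $a=(1+\sigma^2)/d$, and covariance $b=1/d$ under same or $0$ under different, so the log-likelihood ratio depends only on $\sum_j(x_{1j}^2+x_{2j}^2)$ and $\sum_j x_{1j}x_{2j}$. The memorizing prior would instead replace the Gaussian symbol distribution with a uniform mixture over the training symbols; it is not implemented here, and all function names are illustrative.

```python
import numpy as np

def llr_same(x1, x2, sigma2, d):
    """log p(x1, x2 | same) - log p(x1, x2 | different) under the
    hypothetical Gaussian model described above (generalizing prior)."""
    a = (1.0 + sigma2) / d              # per-coordinate variance of an observed item
    b = 1.0 / d                         # per-coordinate covariance when symbols match
    det_s = a * a - b * b               # determinant of the 2x2 "same" covariance
    S = np.sum(x1**2 + x2**2, axis=-1)
    C = np.sum(x1 * x2, axis=-1)
    quad = -0.5 * (a * S - 2.0 * b * C) / det_s + 0.5 * S / a
    return quad - 0.5 * d * np.log(det_s / (a * a))

def bayes_generalizing_accuracy(sigma2, d=64, n=5000, seed=0):
    """Monte Carlo accuracy of the Bayes-optimal rule on novel symbols."""
    if sigma2 <= 0:
        raise ValueError("sigma2 must be positive (sigma2 = 0 is exact identity)")
    rng = np.random.default_rng(seed)
    s_sym, s_noise = np.sqrt(1.0 / d), np.sqrt(sigma2 / d)

    def noisy(z):
        # Each call draws fresh, independent observation noise
        return z + rng.normal(scale=s_noise, size=z.shape)

    z = rng.normal(scale=s_sym, size=(n, d))          # same pairs share one symbol
    same_pred = llr_same(noisy(z), noisy(z), sigma2, d) > 0
    z1 = rng.normal(scale=s_sym, size=(n, d))         # different pairs use two symbols
    z2 = rng.normal(scale=s_sym, size=(n, d))
    diff_pred = llr_same(noisy(z1), noisy(z2), sigma2, d) > 0
    return 0.5 * (same_pred.mean() + (~diff_pred).mean())

for sigma2 in (0.5, 1.0, 4.0, 16.0):
    print(sigma2, bayes_generalizing_accuracy(sigma2))
```

With equal class priors, predicting "same" whenever the log-likelihood ratio is positive is the Bayes-optimal rule, so accuracy falls toward chance as $\sigma^2$ grows.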
Results are plotted in Figure 2. In all cases, we find that the rich model approaches Bayes optimal accuracy under the generalizing prior. Lazier models tend to plateau at lower test accuracies; nonetheless, at higher noise they tend to exceed the memorizing prior, indicating some degree of generalization. Overall, learning richness appears to support robust generalization to novel symbols in the noisy SD task.
[Figure 3 panels: (a, b, c) PSVRT, (d, e, f) Pentomino, (g, h) CIFAR-100. Test accuracy versus number of training bit-patterns (b), image width in patches (c, f), number of training shapes (e), and number of training classes (h), for $\gamma = 10^{0}, 10^{-1}, 10^{-2}, 10^{-3}, 10^{-4}$ and $\gamma \approx 0$; a dashed line marks chance.]
Figure 3: Visual same-different results. (a) PSVRT examples for same (left) and different (right). (b,c) Test accuracy on PSVRT across different numbers of training bit-patterns and image widths. Richer models learn the task with fewer patterns and exhibit less sensitivity to larger sizes. (d) Pentomino examples for same (left) and different (right). (e,f) Test accuracy on Pentomino across training shapes and image widths. As before, richer models learn the task with fewer training shapes and exhibit less sensitivity to larger sizes, though performance across models tends to diminish somewhat with increasing image size. (g) CIFAR-100 examples for same (left) and different (right). (h) Test accuracy on CIFAR-100 same-different across training classes. Richer models tend to perform better with fewer classes, though the richest model in this example performs worse. For this task, very rich models may overfit, necessitating an optimal richness level. (all) Shaded error regions are computed across six runs and correspond to empirical 95 percent confidence intervals.
## 4 Validation in vision tasks
To validate our theoretical findings in a more complex, naturalistic setting, we turn to visual same-different tasks. Specifically, we examine three datasets originally designed to study visual reasoning and computer vision: 1) PSVRT (Kim et al., 2018), 2) Pentomino (Gülçehre & Bengio, 2016), and 3) CIFAR-100 (Krizhevsky & Hinton, 2009). These tasks are substantially more challenging than the simple SD task studied above. Rather than reasoning over symbol embeddings, a model must now reason over complex visual objects. Inputs are images, and equality is no longer exact identity: inputs can be equal up to translation (in PSVRT) or rotation (in Pentomino), or merely share a class label (in CIFAR-100). Additional details on model and task configurations are given in Appendix G.
We continue to use the same MLP model as before. Images are flattened before input to the model. Though better performance may be attained using CNNs or Vision Transformers, our ultimate goal is to study learning richness rather than maximize performance. Nonetheless, as we will soon see, an MLP performs astonishingly well on these tasks despite its simplicity — provided it remains in a rich learning regime. To validate our theoretical findings, we should continue to see the three hallmarks of a conceptual solution (conceptual representations, efficiency, and insensitivity to spurious perceptual details), but only in rich MLPs.
### 4.1 PSVRT
The parameterized-SVRT (PSVRT) dataset is a version of the Synthetic Visual Reasoning Test (SVRT), a collection of challenging visual reasoning tasks based on abstract shapes (Kim et al., 2018). PSVRT replaces the original shapes with random bit-patterns in order to better control image variability. The task input consists of an image that has two blocks of bit-patterns, placed randomly on a blank background. The model must determine whether the blocks contain the same bit-pattern or different patterns. The training set consists of a fixed number of predetermined bit-patterns. The test set consists of novel bit-patterns never encountered during training.
Bit-patterns are patch-aligned: they occur in non-overlapping locations that tile the image. The width of an image may be specified by the number of patches. Figure 3 a illustrates examples from PSVRT that are three patches wide.
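The patch-aligned construction can be sketched as a small generator. This is a hypothetical re-implementation for illustration, not the original PSVRT code of Kim et al. (2018); the function name, the pattern width `k`, the pool of 64 patterns, and the use of distinct integer codes to guarantee pairwise-distinct bit-patterns are all assumptions made here. It places two patterns in distinct cells of an `n_patches` by `n_patches` grid, and training and test patterns come from disjoint pools, mirroring the split described above.

```python
import numpy as np

def make_psvrt_like(n_patches, patterns, same, rng):
    """Place two k x k bit-patterns in distinct cells of an n_patches x n_patches
    grid on a blank canvas. Assumes the patterns in `patterns` are pairwise distinct.
    Returns (image, label, [(row, col), (row, col)]) with label 1 for "same"."""
    num_patterns, k, _ = patterns.shape
    img = np.zeros((n_patches * k, n_patches * k), dtype=np.uint8)
    cells = rng.choice(n_patches * n_patches, size=2, replace=False)
    i = int(rng.integers(num_patterns))
    if same:
        j = i
    else:
        j = int(rng.choice([p for p in range(num_patterns) if p != i]))
    positions = []
    for cell, idx in zip(cells, (i, j)):
        r, c = divmod(int(cell), n_patches)
        img[r * k:(r + 1) * k, c * k:(c + 1) * k] = patterns[idx]
        positions.append((r, c))
    return img, int(same), positions

rng = np.random.default_rng(0)
k = 4                                                     # hypothetical bit-pattern width
codes = rng.choice(2 ** (k * k), size=64, replace=False)  # distinct 16-bit codes
bits = (codes[:, None] >> np.arange(k * k)) & 1           # unpack each code into bits
all_patterns = bits.astype(np.uint8).reshape(-1, k, k)
train_patterns, test_patterns = all_patterns[:48], all_patterns[48:]

img, label, positions = make_psvrt_like(3, train_patterns, same=False, rng=rng)
print(img.shape, label, positions)                        # a 12 x 12 image, three patches wide
```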
Results. Figure 3 b plots a model’s test accuracy on PSVRT as a function of the number of training patterns. As our theory suggests, richer models learn the task more easily and generalize after substantially fewer training patterns. To test our models under perceptual variation, we consider larger image sizes. We keep the same size of bit-patterns, but increase the number of patches to make a bigger input. Figure 3 c indicates that a rich model continues to perform perfectly irrespective of image size, whereas lazier models exhibit a performance decay with larger inputs.
Finally, we identify parallel/antiparallel analogs for PSVRT in the weights of a rich model (Figure A1 a). The presence of these conceptual representations suggests that our theory remains a reasonable description for how a rich MLP may learn a conceptual solution to the PSVRT same-different task.
### 4.2 Pentomino
The Pentomino task uses pentomino polygons as inputs: shapes consisting of five squares joined edge to edge (Gülçehre & Bengio, 2016). The input is an image with two pentominoes placed arbitrarily on a blank background. The pentominoes may either be the same shape or different. In contrast with the PSVRT task, sameness in this task means equality up to rotation. After training on a fixed set of pentomino shapes, the model must generalize to entirely novel shapes. As with PSVRT, shapes are patch-aligned. Figure 3 d illustrates example inputs from this task that are three patches wide.
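Because sameness here means equality up to rotation, a same/different label for two pentominoes can be computed by comparing one shape against the four 90-degree rotations of the other, after cropping both to their bounding boxes. The sketch below is illustrative: the helper names are ours, and it assumes each shape is a non-empty binary grid. Mirror images are not identified, since the equivalence is stated up to rotation only.

```python
import numpy as np

def crop(shape):
    """Crop a non-empty binary grid to the bounding box of its filled cells."""
    rows = np.flatnonzero(shape.any(axis=1))
    cols = np.flatnonzero(shape.any(axis=0))
    return shape[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]

def same_up_to_rotation(a, b):
    """True if shapes a and b coincide under a rotation by a multiple of 90 degrees."""
    a, b = crop(np.asarray(a)), crop(np.asarray(b))
    return any(np.array_equal(np.rot90(a, r), b) for r in range(4))

l_shape = np.array([[1, 0], [1, 0], [1, 0], [1, 1]])   # the L pentomino
p_shape = np.array([[1, 1], [1, 1], [1, 0]])           # the P pentomino
placed = np.zeros((5, 5), dtype=int)
placed[1:3, 0:4] = np.rot90(l_shape)                   # rotated L on a blank canvas

print(same_up_to_rotation(l_shape, placed))            # True
print(same_up_to_rotation(l_shape, p_shape))           # False
print(same_up_to_rotation(l_shape, np.fliplr(l_shape)))  # False: a mirror image
```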
Results. Figure 3 e plots a model’s test accuracy on Pentomino as a function of the number of training shapes. Consistent with our theory, richer models learn the task more easily and generalize after substantially fewer training shapes. To test our models under perceptual variation, we consider larger image sizes. Like with PSVRT, we add additional patches to enlarge the input. Figure 3 f indicates that a rich model continues to perform well on larger image sizes, though its performance does start to decay somewhat. Performance decays substantially faster for lazier models.
We again identify parallel/antiparallel analogs for Pentomino in the weights of a rich model (Figure A1 c). The presence of these conceptual representations continues to support our theoretical perspective. Notably, Gülçehre and Bengio (2016) introduced this task to motivate curriculum learning, finding that their MLP fails to perform above chance. We find that curriculum learning is unnecessary given sufficient richness.
### 4.3 CIFAR-100
The CIFAR-100 dataset consists of 60 thousand real-world images, each 32 by 32 pixels (Krizhevsky & Hinton, 2009). Images belong to one of 100 different classes. In this task, the input consists of two distinct, unlabeled images that belong either to the same class or to different classes. After training on images from a fixed set of classes, the model must generalize to entirely novel classes. Because the sets of training and test classes are disjoint, this is an extremely challenging task. The labels themselves are not provided in any form during training. Example inputs are illustrated in Figure 3 g. We also experiment with providing features from VGG-16 pretrained on ImageNet (Simonyan & Zisserman, 2014): we pass CIFAR-100 images to VGG-16, then use its intermediate features as inputs to our MLP. The weights of VGG-16 are fixed throughout. Note also that ImageNet is disjoint from CIFAR-100, so there is limited possibility of contamination in the test images.
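The pair construction with disjoint training and test classes can be sketched as follows. This is a hypothetical sampler, not the authors' pipeline: it returns index pairs into an image (or VGG-16 feature) array together with a same/different target, uses class labels only to form pairs, and assumes a 16/84 split of the 100 classes and 600 images per class (CIFAR-100's 500 training plus 100 test images) purely for illustration.

```python
import numpy as np

def make_sd_pairs(labels, allowed_classes, n_pairs, rng):
    """Sample index pairs (i, j) of distinct images drawn only from
    `allowed_classes`, with target 1 if both share a class and 0 otherwise.
    Labels are used only to build pairs; a model would see just the two images."""
    allowed = np.array(sorted(allowed_classes))
    by_class = {int(c): np.flatnonzero(labels == c) for c in allowed}
    pairs, targets = [], []
    for _ in range(n_pairs):
        same = bool(rng.random() < 0.5)
        if same:
            c = int(rng.choice(allowed))
            i, j = rng.choice(by_class[c], size=2, replace=False)
        else:
            c1, c2 = rng.choice(allowed, size=2, replace=False)
            i, j = rng.choice(by_class[int(c1)]), rng.choice(by_class[int(c2)])
        pairs.append((int(i), int(j)))
        targets.append(int(same))
    return np.array(pairs), np.array(targets)

rng = np.random.default_rng(0)
labels = np.repeat(np.arange(100), 600)        # stand-in for CIFAR-100 class labels
perm = rng.permutation(100)
train_classes = set(perm[:16].tolist())        # e.g. 2^4 training classes
test_classes = set(perm[16:].tolist())         # novel classes reserved for testing
train_pairs, train_targets = make_sd_pairs(labels, train_classes, 1000, rng)
test_pairs, test_targets = make_sd_pairs(labels, test_classes, 1000, rng)
```

Each index pair would then be turned into a model input, for example by concatenating the two flattened images or their fixed VGG-16 features (an assumption of this sketch).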
Results. Figure 3 h plots a model’s test accuracy on CIFAR-100 images as a function of the number of training classes. We use outputs from VGG-16 block 4, layer 3, which performed best with our model. As before, richer models tend to perform better with fewer training classes, but with a curious exception: in contrast to the previous two tasks, the richest model does not always perform decisively the best. This is particularly evident using activations from other intermediate VGG layers, plotted in Figure G1. For certain layers and numbers of training classes, the optimal $\gamma$ appears to be somewhat less than 1. This outcome may be in part an artifact of overfitting. Given the complexity of the task and the limited data, richer models are plausibly more susceptible to idiosyncratic features of the training set that generalize poorly, analogous to overfitting effects in classical statistics that degrade the performance of powerful models. In this case, slightly less learning richness may be optimal. Since CIFAR-100 images are fixed at 32 by 32 pixels, we do not vary image size for this task.
As before, we identify parallel/antiparallel analogs for this task in the weights of a rich model (Figure A1 e). The general benefit of richness together with the presence of conceptual representations continues to align with our theoretical perspective. Across our three visual same-different tasks, we identified generally consistent relationships between learning richness, conceptual solutions, and good performance, supporting our theoretical findings.
## 5 Discussion
We studied equality reasoning using a simple same-different task. We showed that learning richness drives the development of either conceptual or perceptual behavior. Rich MLPs develop conceptual representations, learn from few training examples, and remain largely insensitive to perceptual variation. Meanwhile, lazy MLPs require exhaustive training examples and deteriorate substantially with spurious perceptual changes.
Varying learning richness recapitulates the continuum between perceptual and conceptual behavior on same-different tasks described by E. Wasserman et al. (2017). Perhaps a pigeon’s competency at equality reasoning may be broadly comparable to a lazy MLP’s, requiring a great deal of training and exhibiting persistent sensitivity to spurious details. Perhaps equality reasoning in humans, or even in language-trained great apes, may be comparable to a rich MLP, where learning is faster, less sensitive to spurious details, and presumably involves conceptual abstractions. We suggest that a key parameter underlying these behavioral differences may be learning richness.
Learning richness is a concept imported from machine learning theory, and it is not altogether clear how to measure richness in a living brain. Since richness specifies the degree of change in a neural network’s hidden representations, the most direct analogy in the brain is to look for adaptive representations that seem to encode task-specific variables. Such approaches have implicated richness as an essential property for context-dependent decision making, multitask cognition, and knowledge generalization, among many other phenomena (Flesch et al., 2022; Ito & Murray, 2023; Johnston & Fusi, 2023; Farrell et al., 2023). Our theory predicts that greater learning richness relates to faster generalization in equality reasoning, and we look forward to possible experimental validation of this principle.
Our work also contributes to the longstanding debate on a neural network’s facility with abstract reasoning. Rich MLPs demonstrate successful generalization to unseen symbols irrespective of input dimension or even high noise variance. Further, the rich MLP’s development of parallel/antiparallel components suggests the formation of abstractions, supporting the account that neural networks may indeed learn to develop and manipulate symbolic representations.
Practically, we demonstrate that learning richness is a vital hyperparameter. Increasing richness generally increases test performance substantially, improves data efficiency, and reduces sensitivity to spurious details. For complex tasks tuned over a large range of $\gamma$, however, there may be an optimal intermediate level of richness: for CIFAR-100, we observed that more richness is not always better. We encourage more widespread application of richness parametrizations like $\mu$P, and advocate for adding $\gamma$ to the list of tunable hyperparameters that every practitioner must consider when developing neural networks (Atanasov et al., 2024).
Acknowledgments.
Special thanks to Alex Atanasov for a serendipitous conversation that inspired much of this project. We also thank Hamza Chaudhry, Ben Ruben, Sab Sainathan, Jacob Zavatone-Veth, and members of the Pehlevan Group for many helpful comments and discussions on our manuscript. WLT is supported by a Kempner Graduate Fellowship. CP is supported by NSF grant DMS-2134157, NSF CAREER Award IIS-2239780, DARPA grant DIAL-FP-038, a Sloan Research Fellowship, and The William F. Milton Fund from Harvard University. This work has been made possible in part by a gift from the Chan Zuckerberg Initiative Foundation to establish the Kempner Institute for the Study of Natural and Artificial Intelligence. The computations in this paper were run on the FASRC cluster supported by the FAS Division of Science Research Computing Group at Harvard University.
## References
- Alhama, R. G., & Zuidema, W. (2019). A review of computational models of basic rule learning: The neural-symbolic debate and beyond. Psychon. Bull. Rev., 26(4), 1174–1194.
- Atanasov, A., Meterez, A., Simon, J. B., & Pehlevan, C. (2024). The optimization landscape of SGD across the feature learning strength. arXiv preprint arXiv:2410.04642.
- Battaglia, P. W., Hamrick, J. B., Bapst, V., Sanchez-Gonzalez, A., Zambaldi, V., Malinowski, M., … Pascanu, R. (2018). Relational inductive biases, deep learning, and graph networks. arXiv.
- Bernstein, J., Wang, Y. X., Azizzadenesheli, K., & Anandkumar, A. (2018). signSGD: Compressed optimisation for non-convex problems. In International Conference on Machine Learning (pp. 560–569).
- Biehl, M., & Schwarze, H. (1995). Learning by on-line gradient descent. Journal of Physics A: Mathematical and General, 28(3), 643.
- Boix-Adsera, E., Saremi, O., Abbe, E., Bengio, S., Littwin, E., & Susskind, J. (2023). When can transformers reason with abstract symbols? arXiv preprint arXiv:2310.09753.
- Bordelon, B., & Pehlevan, C. (2022). Self-consistent dynamical field theory of kernel evolution in wide neural networks. Advances in Neural Information Processing Systems, 35, 32240–32256.
- Bubeck, S., Chandrasekaran, V., Eldan, R., Gehrke, J., Horvitz, E., Kamar, E., … Zhang, Y. (2023). Sparks of artificial general intelligence: Early experiments with GPT-4. arXiv.
- Carstensen, A., & Frank, M. C. (2021). Do graded representations support abstract thought? Current Opinion in Behavioral Sciences, 37, 90–97.
- Chizat, L., & Bach, F. (2020). Implicit bias of gradient descent for wide two-layer neural networks trained with the logistic loss. In Conference on Learning Theory (pp. 1305–1338).
- Chizat, L., Oyallon, E., & Bach, F. (2019). On lazy training in differentiable programming. Advances in Neural Information Processing Systems, 32.
- Cho, Y., & Saul, L. (2009). Kernel methods for deep learning. Advances in Neural Information Processing Systems, 22.
- Farrell, M., Recanatesi, S., & Shea-Brown, E. (2023). From lazy to rich to exclusive task representations in neural networks and neural codes. Curr. Opin. Neurobiol., 83, 102780.
- Flesch, T., Juechems, K., Dumbalska, T., Saxe, A., & Summerfield, C. (2022). Orthogonal representations for robust context-dependent task performance in brains and neural networks. Neuron, 110(7), 1258–1270.e11.
- Fleuret, F., Li, T., Dubout, C., Wampler, E. K., Yantis, S., & Geman, D. (2011). Comparing machines and humans on a visual categorization test. Proceedings of the National Academy of Sciences, 108(43), 17621–17625.
- Fodor, J. A., & Pylyshyn, Z. W. (1988). Connectionism and cognitive architecture: A critical analysis. Cognition, 28(1–2), 3–71.
- Geiger, A., Carstensen, A., Frank, M. C., & Potts, C. (2023). Relational reasoning and generalization using nonsymbolic neural networks. Psychological Review, 130(2), 308.
- Geiger, M., Spigler, S., Jacot, A., & Wyart, M. (2020). Disentangling feature and lazy training in deep neural networks. Journal of Statistical Mechanics: Theory and Experiment, 2020(11), 113301.
- Giurfa, M., Zhang, S., Jenett, A., Menzel, R., & Srinivasan, M. V. (2001). The concepts of 'sameness' and 'difference' in an insect. Nature, 410(6831), 930–933.
- Goldt, S., Advani, M., Saxe, A. M., Krzakala, F., & Zdeborová, L. (2019). Dynamics of stochastic gradient descent for two-layer neural networks in the teacher-student setup. Advances in Neural Information Processing Systems, 32.
- Gülçehre, Ç., & Bengio, Y. (2016). Knowledge matters: Importance of prior information for optimization. The Journal of Machine Learning Research, 17(1), 226–257.
- Guo, D., Yang, D., Zhang, H., Song, J., Zhang, R., Xu, R., et al. (2025). DeepSeek-R1: Incentivizing reasoning capability in LLMs via reinforcement learning. arXiv preprint arXiv:2501.12948.
- Ito, T., & Murray, J. D. (2023). Multitask representations in the human cortex transform along a sensory-to-motor hierarchy. Nat. Neurosci., 26(2), 306–315.
- Jacot, A., Gabriel, F., & Hongler, C. (2018). Neural tangent kernel: Convergence and generalization in neural networks. Advances in Neural Information Processing Systems, 31.
- James, W. (1905). The principles of psychology. New York: H. Holt.
- Johnston, W. J., & Fusi, S. (2023). Abstract representations emerge naturally in neural networks trained to perform multiple tasks. Nature Communications, 14(1), 1040.
- Kim, J., Ricci, M., & Serre, T. (2018). Not-So-CLEVR: Learning same–different relations strains feedforward neural networks. Interface Focus, 8(4), 20180011.
- Kingma, D. P. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
- Krizhevsky, A., & Hinton, G. (2009). Learning multiple layers of features from tiny images. Technical report.
- Mahowald, K., Ivanova, A. A., Blank, I. A., Kanwisher, N., Tenenbaum, J. B., & Fedorenko, E. (2024). Dissociating language and thought in large language models. Trends Cogn. Sci., 28(6), 517–540.
- Marcus, G. (2020). The next decade in AI: Four steps towards robust artificial intelligence. arXiv.
- Marcus, G. F. (1998). Rethinking eliminative connectionism. Cognitive Psychology, 37(3), 243–282.
- Marcus, G. F. (2003). The algebraic mind: Integrating connectionism and cognitive science. Cambridge, MA: Bradford Books.
- Marcus, G. F., Vijayan, S., Bandi Rao, S., & Vishton, P. M. (1999). Rule learning by seven-month-old infants. Science, 283(5398), 77–80.
- Mastrovito, D., Liu, Y. H., Kusmierz, L., Shea-Brown, E., Koch, C., & Mihalas, S. (2024). Transition to chaos separates learning regimes and relates to measure of consciousness in recurrent neural networks. bioRxiv, 2024.05.15.594236.
- McCoy, R. T., Yao, S., Friedman, D., Hardy, M., & Griffiths, T. L. (2023). Embers of autoregression: Understanding large language models through the problem they are trained to solve. arXiv.
- Mei, S., Montanari, A., & Nguyen, P. M. (2018). A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences, 115(33), E7665–E7671.
- Morwani, D., Edelman, B. L., Oncescu, C. A., Zhao, R., & Kakade, S. (2023). Feature emergence via margin maximization: Case studies in algebraic tasks. arXiv preprint arXiv:2311.07568.
- Obozova, T., Smirnova, A., Zorina, Z., & Wasserman, E. (2015). Analogical reasoning in amazons. Anim. Cogn., 18(6), 1363–1371.
- OpenAI. (2023). GPT-4 technical report. arXiv.
- OpenAI. (2024). Learning to reason with LLMs. https://openai.com/index/learning-to-reason-with-llms/
- Penn, D. C., Holyoak, K. J., & Povinelli, D. J. (2008). Darwin's mistake: Explaining the discontinuity between human and nonhuman minds. Behav. Brain Sci., 31(2), 109–130.
- Premack, D. (1983). The codes of man and beasts. Behav. Brain Sci., 6(1), 125–136.
- Rabagliati, H., Ferguson, B., & Lew-Williams, C. (2019). The profile of abstract rule learning in infancy: Meta-analytic and experimental evidence. Dev. Sci., 22(1), e12704.
- Raven, J. (2003). Raven progressive matrices. In Handbook of nonverbal assessment (pp. 223–237). Springer.
- Rotskoff, G., & Vanden-Eijnden, E. (2022). Trainability and accuracy of artificial neural networks: An interacting particle system approach. Communications on Pure and Applied Mathematics, 75(9), 1889–1935.
- Saad, D., & Solla, S. A. (1995). On-line learning in soft committee machines. Physical Review E, 52(4), 4225.
- Saffran, J. R., & Thiessen, E. D. (2003). Pattern induction by infant language learners. Dev. Psychol., 39(3), 484–494.
- Santoro, A., Raposo, D., Barrett, D. G. T., Malinowski, M., Pascanu, R., Battaglia, P., & Lillicrap, T. (2017). A simple neural network module for relational reasoning. arXiv.
- Seidenberg, M. S., Elman, J., Eimas, P. D., M, N., & Marcus, G. F. (1999). Do infants learn grammar with algebra or statistics? Science, 284(5413), 434–5; author reply 436–7.
- Seidenberg, M. S., & Elman, J. L. (1999). Networks are not 'hidden rules'. Trends Cogn. Sci., 3(8), 288–289.
- Simonyan, K., & Zisserman, A. (2014). Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556.
- Smirnova, A., Zorina, Z., Obozova, T., & Wasserman, E. (2015). Crows spontaneously exhibit analogical reasoning. Curr. Biol., 25(2), 256–260.
- Stabinger, S., Peer, D., Piater, J., & Rodríguez-Sánchez, A. (2021). Evaluating the progress of deep learning for visual relational concepts. J. Vis., 21(11), 8.
- Tong, W. L., & Pehlevan, C. (2024). MLPs learn in-context on regression and classification tasks. arXiv.
- Ullman, T. (2023). Large language models fail on trivial alterations to theory-of-mind tasks. arXiv.
- Vaishnav, M., Cadene, R., Alamia, A., Linsley, D., VanRullen, R., & Serre, T. (2022). Understanding the computational demands underlying visual reasoning. Neural Comput., 34(5), 1075–1099.
- Vonk, J. (2003). Gorilla (Gorilla gorilla gorilla) and orangutan (Pongo abelii) understanding of first- and second-order relations. Anim. Cogn., 6(2), 77–86.
- Wasserman, E., Castro, L., & Fagot, J. (2017). Relational thinking in animals and humans: From percepts to concepts. In APA handbook of comparative psychology: Perception, learning, and cognition (pp. 359–384). Washington: American Psychological Association.
- Wasserman, E. A., & Young, M. E. (2010). Same-different discrimination: The keel and backbone of thought and reasoning. J. Exp. Psychol. Anim. Behav. Process., 36(1), 3–22.
- Webb, T. W., Frankland, S. M., Altabaa, A., Krishnamurthy, K., Campbell, D., Russin, J., … Cohen, J. D. (2023). The relational bottleneck as an inductive bias for efficient abstraction. arXiv.
- Webb, T. W., Sinha, I., & Cohen, J. D. (2020). Emergent symbols through binding in external memory. arXiv.
- Wei, C., Lee, J. D., Liu, Q., & Ma, T. (2019). Regularization matters: Generalization and optimization of neural nets vs their induced kernel. Advances in Neural Information Processing Systems, 32.
- Woodworth, B., Gunasekar, S., Lee, J. D., Moroshko, E., Savarese, P., Golan, I., … Srebro, N. (2020). Kernel and rich regimes in overparametrized models. In Conference on Learning Theory (pp. 3635–3673).
- Yang, G., & Hu, E. J. (2021). Tensor programs IV: Feature learning in infinite-width neural networks. In International Conference on Machine Learning (pp. 11727–11737).
- Yang, G., Hu, E. J., Babuschkin, I., Sidor, S., Liu, X., Farhi, D., … Gao, J. (2022). Tensor programs V: Tuning large neural networks via zero-shot hyperparameter transfer. arXiv.
## Appendix
## Appendix A Conceptual representations in visual same-different
[Figure A1 image: left column "Rich regime (γ = 1)", right column "Lazy regime (γ ≈ 0)"; rows show PSVRT (a, b), Pentomino (c, d), and CIFAR-100 (e, f); panels e and f plot weight alignment against the readout weight $a_i$.]
Figure A1: Conceptual representations in visual same-different tasks. (a) Visualization of representative hidden weights associated with maximal positive and negative readout weights, for a rich model trained on PSVRT. Parallel/antiparallel structures are visible: regions are either identical to or precisely the opposite of their neighbors. (b) Visualization of representative hidden weights associated with maximal positive and negative readouts, for a lazy model trained on PSVRT. There is no discernible structure, and the magnitudes of both readouts and hidden weights are small. (c) The same as (a), but for a rich model on the Pentomino task. Parallel/antiparallel structures are likewise visible. (d) The same as (b), but for a lazy model on the Pentomino task. There is no discernible structure, and the magnitudes of both readouts and hidden weights are small. (e) For a rich model trained on the CIFAR-100 task, we plot the alignment between the weight components corresponding to the two images. A distinct parallel/antiparallel-like structure is visible in the weight alignment, as we saw for MLPs trained on our simple SD task. (f) The same as (e), but for a lazy model trained on the CIFAR-100 task. There is no discernible parallel/antiparallel structure.
In Section 4, we experimented with three different visual same-different tasks to validate our theoretical predictions in more complex settings. We found that richer models tend to learn the task with fewer training examples and display some insensitivity to spurious details. The final signature of a conceptual solution is the presence of conceptual representations. In this appendix, we examine the hidden weights learned by rich and lazy models on these tasks and present evidence for conceptual representations.
PSVRT
Recall our MLP given in Eq. (1). To look for conceptual representations in the hidden weights $\mathbf{w}_{i}$, we reshape each weight vector to the input shape and visualize it directly. The results are plotted in Figure A1 a and b. For ease of visualization, this task uses images that are two patches wide.
The rich model learns an interesting analog of parallel/antiparallel weights. Recall that for our MLPs trained on the simple same-different task, weight vectors associated with negative readouts tend to develop antiparallel components, while weight vectors associated with positive readouts tend to develop parallel components.
We observe a similar pattern for PSVRT. For the example weights with a negative readout, adjacent patches are exactly opposite: each patch learns a negative weight wherever its neighbor has a positive weight. This structure mirrors the antiparallel weights learned in the simple same-different task. One difference is that for PSVRT, while two pairs of regions are precisely opposite, the other two pairs are identical. It is impossible for every region to be antiparallel to every other region (if region A is the negative of region B and B is the negative of region C, then A and C must be identical), but it is not obvious why two pairs should become parallel despite the negative readout weight.
Meanwhile, the example weights with a positive readout feature identical patches, matching weights exactly across the four regions of the input. These parallel regions are exactly what we would expect from our consideration of the simple same-different task.
In the lazy regime, the model learns no discernible structure, and the magnitudes of both the readout and hidden weights are much smaller. Altogether, the parallel/antiparallel analogs that appear for PSVRT strongly suggest that only the rich model has learned conceptual representations.
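To check the parallel/antiparallel structure quantitatively rather than by eye, one can reshape a hidden weight vector into the input image, cut it into its patches, and compare patches by cosine similarity. The NumPy sketch below is a minimal illustration under stated assumptions: a single-channel square input arranged as a 2×2 grid of square patches, and a hypothetical helper `patch_similarities` that is not part of the paper's code. The toy weight mimics the negative-readout Pentomino pattern, where the top patches are the opposite of the bottom patches.

```python
import numpy as np


def patch_similarities(w, patch_size, grid=2):
    """Cosine similarity between every pair of patches in a hidden weight vector.

    Hypothetical helper: assumes the flattened input is a single-channel square
    image laid out as a `grid` x `grid` arrangement of square patches.
    """
    side = grid * patch_size
    img = w.reshape(side, side)  # reshape the weights to the input image shape
    patches = [
        img[r * patch_size:(r + 1) * patch_size,
            c * patch_size:(c + 1) * patch_size].ravel()
        for r in range(grid)
        for c in range(grid)
    ]  # row-major order: top-left, top-right, bottom-left, bottom-right
    P = np.stack(patches)
    P = P / np.linalg.norm(P, axis=1, keepdims=True)
    return P @ P.T  # entry (i, j): +1 parallel, -1 antiparallel, ~0 unrelated


# Toy weight vector: both top patches share a pattern, both bottom patches negate it.
rng = np.random.default_rng(0)
k = 4
base = rng.standard_normal((k, k))
toy = np.block([[base, base], [-base, -base]])
S = patch_similarities(toy.ravel(), patch_size=k)
print(np.round(S, 2))
```

Applied to trained weights, an entry near $+1$ flags a parallel pair of regions and an entry near $-1$ an antiparallel pair; for unstructured weights, such as those described for the lazy regime, entries would be expected to scatter around zero.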
Pentomino
We perform the same analysis of the hidden weights $\mathbf{w}_{i}$ for the Pentomino task. The results are visualized in Figure A1 c and d. For ease of visualization, this task consists of images that are two patches wide.
As we saw for PSVRT, the rich model learns analogs of parallel/antiparallel weights. For the example weights corresponding to a negative readout, the top regions are precisely the opposite of the bottom regions, suggestive of the antiparallel weight components we characterized in the simple same-different task. For the example weights corresponding to a positive readout, all four regions are the same, suggestive of parallel weight components.
These structures emerge only in the rich regime. For the lazy regime model, no discernible structure is learned, and the overall magnitudes of the readouts and hidden weights are also much smaller. Altogether, the parallel/antiparallel analogs that appear for Pentomino strongly suggest that only the rich model has learned conceptual representations.
### A.1 CIFAR-100
For the CIFAR-100 task, we visualize the hidden weights in the same way as we did for the simple same-different task in Figure 1. We separate the weight vector $\mathbf{w}_{i}$ into two components, corresponding to the two flattened input images, and measure their alignment. The results are plotted in Figure A1 e and f.
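A minimal NumPy sketch of this measurement follows. We assume, as our own choice rather than a stated detail of the paper, that alignment is the cosine similarity between the half of $\mathbf{w}_{i}$ acting on the first flattened image and the half acting on the second. The weights are synthetic stand-ins (not trained CIFAR-100 weights), built to show the sign pattern expected in the rich regime: parallel halves with positive readouts, antiparallel halves with negative readouts.

```python
import numpy as np


def half_alignment(W):
    """Cosine similarity between the two halves of each hidden weight vector.

    W has shape (m, 2d): row i is w_i = (w_i^(1); w_i^(2)), where the first d
    entries act on the first flattened image and the last d on the second.
    """
    d = W.shape[1] // 2
    w1, w2 = W[:, :d], W[:, d:]
    num = np.sum(w1 * w2, axis=1)
    den = np.linalg.norm(w1, axis=1) * np.linalg.norm(w2, axis=1)
    return num / den


# Hypothetical rich-like weights: three units with parallel halves and positive
# readouts, three with antiparallel halves and negative readouts.
rng = np.random.default_rng(1)
m, d = 6, 32
shared = rng.standard_normal((m, d))
noise = 0.3 * rng.standard_normal((m, d))
sign = np.array([1, 1, 1, -1, -1, -1])
W = np.hstack([shared, sign[:, None] * shared + noise])
a = sign * rng.uniform(0.5, 2.0, size=m)  # readout weights a_i

align = half_alignment(W)
print(np.corrcoef(a, align)[0, 1])  # positive correlation between readout and alignment
```

For a trained model, one would plot `align` against the readouts `a`, as in Figure A1 e and f.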
For the rich case, alignment associated with negative readouts tends to be negative, and alignment associated with positive readouts tends to be positive, suggestive of the expected parallel/antiparallel structure. The alignment is quite similar to what we saw for the rich model on the simple same-different task, though the antiparallel alignment is not as strong. The lazy model shows no apparent correlation at all between readouts and alignment. Altogether, the relationship between readouts and alignment, seen only in the rich model, strongly suggests that only the rich model has learned conceptual representations.
## Appendix B Hand-crafted solution details
We outline in full detail how our hand-crafted solution solves the same-different task. Recall that the hand-crafted solution is given by the following weight configuration:
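$$
\begin{aligned}
\mathbf{w}_{1}^{+}&=(\mathbf{1};\mathbf{1})\,, & \mathbf{w}_{2}^{+}&=(-\mathbf{1};-\mathbf{1})\,, & a^{+}&=1\,,\\
\mathbf{w}_{1}^{-}&=(\mathbf{1};-\mathbf{1})\,, & \mathbf{w}_{2}^{-}&=(-\mathbf{1};\mathbf{1})\,, & a^{-}&=-\rho\,,
\end{aligned}
$$

for $\mathbf{1}=(1,1,\ldots,1)\in\mathbb{R}^{d}$ and some $\rho>0$. The MLP is given by

$$
f(\mathbf{x})=a^{+}\big(\phi(\mathbf{w}_{1}^{+}\cdot\mathbf{x})+\phi(\mathbf{w}_{2}^{+}\cdot\mathbf{x})\big)+a^{-}\big(\phi(\mathbf{w}_{1}^{-}\cdot\mathbf{x})+\phi(\mathbf{w}_{2}^{-}\cdot\mathbf{x})\big)\,,
$$

where $\phi$ is the ReLU nonlinearity.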
Upon receiving a same example $\mathbf{x}^{+}=(\mathbf{z},\mathbf{z})$ , our model returns
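$$
\begin{aligned}
f(\mathbf{x}^{+})&=\phi(\mathbf{1}\cdot\mathbf{z}+\mathbf{1}\cdot\mathbf{z})+\phi(-\mathbf{1}\cdot\mathbf{z}-\mathbf{1}\cdot\mathbf{z})-\rho\big(\phi(\mathbf{1}\cdot\mathbf{z}-\mathbf{1}\cdot\mathbf{z})+\phi(-\mathbf{1}\cdot\mathbf{z}+\mathbf{1}\cdot\mathbf{z})\big)\\
&=\phi(2\,\mathbf{1}\cdot\mathbf{z})+\phi(-2\,\mathbf{1}\cdot\mathbf{z})=2\,|\mathbf{1}\cdot\mathbf{z}|\,,
\end{aligned}
$$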
which is positive whenever $\mathbf{1}\cdot\mathbf{z}\neq 0$, that is, with probability one. Hence, the model classifies all positive (same) examples correctly.
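The same-example case can be checked numerically. The sketch below builds the hand-crafted weights with NumPy and evaluates the model on inputs $(\mathbf{z},\mathbf{z})$; the helper name `hand_crafted_mlp` and the values of $d$, $\rho$, and the sample size are illustrative choices, not from the paper.

```python
import numpy as np


def hand_crafted_mlp(d, rho):
    """Hand-crafted parallel/antiparallel MLP f(x) for inputs x = (z; z') in R^{2d}."""
    one = np.ones(d)
    W = np.stack([
        np.concatenate([one, one]),     # w_1^+ (parallel)
        np.concatenate([-one, -one]),   # w_2^+ (parallel)
        np.concatenate([one, -one]),    # w_1^- (antiparallel)
        np.concatenate([-one, one]),    # w_2^- (antiparallel)
    ])
    a = np.array([1.0, 1.0, -rho, -rho])  # readouts a^+ = 1 and a^- = -rho

    def f(X):
        # ReLU hidden layer followed by a linear readout; X has shape (n, 2d).
        return np.maximum(X @ W.T, 0.0) @ a

    return f


d, rho, n = 16, 10.0, 1_000
rng = np.random.default_rng(0)
Z = rng.normal(0.0, 1.0 / np.sqrt(d), size=(n, d))  # z ~ N(0, I/d)
out = hand_crafted_mlp(d, rho)(np.hstack([Z, Z]))  # same examples x^+ = (z, z)

# Antiparallel units receive zero input, so f(x^+) = 2|1.z| regardless of rho.
print(np.allclose(out, 2 * np.abs(Z.sum(axis=1))), np.mean(out > 0))
```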
Upon receiving a different example $\mathbf{x}^{-}=(\mathbf{z},\mathbf{z}^{\prime})$ , our model returns
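$$
\begin{aligned}
f(\mathbf{x}^{-})&=\phi(\mathbf{1}\cdot\mathbf{z}+\mathbf{1}\cdot\mathbf{z}^{\prime})+\phi(-\mathbf{1}\cdot\mathbf{z}-\mathbf{1}\cdot\mathbf{z}^{\prime})-\rho\big(\phi(\mathbf{1}\cdot\mathbf{z}-\mathbf{1}\cdot\mathbf{z}^{\prime})+\phi(-\mathbf{1}\cdot\mathbf{z}+\mathbf{1}\cdot\mathbf{z}^{\prime})\big)\\
&=|\mathbf{1}\cdot\mathbf{z}+\mathbf{1}\cdot\mathbf{z}^{\prime}|-\rho\,|\mathbf{1}\cdot\mathbf{z}-\mathbf{1}\cdot\mathbf{z}^{\prime}|\,.
\end{aligned}
$$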
Since training symbols are sampled as $\mathbf{z}\sim\mathcal{N}(0,\mathbf{I}/d)$, we have that $\mathbf{z}\overset{d}{=}-\mathbf{z}$ and $\mathbf{1}\cdot\mathbf{z}\sim\mathcal{N}(0,1)$ (a sum of $d$ independent terms, each with variance $1/d$). Furthermore, a sum of independent Gaussians remains Gaussian, so $\mathbf{1}\cdot\mathbf{z}\pm\mathbf{1}\cdot\mathbf{z}^{\prime}\sim\mathcal{N}(0,2)$. Hence, $f(\mathbf{x}^{-})\overset{d}{=}u-\rho v$, where $u,v\sim\text{HalfNormal}(0,2)$. The ReLU nonlinearity is what makes these quantities Half-Normal rather than Gaussian: since $\phi(s)+\phi(-s)=|s|$, each pair of units returns the absolute value of a Gaussian. Further, $u$ and $v$ are independent since $\mathbf{1}\cdot\mathbf{z}+\mathbf{1}\cdot\mathbf{z}^{\prime}$ is independent of $\mathbf{1}\cdot\mathbf{z}-\mathbf{1}\cdot\mathbf{z}^{\prime}$ (the two sums are jointly Gaussian with covariance $\operatorname{Var}(\mathbf{1}\cdot\mathbf{z})-\operatorname{Var}(\mathbf{1}\cdot\mathbf{z}^{\prime})=0$).
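The distributional steps above can be checked numerically. The NumPy sketch below (dimension, sample size, and seed are arbitrary illustrative choices) draws independent symbols $\mathbf{z},\mathbf{z}^{\prime}\sim\mathcal{N}(0,\mathbf{I}/d)$, forms the two projected sums, and checks their variance, their zero covariance, the identity $\phi(s)+\phi(-s)=|s|$, and the Half-Normal mean $2/\sqrt{\pi}$ implied by variance parameter $2$.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 32, 100_000
Z = rng.normal(0.0, 1.0 / np.sqrt(d), size=(n, d))   # z  ~ N(0, I/d)
Zp = rng.normal(0.0, 1.0 / np.sqrt(d), size=(n, d))  # z' ~ N(0, I/d), independent of z

s_plus = Z.sum(axis=1) + Zp.sum(axis=1)   # 1.z + 1.z'
s_minus = Z.sum(axis=1) - Zp.sum(axis=1)  # 1.z - 1.z'


def relu(s):
    return np.maximum(s, 0.0)


u = relu(s_plus) + relu(-s_plus)    # parallel pair: phi(s) + phi(-s) = |s|
v = relu(s_minus) + relu(-s_minus)  # antiparallel pair: phi(s) + phi(-s) = |s|

print(s_plus.var(), s_minus.var())    # both should be close to 2
print(np.cov(s_plus, s_minus)[0, 1])  # should be close to 0
print(u.mean(), 2 / np.sqrt(np.pi))   # HalfNormal(0, 2) mean is 2/sqrt(pi)
```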
The test accuracy of the model on $\mathbf{x}^{-}$ is given by $p(f(\mathbf{x}^{-})<0)$ , which can be expressed as an integral over the joint PDF of $u,v$ :
$$
p(f(\mathbf{x}^{-})<0)=p(u-\rho v<0)=\frac{1}{\pi}\int_{0}^{\infty}\!\!\int_{0}^{\rho v}e^{-(u^{2}+v^{2})/4}\,du\,dv\,,
$$

using the $\text{HalfNormal}(0,2)$ density $p(u)=\frac{1}{\sqrt{\pi}}e^{-u^{2}/4}$ for $u\geq 0$. To compute this quantity, we convert to polar coordinates. Let $u=r\cos(\theta)$ and $v=r\sin(\theta)$. The region $u<\rho v$ corresponds to $\theta>\tan^{-1}(1/\rho)$, so under this change of variables we have

$$
p(f(\mathbf{x}^{-})<0)=\frac{1}{\pi}\int_{\tan^{-1}(1/\rho)}^{\pi/2}\int_{0}^{\infty}r\,e^{-r^{2}/4}\,dr\,d\theta=\frac{2}{\pi}\left(\frac{\pi}{2}-\tan^{-1}(1/\rho)\right)=\frac{2}{\pi}\tan^{-1}(\rho)\,.
$$
As $\rho\rightarrow\infty$, this quantity approaches $1$. Since the antiparallel units receive zero input on same examples, $\rho$ does not appear in $f(\mathbf{x}^{+})$, so $\rho$ can be made arbitrarily large without affecting classification accuracy on same inputs. Hence, the hand-crafted solution solves the same-different task provided the relative magnitude $\rho$ of the negative readouts is large.
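Evaluating the polar integral gives $p(f(\mathbf{x}^{-})<0)=\frac{2}{\pi}\tan^{-1}(\rho)$, which a Monte Carlo simulation can check. The NumPy sketch below assumes that test symbols are drawn from the same $\mathcal{N}(0,\mathbf{I}/d)$ distribution as training symbols; the helper `hand_crafted_accuracy` and all numerical settings are illustrative choices of ours. Same inputs should be classified correctly for every $\rho$, while accuracy on different inputs should track the closed form.

```python
import numpy as np


def hand_crafted_accuracy(d, rho, n, rng):
    """Monte Carlo accuracy of the hand-crafted MLP on same and different inputs."""
    one = np.ones(d)
    W = np.stack([np.r_[one, one], np.r_[-one, -one],   # parallel units w^+
                  np.r_[one, -one], np.r_[-one, one]])  # antiparallel units w^-
    a = np.array([1.0, 1.0, -rho, -rho])                # a^+ = 1, a^- = -rho

    def f(X):
        return np.maximum(X @ W.T, 0.0) @ a  # ReLU hidden layer, linear readout

    Z = rng.normal(0.0, 1.0 / np.sqrt(d), size=(n, d))   # z  ~ N(0, I/d)
    Zp = rng.normal(0.0, 1.0 / np.sqrt(d), size=(n, d))  # z' ~ N(0, I/d)
    acc_same = np.mean(f(np.hstack([Z, Z])) > 0)   # same: correct if f > 0
    acc_diff = np.mean(f(np.hstack([Z, Zp])) < 0)  # different: correct if f < 0
    return acc_same, acc_diff


rng = np.random.default_rng(0)
for rho in [1.0, 5.0, 50.0]:
    acc_same, acc_diff = hand_crafted_accuracy(d=16, rho=rho, n=50_000, rng=rng)
    theory = 2 / np.pi * np.arctan(rho)
    print(f"rho={rho:5.1f}  same={acc_same:.3f}  different={acc_diff:.3f}  theory={theory:.3f}")
```

With $\rho=1$ the different-input accuracy should sit near chance, since $u$ and $v$ are identically distributed, and it should approach $1$ as $\rho$ grows.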
Both positive and negative classifications also require that the test example not be exactly orthogonal to the parallel/antiparallel weight vectors; otherwise, the relevant dot products would be zero. However, a test example is exactly orthogonal to the weight vectors with probability zero, so this eventuality does not affect the solution's overall test accuracy.
Figure B1 illustrates how parallel/antiparallel weight vectors may correctly classify a same or different example.
[Figure B1 image: pictorial weight vectors $\mathbf{w}^{+}$ and $\mathbf{w}^{-}$ applied to a same example, giving $\mathbf{w}^{+}\cdot\mathbf{x}-\mathbf{w}^{-}\cdot\mathbf{x}>0$, and to a different example, giving $\mathbf{w}^{+}\cdot\mathbf{x}-\rho\,\mathbf{w}^{-}\cdot\mathbf{x}<0$.]
Figure B1: Illustration of the hand-crafted solution. Parallel/antiparallel weight vectors $\mathbf{w}^{+},\mathbf{w}^{-}$ are represented pictorially as pairs of vectors, one for each half of the input. The dot product is represented by conjoining the corresponding vectors: each term equals the cosine of the angle between the two vectors, scaled by their magnitudes. For a same test example, $\mathbf{w}^{+}\cdot\mathbf{x}$ remains positive while $\mathbf{w}^{-}\cdot\mathbf{x}$ cancels to zero. For a different test example, the relative magnitude $\rho$ applied to the $\mathbf{w}^{-}$ term enables a successful negative classification.
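The behavior sketched in Figure B1 can be checked numerically. The NumPy sketch below is our own illustration under stated assumptions, not the construction used in the main text: symbols are drawn as $\mathbf{z}\sim\mathcal{N}(\mathbf{0},\mathbf{I}/d)$, each parallel unit is $(\mathbf{v};\mathbf{v})$ and each antiparallel unit is $(\mathbf{v};-\mathbf{v})$ for random directions $\mathbf{v}$, and the readouts are $+1$ for parallel units and $-\rho$ for antiparallel units. The helper names (`hand_crafted_weights`, `readout`, `estimate_accuracy`) and the values of `d`, `m`, and `rho` are hypothetical.

```python
import numpy as np


def relu(u):
    return np.maximum(u, 0.0)


def hand_crafted_weights(d, m, rng):
    """Hypothetical hand-crafted weights: m parallel and m antiparallel ReLU units."""
    v_plus = rng.normal(size=(m, d)) / np.sqrt(d)
    v_minus = rng.normal(size=(m, d)) / np.sqrt(d)
    w_plus = np.concatenate([v_plus, v_plus], axis=1)       # (v; v): parallel halves
    w_minus = np.concatenate([v_minus, -v_minus], axis=1)   # (v; -v): antiparallel halves
    return w_plus, w_minus


def readout(x, w_plus, w_minus, rho):
    """f(x) = sum ReLU(w+ . x) - rho * sum ReLU(w- . x); f > 0 means 'same'."""
    return relu(w_plus @ x).sum() - rho * relu(w_minus @ x).sum()


def estimate_accuracy(d=64, m=512, rho=2.0, trials=500, seed=0):
    """Accuracy on same/different pairs built from freshly sampled (unseen) symbols."""
    rng = np.random.default_rng(seed)
    w_plus, w_minus = hand_crafted_weights(d, m, rng)
    correct = 0
    for _ in range(trials):
        z1, z2 = rng.normal(size=(2, d)) / np.sqrt(d)
        same = np.concatenate([z1, z1])
        diff = np.concatenate([z1, z2])
        correct += readout(same, w_plus, w_minus, rho) > 0
        correct += readout(diff, w_plus, w_minus, rho) < 0
    return correct / (2 * trials)


print(estimate_accuracy(rho=2.0))  # expected to be close to 1
print(estimate_accuracy(rho=0.5))  # roughly chance: 'different' pairs are called 'same'
```

With $\rho<1$ the two ReLU sums for a different pair are of similar size, so the positive term wins and different pairs are misclassified; a larger weight on the antiparallel units restores the negative decision, in line with the role of $\rho$ in the caption.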
## Appendix C Rich regime details
We conduct our analysis of the rich regime in two parts. We begin with a derivation of the max margin solution to our same-different task in Section C.1. Doing so requires us to replace our model’s ReLU activations with quadratic activations. Moreover, the max margin solution alone does not explain the rich model’s learning efficiency or its insensitivity to perceptual details. To address these shortcomings, we extend our analysis with a heuristic construction in which we approximate a rich MLP using an ensemble of independent Markov processes (Section C.2). Using this construction, we derive a finer-grained characterization of the MLP’s weight structure and apply it to estimate the model’s test accuracy for varying $L$ and $d$ .
### C.1 Max margin solution
An MLP trained on a classification objective often learns a max margin solution over the dataset (Morwani et al., 2023; Chizat & Bach, 2020; Wei et al., 2019). While this outcome is not guaranteed in our setting, studying the structure of the max margin solution nonetheless reveals critical details about how our MLP may be solving the same-different task. Following Morwani et al. (2023), we adopt two conditions to expedite our analysis:
1. We replace a strict max margin objective with a max average margin objective over a dataset $\mathcal{D}=\left\{\mathbf{x}_{n},y_{n}\right\}_{n=1}^{P}$
$$
\max_{\mathbf{\theta}}\,\frac{1}{P}\sum_{n=1}^{P}\left[(2y_{n}-1)f(\mathbf{x}_{n};\mathbf{\theta})\right]\,,
$$
where $\mathbf{x}_{n},y_{n}$ are sampled over a training distribution with $L$ symbols and the objective is subject to some norm constraint on $\mathbf{\theta}$ . Given the symmetry of the task, a max average margin objective forms a reasonable proxy to the strict max margin.
1. We consider quadratic activations $\phi(\cdot)=(\cdot)^{2}$ rather than ReLU. Doing so alters our model from Eq (1), but we later use a heuristic construction to argue that the resulting solution is recovered under ReLU activations in a rich learning regime.
We further allow $P,L\rightarrow\infty$ . Following these simplifications, we derive the max average margin solution.
**Theorem 1**
*Let $\mathcal{D}=\{\mathbf{x}_{n},y_{n}\}_{n=1}^{P}$ be a training set consisting of $P$ points sampled across $L$ training symbols, as specified in Section 2. Let $f$ be the MLP given by Eq (1), with two changes:
1. Fix the readouts $a_{i}=\pm 1$ , where exactly $m/2$ readouts are positive and the remaining are negative.
1. Use quadratic activations $\phi(\cdot)=(\cdot)^{2}$ .
For weights $\mathbf{\theta}=\left\{\mathbf{w}_{i}\right\}_{i=1}^{m}$ , define the max margin set $\Delta(\mathbf{\theta})$ to be
$$
\Delta(\mathbf{\theta})=\operatorname*{arg\,max}_{\mathbf{\theta}}\frac{1}{P}\sum_{n=1}^{P}\left[(2y_{n}-1)f(\mathbf{x}_{n};\mathbf{\theta})\right]\,,
$$
subject to the norm constraints $\left|\left|\mathbf{w}_{i}\right|\right|=1$ . If $P,L\rightarrow\infty$ , then for any maximizer $\mathbf{\theta}\in\Delta(\mathbf{\theta})$ and each of its hidden weights $\mathbf{w}_{i}=(\mathbf{v}_{i}^{1};\mathbf{v}_{i}^{2})$ , with $\ell_{i}=\left|\left|\mathbf{v}_{i}^{1}\right|\right|\,\left|\left|\mathbf{v}_{i}^{2}\right|\right|$ , we have that $\mathbf{v}_{i}^{1}\cdot\mathbf{v}_{i}^{2}/\ell_{i}=1$ if $a_{i}=1$ and $\mathbf{v}_{i}^{1}\cdot\mathbf{v}_{i}^{2}/\ell_{i}=-1$ if $a_{i}=-1$ . Further, $\left|\left|\mathbf{v}_{i}^{1}\right|\right|=\left|\left|\mathbf{v}_{i}^{2}\right|\right|$ .*
*Proof.*
Let $\mathcal{D}^{+}$ be the subset of $\mathcal{D}$ consisting of same examples and $\mathcal{D}^{-}$ be the subset of different examples. Let $\mathcal{I}^{+}$ be the set of indices $i$ such that the readout weight $a_{i}>0$ . Let $\mathcal{I}^{-}$ be the set of indices $j$ such that $a_{j}<0$ . Then our max average margin objective becomes
$$
\max_{\mathbf{\theta}}\,\frac{1}{P}\sum_{n=1}^{P}\left[(2y_{n}-1)f(\mathbf{x}_{n};\mathbf{\theta})\right]=\max_{\mathbf{\theta}}\,\frac{1}{P}\left[\sum_{\mathbf{x}^{+}\in\mathcal{D}^{+}}f(\mathbf{x}^{+};\mathbf{\theta})-\sum_{\mathbf{x}^{-}\in\mathcal{D}^{-}}f(\mathbf{x}^{-};\mathbf{\theta})\right]\,.
$$
Suppose we stack all same training examples $\mathbf{x}^{+}$ into a large matrix $\mathbf{X}^{+}\in\mathbb{R}^{|\mathcal{D}^{+}|\times 2d}$ and stack all different training examples $\mathbf{x}^{-}$ into a large matrix $\mathbf{X}^{-}\in\mathbb{R}^{|\mathcal{D}^{-}|\times 2d}$ . With quadratic activations, each hidden unit contributes a term proportional to $a_{i}(\mathbf{w}_{i}\cdot\mathbf{x})^{2}$ , so the objective decouples into a sum of $\left|\left|\mathbf{X}^{+}\mathbf{w}_{i}\right|\right|^{2}-\left|\left|\mathbf{X}^{-}\mathbf{w}_{i}\right|\right|^{2}$ over $i\in\mathcal{I}^{+}$ , minus the same quantity summed over $i\in\mathcal{I}^{-}$ . Applying the norm constraints $\left|\left|\mathbf{w}_{i}\right|\right|=1$ , our max margin solution is resolved by the following objectives
$$
\begin{aligned}
\mathbf{w}_{*}^{+}&=\operatorname*{arg\,max}_{\mathbf{w}}\,\frac{1}{P}\left[\left|\left|\mathbf{X}^{+}\mathbf{w}\right|\right|^{2}-\left|\left|\mathbf{X}^{-}\mathbf{w}\right|\right|^{2}\right]\quad\text{such that}\;\left|\left|\mathbf{w}\right|\right|=1\,,\\
\mathbf{w}_{*}^{-}&=\operatorname*{arg\,min}_{\mathbf{w}}\,\frac{1}{P}\left[\left|\left|\mathbf{X}^{+}\mathbf{w}\right|\right|^{2}-\left|\left|\mathbf{X}^{-}\mathbf{w}\right|\right|^{2}\right]\quad\text{such that}\;\left|\left|\mathbf{w}\right|\right|=1\,,
\end{aligned}
$$
where $\mathbf{w}_{*}^{+}$ and $\mathbf{w}_{*}^{-}$ represent the hidden weights of the max average margin solution. Maximizing (or minimizing) this objective is equivalent to finding the eigenvector with the largest (or smallest) eigenvalue of the matrix $\mathbf{X}=\frac{1}{P}\left[\left(\mathbf{X}^{+}\right)^{\intercal}\mathbf{X}^{+}-\left(\mathbf{X}^{-}\right)^{\intercal}\mathbf{X}^{-}\right]$ . In the limit $P,L\rightarrow\infty$ , this matrix becomes circulant. Let us see how. Note that $\mathbf{X}\in\mathbb{R}^{2d\times 2d}$ . Suppose there are exactly $P/2$ same examples and $P/2$ different examples. Along the diagonal of $\mathbf{X}$ are terms
$$
X_{ii}=\frac{1}{P}\sum_{j=1}^{P/2}\left[(x_{ij}^{+})^{2}-(x_{ij}^{-})^{2}\right]\,,
$$
where $x_{ij}^{+}$ corresponds to the $i$ th index of the $j$ th same example, and $x_{ij}^{-}$ is the same for the $j$ th different example. Because $L\rightarrow\infty$ , we have that $x_{i,j}^{+}\sim\mathcal{N}(0,1/d)$ and $x_{i,j}^{-}\sim\mathcal{N}(0,1/d)$ , where $x_{i,j}^{+}$ is independent of $x_{i,j}^{-}$ . Hence, $X_{ii}\rightarrow 0$ as $P\rightarrow\infty$ . Now let us consider the diagonal of the first quadrant, i.e., the entries $X_{i,i+d}$ for $1\leq i\leq d$ , which pair the $i$ th coordinate of the first symbol with the $i$ th coordinate of the second symbol:
$$
X_{i,i+d}=\frac{1}{P}\sum_{j=1}^{P/2}\left[x_{ij}^{+}\,x_{i+d,j}^{+}-x_{ij}^{-}\,x_{i+d,j}^{-}\right]\,.
$$
For same examples, $x_{ij}^{+}=x_{i+d,j}^{+}$ , so
$$
\frac{1}{P}\sum_{j=1}^{P/2}x_{ij}^{+}x_{i+d,j}^{+}=\frac{1}{P}\sum_{j=1}^{P/2}\left(x_{ij}^{+}\right)^{2}\rightarrow\frac{1}{2}\mathbb{E}\left[\left(x_{ij}^{+}\right)^{2}\right]=\frac{1}{2d}\,.
$$
For different examples, $x_{ij}^{-}$ remains independent of $x_{i+d,j}^{-}$ , so
$$
\frac{1}{P}\sum_{j=1}^{P/2}x_{ij}^{-}x_{i+d,j}^{-}\rightarrow 0\,.
$$
We therefore have overall that $X_{i,i+d}\rightarrow 1/(2d)$ . The same argument applies for the diagonal of the third quadrant, revealing that $X_{i+d,i}\rightarrow 1/(2d)$ . For all other terms $X_{ik}$ , where $i\neq k$ and $|i-k|\neq d$ , $x_{ij}$ is independent of $x_{kj}$ for both same and different examples, so $X_{ik}\rightarrow 0$ . Hence, $\mathbf{X}$ is circulant with nonzero values $1/(2d)$ only on the diagonals of the first and third quadrants. Figure C1 plots an example $\mathbf{X}$ . In the remainder of this section, we multiply $\mathbf{X}$ by a normalization factor $2d$ . Doing so does not change the max margin weights, but sets the value along the quadrant diagonals to 1.
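The limiting form of $\mathbf{X}$ can be checked by sampling, in the spirit of the empirical panel of Figure C1. The NumPy sketch below assumes symbols $\mathbf{z}\sim\mathcal{N}(\mathbf{0},\mathbf{I}/d)$, exactly $P/2$ same and $P/2$ different examples, and the normalization by $2d$ described above; the function names `empirical_X` and `ideal_X` and the sizes used are our own choices.

```python
import numpy as np


def empirical_X(d, L, P, seed=0):
    """Normalized 2d * X built from P examples over L symbols (half same, half different)."""
    rng = np.random.default_rng(seed)
    Z = rng.normal(size=(L, d)) / np.sqrt(d)              # training symbols z_1..z_L
    half = P // 2
    u_same = rng.integers(L, size=half)
    X_plus = np.concatenate([Z[u_same], Z[u_same]], axis=1)
    u = rng.integers(L, size=half)
    v = (u + rng.integers(1, L, size=half)) % L            # v uniform over [L] \ {u}
    X_minus = np.concatenate([Z[u], Z[v]], axis=1)
    X = (X_plus.T @ X_plus - X_minus.T @ X_minus) / P
    return 2 * d * X


def ideal_X(d):
    """Limit as P, L -> infinity: identity blocks on the first and third quadrants."""
    I, O = np.eye(d), np.zeros((d, d))
    return np.block([[O, I], [I, O]])


d = 16
X_emp = empirical_X(d, L=2000, P=40000)
print(np.diag(X_emp[:d, d:]).mean())       # quadrant-diagonal entries, close to 1
print(np.abs(X_emp - ideal_X(d)).mean())   # small for large L and P
```

With few symbols (for example $L=64$) the off-structure entries are noticeably larger, which is consistent with the noisier empirical panel in Figure C1.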
Figure C1: Example ideal and empirical $\mathbf{X}$ . The matrix does indeed become circulant with nonzero values on the diagonals of the first and third quadrants. The empirical $\mathbf{X}$ is computed from a batch of 3000 examples sampled from a training set consisting of 64 symbols.

The eigendecomposition of a circulant matrix is well studied, and can be given in terms of Fourier modes. In particular, the $\ell$ th eigenvector $\mathbf{u}_{\ell}$ is given by
$$
\mathbf{u}_{\ell}=\frac{1}{\sqrt{2d}}(1,r^{\ell},r^{2\ell},\ldots,r^{(2d-1)\ell})\,,
$$
where
$$
r=e^{\frac{\pi i}{d}}\,
$$
and $\ell$ ranges from $0$ to $2d-1$ . The corresponding eigenvalues $\lambda_{\ell}$ of the normalized $\mathbf{X}$ are
$$
\lambda_{\ell}=r^{d\ell}=e^{\ell\pi i}\,.
$$
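As a quick numerical check of these eigenpairs, the NumPy sketch below builds the normalized limiting matrix and verifies that every Fourier mode satisfies $\mathbf{X}\mathbf{u}_{\ell}=(-1)^{\ell}\mathbf{u}_{\ell}$ . The helper names and the small value of $d$ are our own choices.

```python
import numpy as np


def ideal_X(d):
    """Normalized limiting matrix: identity blocks on the first and third quadrants."""
    I, O = np.eye(d), np.zeros((d, d))
    return np.block([[O, I], [I, O]])


def fourier_mode(d, ell):
    """u_ell = (1, r^ell, ..., r^((2d-1) ell)) / sqrt(2d) with r = exp(pi i / d)."""
    k = np.arange(2 * d)
    return np.exp(1j * np.pi * k * ell / d) / np.sqrt(2 * d)


def max_eigen_residual(d):
    """Largest ||X u_ell - (-1)^ell u_ell|| over ell = 0, ..., 2d - 1."""
    X = ideal_X(d)
    return max(
        np.linalg.norm(X @ fourier_mode(d, ell) - (-1) ** ell * fourier_mode(d, ell))
        for ell in range(2 * d)
    )


print(max_eigen_residual(8))  # numerically zero
```

The check works because $\mathbf{X}$ shifts coordinates by $d$ (modulo $2d$ ), and shifting a Fourier mode by $d$ multiplies it by $r^{d\ell}=(-1)^{\ell}$ .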
This expression implies that $\lambda_{\ell}=1$ for even $\ell$ and $\lambda_{\ell}=-1$ for odd $\ell$ . Hence, $\mathbf{w}_{*}^{+}$ lies in the subspace spanned by $\mathbf{u}_{\ell}$ for even $\ell$ (eigenvalue $+1$ ), and $\mathbf{w}_{*}^{-}$ lies in the subspace spanned by eigenvectors with odd $\ell$ (eigenvalue $-1$ ). To characterize this solution further, suppose we partition a weight vector $\mathbf{w}\in\mathbb{R}^{2d}$ into equal halves $\mathbf{w}=(\mathbf{v}^{1};\mathbf{v}^{2})$ , where $\mathbf{v}^{1},\mathbf{v}^{2}\in\mathbb{R}^{d}$ . Considering the even case first, suppose $\mathbf{w}\in\mathbf{U}_{2}$ where $\mathbf{U}_{2}=\text{span}\left\{\mathbf{u}_{0},\mathbf{u}_{2},\ldots,\mathbf{u}_{2(d-1)}\right\}$ . Then there exist coefficients $c_{0},c_{2},\ldots,c_{2(d-1)}$ such that
$$
\mathbf{w}=\sum_{n=0}^{d-1}c_{2n}\mathbf{u}_{2n}\,.
$$
Note that
$$
r^{k\ell_{1}}\cdot\overline{r}^{(k+d)\ell_{2}}=e^{\frac{k\pi i}{d}(\ell_{1}-\ell_{2})}\cdot e^{-\pi i\ell_{2}}\,.
$$
If we partition our set of eigenvectors as $\mathbf{u}_{\ell}=(\mathbf{s}_{\ell}^{1};\mathbf{s}_{\ell}^{2})$ , then, including the normalization factor $1/(2d)$ ,
$$
\mathbf{s}_{\ell_{1}}^{1}\cdot\overline{\mathbf{s}_{\ell_{2}}^{2}}=\frac{e^{-\pi i\ell_{2}}}{2d}\sum_{k=0}^{d-1}e^{\frac{k\pi i}{d}(\ell_{1}-\ell_{2})}\,.
$$
This quantity is 0 when $\ell_{1}\neq\ell_{2}$ have the same parity, since $e^{\pi i(\ell_{1}-\ell_{2})/d}\neq 1$ is then a $d$ th root of unity and the geometric sum vanishes. When $\ell_{1}=\ell_{2}$ , it equals $\frac{1}{2}$ if $\ell_{1}$ is even and $-\frac{1}{2}$ if $\ell_{1}$ is odd. Hence, for $(\mathbf{v}^{1};\mathbf{v}^{2})\in\mathbf{U}_{2}$ , we have that
$$
\mathbf{v}^{1}\cdot\mathbf{v}^{2}=\frac{1}{2}\left(|c_{0}|^{2}+|c_{2}|^{2}+\ldots+|c_{2(d-1)}|^{2}\right)\,.
$$
Observe also that
$$
\left|\left|\mathbf{v}^{1}\right|\right|=\left|\left|\mathbf{v}^{2}\right|\right|=\frac{1}{\sqrt{2}}\sqrt{|c_{0}|^{2}+|c_{2}|^{2}+\ldots+|c_{2(d-1)}|^{2}}\,,
$$
so we must have
$$
\frac{\mathbf{v}^{1}\cdot\mathbf{v}^{2}}{\left|\left|\mathbf{v}^{1}\right|\right|\,\left|\left|\mathbf{v}^{2}\right|\right|}=1\,.
$$
In this way, we see that the components of $\mathbf{w}^{+}_{*}$ must be parallel and share the same magnitude. We may repeat the same calculation for $\mathbf{w}\in\mathbf{U}_{1}$ , where $\mathbf{U}_{1}=\text{span}\left\{\mathbf{u}_{1},\mathbf{u}_{3},\ldots,\mathbf{u}_{2d-1}\right\}$ . Doing so reveals that
$$
\frac{\mathbf{v}^{1}\cdot\mathbf{v}^{2}}{\left|\left|\mathbf{v}^{1}\right|\right|\,\left|\left|\mathbf{v}^{2}\right|\right|}=-1\,,
$$
so the components of $\mathbf{w}^{-}_{*}$ must be antiparallel. ∎
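The conclusion of the proof can also be checked numerically: any unit vector in the $+1$ eigenspace of the normalized $\mathbf{X}$ has parallel halves of equal norm, and any unit vector in the $-1$ eigenspace has antiparallel halves of equal norm. The NumPy sketch below draws random unit vectors from each eigenspace returned by `numpy.linalg.eigh`; the helper `half_cosines` and the choice $d=8$ are our own.

```python
import numpy as np


def half_cosines(d, n_samples=5, seed=0):
    """Cosine and norm ratio between the halves of random unit vectors in each eigenspace."""
    rng = np.random.default_rng(seed)
    I, O = np.eye(d), np.zeros((d, d))
    X = np.block([[O, I], [I, O]])                        # normalized limiting matrix
    eigvals, eigvecs = np.linalg.eigh(X)                  # eigenvalues in ascending order
    results = {}
    for sign in (+1, -1):
        basis = eigvecs[:, np.isclose(eigvals, sign)]     # d columns spanning the eigenspace
        stats = []
        for _ in range(n_samples):
            w = basis @ rng.normal(size=basis.shape[1])
            w /= np.linalg.norm(w)                        # ||w|| = 1, as in Theorem 1
            v1, v2 = w[:d], w[d:]
            cos = v1 @ v2 / (np.linalg.norm(v1) * np.linalg.norm(v2))
            stats.append((cos, np.linalg.norm(v1) / np.linalg.norm(v2)))
        results[sign] = stats
    return results


for sign, stats in half_cosines(8).items():
    print(sign, [round(float(c), 6) for c, _ in stats])  # +1 -> cosines near 1, -1 -> near -1
```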
### C.2 Heuristic construction
By examining the max average margin solution, we witness the emergence of parallel/antiparallel weight vectors. In Section 3.1, we discussed how parallel/antiparallel weights allow an MLP to solve the SD task. However, it remains unclear how to characterize the learning efficiency and insensitivity to spurious perceptual details of the resulting model, and whether these results apply at all to a ReLU MLP trained on a finite dataset. To begin answering these questions, we develop in the following subsections a heuristic construction that summarizes the learning dynamics of a ReLU MLP. We will demonstrate that
1. The hidden weights $\mathbf{w}$ become parallel (or antiparallel) for correspondingly positive (or negative) readout weights $a$ .
1. The readout weights satisfy $|\overline{a^{-}}|>|\overline{a^{+}}|$ , where $\overline{a^{-}}$ denotes the average across negative readout weights and $\overline{a^{+}}$ denotes the average across positive readout weights.
We then leverage our understanding of the weight structure to estimate a rich model’s test accuracy on our SD task. The remainder of this appendix is dedicated to developing this heuristic approach.
We proceed using a Markov process approximation to the full learning dynamics in the noiseless setting ( $\sigma^{2}=0$ ). Observe that the gradient updates to the readout and hidden weights take the following form. For a batch containing $N$ training examples,
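$$
\begin{aligned}
\Delta a_{i}&=-\frac{c}{N}\sum_{j=1}^{N}\frac{\partial\mathcal{L}_{j}}{\partial f}\,\phi(\mathbf{w}_{i}\cdot\mathbf{x}_{j})\,,\\
\Delta\mathbf{w}_{i}&=-\frac{c}{N}\sum_{j=1}^{N}\frac{\partial\mathcal{L}_{j}}{\partial f}\,a_{i}\,\phi^{\prime}(\mathbf{w}_{i}\cdot\mathbf{x}_{j})\,\mathbf{x}_{j}\,,
\end{aligned}
$$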
where $c=\frac{\alpha}{\gamma\sqrt{d}}$ , $\alpha$ is the learning rate, and
$$
\begin{aligned}
-\frac{\partial\mathcal{L}_{j}}{\partial f}&=-\frac{\partial\mathcal{L}(y,f(\mathbf{x}))}{\partial f}\bigg|_{y_{j},f(\mathbf{x}_{j})}\\
&=\frac{y_{j}}{1+e^{f(\mathbf{x}_{j})}}-\frac{1-y_{j}}{1+e^{-f(\mathbf{x}_{j})}}\,.
\end{aligned}\tag{C.1}
$$
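The closed form above coincides with the familiar gradient of the binary cross-entropy loss written in terms of the logit $f$ , namely $-\partial\mathcal{L}_{j}/\partial f=y_{j}-\sigma(f(\mathbf{x}_{j}))$ with $\sigma$ the logistic sigmoid. The short NumPy check below assumes $\mathcal{L}(y,f)=-y\log\sigma(f)-(1-y)\log(1-\sigma(f))$ and compares the closed form against a central finite difference; the function names are our own. It also confirms the sign pattern used later: the negative gradient is positive for $y=1$ and negative for $y=0$ .

```python
import numpy as np


def bce_loss(y, f):
    """Binary cross-entropy on a logit f (assumed form of the loss), computed stably."""
    # -log(sigmoid(f)) = log(1 + e^{-f});  -log(1 - sigmoid(f)) = log(1 + e^{f})
    return y * np.logaddexp(0.0, -f) + (1 - y) * np.logaddexp(0.0, f)


def neg_grad_closed_form(y, f):
    """-dL/df = y / (1 + e^f) - (1 - y) / (1 + e^{-f})."""
    return y / (1.0 + np.exp(f)) - (1 - y) / (1.0 + np.exp(-f))


def neg_grad_finite_difference(y, f, h=1e-6):
    """Central-difference estimate of -dL/df."""
    return -(bce_loss(y, f + h) - bce_loss(y, f - h)) / (2 * h)


f = np.linspace(-4.0, 4.0, 9)
for y in (0.0, 1.0):
    gap = np.max(np.abs(neg_grad_closed_form(y, f) - neg_grad_finite_difference(y, f)))
    print(y, gap)  # tiny for both labels
```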
Focusing on $\Delta\mathbf{w}_{i}$ , we rewrite its gradient update as
$$
\Delta\mathbf{w}_{i}=\sum_{j=1}^{N}\xi_{ij}\mathbf{x}_{j}\,,
$$
where
$$
\xi_{ij}=-\frac{c}{N}\frac{\partial\mathcal{L}_{j}}{\partial f}\,a_{i}\,\phi^{\prime}(\mathbf{w}_{i}\cdot\mathbf{x}_{j})\,.
$$
When written in this form, it becomes clear that the hidden weight gradient updates lie in the span of the training examples $\mathbf{x}_{j}$ . If the initialization of the hidden weights $\mathbf{w}_{i}$ is small, then $\mathbf{w}_{i}$ also lies approximately in this span. Specifically, we require that
$$
\mathbf{w}_{i}(0)\neq\mathbf{0}\;\;\text{and}\;\;\frac{1}{\xi_{ij}}\,\mathbf{w}_{i}(0)\cdot\mathbf{x}_{j}\rightarrow 0\quad\text{as}\;\;d\rightarrow\infty\,, \tag{C.2}
$$
where $\mathbf{w}_{i}(0)$ refers to $\mathbf{w}_{i}$ at initialization (generally, $\mathbf{w}_{i}(t)$ is the value of $\mathbf{w}_{i}$ after $t$ gradient steps). The requirement that $\mathbf{w}_{i}(0)\neq\mathbf{0}$ ensures that the initial gradient update is nonzero.
Suppose our training set consists of $L$ symbols $\mathbf{z}_{1},\mathbf{z}_{2},\ldots,\mathbf{z}_{L}$ . If we partition $\mathbf{w}_{i}=(\mathbf{v}_{i}^{1};\mathbf{v}_{i}^{2})$ , then after $t$ gradient steps and in the limit $d\rightarrow\infty$ ,
$$
\begin{aligned}
\mathbf{v}_{i}^{1}(t)&=\omega_{i,1}^{1}(t)\,\mathbf{z}_{1}+\omega_{i,2}^{1}(t)\,\mathbf{z}_{2}+\ldots+\omega_{i,L}^{1}(t)\,\mathbf{z}_{L}\,,\\
\mathbf{v}_{i}^{2}(t)&=\omega_{i,1}^{2}(t)\,\mathbf{z}_{1}+\omega_{i,2}^{2}(t)\,\mathbf{z}_{2}+\ldots+\omega_{i,L}^{2}(t)\,\mathbf{z}_{L}\,,
\end{aligned}\tag{C.3}
$$
where $\omega_{i,k}^{p}(t)$ corresponds to the overlap $\mathbf{v}_{i}^{p}(t)\cdot\mathbf{z}_{k}$ . The large number of indices on $\omega$ is unwieldy, so we will omit some or all of them where context allows. Note that for these relations to hold, we require that $\mathbf{z}_{i}\cdot\mathbf{z}_{j}=\delta_{ij}$ as $d\rightarrow\infty$ . In this way, we may consider the $\omega$ ’s to be the coordinates of $\mathbf{w}$ in the basis of the training symbols $\mathbf{z}$ .
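Because every update $\Delta\mathbf{w}_{i}=\sum_{j}\xi_{ij}\mathbf{x}_{j}$ lies in the span of the training inputs, the component of $\mathbf{w}_{i}$ orthogonal to the training symbols never changes: only the initialization contributes there, which is why a small initialization leaves $\mathbf{w}_{i}$ approximately in that span. The NumPy sketch below illustrates this with full-batch gradient descent on a toy ReLU MLP. The parametrization $f(\mathbf{x})=\frac{1}{m}\sum_{i}a_{i}\,\mathrm{ReLU}(\mathbf{w}_{i}\cdot\mathbf{x})$ , the binary cross-entropy loss, the function name `train_and_track`, and all sizes and learning rates are assumptions for illustration, not the exact setup of Eq (1).

```python
import numpy as np


def train_and_track(d=32, L=6, m=16, steps=300, lr=0.5, init_scale=1e-3, seed=0):
    """Full-batch GD on a toy ReLU MLP (hypothetical parametrization f = mean_i a_i ReLU(w_i . x))."""
    rng = np.random.default_rng(seed)
    Z = rng.normal(size=(L, d)) / np.sqrt(d)                        # training symbols
    pairs = [(u, v) for u in range(L) for v in range(L)]
    X = np.array([np.concatenate([Z[u], Z[v]]) for u, v in pairs])
    y = np.array([1.0 if u == v else 0.0 for u, v in pairs])       # 1 = same, 0 = different
    W = init_scale * rng.normal(size=(m, 2 * d))
    a = rng.choice([-1.0, 1.0], size=m)

    # Orthonormal basis of span{(z_k; 0), (0; z_k)}, which contains every training input.
    B = np.zeros((2 * L, 2 * d))
    B[:L, :d], B[L:, d:] = Z, Z
    Q, _ = np.linalg.qr(B.T)

    def off_span(W):
        return W - (W @ Q) @ Q.T                                     # component orthogonal to the span

    W0_perp = off_span(W)
    N = len(X)
    for _ in range(steps):
        pre = X @ W.T                                                # (N, m) preactivations w_i . x_j
        f = np.maximum(pre, 0.0) @ a / m
        neg_dL_df = y - 1.0 / (1.0 + np.exp(-f))                     # -dL/df for cross-entropy
        xi = (lr / N) * neg_dL_df[:, None] * a[None, :] * (pre > 0) / m  # xi_ij, shape (N, m)
        a += (lr / N) * (np.maximum(pre, 0.0).T @ neg_dL_df) / m
        W += xi.T @ X                                                # Delta w_i = sum_j xi_ij x_j
    return W0_perp, off_span(W), W


W0_perp, W_perp, W_final = train_and_track()
print(np.abs(W_perp - W0_perp).max())                    # off-span component is unchanged
print(np.linalg.norm(W_perp) / np.linalg.norm(W_final))  # fraction of ||W|| outside the span
```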
Note that $\omega$ is a function of the update coefficients $\xi_{ij}$ . If $\mathbf{w}_{i}\cdot\mathbf{x}_{j}<0$ , then $\xi_{ij}=0$ , since the ReLU derivative $\phi^{\prime}$ vanishes for negative inputs. Otherwise, $\xi_{ij}$ depends on $\frac{\partial\mathcal{L}_{j}}{\partial f}$ and $a_{i}$ , introducing many additional and complex couplings to other parameters in the model. Our ultimate goal is to understand the general structure of the hidden weights $\mathbf{w}_{i}$ , rather than to obtain exact formulas, so we apply the following coarse approximation:
$$
\xi_{ij}=\begin{cases}\operatorname{sign}\left(-\frac{\partial\mathcal{L}_{j}}{\partial f}\,a_{i}\right)&\mathbf{w}_{i}\cdot\mathbf{x}_{j}>0\,,\\
0&\mathbf{w}_{i}\cdot\mathbf{x}_{j}\leq 0\,.\end{cases}\tag{C.4}
$$
Such an approximation resembles sign-based gradient methods like signSGD (Bernstein et al., 2018) and Adam (Kingma & Ba, 2014). We also verify empirically that this approximation describes rich regime learning dynamics well.
From Eq (C.1), observe that $\text{sign}\left(-\frac{\partial\mathcal{L}_{j}}{\partial f}\right)=1$ given a label $y_{j}=1$ , and $\text{sign}\left(-\frac{\partial\mathcal{L}_{j}}{\partial f}\right)=-1$ for the label $y_{j}=0$ . Recall from our hand-crafted solution that $a_{i}>0$ implies that $\mathbf{w}_{i}$ should align with same examples, and $a_{i}<0$ implies that $\mathbf{w}_{i}$ should align with different examples. Hence, if $\mathbf{w}_{i}\cdot\mathbf{x}_{j}>0$ , we conclude that $\xi_{ij}=1$ if the example $\mathbf{x}_{j}$ matches the corresponding readout weight $a_{i}$ — that is, $\mathbf{x}_{j}$ is same and $a_{i}>0$ , or $\mathbf{x}_{j}$ is different and $a_{i}<0$ . Otherwise, if there is a mismatch, $\xi_{ij}=-1$ . In this way, we may interpret $\mathbf{w}_{i}$ as a “state vector” to which we add or subtract examples $\mathbf{x}_{j}$ based on a simple set of update rules. We proceed to study the limiting form of $\mathbf{w}_{i}$ by treating it as a Markov process. Our approximation in Eq (C.4) decouples the dependency between hidden weight vectors. The set of hidden weights can be treated as an ensemble of independent Markov processes evolving in parallel, allowing us to understand the overall structure of the hidden weights.
### C.3 Markov process approximation
Altogether, we summarize the learning dynamics on $\mathbf{w}$ through the following Markov process. In the remainder of this section, we drop the index $i$ from $\mathbf{w}_{i}$ and $a_{i}$ ; $\mathbf{w}$ and $a$ should be understood as a representative hidden weight vector and its readout weight. Similarly, we write $\omega_{k}^{p}(t)$ to represent the coefficient of $\mathbf{v}^{p}$ (the $p$ th partition of $\mathbf{w}$ , $p\in\{1,2\}$ ) for the $k$ th training symbol after $t$ steps. The set of coefficients $\omega$ represents the state of the Markov process, which proceeds as follows.
Step 1.
Initialize $\omega_{k}^{p}(0)=0$ for all $k$ , $p$ . Initialize the time step $t=0$ . Initialize batch updates $b_{k}^{p}=0$ and batch counter $n=0$ . Set the batch size $N$ .
Step 2.
Sample an integer $u$ uniformly at random from the set $[L]=\{1,2,\ldots,L\}$ .
With probability 1/2, set $v=u$ .
Otherwise, sample $v$ uniformly from $[L]\setminus\{u\}$ .
Step 3.
Compute $\rho=\omega_{u}^{1}(t)+\omega_{v}^{2}(t)$ .
If $\rho>0$ , proceed to step 4.
If $\rho=0$ , with probability $1/2$ proceed to step 4. Otherwise, proceed to step 5.
If $\rho<0$ , proceed to step 5.
Step 4.
If $a>0$ and $u=v$ , or $a<0$ and $u\neq v$ , update
$$
b_{u}^{1}\leftarrow b_{u}^{1}+1\,,\qquad b_{v}^{2}\leftarrow b_{v}^{2}+1\,.
$$
Otherwise, update
$$
b_{u}^{1}\leftarrow b_{u}^{1}-1\,,\qquad b_{v}^{2}\leftarrow b_{v}^{2}-1\,.
$$
Step 5.
Increment the batch counter $n\leftarrow n+1$ .
If $n=N$ , set $\omega_{k}^{p}(t+1)\leftarrow\omega_{k}^{p}(t)+b_{k}^{p}$ for all $k$ , $p$ and increment the time step $t\leftarrow t+1$ . Reset $n\leftarrow 0,b_{k}^{p}\leftarrow 0$ .
Proceed to step 2.
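Steps 1 to 5 translate directly into a short simulation. The Python sketch below is our own illustration; the function names `run_markov_process` and `half_cosine` and all parameter values are hypothetical. It represents the symbols by an orthonormal basis (the $d\rightarrow\infty$ idealization), so that $\mathbf{v}^{1}\cdot\mathbf{v}^{2}=\sum_{k}\omega_{k}^{1}\omega_{k}^{2}$ , and it reports the cosine between the two halves of $\mathbf{w}$ , which the analysis of Section C.4 predicts approaches $+1$ for $a>0$ and $-1$ for $a<0$ .

```python
import numpy as np


def run_markov_process(L, a_sign, steps, N=1, seed=0):
    """Simulate Steps 1-5; returns omega with omega[p - 1, k] = coefficient of z_k in v^p."""
    rng = np.random.default_rng(seed)
    omega = np.zeros((2, L))                            # Step 1: omega_k^p(0) = 0
    for _ in range(steps):
        b = np.zeros((2, L))                            # batch updates b_k^p
        for _ in range(N):
            u = int(rng.integers(L))                    # Step 2: u uniform over [L]
            if rng.random() < 0.5:
                v = u                                   # same example
            else:
                v = (u + int(rng.integers(1, L))) % L   # different: v uniform over [L] \ {u}
            rho = omega[0, u] + omega[1, v]             # Step 3: overlap w . x
            active = rho > 0 or (rho == 0 and rng.random() < 0.5)
            if active:                                  # Step 4: xi = +1 on a match, -1 otherwise
                match = (a_sign > 0) == (u == v)
                delta = 1.0 if match else -1.0
                b[0, u] += delta
                b[1, v] += delta
        omega += b                                      # Step 5: apply the batch update
    return omega


def half_cosine(omega):
    """Cosine between v^1 and v^2 when the symbols z_k are orthonormal."""
    v1, v2 = omega
    return v1 @ v2 / (np.linalg.norm(v1) * np.linalg.norm(v2))


for a_sign in (+1, -1):
    omega = run_markov_process(L=32, a_sign=a_sign, steps=30000)
    print(a_sign, round(float(half_cosine(omega)), 3))
```

Each accepted update touches exactly one coordinate of $\mathbf{v}^{1}$ and one of $\mathbf{v}^{2}$ by the same amount, so $\sum_{k}\omega_{k}^{1}=\sum_{k}\omega_{k}^{2}$ holds at every step; this is a convenient sanity check on any implementation of Step 4.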
Remarks.
This Markov process approximates the learning dynamics of an MLP given the simplification in Eq (C.4). We elaborate on the link below.
In step 1, we initialize the weight vector to zero. In practice, weight vectors are initialized as $\mathbf{w}\sim\mathcal{N}(\mathbf{0},\mathbf{I}/d)$ . For large $d$ , the condition required in Eq (C.2) allows us to approximate this as zero, with some caveats described below.
In step 2, we sample a training example. Because we operate in the basis spanned by training examples, we discard the vector content of a training symbol and consider only its index. With half the training examples being same and half being different, we sample indices accordingly.
In step 3, we consider the overlap $\mathbf{w}\cdot\mathbf{x}$ , where $\mathbf{x}=(\mathbf{z}_{u};\mathbf{z}_{v})$ . Given the assumptions in Eq (C.2), we have that
$$
\rho=\mathbf{w}\cdot\mathbf{x}=\mathbf{v}^{1}\cdot\mathbf{z}_{u}+\mathbf{v}^{2}\cdot\mathbf{z}_{v}=\omega_{u}^{1}(t)+\omega_{v}^{2}(t)\,.
$$
If the overlap $\rho$ is positive, then the update coefficient $\xi\neq 0$ , so we branch to the step where the state is updated. If $\rho$ is negative, we must have that $\xi=0$ , so we skip the update. If $\omega_{u}^{1}+\omega_{v}^{2}=0$ , the overlap still picks up the initialization, $\mathbf{w}\cdot\mathbf{x}=\mathbf{w}(0)\cdot\mathbf{x}$ . In the limit $d\rightarrow\infty$ , we have that $\mathbf{w}(0)\cdot\mathbf{x}\rightarrow 0$ . However, $\mathbf{w}(0)\cdot\mathbf{x}\neq 0$ almost surely for any finite $d$ . Because $\mathbf{w}(0)$ and $\mathbf{x}$ are radially symmetric about the origin, their overlap is positive with probability 1/2. Thus, when $\rho=0$ , we branch to the corresponding positive or negative overlap steps each with probability 1/2.
In step 4, we apply the updates from $\xi=\pm 1$ , based on whether the training example $\mathbf{z}_{u},\mathbf{z}_{v}$ matches the readout weight $a$ . We conclude the training loop in step 5, and restart with a fresh example.
### C.4 Limiting weight structure
With the setup now complete, we analyze the Markov process proposed above to understand the limiting structure of $\mathbf{w}$ and $a$ . Specifically, we will show that for large $L$ and as $t\rightarrow\infty$ ,
1. If $a>0$ , then $\left|\left|\mathbf{v}^{1}\right|\right|=\left|\left|\mathbf{v}^{2}\right|\right|$ and $\mathbf{v}^{1}\cdot\mathbf{v}^{2}/(\left|\left|\mathbf{v}^{1}\right|\right|\,\left|\left|\mathbf{v}^{2}\right|\right|)=1$
1. If $a<0$ , then $\left|\left|\mathbf{v}^{1}\right|\right|=\left|\left|\mathbf{v}^{2}\right|\right|$ and $\mathbf{v}^{1}\cdot\mathbf{v}^{2}/(\left|\left|\mathbf{v}^{1}\right|\right|\,\left|\left|\mathbf{v}^{2}\right|\right|)=-1$
1. Suppose $\overline{a^{+}}$ is the average over all weights $a>0$ , and $\overline{a^{-}}$ is the average over all weights $a<0$ . Then $|\overline{a^{-}}|>|\overline{a^{+}}|$ .
Our general approach involves factorizing the Markov process into an ensemble of random walkers with simple dependencies, then reasoning about the long time-scale behavior of these walkers. For simplicity, we focus on the case of batch size $N=1$ . Generalizing to $N>1$ is straightforward but notationally cluttered, and does not change the final result.
#### C.4.1 Same case
We begin by examining weights $\mathbf{w}$ such that the corresponding readout $a>0$ . Recall that these weights favor same examples. Consider a random walker on $\mathbb{R}$ whose position at time $t$ is given by $s_{u}(t)=\omega_{u}^{1}(t)+\omega_{u}^{2}(t)$ . Then the following rules govern the walker’s dynamics:
1. If $s_{u}(t)>0$ and the model receives a same training example $(\mathbf{z}_{u},\mathbf{z}_{u})$ , then $s_{u}(t+1)=s_{u}(t)+2$ .
1. If the model receives a different training example $(\mathbf{z}_{u},\mathbf{z}_{v})$ , where $u\neq v$ , then $s_{u}(t+1)\geq s_{u}(t)-1$ .
1. If $s_{u}(T)<0$ , then $s_{u}(t)<0$ for all $t>T$ .
Rules 1 and 2 reflect the update dynamics of the Markov process. If $s_{u}(t)>0$ , then upon receiving a same example $(\mathbf{z}_{u},\mathbf{z}_{u})$ the overlap is positive, and we witness updates $\omega_{u}^{1}(t+1)=\omega_{u}^{1}(t)+1$ and $\omega_{u}^{2}(t+1)=\omega_{u}^{2}(t)+1$ , so $s_{u}(t+1)=s_{u}(t)+2$ . Similarly, upon receiving a different example $(\mathbf{z}_{u},\mathbf{z}_{v})$ , we have that $\omega_{u}^{1}(t+1)=\omega_{u}^{1}(t)-1$ if $\omega_{u}^{1}(t)+\omega_{v}^{2}(t)>0$ , and no update otherwise, so $s_{u}$ decreases by at most 1. Finally, for rule 3, if $s_{u}$ ever falls below 0, then a same example $(\mathbf{z}_{u},\mathbf{z}_{u})$ has negative overlap and produces no update, while different examples can only decrement $s_{u}$ . Hence, $s_{u}$ never increments again and remains negative for all subsequent steps.
Together, these rules partition our ensemble of walkers $s_{u}$ into two sets: walkers with positive position $\mathcal{S}^{+}(t)=\left\{s_{u}:s_{u}(t)>0\right\}$ and walkers with negative position $\mathcal{S}^{-}(t)=\left\{s_{v}:s_{v}(t)<0\right\}$ . We will show that under typical conditions, members of $\mathcal{S}^{+}$ grow continually more positive, while members of $\mathcal{S}^{-}$ grow continually more negative. We denote $n^{+}(t)=|\mathcal{S}^{+}(t)|$ and $n^{-}(t)=|\mathcal{S}^{-}(t)|$ . Where the meaning is unambiguous, we drop the indices $t$ .
We first make the following counterintuitive observation about the relative occurrence of same and different examples. Although training examples are sampled from each class with equal probability, the probabilities of observing a same or different pair when conditioned on observing a particular training symbol are not equal. Suppose we would like to count all training pairs that contain at least one occurrence of $\mathbf{z}_{u}$ . Out of all same examples, we would expect roughly $\frac{1}{L}$ of such examples to contain $\mathbf{z}_{u}$ . Out of all different examples, we would expect roughly $\frac{2}{L(L-1)}$ occurrences of the pair $(\mathbf{z}_{u},\mathbf{z}_{v})$ , for a specific $v\neq u$ . Across all $L-1$ possible $v$ , this proportion rises to $\frac{2}{L(L-1)}\cdot(L-1)=\frac{2}{L}$ . Hence, the probability of observing a same example conditioned on containing $\mathbf{z}_{u}$ is actually $\frac{1/L}{1/L+2/L}=\frac{1}{3}$ , while the probability of observing different is $\frac{2}{3}$ .
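These conditional probabilities can be confirmed by exact enumeration. The sketch below uses Python's `fractions` module to sum the sampling probabilities of Step 2 over all ordered pairs that contain a fixed symbol; the function name `p_same_given_symbol` is our own.

```python
from fractions import Fraction


def p_same_given_symbol(L, symbol=0):
    """P(same | the pair contains `symbol`) under the sampling rule of Step 2."""
    p_same = Fraction(0)
    p_diff = Fraction(0)
    for u in range(L):
        for v in range(L):
            if symbol not in (u, v):
                continue
            if u == v:
                p_same += Fraction(1, 2 * L)            # same pair (u, u)
            else:
                p_diff += Fraction(1, 2 * L * (L - 1))  # ordered different pair (u, v)
    return p_same / (p_same + p_diff)


print([p_same_given_symbol(L) for L in (2, 5, 64)])  # each equals Fraction(1, 3)
```

The result is independent of $L$ : a fixed symbol appears in one same pair but in $2(L-1)$ ordered different pairs, which exactly offsets the lower probability of each individual different pair.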
Suppose we allow our Markov process to run for $t$ time steps, after which there are $n^{+}$ walkers in a positive position and $n^{-}$ walkers in a negative position, among $L=n^{+}+n^{-}\equiv n$ total walkers. Upon receiving the next training example $(\mathbf{z}_{u},\mathbf{z}_{v})$ , there are four possible outcomes.
Case 1.
The walker $s_{u}(t)>0$ and we receive a same example (so $u=v$ ). In this case, $s_{u}(t+1)=s_{u}(t)+2$ . The probability of observing a same pair containing $\mathbf{z}_{u}$ is $\frac{1}{3}$ , so we summarize this case as
$$
p(s_{u}\leftarrow s_{u}+2\,|\,s_{u}>0)=\frac{1}{3}\,. \tag{C.5}
$$
Case 2.
The walker $s_{u}(t)>0$ and we receive a different example (so $u\neq v$ ). Whether $s_{u}$ decrements in this case is complex to determine, and depends on the precise coordinates $\omega_{u}$ and $\omega_{v}$ . We treat this issue coarsely by modeling the probability of decrement through an average case approximation: if $s_{v}>0$ , we assume that $s_{u}$ will always decrement; if $s_{v}<0$ , we assume that $s_{u}$ will decrement with some mean probability $\mu$ . Since $p(s_{v}>0)=\frac{n^{+}-1}{n-1}$ and $p(s_{v}<0)=\frac{n^{-}}{n-1}$ , and the probability of selecting a different example overall remains $2/3$ , we summarize this case as
$$
p(s_{u}\leftarrow s_{u}-1\,|\,s_{u}>0)=\frac{2}{3}\left(\frac{n^{+}-1}{n-1}+\frac{n^{-}}{n-1}\mu\right)\,. \tag{C.6}
$$
This average case approximation is similar in flavor to the mean field ansatz common in physics, and we employ it for similar reasons: it simplifies a complex many-bodied interaction into a simple interaction between a single body and an average field. We validate the accuracy of this approximation later.
Case 3.
The walker $s_{u}(t)<0$ and we receive a same example. No updates occur in this case. For completeness, we summarize it as
$$
p(s_{u}\leftarrow s_{u}+2\,|\,s_{u}<0)=0\,. \tag{C.7}
$$
Case 4.
The walker $s_{u}(t)<0$ and we receive a different example. We again apply a coarse, average case approximation to model the probability of decrementing. If $s_{v}<0$ , we assume that $s_{u}$ will never decrement. If $s_{v}>0$ , the probability of decrementing again depends on our mean quantity $\mu$ . The probability of selecting a different example overall remains $2/3$ , so we summarize this case as
$$
p(s_{u}\leftarrow s_{u}-1\,|\,s_{u}<0)=\frac{2}{3}\left(\frac{n^{+}}{n-1}\mu\right)\,. \tag{C.8}
$$
To gain greater insight into $\mu$, we consider how much a walker’s position may drift as it encounters different training examples. Define a walker’s expected drift to be the quantity $\Delta s(t)=\mathbb{E}[s(t+1)-s(t)]$, averaged over possible walker states $s$. Then, under Eq (C.5) and Eq (C.6), a positive walker $s^{+}>0$ has expected drift
$$
\begin{aligned}
\Delta s^{+}(t) &= 2\,p(s^{+}>0)\,p(s^{+}\leftarrow s^{+}+2\,|\,s^{+}>0)-p(s^{+}>0)\,p(s^{+}\leftarrow s^{+}-1\,|\,s^{+}>0)\\
&= \frac{2n^{+}}{3n}\left(1-\frac{n^{+}+n^{-}\mu-1}{n-1}\right)\\
&= \frac{2n^{+}n^{-}(1-\mu)}{3n(n-1)}\,.
\end{aligned}
$$
Similarly, under Eq (C.7) and Eq (C.8), a negative walker $s^{-}$ has expected drift
$$
\begin{aligned}
\Delta s^{-}(t) &= -\,p(s^{-}<0)\,p(s^{-}\leftarrow s^{-}-1\,|\,s^{-}<0)\\
&= -\frac{2n^{-}n^{+}\mu}{3n(n-1)}\,.
\end{aligned}
$$
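As a check on the algebra, the sketch below (Python, exact rational arithmetic from the standard `fractions` module) rebuilds $\Delta s^{+}$ and $\Delta s^{-}$ from the Case 1 to 4 probabilities and compares them with Eqs (C.9) and (C.10). The function name `same_case_drifts` is illustrative, and the sketch assumes $n^{+}\geq 1$ and $n\geq 2$ so that the mean-field fractions are defined.

```python
from fractions import Fraction


def same_case_drifts(n_pos: int, n_neg: int, mu: Fraction) -> tuple[Fraction, Fraction]:
    """Expected drifts (Δs+, Δs-) in the same case (readout a > 0).

    Illustrative sketch: each drift is the fraction of walkers on that side of
    zero times the expected step implied by Cases 1-4.
    """
    n = n_pos + n_neg
    # Case 1: positive walker, same example -> step +2 with probability 1/3.
    p_up_pos = Fraction(1, 3)
    # Case 2: positive walker, different example -> step -1 (mean-field probability).
    p_down_pos = Fraction(2, 3) * (Fraction(n_pos - 1, n - 1) + Fraction(n_neg, n - 1) * mu)
    # Case 3: negative walker, same example -> no update, so it contributes nothing.
    # Case 4: negative walker, different example -> step -1.
    p_down_neg = Fraction(2, 3) * Fraction(n_pos, n - 1) * mu
    drift_pos = Fraction(n_pos, n) * (2 * p_up_pos - p_down_pos)
    drift_neg = -Fraction(n_neg, n) * p_down_neg
    return drift_pos, drift_neg


print(same_case_drifts(8, 8, Fraction(1, 2)))  # (Fraction(4, 45), Fraction(-4, 45))
```

Exact fractions make the comparison with the closed forms an equality test rather than a floating-point tolerance check.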
Suppose $\mu=1$. In this case, if we encounter a different example $(\mathbf{z}_{u},\mathbf{z}_{v})$ such that $s_{u}>0$ and $s_{v}<0$, then $s_{u}$ will always decrement. On average, $\Delta s_{u}=0$ while $\Delta s_{v}$ is negative, indicating that $s_{u}$ will on average remain around the same position while $s_{v}$ decreases. However, if $s_{v}$ decreases without bound, there comes a point where $\omega_{u}+\omega_{v}<0$, preventing further decrements. This situation implies that $\mu=0$, resulting in $\Delta s_{u}>0$ and $\Delta s_{v}=0$. If $s_{u}$ then increases without bound, there comes a point where $\omega_{u}+\omega_{v}>0$ again, allowing further decrements and raising our mean update probability back to $\mu=1$.
In general, if $|\Delta s_{u}|<|\Delta s_{v}|$ , we experience further increments until $|\Delta s_{u}|>|\Delta s_{v}|$ , at which point we experience further decrements, returning us back to $|\Delta s_{u}|<|\Delta s_{v}|$ . Over a long time period, we might therefore expect our dynamics to settle around an average point $|\Delta s_{u}|=|\Delta s_{v}|$ . If this is true, then we employ the relation $|\Delta s^{+}|=|\Delta s^{-}|$ as a self-consistency condition to solve for $\mu$ . Equating (C.9) and (C.10) reveals that $\mu=\frac{1}{2}$ .
Altogether, we arrive at the following picture of the walkers’ dynamics. Walkers $s_{u}$ at a positive position drift with an average rate
$$
\Delta s_{u}=\frac{n^{+}n^{-}}{3n(n-1)}\,.
$$
Meanwhile, walkers $s_{v}$ at a negative position drift with an average rate $\Delta s_{v}=-\Delta s_{u}$. Over long time periods, $s_{u}\rightarrow\infty$ while $s_{v}\rightarrow-\infty$. Because positive updates increment both coordinates $\omega_{u}^{1}$ and $\omega_{u}^{2}$ equally, we have that $\omega_{u}^{1}\approx\omega_{u}^{2}>0$. Likewise, because negative updates have a higher chance of decrementing a more positive coordinate, we also have that $\omega_{v}^{1}\approx\omega_{v}^{2}<0$. Overall, we must therefore have $\|\mathbf{v}^{1}\|=\|\mathbf{v}^{2}\|$ and $\mathbf{v}^{1}\cdot\mathbf{v}^{2}/(\|\mathbf{v}^{1}\|\,\|\mathbf{v}^{2}\|)=1$ at long time scales, confirming that a weight vector aligned with same examples adopts parallel components.
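The self-consistency step can also be checked numerically. The sketch below bisects on $|\Delta s^{+}|-|\Delta s^{-}|$ using Eqs (C.9) and (C.10), recovering $\mu=\frac{1}{2}$ and the drift rate $n^{+}n^{-}/(3n(n-1))$ of Eq (C.11) for any split $0<n^{+}<n$. The helper names are hypothetical and the tolerance is an arbitrary choice.

```python
def drift_closed_forms(n_pos: int, n_neg: int, mu: float) -> tuple[float, float]:
    """Eq (C.9) and Eq (C.10): expected drifts of positive and negative walkers."""
    n = n_pos + n_neg
    scale = 2 * n_pos * n_neg / (3 * n * (n - 1))
    return scale * (1 - mu), -scale * mu


def solve_mu(n_pos: int, n_neg: int, tol: float = 1e-12) -> float:
    """Bisection for |Δs+| = |Δs-| on [0, 1].

    |Δs+| - |Δs-| = scale * (1 - 2 mu) decreases in mu, so a positive gap means
    the root lies to the right of the midpoint. Assumes 0 < n_pos < n.
    """
    lo, hi = 0.0, 1.0
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        drift_pos, drift_neg = drift_closed_forms(n_pos, n_neg, mid)
        if abs(drift_pos) > abs(drift_neg):
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


mu = solve_mu(5, 11)
n = 16
print(mu, drift_closed_forms(5, 11, mu)[0], 5 * 11 / (3 * n * (n - 1)))
```

Because the root does not depend on $n^{+}$ or $n^{-}$, any balanced or unbalanced split returns the same $\mu$.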
To validate our key assumption on $\mu$, we simulate the Markov process $100$ times with $L=16$ and a batch size of $512$. We empirically find that $\mu=0.508\pm 0.056$ (two standard deviations), closely matching our conjecture that $\mu=\frac{1}{2}$.
One caveat we have not addressed is the case where $n^{-}=n$ . In this case, no further updates occur and the weights are frozen in their current position. However, as we suggest later in Section C.4.3, the corresponding readout of these weights will be relatively small, reducing its impact. In practice, “dead” weights like these that are negatively aligned with all training symbols appear to be rare in trained models.
#### C.4.2 Different case
For weights $\mathbf{w}$ corresponding to a readout $a<0$ , similar rules hold but now with flipped signs. Considering again a random walker with position $s_{u}(t)=\omega_{u}^{1}(t)+\omega_{u}^{2}(t)$ , the following rules govern the walker’s dynamics:
1. If $s_{u}(t)>0$ and the model receives a training example $(\mathbf{z}_{u},\mathbf{z}_{u})$ , then $s_{u}(t+1)=s_{u}(t)-2$ .
1. If the model receives a training example $(\mathbf{z}_{u},\mathbf{z}_{v})$ or its reverse $(\mathbf{z}_{v},\mathbf{z}_{u})$, where $u\neq v$, then $s_{u}(t+1)\leq s_{u}(t)+1$.
The rules follow from the update dynamics of the Markov process precisely as before, now for weights sensitive to different examples. Note that there is no analogue of Rule 3 in this case, since a walker may (in most cases) continue to receive either positive or negative updates regardless of the sign of its position. This added symmetry simplifies the analysis of the different case somewhat compared to the same case, where it was necessary to study the interactions between two ensembles of random walkers that evolve in different ways. Here, we may treat all walkers uniformly.
Our general approach for analyzing this case is the same. Conditioned on training examples containing the $u$ th training symbol, recall that observing a same pair $(\mathbf{z}_{u},\mathbf{z}_{u})$ has probability 1/3, and observing a different pair $(\mathbf{z}_{u},\mathbf{z}_{v})$ has probability 2/3. Then we have the following cases:
Case 1.
The walker $s_{u}(t)>0$ and we receive a same example. In this case, $s_{u}(t+1)=s_{u}(t)-2$. The probability of observing a same pair containing $\mathbf{z}_{u}$ is 1/3, so we summarize this case as
$$
p(s_{u}\leftarrow s_{u}-2\,|\,s_{u}>0)=\frac{1}{3}\,.
$$
Case 2.
The walker $s_{u}(t)>0$ and we receive a different example. Whether $s_{u}$ increments is complex to determine, and depends on the precise coordinates $\omega_{u}$ and $\omega_{v}$ . As before, we treat this issue coarsely by approximating the probability of incrementing through an average case parameter $\mu$ . The probability of selecting a different example overall remains $2/3$ , so we summarize this case as
$$
p(s_{u}\leftarrow s_{u}+1\,|\,s_{u}>0)=\frac{2}{3}\mu\,.
$$
Case 3.
The walker $s_{u}(t)<0$ and we receive a same example. No updates occur in this case. For completeness, we summarize it as
$$
p(s_{u}\leftarrow s_{u}-2\,|\,s_{u}<0)=0\,.
$$
Case 4.
The walker $s_{u}(t)<0$ and we receive a different example. We again use our average case parameter to describe the probability of incrementing. The probability of selecting a different example overall remains $2/3$ , so we summarize this case as
$$
p(s_{u}\leftarrow s_{u}+1\,|\,s_{u}<0)=\frac{2}{3}\mu\,.
$$
To obtain a self-consistent condition for $\mu$ , we consider again the expected drift of walkers at positive or negative positions. After $t$ timesteps have elapsed, suppose the number of walkers with position $s^{+}>0$ is $n^{+}$ , and the number of walkers with position $s^{-}<0$ is $n^{-}$ , where $L=n^{+}+n^{-}\equiv n$ . Then combining Eq (C.12) and (C.13), the expected drift for positive walkers is
$$
\Delta s^{+}=\frac{2n^{+}}{3n}(\mu-1)\,.
$$
Combining Eq (C.14) and (C.15), the expected drift of the negative walkers is
$$
\Delta s^{-}=\frac{2n^{-}}{3n}\mu\,.
$$
Note that, unlike the same case, the expected drift for positive walkers is negative, and the expected drift for negative walkers is positive. Hence, for $\mu\in(0,1)$, if $|\Delta s^{+}|>|\Delta s^{-}|$, the number of positive walkers decreases faster than it increases, so we eventually reach a point where $|\Delta s^{+}|\leq|\Delta s^{-}|$. Conversely, when $|\Delta s^{+}|<|\Delta s^{-}|$, the number of negative walkers decreases faster than it increases, so we oscillate back to $|\Delta s^{+}|\geq|\Delta s^{-}|$. Over a long time period, we assume our walkers settle around an average point $|\Delta s^{+}|=|\Delta s^{-}|$. If this is true, then as before, we employ the relation $|\Delta s^{+}|=|\Delta s^{-}|$ as a self-consistency condition to solve for $\mu$. Equating (C.16) and (C.17) indicates that $\mu=\frac{n^{+}}{n}$.
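The same exact-arithmetic check applies to the different case. The sketch below (hypothetical helper names) builds $\Delta s^{\pm}$ from Cases 1 to 4, uses the fact that $|\Delta s^{+}|-|\Delta s^{-}|$ is linear in $\mu$ on $[0,1]$ to locate its root, and confirms $\mu=n^{+}/n$.

```python
from fractions import Fraction


def different_case_drifts(n_pos: int, n_neg: int, mu: Fraction) -> tuple[Fraction, Fraction]:
    """Expected drifts (Δs+, Δs-) in the different case (readout a < 0)."""
    n = n_pos + n_neg
    # Cases 1 and 2 (positive walkers): step -2 w.p. 1/3, step +1 w.p. (2/3) mu.
    drift_pos = Fraction(n_pos, n) * (-2 * Fraction(1, 3) + Fraction(2, 3) * mu)
    # Cases 3 and 4 (negative walkers): no update on same examples, step +1 w.p. (2/3) mu.
    drift_neg = Fraction(n_neg, n) * Fraction(2, 3) * mu
    return drift_pos, drift_neg


def self_consistent_mu(n_pos: int, n_neg: int) -> Fraction:
    """Root of h(mu) = |Δs+| - |Δs-| on [0, 1].

    Δs+ <= 0 and Δs- >= 0 for every mu in [0, 1], so h(mu) = -Δs+ - Δs- is linear
    there and its root follows from the two endpoint values.
    """
    def h(mu: Fraction) -> Fraction:
        drift_pos, drift_neg = different_case_drifts(n_pos, n_neg, mu)
        return -drift_pos - drift_neg

    h0, h1 = h(Fraction(0)), h(Fraction(1))
    return h0 / (h0 - h1)


print(self_consistent_mu(4, 12))  # 1/4
```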
There are three potential settings for $n^{+}$ to consider: $0<n^{+}<n$ , $n^{+}=n$ , or $n^{+}=0$ . Let us begin with $0<n^{+}<n$ . Because $n^{+}<n$ , the average position of positive walkers experiences a net negative drift, gradually bringing them closer to zero. Because $n^{+}>0$ , the average position of negative walkers experiences a net positive drift, gradually bringing them closer to zero also. Over a long period of time, we would therefore expect $n^{+}\approx n^{-}$ (and $\mu\approx\frac{1}{2}$ ).
If $n^{+}=n$ , then all walkers are in a positive position and $\Delta s^{+}=0$ . Over a long period of time, as the variance in walker position grows, through random chance at least one walker will eventually drift to a negative position, returning us to the case where $0<n^{+}<n$ . If $n^{+}=0$ , then all walkers are in a negative position, and $\Delta s^{-}>0$ . Over time, the walkers’ average position grows more positive, until at least one becomes positive and we again re-enter the case $0<n^{+}<n$ .
Altogether, we arrive at the following picture of the walkers’ dynamics. Walkers at a positive position drift with a negative rate down to zero, and walkers at a negative position drift with a positive rate up to zero. After sufficient time has elapsed, we would therefore expect the position of all walkers to be close to zero. However, for a walker $s_{u}$, positive updates increment the underlying coordinates $\omega_{u}^{1}$ and $\omega_{u}^{2}$ asymmetrically. Furthermore, a more positive coordinate receives a positive update with greater probability. Hence, we must have that $\omega_{u}^{1}\gg\omega_{u}^{2}$ or $\omega_{u}^{1}\ll\omega_{u}^{2}$. Because their sum must remain close to zero, it must be true that $\omega_{u}^{1}\approx-\omega_{u}^{2}$. Overall, we therefore have $\|\mathbf{v}^{1}\|=\|\mathbf{v}^{2}\|$ and $\mathbf{v}^{1}\cdot\mathbf{v}^{2}/(\|\mathbf{v}^{1}\|\,\|\mathbf{v}^{2}\|)=-1$, confirming that a weight vector aligned with different examples adopts antiparallel components.
To validate our key assumption on $\mu$, we simulate the Markov process $100$ times with $L=16$ and a batch size of $512$. We empirically find that $\mu=0.499\pm 0.009$ (two standard deviations), closely matching our conjecture that $\mu=\frac{1}{2}$.
#### C.4.3 Magnitude of readouts
The final piece to demonstrate in our study of the rich regime is that $|\overline{a^{+}}|<|\overline{a^{-}}|$, where $\overline{a^{+}}$ is the average across all positive readout weights and $\overline{a^{-}}$ is the average across all negative readout weights. Exactly characterizing these magnitudes is difficult, so we apply a heuristic argument based on what we learned about the structure of parallel and antiparallel weights above and support it with numerical evidence.
Recall that the update rule for a readout weight $a$ (dropping the index $i$) is given by
$$
\Delta a=-\frac{c}{N}\sum_{j=1}^{N}\frac{\partial\mathcal{L}_{j}}{\partial f}\phi(\mathbf{w}\cdot\mathbf{x}_{j}),
$$
where $c=\frac{\alpha}{\gamma\sqrt{d}}$ and $\alpha$ is the learning rate. Suppose $\mathbf{x}_{j}=(\mathbf{z}_{u},\mathbf{z}_{v})$ . Then the update rule becomes
$$
\Delta a=-\frac{c}{N}\sum_{u,v}\frac{\partial\mathcal{L}_{u,v}}{\partial f}\phi(\omega_{u}^{1}+\omega_{v}^{2})\,.
$$
If $\frac{\partial\mathcal{L}}{\partial f}$ is about the same in magnitude across all training examples, then our readout updates are proportional to
$$
\Delta a\propto\sum_{u,v}S(u,v)\,\phi(\omega_{u}^{1}+\omega_{v}^{2})\,,
$$
where
$$
S(u,v)=\begin{cases}1&u=v\\
-1&u\neq v\,.\end{cases}
$$
Let us first consider the case where $\mathbf{w}$ corresponds to a negative readout weight $a^{-}$. From above, we know that $\mathbf{w}$ is antiparallel. Hence, when encountering a same example, $\omega_{u}^{1}+\omega_{u}^{2}\approx 0$. If all coordinates are roughly equal in magnitude, then when encountering a different example, $\omega_{u}^{1}+\omega_{v}^{2}>0$ about $1/4$ of the time. Comparing Eq (C.11) to Eq (C.16), we see that the expected drift for antiparallel weight vectors is roughly twice that of parallel weight vectors over long timescales. Altogether, since the number of same and different examples is balanced, we have overall that $|\Delta a^{-}|\propto 2\cdot\frac{1}{4}=\frac{1}{2}$ for $a^{-}<0$.
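The factor of two can be made explicit. Assuming balanced walkers, $n^{+}=n^{-}=n/2$ (an assumption made here for illustration), with $\mu=\frac{1}{2}$ in the different case, Eq (C.16) and Eq (C.11) give

$$
\frac{|\Delta s^{+}|_{\text{antiparallel}}}{\Delta s_{u,\text{parallel}}}=\frac{\frac{2(n/2)}{3n}\cdot\frac{1}{2}}{\frac{(n/2)^{2}}{3n(n-1)}}=\frac{1/6}{n/\bigl(12(n-1)\bigr)}=\frac{2(n-1)}{n}\,,
$$

which equals $15/8$ for $n=L=16$ and approaches $2$ as $L$ grows.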
Now consider the case where $\mathbf{w}$ corresponds to a positive readout weight $a^{+}$. From above, we know that $\mathbf{w}$ is parallel. For large batch sizes $N$, the expected drift of a walker at initialization is zero: there is a 1/3 chance of observing a same example, which increments by 2, and a 2/3 chance of observing a different example, which decrements by 1. We therefore expect $n^{+}\approx n^{-}$. If all coordinates are roughly equal in magnitude and $L$ is large, then when encountering a same example, $\omega_{u}^{1}+\omega_{u}^{2}>0$ about $1/2$ of the time. When encountering a different example, $\omega_{u}^{1}+\omega_{v}^{2}>0$ about $1/4$ of the time. Altogether, we would therefore expect $|\Delta a^{+}|\propto\frac{1}{2}-\frac{1}{4}=\frac{1}{4}$.
From this rough estimate, we find that $\frac{|\Delta a^{-}|}{|\Delta a^{+}|}\approx 2$ for average negative and positive readouts. Since negative readouts tend to grow faster than positive readouts, we would expect the magnitude of negative readouts to be correspondingly larger. In fact, if readouts start with small initialization, we may conjecture that $\frac{|\overline{a^{-}}|}{|\overline{a^{+}}|}\approx 2$. In practice, this quantity turns out to be about $1.56\pm 0.09$ (two standard deviations), computed from 10 runs of an MLP with width $m=1024$, input dimension $d=512$, and $L=32$ training symbols.
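The counting behind $|\Delta a^{-}|/|\Delta a^{+}|\approx 2$ can be reproduced with idealized weights: every coordinate is $\pm 1$ (equal magnitudes), signs are exactly balanced across the $L$ symbols, and $\omega^{2}=\omega^{1}$ (parallel) or $\omega^{2}=-\omega^{1}$ (antiparallel). These idealizations are assumptions of the sketch, not measured weights. It counts how often $\omega_{u}^{1}+\omega_{v}^{2}>0$ over same pairs ($u=v$) and different pairs ($u\neq v$), then applies the factor of two in drift from comparing Eq (C.11) with Eq (C.16).

```python
import numpy as np


def active_fractions(n_symbols: int) -> dict[str, float]:
    """Fraction of same / different pairs (u, v) with omega_u^1 + omega_v^2 > 0.

    Idealized weights (assumption): coordinates are ±1 with exactly balanced signs.
    """
    signs = np.repeat([1.0, -1.0], n_symbols // 2)
    n = signs.size
    off_diagonal = ~np.eye(n, dtype=bool)
    # Parallel weight: omega^1 = omega^2 = signs, so pair (u, v) gives signs[u] + signs[v].
    parallel = signs[:, None] + signs[None, :]
    # Antiparallel weight: omega^1 = signs and omega^2 = -signs.
    antiparallel = signs[:, None] - signs[None, :]
    return {
        "parallel_same": float(np.mean(np.diag(parallel) > 0)),
        "parallel_diff": float(np.mean(parallel[off_diagonal] > 0)),
        "antiparallel_same": float(np.mean(np.diag(antiparallel) > 0)),
        "antiparallel_diff": float(np.mean(antiparallel[off_diagonal] > 0)),
    }


frac = active_fractions(1000)
# Negative readouts: same pairs are inactive, active different pairs (S = -1) drive a
# further negative, and the antiparallel drift is about twice the parallel drift.
delta_a_neg = 2 * frac["antiparallel_diff"]
# Positive readouts: active same pairs push up (S = +1), active different pairs push down.
delta_a_pos = frac["parallel_same"] - frac["parallel_diff"]
print(frac, delta_a_neg / delta_a_pos)
```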
Note that when $L$ is small (for instance, $L=2$), our estimate $|\Delta a^{+}|\propto\frac{1}{4}$ breaks down, since there may be only a single set of positive coordinates and no penalty is incurred on negative examples. In this case, we would have $|\Delta a^{+}|\propto\frac{1}{2}$, so $\frac{|\Delta a^{-}|}{|\Delta a^{+}|}=1$. Indeed, this is exactly what we observe when $L=2$: computing this quantity empirically yields $0.99\pm 0.07$ (two standard deviations) over 10 runs, again with width $m=1024$ and input dimension $d=512$. This outcome seems to be part of the reason why the model does not generalize well on the SD task with only $2$ symbols, despite developing parallel/antiparallel weights (Figure C2). For $L\geq 3$, there appear to be sufficiently many pairs of positive coordinates in parallel weight vectors to restore the situation where $|\Delta a^{-}|>|\Delta a^{+}|$.
[Figure: scatter plot of readout weights $a_{i}$ (x-axis, roughly $-2.5$ to $2.5$) against the cosine similarity $\mathbf{v}_{i}^{1}\cdot\mathbf{v}_{i}^{2}/(\|\mathbf{v}_{i}^{1}\|\,\|\mathbf{v}_{i}^{2}\|)$ (y-axis, $-1$ to $1$); points cluster near $-1$ for $a_{i}<0$ and near $1$ for $a_{i}>0$.]
Figure C2: Rich-regime weight structure when $L=2$ . The model continues to develop parallel/antiparallel weights, though the magnitude of negative readouts is now about the same as the magnitude of positive readouts.
### C.5 Test accuracy prediction
We apply our knowledge of the structure of $\mathbf{w}$ and $a$ to estimate the test accuracy of the rich-regime model. Our derivation is heuristic and seeks to capture broad phenomena rather than exact values. We validate our predicted test accuracy in Figure C3, demonstrating excellent agreement.
Recall from Section 3.1 that a model achieving the hand-crafted solution exhibits perfect classification of same examples. Any errors are therefore accumulated from misclassifying different examples. The crux of our estimate stems from approximating the classification accuracy of different examples.
Define $\mathcal{I}^{+}$ to be the set of weight indices $i$ such that $a_{i}>0$ , and define $\mathcal{I}^{-}$ to be the set of weight indices $j$ where $a_{j}<0$ . Let $\mathbf{x}$ be a different example. Dropping constants that do not affect the outcome of a classification, our model becomes
$$
f(\mathbf{x})=\sum_{i\in\mathcal{I}^{+}}|a_{i}|\,\phi(\mathbf{w}_{i}\cdot\mathbf{x})-\sum_{j\in\mathcal{I}^{-}}|a_{j}|\,\phi(\mathbf{w}_{j}\cdot\mathbf{x})\,.
$$
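Define the weighted sums
$$
\overline{a^{+}}_{*}=\frac{\sum_{i\in\mathcal{I}^{+}}|a_{i}|\,\phi(\mathbf{w}_{i}\cdot\mathbf{x})}{\sum_{i\in\mathcal{I}^{+}}\phi(\mathbf{w}_{i}\cdot\mathbf{x})}\,,\qquad
\overline{a^{-}}_{*}=\frac{\sum_{j\in\mathcal{I}^{-}}|a_{j}|\,\phi(\mathbf{w}_{j}\cdot\mathbf{x})}{\sum_{j\in\mathcal{I}^{-}}\phi(\mathbf{w}_{j}\cdot\mathbf{x})}\,.
$$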
Then
$$
f(\mathbf{x})=\overline{a^{+}}_{*}\sum_{i\in\mathcal{I}^{+}}\phi(\mathbf{w}_{i}\cdot\mathbf{x})-\overline{a^{-}}_{*}\sum_{j\in\mathcal{I}^{-}}\phi(\mathbf{w}_{j}\cdot\mathbf{x})\,.
$$
If the magnitudes of $\phi(\mathbf{w}_{i}\cdot\mathbf{x})$ are the same for all $i$ and all $\mathbf{x}$, then $\overline{a^{+}}_{*}=\overline{a^{+}}=\frac{1}{|\mathcal{I}^{+}|}\sum_{i\in\mathcal{I}^{+}}a_{i}$. Since $\mathbf{x}$ is an unseen different example, by symmetry we conclude that $\overline{a^{+}}_{*}=\overline{a^{+}}$ is a reasonable approximation. The same applies to $\overline{a^{-}}_{*}=\overline{a^{-}}$.
Since the magnitude of $f(\mathbf{x})$ does not affect its classification, we divide through by $\overline{a^{+}}$ to redefine our model as
$$
f(\mathbf{x})=\sum_{i\in I^{+}}\phi(\mathbf{w}_{i}\cdot\mathbf{x})-\rho\sum_{j\in I^{-}}\phi(\mathbf{w}_{j}\cdot\mathbf{x})\,,
$$
where $\rho=\frac{|\overline{a^{-}}|}{|\overline{a^{+}}|}$. To calculate the probability of correctly classifying an unseen different example, we would like to estimate $p(f(\mathbf{x})<0)$ for $\mathbf{x}=(\mathbf{z},\mathbf{z}^{\prime})$ and $\mathbf{z},\mathbf{z}^{\prime}\sim\mathcal{N}(\mathbf{0},\mathbf{I}/d)$.
Then, from Eq (C.3),
$$
\phi(\mathbf{w}_{i}\cdot\mathbf{x})=\phi\left(\sum_{k=1}^{L}\left[\omega_{i,k}^{1}\,\mathbf{z}_{k}\cdot\mathbf{z}\pm\omega_{i,k}^{2}\,\mathbf{z}_{k}\cdot\mathbf{z}^{\prime}\right]\right)\,.
$$
Over the distribution of an unseen symbol $\mathbf{z}$,
$$
\mathbf{z}_{k}\cdot\mathbf{z}\overset{d}{=}-\mathbf{z}_{k}\cdot\mathbf{z}\,,
$$
so we may replace $\pm$ with $+$ in the summation. Summing across all weight vectors corresponding to the same class yields
$$
\sum_{i\in\mathcal{I}^{+}}\phi(\mathbf{w}_{i}\cdot\mathbf{x})=\sum_{i\in\mathcal{I}^{+}}\phi\left(\sum_{k=1}^{L}\left[\omega_{i,k}^{1}\,\mathbf{z}_{k}\cdot\mathbf{z}+\omega_{i,k}^{2}\,\mathbf{z}_{k}\cdot\mathbf{z}^{\prime}\right]\right)\,.
$$
For the terms that survive the ReLU, $\mathbf{w}_{i}\cdot\mathbf{x}>0$, so it is likely that $\omega_{i,k}^{1}\,\mathbf{z}_{k}\cdot\mathbf{z}>0$. Since $\mathbf{z}_{k}\cdot\mathbf{z}$ is approximately Normal at high $d$, we approximate the term $c_{i,k}^{1}\equiv\omega_{i,k}^{1}\mathbf{z}_{k}\cdot\mathbf{z}$ as a Half-Normal random variable (and likewise $c_{i,k}^{2}\equiv\omega_{i,k}^{2}\mathbf{z}_{k}\cdot\mathbf{z}^{\prime}$). We therefore focus on characterizing the distribution of a sum over Half-Normal random variables, which we denote by $\overline{c^{+}}$:
$$
\sum_{i\in\mathcal{I}^{+}}\phi(\mathbf{w}_{i}\cdot\mathbf{x})\overset{d}{=}\overline{c^{+}}\equiv\sum_{i=1}^{|\mathcal{I}^{+}|}\sum_{k=1}^{L}\left[c_{i,k}^{1}+c_{i,k}^{2}\right]\,.
$$
The individual $c_{i,k}^{1}$ and $c_{i,k}^{2}$ are learned from a finite set of training symbols, so they cannot all be independently distributed. However, since training symbols are sampled independently, we expect that, for any particular weight index $i=i_{0}$, the corresponding $c_{i_{0},k}^{1}$ and $c_{i_{0},k}^{2}$ are indeed independent. Hence, we have at least $2L$ independent terms for a particular index $i=i_{0}$.
The dependency structure across weight indices $i$ is more subtle, but we make a reasonable guess at their structure and later validate this heuristic with numerics. While our analysis in Appendix C assumed each weight vector evolves independently, during training they all evolve based on the same set of inputs. As a result, significant correlations emerge across weight vectors. To understand these correlations, let us fix our training symbol index to $k=k_{0}$ and consider all terms $c_{i,k_{0}}^{1}=\omega_{i,k_{0}}^{1}\mathbf{z}_{k_{0}}\cdot\mathbf{z}$. Since they all share $\mathbf{z}_{k_{0}}\cdot\mathbf{z}$, these quantities are not strictly independent. However, after $T$ training steps, we have that $\omega_{i,k_{0}}^{1}=\mathcal{O}(T)$ and $c_{i,k_{0}}^{1}=\mathcal{O}(T)$ while $\mathbf{z}_{k_{0}}\cdot\mathbf{z}=\mathcal{O}(1)$, so for our approximation we treat the dependency incurred from $\mathbf{z}_{k_{0}}\cdot\mathbf{z}$ as negligible.
Let us therefore turn our attention to the coordinates $\omega_{i,k_{0}}^{1}$. If the model only ever received same inputs $(\mathbf{z}_{k_{0}},\mathbf{z}_{k_{0}})$, all $\omega_{i,k_{0}}^{1}$ would be identical or zero across indices $i$. However, if we allow the model to witness different inputs $(\mathbf{z}_{k_{0}},\mathbf{z}_{\ell})$ for some $\ell\neq k_{0}$, we would expect a distribution of $\omega_{i,k_{0}}^{1}$ driven by the underlying initialization of $\mathbf{w}_{i}$ and the number of symbols $L$. Coordinates where $\omega_{i,k_{0}}^{1}(0)+\omega_{i,1}^{2}(0)>0$ and $\omega_{i,k_{0}}^{1}(0)+\omega_{i,2}^{2}(0)<0$ would evolve differently from coordinates where $\omega_{i,k_{0}}^{1}(0)+\omega_{i,1}^{2}(0)<0$ and $\omega_{i,k_{0}}^{1}(0)+\omega_{i,2}^{2}(0)>0$. As the number of training symbols increases, we would expect the number of independent coordinates $\omega_{i,k_{0}}^{1}$ to increase as well. Given $L$ training symbols, we might therefore guess that the number of independent coordinates is proportional to $L-1$, one for each symbol $\ell\neq k_{0}$. However, we also need to account for the sign of $\mathbf{z}_{\ell}\cdot\mathbf{z}^{\prime}$. If this quantity is positive and the corresponding readout is positive, then correlations resulting from the symbol $\mathbf{z}_{\ell}$ would be unimportant, since they lower the probability that $\mathbf{w}_{i}\cdot\mathbf{x}>0$, filtering them from the sum. (The reverse is true for weights corresponding to negative readouts.) Hence, roughly half of the $L-1$ symbols contribute unique coordinates, and the total number of unique coordinates $\omega_{i,k_{0}}^{1}$ is approximately $\frac{1}{2}(L-1)$.
If each of our $\frac{1}{2}(L-1)$ independent coordinates carries $2L$ independent dot-product terms, we have
$$
\overline{c^{+}}\overset{d}{=}\,\sum_{\ell=1}^{L\left(L-1\right)}c_{\ell}\,,
$$
where each $c_{\ell}$ is Half-Normal, distributed as the absolute value of a zero-mean Normal with variance $\sigma^{2}$; the scale $\sigma$ will cancel in the final calculation.
Applying the central limit theorem together with the first and second moments of a Half-Normal distribution reveals that
$$
\overline{c^{+}}\sim\mathcal{N}\left((L^{2}-L)\,\sigma\sqrt{\frac{2}{\pi}}\,,\,(L^{2}-L)\,\sigma^{2}\left(1-\frac{2}{\pi}\right)\right)\,.
$$
As we noted before, $\mathbf{z}_{k}\cdot\mathbf{z}\overset{d}{=}\mathbf{z}_{k}\cdot\mathbf{z}^{\prime}$ for large $d$, so the distribution of $\sum\phi(\mathbf{w}\cdot\mathbf{x})$ is the same regardless of whether the weight vectors $\mathbf{w}$ are parallel or antiparallel. Hence, $\overline{c^{+}}\overset{d}{=}\overline{c^{-}}$. Treating $\overline{c^{+}}$ and $\overline{c^{-}}$ as independent, the distribution of our output is therefore
$$
f(\mathbf{x})\overset{d}{=}\overline{c^{+}}-\rho\,\overline{c^{-}}\,.
$$
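The Normal approximation can be probed by direct sampling. The sketch below draws $\overline{c^{\pm}}$ as sums of $L(L-1)$ independent Half-Normal terms, compares their moments with the Normal approximation above, and estimates $p(f(\mathbf{x})<0)$ for $f=\overline{c^{+}}-\rho\,\overline{c^{-}}$. The independence of $\overline{c^{+}}$ and $\overline{c^{-}}$ is an assumption of the sketch; function names, sample sizes, and seeds are arbitrary choices.

```python
import math

import numpy as np


def sample_c_bar(n_symbols: int, sigma: float, n_samples: int, rng: np.random.Generator) -> np.ndarray:
    """Samples of c-bar: a sum of L(L-1) independent Half-Normal terms with scale sigma."""
    n_terms = n_symbols * (n_symbols - 1)
    return np.abs(rng.normal(0.0, sigma, size=(n_samples, n_terms))).sum(axis=1)


def normal_approx_moments(n_symbols: int, sigma: float) -> tuple[float, float]:
    """Mean and variance of c-bar under the Normal (CLT) approximation."""
    n_terms = n_symbols * (n_symbols - 1)
    return n_terms * sigma * math.sqrt(2 / math.pi), n_terms * sigma**2 * (1 - 2 / math.pi)


def mc_p_different_correct(n_symbols: int, rho: float, sigma: float = 1.0,
                           n_samples: int = 40_000, seed: int = 0) -> float:
    """Monte Carlo estimate of p(f(x) < 0) with f = c+ - rho * c- (independent draws)."""
    rng = np.random.default_rng(seed)
    c_pos = sample_c_bar(n_symbols, sigma, n_samples, rng)
    c_neg = sample_c_bar(n_symbols, sigma, n_samples, rng)
    return float(np.mean(c_pos - rho * c_neg < 0))


samples = sample_c_bar(8, 0.3, 50_000, np.random.default_rng(1))
print(samples.mean(), samples.var(), normal_approx_moments(8, 0.3))
print(mc_p_different_correct(8, 1.5))
```

Rescaling $\sigma$ rescales every sample by the same factor, so the estimated probability does not change, consistent with $\sigma$ cancelling.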
Our final probability of classifying an unseen different example correctly is
$$
p(f(\mathbf{x})<0)=\Phi\left(\sqrt{\frac{(2L^{2}-2L)(\rho-1)^{2}}{(\pi-2)(\rho^{2}+1)}}\right)\,,
$$
where $\Phi$ is the CDF of a standard Normal.
From Section C.4.3, we found that $\rho\approx 1.5$ for $L\geq 3$ and $\rho=1$ for $L=2$. Plugging these values into Eq (C.18) allows us to compute the probability of classifying an unseen different example. If the model classifies all unseen same examples correctly, the total test accuracy of the rich regime model is given by $\frac{1}{2}+\frac{1}{2}p(f(\mathbf{x})<0)$, yielding the expression we report in Eq (3). We validate this prediction in Figure C3, showing excellent agreement with the measured test accuracy of a rich model.
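Eq (C.18) and the resulting test accuracy are straightforward to evaluate. The sketch below uses $\rho=1$ for $L=2$ and $\rho=1.5$ otherwise, following the values quoted above; the function names are hypothetical. Since $\sqrt{(\rho-1)^{2}}=|\rho-1|$, the expression assumes $\rho\geq 1$. The standard Normal CDF is written with `math.erf` to avoid external dependencies.

```python
import math
from typing import Optional


def p_different_correct(n_symbols: int, rho: float) -> float:
    """Eq (C.18): probability that an unseen different example gets f(x) < 0.

    Assumes rho >= 1, since the square root returns |rho - 1|.
    """
    L = n_symbols
    z = math.sqrt((2 * L**2 - 2 * L) * (rho - 1) ** 2 / ((math.pi - 2) * (rho**2 + 1)))
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))  # standard Normal CDF at z


def rich_test_accuracy(n_symbols: int, rho: Optional[float] = None) -> float:
    """Unseen same examples are assumed always correct, so accuracy = 1/2 + p/2."""
    if rho is None:
        rho = 1.0 if n_symbols == 2 else 1.5
    return 0.5 + 0.5 * p_different_correct(n_symbols, rho)


for L in [2, 3, 4, 8]:
    print(L, round(rich_test_accuracy(L), 4))
```

With $\rho=1$ the argument of $\Phi$ is zero, so the predicted accuracy for $L=2$ is exactly $3/4$; for $L=3$ it already exceeds $0.9$, and no $d$ appears anywhere in the computation.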
[Figure: test accuracy (y-axis) versus number of training symbols $L$ (x-axis), comparing the theoretical prediction (dashed line) with measured accuracies for $\gamma=10^{0.0}$, $10^{0.4}$, $10^{0.8}$, $10^{1.2}$, and $10^{1.6}$.]
Figure C3: Rich-regime test accuracy. We demonstrate close agreement between the theoretically predicted and empirically measured rich-regime test accuracy. The rich-regime parametrization explored in the main text corresponds to $\gamma=1$. To confirm that our results hold for arbitrarily rich models, we also plot accuracies attained in the ultra-rich regime $\gamma\gg 1$ (Atanasov et al., 2024). In all cases, our predictions continue to hold.
Two important details to note:
- The test accuracy of the rich model rises rapidly with $L$ . By $L=3$ , the model already attains over 90 percent test accuracy.
- The test error does not depend on the input dimension $d$. The impact of $d$ is captured in the variance $\sigma^{2}$ of our Half-Normal random variables $c$, which cancels in the final calculation.
In this way, we see how the conceptual parallel/antiparallel representations of the same-different model lead to highly efficient learning and insensitivity to input dimension, completing our analysis of the rich regime.
## Appendix D Lazy regime details
In the lazy regime, we will demonstrate that the model requires 1) far more training symbols than in the rich regime to learn the SD task, and 2) the model’s test accuracy depends explicitly on the input dimension $d$ .
A lazy MLP's learning dynamics can be described using kernel methods. In particular, the limit $\gamma\rightarrow 0$ corresponds to using the Neural Tangent Kernel (NTK) (Jacot et al., 2018), in which the model evolves as its linearization around the initial weights. We show that, under our bound, the number of training symbols required to generalize using the NTK grows quadratically with the input dimension.
Recall that our model has form
$$
f(\mathbf{x})=\sum_{i=1}^{m}a_{i}\,\phi(\mathbf{w}_{i}\cdot\mathbf{x})\,.
$$
If there are $P$ unique training examples in our dataset, we may rewrite our model in its dual form
$$
f(\mathbf{x})=\sum_{j=1}^{P}b_{j}\,K(\mathbf{x},\mathbf{x}_{j})\,,
$$
for the kernel
$$
K(\mathbf{x},\mathbf{x}_{j})=\frac{1}{m}\sum_{i=1}^{m}\phi(\mathbf{w}_{i}\cdot\mathbf{x})\,\phi(\mathbf{w}_{i}\cdot\mathbf{x}_{j})\,.
$$
For ease of exposition, we assume that inputs $\mathbf{x}$ lie on the unit sphere $\mathbf{x}\in\mathbb{S}^{2d-1}$. This becomes exact (up to a constant radius) as $d\rightarrow\infty$, because the norm of a high-dimensional Gaussian input concentrates. For width $m\rightarrow\infty$ and ReLU activations $\phi$, the analytic form of the NTK $K$ is known to be
$$
K(u)=u\left(1-\frac{1}{\pi}\cos^{-1}(u)\right)+\frac{1}{2\pi}\sqrt{1-u^{2}}\,.
$$
where $u=\mathbf{x}\cdot\mathbf{x}^{\prime}$ (Cho & Saul, 2009).
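As a sanity check on this closed form, the sketch below compares it with a Monte Carlo average over Gaussian first-layer weights. It assumes (our assumption, not stated above) the standard two-layer NTK with standard Gaussian first-layer weights, for which $K(u)=\mathbb{E}[\phi(\mathbf{w}\cdot\mathbf{x})\phi(\mathbf{w}\cdot\mathbf{x}^{\prime})]+u\,\mathbb{E}[\phi^{\prime}(\mathbf{w}\cdot\mathbf{x})\phi^{\prime}(\mathbf{w}\cdot\mathbf{x}^{\prime})]$ on unit-norm inputs. It also evaluates the second-order expansion of $K$ that the proof below relies on. All function names are hypothetical.

```python
import math
import random


def ntk_relu(u: float) -> float:
    """Closed-form infinite-width ReLU NTK on the unit sphere, with u = x . x'."""
    u = max(-1.0, min(1.0, u))  # guard against round-off just outside [-1, 1]
    return u * (1.0 - math.acos(u) / math.pi) + math.sqrt(1.0 - u * u) / (2.0 * math.pi)


def ntk_relu_monte_carlo(u: float, n_samples: int = 100_000, seed: int = 0) -> float:
    """Monte Carlo estimate of the same kernel over Gaussian first-layer weights.

    Only the plane spanned by x and x' matters, so we work in 2D with
    x = (1, 0) and x' = (u, sqrt(1 - u^2)). Each sample adds the feature term
    relu(w.x) relu(w.x') and the first-layer gradient term u * step(w.x) step(w.x').
    """
    rng = random.Random(seed)
    c, s = u, math.sqrt(1.0 - u * u)
    total = 0.0
    for _ in range(n_samples):
        w0, w1 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        h1, h2 = w0, w0 * c + w1 * s  # pre-activations for x and x'
        if h1 > 0.0 and h2 > 0.0:  # both ReLUs active; otherwise both terms vanish
            total += h1 * h2 + u
    return total / n_samples


def ntk_taylor(u: float) -> float:
    """Second-order expansion of K around u = 0."""
    return 1.0 / (2.0 * math.pi) + u / 2.0 + 3.0 * u * u / (4.0 * math.pi)


for u in (-0.5, 0.0, 0.3, 0.9):
    print(f"u={u:+.1f}  closed form={ntk_relu(u):.4f}  "
          f"Monte Carlo={ntk_relu_monte_carlo(u):.4f}  Taylor={ntk_taylor(u):.4f}")
```

Because every odd-order term beyond $u/2$ cancels in $K$, the error of the second-order expansion near $u=0$ shrinks like $u^{4}$ rather than $u^{3}$.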
With the setup complete, we present our central result.
**Theorem 2**
*Let $f$ be an infinite-width ReLU MLP as given in Eq D.1, with an NTK kernel. Suppose inputs $\mathbf{x}$ are restricted to lie on the unit sphere $\mathbf{x}\in\mathbb{S}^{2d-1}$ . If $f$ is trained on a dataset consisting of $P$ points constructed from $L$ symbols with input dimension $d$ , then the test error of $f$ is upper bounded by $\mathcal{O}\left(\exp\left\{-\frac{L}{d^{2}}\right\}\right)$ .*
*Proof.*
Our proof strategy proceeds as follows. We restrict the space over which our dual coefficients $b$ can vary to a convenient subset, and upper bound the test error of $f$ achievable over this restricted parameter space. Because the restricted parameter space is a subset of the full parameter space, our derived upper bound applies to the unrestricted model as well. We restrict the dual coefficients $b$ as follows. Let $I^{+}$ be the set of all indices $i$ such that $\mathbf{x}_{i}$ is a same example, and $I^{-}$ the set of all indices $j$ such that $\mathbf{x}_{j}$ is a different example. For all $i\in I^{+}$, we fix $b_{i}=b^{+}>0$, and for all $j\in I^{-}$, we fix $b_{j}=b^{-}<0$. Hence, we effectively tune just two parameters: $b^{+}$ and $b^{-}$. We also set a number of coefficients $b$ to zero. Given a dataset with symbols $\mathbf{z}_{1},\mathbf{z}_{2},\ldots,\mathbf{z}_{L}$, partition the symbols into the sets $\mathcal{S}_{1}=\left\{\mathbf{z}_{1},\mathbf{z}_{2},\ldots,\mathbf{z}_{L/3}\right\}$ and $\mathcal{S}_{2}=\left\{\mathbf{z}_{L/3+1},\mathbf{z}_{L/3+2},\ldots,\mathbf{z}_{L}\right\}$. Consider the kernel coefficient $b_{k}$, which corresponds to a training example $\mathbf{x}_{k}=(\mathbf{z}_{\ell_{1}};\mathbf{z}_{\ell_{2}})$. If $\mathbf{z}_{\ell_{1}}=\mathbf{z}_{\ell_{2}}$ and $\mathbf{z}_{\ell_{1}}\notin\mathcal{S}_{1}$, then we fix $b_{k}=0$. If $\mathbf{z}_{\ell_{1}}\neq\mathbf{z}_{\ell_{2}}$, then we check three conditions: (1) $\mathbf{z}_{\ell_{1}}\in\mathcal{S}_{2}$, (2) $\ell_{2}-\ell_{1}=1$, and (3) $\ell_{1}$ is odd. If any one of these conditions is violated, we set $b_{k}=0$. This procedure ensures that the remaining nonzero terms in Eq (D.1) are independent, and that equal numbers of same and different examples remain. The set $\mathcal{S}_{1}$ determines the symbols that contribute to same examples.
The disjoint set $\mathcal{S}_{2}$ determines the symbols that contribute to different examples. We further stipulate that different examples do not contain overlapping symbols, which leads to the three conditions enumerated above. Note that to construct a dataset with $P$ nonzero terms in our kernel sum, we require $L\propto P$ symbols. First, suppose $\mathbf{x}$ is a same test example. Since we restricted the summands of our kernel function to be independent, the probability of misclassifying $\mathbf{x}$ can be bounded through a straightforward application of Hoeffding's inequality
$$
p(f(\mathbf{x})<0)\leq\exp\left\{-\frac{2\mathbb{E}[f(\mathbf{x})]^{2}}{Pc^{2}}\right\}\,
$$
where $P$ is the size of the training set and $c$ is a constant related to the range of individual summands $b_{j}K(\mathbf{x},\mathbf{x}_{j})$ . Note, $b_{j}$ can be arbitrarily small without changing the classification and $0\leq K(u)<3$ , so $c$ is finite. Distributing the expectation, we have
$$
\mathbb{E}[f(\mathbf{x})]=\sum_{i\in I^{+}}b^{+}\,\mathbb{E}[K(\mathbf{x},\mathbf{x}_{i})]+\sum_{j\in I^{-}}b^{-}\,\mathbb{E}[K(\mathbf{x},\mathbf{x}_{j})]\,.
$$
Taylor expanding $K$ to second order in $u$ reveals that
$$
\mathbb{E}[K(u)]=\frac{1}{2\pi}+\frac{\mathbb{E}[u]}{2}+\frac{3\mathbb{E}[u^{2}]}{4\pi}+o(\mathbb{E}[u^{2}])
$$
Since input symbols are normally distributed with mean zero, we know that $\mathbb{E}[u]=0$ and $\mathbb{E}[u^{2}]\propto 1/d$ . Furthermore, if $\mathbf{x}_{j}$ is a same training example and $\mathbf{x}_{k}$ is a different training example, inspecting second moments reveals that $\mathbb{E}[(\mathbf{x}\cdot\mathbf{x}_{j})^{2}]=2\mathbb{E}[(\mathbf{x}\cdot\mathbf{x}_{k})^{2}]$ , for an unseen same example $\mathbf{x}$ . Thus, provided that $|b^{-}/b^{+}|<2$ , substituting (D.4) into (D.3) yields
$$
\mathbb{E}[f(\mathbf{x})]=\mathcal{O}\left(\frac{P}{d}\right)\,,
$$
which implies that
$$
p(f(\mathbf{x})<0)\leq\mathcal{O}\left(\exp\left\{-\frac{P}{d^{2}}\right\}\right)\,.
$$
Now suppose $\mathbf{x}$ is a different test example. If $\mathbf{x}^{+}$ is a same training example and $\mathbf{x}^{-}$ is a different training example, then the first and second moments of $\mathbf{x}\cdot\mathbf{x}^{+}$ equal those of $\mathbf{x}\cdot\mathbf{x}^{-}$. Hence, (D.4) and (D.3) suggest that if $|b^{-}|-|b^{+}|=\mathcal{O}(1)$, then $\mathbb{E}[f(\mathbf{x})]=-\mathcal{O}(P)$. Applying Hoeffding's inequality a second time suggests that
$$
p(f(\mathbf{x})>0)=\mathcal{O}\left(\exp\left\{-P\right\}\right)\,.
$$
Note that it is possible to satisfy both $|b^{-}|-|b^{+}|=\mathcal{O}(1)$ and $|b^{-}/b^{+}|<2$, for example with $b^{+}=1$ and $b^{-}=-1.1$. The test error overall is dominated by the contribution from misclassifying same examples, $p(f(\mathbf{x})<0)=\mathcal{O}(\exp\left\{-P/d^{2}\right\})$. Because of our independence restriction on the dual coefficients $b$, producing $P$ training examples requires $L\propto P$ training symbols. The test error of the model overall is therefore upper bounded by $\mathcal{O}\left(\exp\left\{-\frac{L}{d^{2}}\right\}\right)$. ∎
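The coefficient-zeroing rule in the proof is easy to misread, so the following sketch lists which training pairs keep a nonzero dual coefficient. It assumes (our assumption) that $L$ is divisible by 6, so that $L/3$ is an even integer and the same and different counts come out exactly equal; the function name is hypothetical. It makes the relationship $P=2L/3\propto L$ explicit.

```python
def restricted_dual_support(L: int):
    """Index pairs (l1, l2) whose dual coefficients b_k remain nonzero.

    Indices are 1-based, matching z_1, ..., z_L in the proof. Same examples use
    each symbol of S1 = {1, ..., L/3} once; different examples use disjoint,
    consecutive pairs from S2 = {L/3 + 1, ..., L} whose first index is odd.
    """
    if L % 6 != 0:
        raise ValueError("this sketch assumes L is divisible by 6")
    cutoff = L // 3  # last index in S1
    same = [(l, l) for l in range(1, cutoff + 1)]
    different = [
        (l1, l1 + 1)
        for l1 in range(cutoff + 1, L)  # condition (1): z_l1 in S2, with l1 + 1 <= L
        if l1 % 2 == 1  # condition (3): l1 odd; condition (2) holds by construction
    ]
    return same, different


same, different = restricted_dual_support(12)
print(same)       # [(1, 1), (2, 2), (3, 3), (4, 4)]
print(different)  # [(5, 6), (7, 8), (9, 10), (11, 12)]
```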
Hence, to maintain a constant error rate, our bound suggests that the number of training symbols $L$ should scale as $L\propto d^{2}$. While this scaling follows from an upper bound on the error of a lazy model, Figure 1f suggests that this quadratic relationship remains descriptive of the full model. There are two important consequences of this result:
1. For large $d$, the lazy model requires substantially more training symbols to learn the SD task than the rich model. In Appendix C, we found that the rich model can generalize with as few as $L=3$ symbols. In contrast, Figure 1f suggests the lazy model will often require hundreds or thousands of training symbols to generalize.
1. For a fixed number of training symbols, a lazy model’s performance decays as $d$ increases. Unlike in the rich case, there is an explicit dependency on $d$ in the test error for the lazy model, hurting its performance as $d$ grows larger.
In this way, we see how a lazy model can leverage the differing statistics of same and different examples to accomplish the SD task, but at the cost of exhaustive training data and strong sensitivity to input dimension.
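The differing statistics that the lazy model exploits can be checked numerically. The sketch below estimates, by Monte Carlo, the second moment of $u=\mathbf{x}\cdot\mathbf{x}^{\prime}$ between an unseen same test input and a same or a different training input. It assumes noiseless symbols $\mathbf{s}\sim\mathcal{N}(\mathbf{0},\mathbf{I}/d)$ and inputs formed by concatenation, $\mathbf{x}=(\mathbf{z}_{1};\mathbf{z}_{2})$. A short calculation gives $\mathbb{E}[u^{2}]=4/d$ against same training inputs and $2/d$ against different ones, the factor of 2 used in the proof. Function names are hypothetical.

```python
import random


def gaussian_symbol(d, rng):
    """Noiseless symbol s ~ N(0, I/d)."""
    return [rng.gauss(0.0, d ** -0.5) for _ in range(d)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def second_moments(d, n_trials=5_000, seed=0):
    """Monte Carlo estimates of E[u^2] for a fresh same test input x = (t; t)
    against a same training input (s; s) and a different one (s1; s2)."""
    rng = random.Random(seed)
    same_total = diff_total = 0.0
    for _ in range(n_trials):
        t = gaussian_symbol(d, rng)
        x = t + t  # list concatenation: x = (t; t)
        s = gaussian_symbol(d, rng)
        s1, s2 = gaussian_symbol(d, rng), gaussian_symbol(d, rng)
        same_total += dot(x, s + s) ** 2
        diff_total += dot(x, s1 + s2) ** 2
    return same_total / n_trials, diff_total / n_trials


d = 16
same_m2, diff_m2 = second_moments(d)
print(d * same_m2, d * diff_m2)  # theory predicts 4 and 2
```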
## Appendix E Bayesian posterior calculations
In Section 3.4, we work with the posteriors corresponding to two different idealized models: one that generalizes to novel symbols based on the true underlying symbol distribution, and one that memorizes the training symbols. Below, we present the Bayes optimal classifier for our noisy same-different task, and derive the posteriors associated with these two settings.
### E.1 Generalizing prior
We define the following data generating process that constitutes a prior which generalizes to arbitrary, unseen symbols.
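$$
\begin{aligned}
r&\sim\text{Bernoulli}\left(p=\frac{1}{2}\right)\,,\\
\mathbf{s}_{1}&\sim\mathcal{N}\left(\mathbf{0},\mathbf{I}/d\right)\,,\\
\mathbf{s}_{2}\,|\,\mathbf{s}_{1},r&\sim\begin{cases}\delta(\mathbf{s}_{1})&r=1\\ \mathcal{N}\left(\mathbf{0},\mathbf{I}/d\right)&r=0\end{cases}\,,\\
\mathbf{z}_{i}\,|\,\mathbf{s}_{i}&\sim\mathcal{N}\left(\mathbf{s}_{i},\sigma^{2}\mathbf{I}/d\right)\,,\quad i=1,2\,.
\end{aligned}
$$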
The quantity $r$ represents either a same or different relation. Variables $\mathbf{s}_{1},\mathbf{s}_{2}$ are symbols matching their description in Section 2. The notation $\delta(\mathbf{s}_{1})$ denotes a Delta distribution centered at $\mathbf{s}_{1}$ . Hence, $\mathbf{s}_{1}=\mathbf{s}_{2}$ if $r=1$ , and differ otherwise. Typically, we consider the noiseless case $\sigma^{2}=0$ , but to develop a Bayesian treatment, we allow $\sigma^{2}>0$ . We approximate the noiseless case by considering $\sigma^{2}\rightarrow 0$ .
The Bayes optimal classifier is
$$
\hat{y}_{bayes}=\begin{cases}1&p(r=1\,|\,\mathbf{z}_{1},\mathbf{z}_{2})\geq\frac{1}{2}\\
0&\text{otherwise}\end{cases}
$$
From Bayes rule, we know that
$$
p(r\,|\,\mathbf{z}_{1},\mathbf{z}_{2})\propto p(\mathbf{z}_{1},\mathbf{z}_{2}\,|\,r)\,p(r)\,.
$$
Since $r$ takes the values $1$ and $0$ with equal probability, we simply have
$$
p(r\,|\,\mathbf{z}_{1},\mathbf{z}_{2})\propto p(\mathbf{z}_{1},\mathbf{z}_{2}\,|\,r)\,.
$$
We use the notation $\mathcal{N}(\mathbf{x};\mathbf{\mu},\sigma^{2})$ to mean the PDF of a Normal distribution evaluated at $\mathbf{x}$ , with mean $\mathbf{\mu}$ and covariance $\sigma^{2}\mathbf{I}$ . We then compute
$$
\begin{aligned}
p(\mathbf{z}_{1},\mathbf{z}_{2}\,|\,r=1)
&=\int\mathcal{N}\left(\mathbf{z}_{1};\mathbf{s},\frac{\sigma^{2}}{d}\right)\,\mathcal{N}\left(\mathbf{z}_{2};\mathbf{s},\frac{\sigma^{2}}{d}\right)\,\mathcal{N}\left(\mathbf{s};\mathbf{0},\frac{1}{d}\right)\,d\mathbf{s}\\
&=\left(\frac{d}{2\pi\sqrt{\sigma^{2}(2+\sigma^{2})}}\right)^{d}\exp\left\{-\frac{d}{2\sigma^{2}}\left(\frac{1+\sigma^{2}}{2+\sigma^{2}}\left(\|\mathbf{z}_{1}\|^{2}+\|\mathbf{z}_{2}\|^{2}\right)-\frac{2}{2+\sigma^{2}}\left(\mathbf{z}_{1}\cdot\mathbf{z}_{2}\right)\right)\right\}\,,\\
p(\mathbf{z}_{1},\mathbf{z}_{2}\,|\,r=0)
&=\iint\mathcal{N}\left(\mathbf{z}_{1};\mathbf{s}_{1},\frac{\sigma^{2}}{d}\right)\,\mathcal{N}\left(\mathbf{z}_{2};\mathbf{s}_{2},\frac{\sigma^{2}}{d}\right)\,\mathcal{N}\left(\mathbf{s}_{1};\mathbf{0},\frac{1}{d}\right)\,\mathcal{N}\left(\mathbf{s}_{2};\mathbf{0},\frac{1}{d}\right)\,d\mathbf{s}_{1}\,d\mathbf{s}_{2}\\
&=\left(\frac{d}{2\pi(1+\sigma^{2})}\right)^{d}\exp\left\{-\frac{d}{2}\left(\frac{1}{1+\sigma^{2}}\right)\left(\|\mathbf{z}_{1}\|^{2}+\|\mathbf{z}_{2}\|^{2}\right)\right\}\,.
\end{aligned}
$$
Using the two likelihoods above, we compute
$$
p(r=1\,|\,\mathbf{z}_{1},\mathbf{z}_{2})=\frac{p(\mathbf{z}_{1},\mathbf{z}_{2}\,|\,r=1)}{p(\mathbf{z}_{1},\mathbf{z}_{2}\,|\,r=1)+p(\mathbf{z}_{1},\mathbf{z}_{2}\,|r=0)}\,,
$$
which we plug back into Eq (E.1) to obtain our Bayes classifier under a generalizing prior.
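A minimal sketch of this classifier, working in log space to avoid underflow of the $d$-th powers. It assumes, as in the likelihoods above, symbols $\mathbf{s}\sim\mathcal{N}(\mathbf{0},\mathbf{I}/d)$, noise covariance $\sigma^{2}\mathbf{I}/d$, and $\sigma^{2}>0$; the function names are hypothetical. Because the two classes are equiprobable, the posterior is a logistic function of the log-likelihood ratio.

```python
import math
import random


def _sigmoid(t):
    """Numerically stable logistic function."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def log_lik_same(z1, z2, sigma2):
    """log p(z1, z2 | r = 1) under the generalizing prior."""
    d = len(z1)
    norms = sum(a * a for a in z1) + sum(b * b for b in z2)
    inner = sum(a * b for a, b in zip(z1, z2))
    log_norm = d * math.log(d / (2 * math.pi * math.sqrt(sigma2 * (2 + sigma2))))
    quad = (1 + sigma2) / (2 + sigma2) * norms - 2 / (2 + sigma2) * inner
    return log_norm - d / (2 * sigma2) * quad


def log_lik_diff(z1, z2, sigma2):
    """log p(z1, z2 | r = 0) under the generalizing prior."""
    d = len(z1)
    norms = sum(a * a for a in z1) + sum(b * b for b in z2)
    return d * math.log(d / (2 * math.pi * (1 + sigma2))) - d / 2 * norms / (1 + sigma2)


def posterior_same_generalizing(z1, z2, sigma2):
    """p(r = 1 | z1, z2); equal priors make this a logistic of the log ratio."""
    return _sigmoid(log_lik_same(z1, z2, sigma2) - log_lik_diff(z1, z2, sigma2))


def bayes_label(posterior):
    """Bayes-optimal rule: predict 'same' (1) when the posterior is at least 1/2."""
    return 1 if posterior >= 0.5 else 0


def noisy(symbol, sigma2, rng):
    """Observation z = s + eta with eta ~ N(0, sigma2 I / d)."""
    sd = math.sqrt(sigma2 / len(symbol))
    return [v + rng.gauss(0.0, sd) for v in symbol]


rng = random.Random(0)
d, sigma2 = 64, 0.1
s = [rng.gauss(0.0, d ** -0.5) for _ in range(d)]
t = [rng.gauss(0.0, d ** -0.5) for _ in range(d)]
print(bayes_label(posterior_same_generalizing(noisy(s, sigma2, rng), noisy(s, sigma2, rng), sigma2)))  # same pair
print(bayes_label(posterior_same_generalizing(noisy(s, sigma2, rng), noisy(t, sigma2, rng), sigma2)))  # different pair
```

Note that the classifier depends on the inputs only through $\|\mathbf{z}_{1}\|^{2}+\|\mathbf{z}_{2}\|^{2}$ and $\mathbf{z}_{1}\cdot\mathbf{z}_{2}$, never on the identity of particular symbols.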
### E.2 Memorizing prior
The data generating process for a model that memorizes the training data is similar to the generalizing model, but the crucial difference is that the symbols $\mathbf{s}$ are now distributed uniformly across the training symbols rather than sampled from their population distribution.
Let $\mathbf{\hat{s}}_{1},\mathbf{\hat{s}}_{2},\ldots,\mathbf{\hat{s}}_{L}$ be the set of $L$ training symbols. Then the data generating process is given by
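$$
\begin{aligned}
r&\sim\text{Bernoulli}\left(p=\frac{1}{2}\right)\,,\\
\mathbf{s}_{1}&\sim\text{Uniform}\left\{\mathbf{\hat{s}}_{1},\mathbf{\hat{s}}_{2},\ldots,\mathbf{\hat{s}}_{L}\right\}\,,\\
\mathbf{s}_{2}\,|\,\mathbf{s}_{1},r&\sim\begin{cases}\delta(\mathbf{s}_{1})&r=1\\ \text{Uniform}\left(\left\{\mathbf{\hat{s}}_{1},\ldots,\mathbf{\hat{s}}_{L}\right\}\setminus\left\{\mathbf{s}_{1}\right\}\right)&r=0\end{cases}\,,\\
\mathbf{z}_{i}\,|\,\mathbf{s}_{i}&\sim\mathcal{N}\left(\mathbf{s}_{i},\sigma^{2}\mathbf{I}/d\right)\,,\quad i=1,2\,.
\end{aligned}
$$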
As before, we compute the probabilities $p(\mathbf{z}_{1},\mathbf{z}_{2}\,|\,r=1)$ and $p(\mathbf{z}_{1},\mathbf{z}_{2}\,|\,r=0)$, which are given by
$$
\begin{aligned}
p(\mathbf{z}_{1},\mathbf{z}_{2}\,|\,r=1)
&=\frac{1}{L}\sum_{i=1}^{L}p(\mathbf{z}_{1},\mathbf{z}_{2}\,|\,\mathbf{\hat{s}}_{i})\\
&=\frac{1}{L}\sum_{i=1}^{L}\mathcal{N}\left(\mathbf{z}_{1};\mathbf{\hat{s}}_{i},\frac{\sigma^{2}}{d}\right)\,\mathcal{N}\left(\mathbf{z}_{2};\mathbf{\hat{s}}_{i},\frac{\sigma^{2}}{d}\right)\\
&=\left(\frac{d}{2\pi\sigma^{2}}\right)^{d}\exp\left\{-\frac{d}{2\sigma^{2}}\left(\frac{1}{2}\left(\|\mathbf{z}_{1}\|^{2}+\|\mathbf{z}_{2}\|^{2}\right)-\mathbf{z}_{1}\cdot\mathbf{z}_{2}\right)\right\}\left(\frac{1}{L}\sum_{i=1}^{L}\exp\left\{-\frac{d}{\sigma^{2}}\left\|\mathbf{\hat{s}}_{i}-\frac{\mathbf{z}_{1}+\mathbf{z}_{2}}{2}\right\|^{2}\right\}\right)\,,\\
p(\mathbf{z}_{1},\mathbf{z}_{2}\,|\,r=0)
&=\frac{1}{L(L-1)}\sum_{i\neq j}p(\mathbf{z}_{1}\,|\,\mathbf{\hat{s}}_{i})\,p(\mathbf{z}_{2}\,|\,\mathbf{\hat{s}}_{j})\\
&=\frac{1}{L(L-1)}\sum_{i\neq j}\mathcal{N}\left(\mathbf{z}_{1};\mathbf{\hat{s}}_{i},\frac{\sigma^{2}}{d}\right)\,\mathcal{N}\left(\mathbf{z}_{2};\mathbf{\hat{s}}_{j},\frac{\sigma^{2}}{d}\right)\\
&=\frac{1}{L(L-1)}\sum_{i\neq j}\left(\frac{d}{2\pi\sigma^{2}}\right)^{d}\exp\left\{-\frac{d}{2\sigma^{2}}\|\mathbf{z}_{1}-\mathbf{\hat{s}}_{i}\|^{2}\right\}\exp\left\{-\frac{d}{2\sigma^{2}}\|\mathbf{z}_{2}-\mathbf{\hat{s}}_{j}\|^{2}\right\}\,.
\end{aligned}
$$
Using the two likelihoods above, we evaluate Eq (E.1) to obtain our Bayes classifier under a memorizing prior.
## Appendix F Rich and lazy scaling
We review rich and lazy regime scaling in our setting. In particular, we consider learning dynamics as we increase the input dimension $d$ (Saad & Solla, 1995; Biehl & Schwarze, 1995; Goldt et al., 2019). This setting differs from other rich-regime studies, where scaling is considered with respect to increasing width $m$. In particular, the maximal update parametrization ($\mu$P) and the related mean-field parametrizations consider an infinite-width limit (Yang & Hu, 2021; Mei et al., 2018; Rotskoff & Vanden-Eijnden, 2022). Our analysis holds $m$ fixed.
Recall that our model is given by
$$
f(\mathbf{x};\mathbf{\theta})=\frac{1}{\gamma\sqrt{d}}\sum_{i=1}^{m}a_{i}\,\phi(\mathbf{w}_{i}\cdot\mathbf{x})\,.
$$
Let $\mathbf{\theta}(t)$ be the value of the parameters $\mathbf{\theta}$ at time step $t$. Crucially, to permit a valid interpolation between rich and lazy learning regimes, our MLP is centered: $f(\mathbf{x};\mathbf{\theta}(0))=0$. Following Chizat et al. (2019), we enforce centering by subtracting the initial logit from every prediction. Hence, our classifier takes the form
$$
\tilde{f}(\mathbf{x};\mathbf{\theta})=f(\mathbf{x};\mathbf{\theta})-f(\mathbf{x};\mathbf{\theta}(0))\,.
$$
We use $\tilde{f}$ as our centered MLP in all experiments.
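A minimal sketch of the centering step, with the network stored as plain lists. The initialization $a_{i}\sim\mathcal{N}(0,1/m)$, $\mathbf{w}_{i}\sim\mathcal{N}(\mathbf{0},\mathbf{I})$ is a hypothetical choice for illustration (the first-layer variance in particular is our assumption), as are the function names. The centered model keeps a frozen copy of the initial parameters and subtracts their output, so it is exactly zero at initialization for every input and every $\gamma$.

```python
import math
import random


def mlp(x, a, W, gamma, d):
    """f(x; theta) = 1 / (gamma sqrt(d)) * sum_i a_i relu(w_i . x)."""
    total = 0.0
    for a_i, w_i in zip(a, W):
        pre = sum(wk * xk for wk, xk in zip(w_i, x))
        total += a_i * max(pre, 0.0)
    return total / (gamma * math.sqrt(d))


def make_centered(a0, W0, gamma, d):
    """Return f_tilde(x, a, W) = f(x; a, W) - f(x; a0, W0)."""
    a0 = list(a0)
    W0 = [list(w) for w in W0]  # frozen copy of the initialization

    def f_tilde(x, a, W):
        return mlp(x, a, W, gamma, d) - mlp(x, a0, W0, gamma, d)

    return f_tilde


rng = random.Random(0)
d, m, gamma = 8, 16, 0.1
a0 = [rng.gauss(0.0, 1.0 / math.sqrt(m)) for _ in range(m)]
W0 = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]
f_tilde = make_centered(a0, W0, gamma, d)
x = [rng.gauss(0.0, 1.0 / math.sqrt(d)) for _ in range(d)]
print(f_tilde(x, a0, W0))  # 0.0 at initialization
```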
To see how changing $\gamma$ interpolates between rich and lazy learning regimes, recall that learning richness is a description of activation change over the course of training. One way to operationalize this description is to define rich learning as the case in which parameters $\mathbf{\theta}$ change substantially in comparison with changes in the model output $\tilde{f}$ , and lazy learning as the case in which $\mathbf{\theta}$ change very little with respect to the model output.
<details>
<summary>x9.png Details</summary>

Two panels, titled "With 1/√d" (left) and "Without 1/√d" (right), plot normalized activation change against $\gamma$ on log-log axes; the legend lists $d$ = 128, 256, 512, and 1024.
</details>
Figure F1: Activation scales with $\gamma$ and $d$. We plot the average absolute activation change across 6000 test examples as a function of $\gamma$ (for $m=4096$), normalized by the initial activation size $|\mathbf{w}(0)\cdot\mathbf{x}|$. Higher $\gamma$ leads to more activation change. In the absence of a $1/\sqrt{d}$ prefactor, the activation change shrinks as $d$ grows. Including the $1/\sqrt{d}$ prefactor suppresses this dependence on $d$.
Consider the change in $\tilde{f}$ after one step of gradient descent. For a learning rate $\alpha$ and training set size $P$ , we update our parameters as
$$
\begin{aligned}
\mathbf{\theta}(1)&=\mathbf{\theta}(0)-\alpha\nabla_{\mathbf{\theta}}\left(\frac{1}{P}\sum_{p=1}^{P}\mathcal{L}(y_{p},\tilde{f}(\mathbf{x}_{p};\mathbf{\theta}))\right)\\
&=\mathbf{\theta}(0)+\frac{\alpha}{P}\sum_{p=1}^{P}\left(y_{p}-\sigma(\tilde{f}(\mathbf{x}_{p};\mathbf{\theta}))\right)\nabla_{\mathbf{\theta}}\tilde{f}(\mathbf{x}_{p};\mathbf{\theta})\,.
\end{aligned}
$$
Note that for an input $\mathbf{x}$ ,
$$
\frac{\partial\tilde{f}}{\partial a_{i}}=\frac{1}{\gamma\sqrt{d}}\phi(\mathbf{w}_{i}\cdot\mathbf{x})\,,\qquad\frac{\partial\tilde{f}}{\partial\mathbf{w}_{i}}=\frac{1}{\gamma\sqrt{d}}\,a_{i}\,\phi^{\prime}(\mathbf{w}_{i}\cdot\mathbf{x})\,\mathbf{x}\,.
$$
Define
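$$
\Delta a_{i}\equiv\frac{1}{P}\sum_{p=1}^{P}\left(y_{p}-\sigma(\tilde{f}(\mathbf{x}_{p};\mathbf{\theta}))\right)\frac{\partial\tilde{f}}{\partial a_{i}}\,,\qquad\Delta\mathbf{w}_{i}\equiv\frac{1}{P}\sum_{p=1}^{P}\left(y_{p}-\sigma(\tilde{f}(\mathbf{x}_{p};\mathbf{\theta}))\right)\frac{\partial\tilde{f}}{\partial\mathbf{w}_{i}}\,.
$$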
Substituting our weight updates into our model reveals that
$$
\begin{aligned}
\tilde{f}(\mathbf{x};\mathbf{\theta}(1))=\frac{\alpha}{\gamma^{2}d}\sum_{i=1}^{m}\Big[&\Delta a_{i}\,\phi(\Delta\mathbf{w}_{i}\cdot\mathbf{x})/(\gamma\sqrt{d})\\
&+\Delta a_{i}\,\phi(\mathbf{w}(0)\cdot\mathbf{x})\\
&+a_{i}(0)\,\phi(\Delta\mathbf{w}_{i}\cdot\mathbf{x})\Big]\,.
\end{aligned}
$$
Observe that $|\Delta a_{i}\,\phi(\mathbf{w}(0)\cdot\mathbf{x})|=\mathcal{O}_{d}(1)$ and $|a_{i}(0)\,\phi(\Delta\mathbf{w}_{i}\cdot\mathbf{x})|=\mathcal{O}_{d}(1/\sqrt{d})$ for a test point $\mathbf{x}$ . If we adopt a learning rate $\alpha=\gamma^{2}d$ , then we have overall
$$
|\tilde{f}(\mathbf{x};\mathbf{\theta}(1))|=\mathcal{O}_{d}(1)\,.
$$
In this way, we find that the model output changes by an amount that does not grow with the input dimension $d$. Meanwhile, the parameters change by a total magnitude
$$
||\mathbf{\theta}(1)-\mathbf{\theta}(0)||=\mathcal{O}_{d}\left(\alpha\left(|\Delta a_{i}|+||\Delta\mathbf{w}_{i}||\right)\right)=\mathcal{O}_{d}(\gamma\sqrt{d})\,.
$$
At initialization, we have that
$$
||\mathbf{\theta}(0)||=\mathcal{O}_{d}(|a_{i}(0)|+||\mathbf{w}_{i}(0)||)=\mathcal{O}_{d}(\sqrt{d})\,,
$$
so the change in weights relative to the scale of their initialization is simply
$$
||\mathbf{\theta}(1)-\mathbf{\theta}(0)||/||\mathbf{\theta}(0)||=\mathcal{O}_{d}(\gamma)\,.
$$
All together, after one gradient step, while the model output changes by a constant amount with respect to $d$ , the model parameters change by $\gamma$ relative to their initialization. For $\gamma\rightarrow 0$ , the initialization dominates (even as the model output changes), resulting in lazy learning. For increasing $\gamma$ , the parameters move proportionally further from their initialization, resulting in progressively rich learning.
Peculiar to our setting is the additional $1/\sqrt{d}$ factor in the output scale of the MLP, not found in other rich-regime studies that consider width scaling (Yang & Hu, 2021; Mei et al., 2018; Rotskoff & Vanden-Eijnden, 2022). In the absence of the $1/\sqrt{d}$ factor, Eqs (F.1) and (F.2) suggest that we should adjust our learning rate to be $\alpha=\gamma^{2}$ in order to maintain a stable $\mathcal{O}_{d}(1)$ output change with increasing $d$. However, the relative weight change in Eq (F.3) then becomes $\mathcal{O}_{d}(\gamma/\sqrt{d})$. For fixed $\gamma$, a model becomes lazier as $d$ increases. Hence, to maintain consistent richness, we require an additional $1/\sqrt{d}$ prefactor on the MLP (along with the corresponding $\alpha=\gamma^{2}d$ learning rate).
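A quick numerical check of both scalings for the first gradient step, under illustrative assumptions that are ours rather than the exact experimental setup: inputs with i.i.d. $\mathcal{N}(0,1/d)$ coordinates, balanced binary labels, $a_{i}\sim\mathcal{N}(0,1/m)$, and $\mathbf{w}_{i}\sim\mathcal{N}(\mathbf{0},\mathbf{I})$. Because the centered model outputs zero at initialization, every residual $y_{p}-\sigma(\tilde{f})$ equals $\pm 1/2$, so the first step scales as $\alpha/(\gamma\sqrt{d})=\gamma\sqrt{d}\,\alpha_{0}$ with the prefactor (taking $\alpha=\gamma^{2}d\,\alpha_{0}$) and as $\alpha/\gamma=\gamma\,\alpha_{0}$ without it (taking $\alpha=\gamma^{2}\alpha_{0}$). The hypothetical function below computes the relative parameter change directly.

```python
import math
import random


def relative_change_after_one_step(gamma, d, m=32, P=16, alpha0=0.1,
                                   sqrt_d_prefactor=True, seed=0):
    """||theta(1) - theta(0)|| / ||theta(0)|| after one full-batch gradient step."""
    rng = random.Random(seed)
    a = [rng.gauss(0.0, 1.0 / math.sqrt(m)) for _ in range(m)]
    W = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]
    X = [[rng.gauss(0.0, 1.0 / math.sqrt(d)) for _ in range(d)] for _ in range(P)]
    y = [p % 2 for p in range(P)]  # balanced labels in {0, 1}

    scale = gamma * math.sqrt(d) if sqrt_d_prefactor else gamma  # output prefactor
    alpha = alpha0 * gamma ** 2 * (d if sqrt_d_prefactor else 1)

    # Delta a_i and Delta w_i as defined above; the centered output is 0 at
    # initialization, so sigmoid(f_tilde) = 1/2 and each residual is y_p - 1/2.
    delta_a = [0.0] * m
    delta_W = [[0.0] * d for _ in range(m)]
    for x, label in zip(X, y):
        residual = (label - 0.5) / P
        for i in range(m):
            pre = sum(wk * xk for wk, xk in zip(W[i], x))
            if pre > 0.0:  # ReLU active: phi = pre, phi' = 1; otherwise both are 0
                delta_a[i] += residual * pre / scale
                for k in range(d):
                    delta_W[i][k] += residual * a[i] * x[k] / scale

    step_sq = sum((alpha * g) ** 2 for g in delta_a)
    step_sq += sum((alpha * g) ** 2 for row in delta_W for g in row)
    init_sq = sum(v * v for v in a) + sum(v * v for row in W for v in row)
    return math.sqrt(step_sq / init_sq)


for gamma in (0.01, 0.1, 1.0):
    print(gamma, relative_change_after_one_step(gamma, d=64))
```

For this first step the relative change is exactly linear in $\gamma$, and removing the prefactor (with $\alpha=\gamma^{2}\alpha_{0}$) shrinks it by exactly $\sqrt{d}$; later steps are not covered by this sketch.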
Figure F1 illustrates these conclusions. Increasing $\gamma$ increases the change in activations $|\mathbf{\tilde{w}}(t)\cdot\mathbf{x}|$ for a test example $\mathbf{x}$. In the absence of the $1/\sqrt{d}$ prefactor, increasing $d$ also decreases the change in activations.
## Appendix G Model and task details
We enumerate all model and task configurations in this Appendix. Exact details are available in our code, https://github.com/wtong98/equality-reasoning, which can be run to reproduce all plots in this manuscript.
### G.1 Model
In all experiments, we use a two-layer MLP without biases that takes inputs $\mathbf{x}\in\mathbb{R}^{d}$ and outputs
$$
f(\mathbf{x})=\frac{1}{\gamma\sqrt{d}}\sum_{i=1}^{m}a_{i}\,\phi(\mathbf{w}_{i}\cdot\mathbf{x})\,,
$$
where $\phi$ is a point-wise ReLU nonlinearity and $\gamma$ is a hyperparameter that governs learning richness. Our MLP is centered using the procedure described in Appendix F. To produce a classification, $f$ is passed through the standard logistic function (the inverse of the logit link)
$$
\hat{y}=\frac{1}{1+e^{-f}}\,.
$$
Parameters are initialized based on $\mu$P (Yang et al., 2022). Specifically, we initialize our weights as
| | $\displaystyle a_{i}$ | $\displaystyle\sim\mathcal{N}\left(0,1/m\right)\,,$ | |
| --- | --- | --- | --- |
We train the model using stochastic gradient descent on binary cross entropy loss. Following Atanasov et al. (2024), we set the learning rate $\alpha$ as $\alpha=\gamma^{2}d\,\alpha_{0}$ for $\gamma\leq 1$ and $\alpha=\gamma\sqrt{d}\,\alpha_{0}$ for $\gamma>1$. The base learning rate $\alpha_{0}$ is task-specific and varies from $0.01$ to $0.5$. To measure a model's performance, we train for a large, fixed number of iterations past convergence in training accuracy, and select the best test accuracy from the model's history.
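The optimization settings above fit in a few lines. This sketch covers only the learning-rate rule and a numerically stable binary cross entropy computed from the logit; the names are hypothetical, and the authoritative implementation is the linked repository.

```python
import math


def learning_rate(gamma: float, d: int, alpha0: float) -> float:
    """alpha = gamma^2 d alpha0 if gamma <= 1, else gamma sqrt(d) alpha0."""
    if gamma <= 1.0:
        return gamma ** 2 * d * alpha0
    return gamma * math.sqrt(d) * alpha0


def bce_with_logit(f: float, y: int) -> float:
    """Binary cross entropy of y_hat = 1 / (1 + e^{-f}) against y in {0, 1}.

    Uses softplus(f) - y f, which equals -y log(y_hat) - (1 - y) log(1 - y_hat)
    but does not overflow for large |f|.
    """
    softplus = max(f, 0.0) + math.log1p(math.exp(-abs(f)))
    return softplus - y * f


print(learning_rate(0.1, d=100, alpha0=0.1), bce_with_logit(0.0, 1))
```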
### G.2 Same-Different
The same-different task consists of input pairs $\mathbf{z}_{1},\mathbf{z}_{2}\in\mathbb{R}^{d}$ , where $\mathbf{z}_{i}=\mathbf{s}_{i}+\mathbf{\eta}_{i}$ . The labeling function $y$ is given by
$$
y(\mathbf{z}_{1},\mathbf{z}_{2})=\begin{cases}1&\mathbf{s}_{1}=\mathbf{s}_{2}\\
0&\mathbf{s}_{1}\neq\mathbf{s}_{2}\end{cases}\,.
$$
We sample these quantities as
$$
\mathbf{s}\sim\mathcal{N}\left(\mathbf{0},\mathbf{I}/d\right)\,,\qquad\mathbf{\eta}\sim\mathcal{N}\left(\mathbf{0},\sigma^{2}\mathbf{I}/d\right)\,.
$$
A training set is sampled such that half the training examples belong to class $1$ and half belong to class $0$. Crucially, the training set consists of $L$ fixed symbols $\mathbf{s}_{1},\mathbf{s}_{2},\ldots,\mathbf{s}_{L}$ sampled prior to the experiment, and all training examples are constructed from these $L$ symbols. During testing, symbols are sampled afresh, forcing the model to generalize. If the noise variance $\sigma^{2}$ is not explicitly stated, we take it to be $\sigma^{2}=0$. We use a base learning rate $\alpha_{0}=0.1$ with batches of size 128.
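A minimal sketch of this sampling procedure using only the standard library. Assumptions not fixed by the text: the noise is $\mathbf{\eta}\sim\mathcal{N}(\mathbf{0},\sigma^{2}\mathbf{I}/d)$ (matching Appendix E), different pairs use two distinct symbols drawn uniformly from the pool, and test examples are built from a freshly sampled pool of symbols. Function names are hypothetical; the MLP input would be the concatenation $(\mathbf{z}_{1};\mathbf{z}_{2})$.

```python
import math
import random


def sample_symbol(d, rng):
    """Symbol s ~ N(0, I / d)."""
    return tuple(rng.gauss(0.0, 1.0 / math.sqrt(d)) for _ in range(d))


def observe(s, sigma, rng):
    """z = s + eta with eta ~ N(0, sigma^2 I / d); returns s itself when sigma = 0."""
    if sigma == 0:
        return s
    sd = sigma / math.sqrt(len(s))
    return tuple(v + rng.gauss(0.0, sd) for v in s)


def make_dataset(symbols, n, sigma, rng):
    """n examples (z1, z2, y) from a symbol pool: half same (y = 1), half different (y = 0)."""
    data = []
    for i in range(n):
        same = i < n // 2
        if same:
            s1 = s2 = rng.choice(symbols)
        else:
            s1, s2 = rng.sample(symbols, 2)  # two distinct symbols from the pool
        data.append((observe(s1, sigma, rng), observe(s2, sigma, rng), int(same)))
    rng.shuffle(data)
    return data


rng = random.Random(0)
d, L = 8, 4
train_symbols = [sample_symbol(d, rng) for _ in range(L)]  # fixed before training
train = make_dataset(train_symbols, n=128, sigma=0.0, rng=rng)
test_symbols = [sample_symbol(d, rng) for _ in range(256)]  # fresh symbols for testing
test = make_dataset(test_symbols, n=128, sigma=0.0, rng=rng)
```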
### G.3 PSVRT
The PSVRT task consists of a single-channel square image with two blocks of bit-patterns. If the bit-patterns match exactly, then the image belongs to the same class. If the bit-patterns differ, then the image belongs to the different class. Images are flattened before being passed to the MLP.
Images are patch-aligned to prevent overlapping bit-patterns. An image is tiled by non-overlapping square regions which may be filled by bit-patterns. No two bit-patterns may share a single patch. Unless otherwise stated, we use patches that are 5 pixels to a side, and images that are 5 patches to a side, for a total of 25 by 25 pixels.
One important feature of PSVRT is that the inputs do not grow in norm as their dimension increases. Because only two patches in an image are ever filled, regardless of its size, the total norm of the input remains constant. As a result, the $1/\sqrt{d}$ scaling on the MLP output is extraneous for PSVRT, and we remove it in these experiments.
A subset of all possible bit-patterns are used for training. The remaining unseen bit-patterns are used for testing. We use a base learning rate $\alpha_{0}=0.5$ with batches of size 128.
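A minimal sketch of how such images can be generated. It assumes (our assumption) that each bit-pattern fills an entire $5\times 5$ patch with independent random bits, that same images repeat one pattern exactly, and that different images use two distinct members of the allowed pool. The pattern pools stand in for the training and held-out subsets of bit-patterns; names are hypothetical.

```python
import random

PATCH = 5  # pixels per patch side
GRID = 5   # patches per image side, giving a 25 x 25 image


def random_pattern(rng):
    """A bit-pattern filling one PATCH x PATCH patch, stored row-major."""
    return tuple(rng.randint(0, 1) for _ in range(PATCH * PATCH))


def psvrt_image(pattern_a, pattern_b, rng):
    """Place two patterns in two distinct patches and return the flattened image."""
    side = PATCH * GRID
    image = [[0] * side for _ in range(side)]
    cells = rng.sample([(r, c) for r in range(GRID) for c in range(GRID)], 2)
    for (pr, pc), pattern in zip(cells, (pattern_a, pattern_b)):
        for k, bit in enumerate(pattern):
            image[pr * PATCH + k // PATCH][pc * PATCH + k % PATCH] = bit
    return [px for row in image for px in row]  # flattened input for the MLP


def psvrt_example(pattern_pool, same, rng):
    """Return (flattened image, label) with label 1 for same and 0 for different."""
    a = rng.choice(pattern_pool)
    b = a if same else rng.choice([p for p in pattern_pool if p != a])
    return psvrt_image(a, b, rng), int(same)


rng = random.Random(0)
patterns = [random_pattern(rng) for _ in range(40)]
train_pool, test_pool = patterns[:30], patterns[30:]  # held-out patterns for testing
image, label = psvrt_example(train_pool, same=True, rng=rng)
```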
### G.4 Pentomino
The Pentomino task consists of a single-channel square image with two pentomino shapes. If the shapes are the same (up to rotation, but not reflection), then the image belongs to the same class. If the shapes differ, then the image belongs to the different class. Images are flattened before being passed to the MLP.
Like before, images are patch-aligned. To provide a border around each pentomino, patches are 7 pixels to a side. Unless otherwise stated, images are 2 patches to a side, for a total of 14 by 14 pixels.
As with PSVRT, the inputs for Pentomino do not grow in norm as their dimension increases. There are only ever two pentomino shapes in an image, regardless of its dimension, so the total norm of the input remains constant. As with PSVRT, we remove the $1/\sqrt{d}$ output scaling on the MLP for these experiments.
There are a total of 18 possible pentomino shapes. A subset of these 18 is held out for testing, and the model trains on the remainder. To improve training stability, mild Gaussian blurs are randomly applied to training images, but not testing images. We use a base learning rate $\alpha_{0}=0.5$ with batches of size 128.
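Deciding whether two pentominoes are the same up to rotation (but not reflection) amounts to comparing a canonical form taken over the four rotations. The sketch below, with hypothetical names and shapes represented as lists of grid cells, implements that comparison and, as a consistency check, enumerates all five-cell shapes up to rotation. There are 18 such one-sided pentominoes, matching the count above.

```python
def normalize(cells):
    """Translate cells so the minimum row and column are 0; return a sorted tuple."""
    min_r = min(r for r, _ in cells)
    min_c = min(c for _, c in cells)
    return tuple(sorted((r - min_r, c - min_c) for r, c in cells))


def rotate90(cells):
    """Rotate cells by 90 degrees about the origin (a proper rotation, no reflection)."""
    return [(c, -r) for r, c in cells]


def canonical(cells):
    """Smallest normalized form over the four rotations."""
    forms = []
    current = list(cells)
    for _ in range(4):
        forms.append(normalize(current))
        current = rotate90(current)
    return min(forms)


def same_shape(cells_a, cells_b):
    """True if the shapes match up to rotation and translation, but not reflection."""
    return canonical(cells_a) == canonical(cells_b)


def one_sided_polyominoes(n):
    """All n-cell polyominoes up to rotation, grown one cell at a time."""
    shapes = {canonical([(0, 0)])}
    for _ in range(n - 1):
        grown = set()
        for shape in shapes:
            occupied = set(shape)
            for r, c in shape:
                for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                    cell = (r + dr, c + dc)
                    if cell not in occupied:
                        grown.add(canonical(list(shape) + [cell]))
        shapes = grown
    return shapes


l_shape = [(0, 0), (1, 0), (2, 0), (3, 0), (3, 1)]
mirrored = [(r, -c) for r, c in l_shape]
print(same_shape(l_shape, rotate90(l_shape)))  # True
print(same_shape(l_shape, mirrored))           # False
print(len(one_sided_polyominoes(5)))           # 18
```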
### G.5 CIFAR-100
The CIFAR-100 same-different task consists of full-color images taken from the CIFAR-100 dataset. Images are 32 by 32 pixels and depict one of 100 classes. To form an input example, we place two images side by side, forming a larger 64 by 32 pixel image. If the images come from the same class (but are not necessarily the same exact image), the example belongs to the same class. If the images come from different classes, the example belongs to the different class.
To separate an MLP's ability to reason about equality from its ability to extract meaningful visual features, we first pass the image through a VGG-16 backbone pretrained on ImageNet. Activations are then taken from an intermediate layer, flattened, and passed to the MLP. Because VGG-16 activations are coordinate-wise $O(1)$ in magnitude, we scale them by $1/\sqrt{d}$ before input to the model. The resulting performance of the MLP trained on activations from each layer is plotted in Figure G1.
Of the 100 total classes, a subset is held out for testing, and the model trains on the remainder. We use a base learning rate $\alpha_{0}=0.01$ with batches of size 128.
[Figure G 1 panels: test accuracy vs. # classes (log scale), one panel per input (VGG-16 activations relu[block]_[layer], or raw images labeled id), one curve per value of γ.]
Figure G 1: CIFAR-100 same-different accuracy across different VGG-16 activations. Activations are named relu[block]_[layer]. The panel labeled id uses the raw images directly, without VGG-16 preprocessing. For the earliest and latest layers, the curves collapse: learning richness does not appear to strongly affect classification accuracy. For intermediate layers, greater learning richness tends to improve accuracy, though the richest model tends to perform poorly. Shaded regions are 95 percent confidence intervals estimated from 6 runs.
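The caption does not state how the 95 percent intervals were computed. One common choice, and the default in plotting libraries such as seaborn, is a percentile bootstrap over runs; the sketch below assumes that method and is not necessarily what produced the figure. The helper name `bootstrap_ci` and the resample count are assumptions.

```python
import numpy as np


def bootstrap_ci(values, rng, n_boot=10_000, level=0.95):
    """Percentile-bootstrap confidence interval for the mean of a few runs.

    Returns (mean, lower, upper). Assumed method; the original figure's
    interval procedure is not specified.
    """
    values = np.asarray(values, dtype=float)
    # Resample the runs with replacement and record each resample's mean.
    idx = rng.integers(0, len(values), size=(n_boot, len(values)))
    means = values[idx].mean(axis=1)
    alpha = (1.0 - level) / 2.0
    lower, upper = np.quantile(means, [alpha, 1.0 - alpha])
    return values.mean(), lower, upper
```

With only 6 runs per point, bootstrap intervals can be somewhat narrower than t-based intervals, so the shaded regions are best read as approximate.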