2404.02258v1
# Mixture-of-Depths: Dynamically allocating compute in transformer-based language models
**Authors**: David Raposo, Sam Ritter, Blake Richards, Timothy Lillicrap, Peter Conway Humphreys, Adam Santoro
> Google DeepMind
> McGill University & Mila
Corresponding authors: draposo@google.com, adamsantoro@google.com
Abstract
Transformer-based language models spread FLOPs uniformly across input sequences. In this work we demonstrate that transformers can instead learn to dynamically allocate FLOPs (or compute) to specific positions in a sequence, optimising the allocation along the sequence for different layers across the model depth. Our method enforces a total compute budget by capping the number of tokens ( $k$ ) that can participate in the self-attention and MLP computations at a given layer. The tokens to be processed are determined by the network using a top- $k$ routing mechanism. Since $k$ is defined a priori, this simple procedure uses a static computation graph with known tensor sizes, unlike other conditional computation techniques. Nevertheless, since the identities of the $k$ tokens are fluid, this method can expend FLOPs non-uniformly across the time and model depth dimensions. Thus, compute expenditure is entirely predictable in sum total, but dynamic and context-sensitive at the token-level. Not only do models trained in this way learn to dynamically allocate compute, they do so efficiently. These models match baseline performance for equivalent FLOPS and wall-clock times to train, but require a fraction of the FLOPs per forward pass, and can be upwards of 50% faster to step during post-training sampling.
1 Introduction
Not all problems require the same amount of time or effort to solve. Analogously, in language modeling not all tokens and sequences require the same time or effort to accurately make a prediction. And yet, transformer models expend the same amount of compute per token in a forward pass. Ideally, transformers would use smaller total compute budgets by not spending compute unnecessarily.
Conditional computation is a technique that tries to reduce total compute by expending it only when needed (Bengio, 2013; Bengio et al., 2013, 2016). Various algorithms offer solutions to when and how much compute should be used (Ainslie et al., 2023; Fedus et al., 2022; Bapna et al., 2020). However, general formulations of this challenging problem may not work well with existing hardware constraints since they tend to introduce dynamic computation graphs (Graves, 2016; Dehghani et al., 2018). The most promising conditional computation methods may instead be those that are harmonious with our current hardware stack, which prioritizes static computation graphs, and known tensor sizes that are selected to maximize hardware utilization.
Here we consider the problem of language modeling using a static compute budget that can be made less than that used by a vanilla transformer. The network must learn how to dynamically allocate the available compute by making per-token decisions, in each layer, about where to spend compute from the available budget. In our implementation, total compute is user-defined and unchanging prior to training, rather than being a function of the network's on-the-fly decisions. Thus, hardware efficiency gains, such as a reduced memory footprint or fewer FLOPs per forward pass, can be anticipated and exploited ahead of time. As we will show, these gains can be had without sacrificing overall performance.
We leverage an approach akin to Mixture of Experts (MoE) transformers, in which dynamic token-level routing decisions are made across the network depth. Departing from MoE, we choose to either apply a computation to a token (as would be the case for a standard transformer), or pass it through a residual connection (remaining unchanged and saving compute). Also in contrast to MoE, we apply this routing to both forward MLPs and multi-head attention. Since this therefore also impacts the keys and queries we process, the routing makes decisions not only about which tokens to update, but also which tokens are made available to attend to. We refer to this strategy as Mixture-of-Depths (MoD) to emphasize how individual tokens pass through different numbers of layers, or blocks, through the depth of the transformer (see figure 1).
The MoD technique also allows one to trade-off performance with speed. On the one hand, one can train an MoD transformer that improves upon vanilla transformers by as much as $1.5\%$ on the final log probability training objective for equivalent training FLOPs (isoFLOP), and while taking an equivalent amount of wall-clock time to train. On the other hand, one can train an MoD transformer that achieves training loss parity with an isoFLOP optimal vanilla transformer, but which uses a fraction of the FLOPs (upwards of 50%) per forward pass, and hence is faster to step. Together, these results imply that MoD transformers learn to route intelligently (i.e., skipping computations that are unnecessary) since they can achieve equal or better log probabilities per sequence despite a smaller FLOP footprint per forward pass.
Figure 1: Mixture-of-Depths Transformer. As in mixture-of-experts (MoE) transformers, we use a router to choose among potential computational paths. But unlike in MoE transformers, the possible choices are a standard block's computation (i.e., self-attention and MLP) or a residual connection. Since some tokens take this second route, Mixture-of-Depths (MoD) transformers have a smaller total FLOP footprint compared to vanilla or MoE transformers. On the top right is depicted a trained model's routing decisions for a short sequence, truncated to 64 tokens for visualization purposes. When examining the choices, one can find tokens processed by later blocks' layers despite passing through relatively few total blocks throughout the model's depth. This is a unique feature of MoD compared to conventional halting-based, or "early-exit", conditional computation, which instead engages blocks serially, and to vanilla transformers, which engage every block.
2 Background
The transformer architecture has become the workhorse of a revolution in practical artificial intelligence, bringing unprecedented capabilities at the cost of expensive training runs and serving procedures. This has spurred tremendous interest in making transformer architectures more efficient (Tay et al., 2020; Gupta and Agrawal, 2021). One of the promising approaches is conditional computation, whereby learned mechanisms determine when and how to expend computation. This terminology was introduced by Bengio (2013), and the concept was explored further over the next several years (Bengio et al., 2013; Cho and Bengio, 2014; Graves, 2016; Jernite et al., 2017; Bengio et al., 2016; Wang et al., 2017).
A wide variety of recent work has developed conditional computation methods for transformers. Some of this work focuses on "early exiting", that is, learning to decide when to end computation on a given token, allowing the token to skip any remaining transformer layers after the exit decision is made (Elbayad et al., 2019; Liu et al., 2021; Schuster et al., 2022). In MoD, unlike in early-exit methods, a token can skip middle layers, then be updated via self-attention with tokens that have gone through all the middle layers. We speculate that this might be a useful property.
Other work has developed methods for iterating transformer layers with shared weights for an adaptive number of steps (Simoulin and CrabbƩ, 2021; Dehghani et al., 2018). Bolya et al. (2023) developed a method for choosing tokens to merge when running inference on a trained vision transformer which notably requires no learning. Lei et al. (2023) make use of conditional computation in a fine tuning setting by building on adapter approaches (He et al., 2021) to learn to skip blocks of frozen pre-trained weights in favor of running only a small fine-tuned adapter.
CoLT5 (Ainslie et al., 2023) uses conditional routing to select whether a given token will pass through a heavy or light pathway for each feedforward layer. Further, they use the same routing mechanism to select whether a token will attend to all other tokens or to a select few, as in Guo et al. (2022). Like MoD, CoLT5 uses soft top-k for making routing decisions. However, CoLT5 focuses on an encoder-decoder setting, and thus does not need to contend with the problem of efficient sequential decoding given the non-causal nature of the top-k operation. In contrast, our current work with MoD focuses on the decoder-only setting, and so we propose a predictive router to enable efficient inference for conditional computation in transformers.
One successful formulation of conditional computation is the "mixture-of-experts" layer (MoE) as introduced by Shazeer et al. (2017). Developed initially in the context of LSTMs, later work showed compelling empirical results for MoE with transformers (Lepikhin et al., 2020; Fedus et al., 2022; Zoph et al., 2022). Unlike other conditional computation approaches that try to conserve or expend additional compute, MoE transformers use conditional logic to route tokens to one of many expert MLPs while keeping total compute expenditure constant. Our mixture-of-depths method can be thought of as using the routing logic from MoE transformers, but rather than having multiple experts, MoD deploys a single expert which can be dynamically skipped.
3 Implementing Mixture-of-Depths Transformers
Our high-level strategy is as follows:
- Set a static compute budget that is less than that of an equivalent vanilla transformer by limiting the number of tokens in a sequence that can participate in a block's computations (i.e., self-attention and subsequent MLP). For example, while a vanilla transformer might permit all the tokens in a sequence to participate in self-attention, we might limit the number to 50% of the tokens in a sequence. See section 3.1.
- Use a per-block router to emit a scalar weight for each token, which expresses the router's preference for that token to participate in a block's computations or to route around it. See section 3.2.
- Identify the top- $k$ scalar weights (per sequence, per block) to select those tokens that will participate in a block's computations. Since precisely $k$ tokens will participate in the block's computations, the computation graph and tensor sizes remain static throughout training; it is merely the tokens' participation that is dynamic and context-sensitive, as determined by the router. See section 3.3.
We then discuss some complications when sampling post-training in section 3.5.
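The per-block selection step above can be sketched in a few lines; this is a minimal numpy sketch, where `select_topk_tokens` and its argument names are illustrative rather than from the paper:

```python
import numpy as np

def select_topk_tokens(router_weights, capacity):
    # Pick the `capacity` tokens with the largest router weights.
    # `router_weights` has shape (seq_len,): one scalar per token.
    topk = np.argpartition(router_weights, -capacity)[-capacity:]
    # Return indices in sequence order, so self-attention sees the
    # selected tokens in their original positions.
    return np.sort(topk)

# Four tokens, capacity of two: tokens 0 and 2 have the largest weights.
print(select_topk_tokens(np.array([0.9, 0.1, 0.7, 0.3]), 2))  # [0 2]
```

Because `capacity` is fixed ahead of time, the selected tensor always has the same shape, which is what keeps the computation graph static.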
3.1 Defining a compute budget
To enforce a total compute budget per forward pass we leverage the notion of capacity, which defines the total number of tokens that comprise the input to a given computation (e.g., the tokens participating in self-attention, a given expert in MoE transformers, etc). For example, the self-attention and MLP in each vanilla transformer block have a capacity of $T$: the total number of tokens across the sequence and batch. MoE transformers, on the other hand, use a capacity less than $T$ per expert MLP so as to more evenly divide the total compute across each expert. But, since they use multiple experts per block, their total capacity is approximately equal to that of a vanilla transformer.
Generally, it is the token capacity that determines the total FLOPs for transformers that use conditional computation, rather than the outcomes of any routing decisions. This is because static-graph implementations account for worst-case routing decisions; e.g., a computation's inputs will be padded to its capacity amount even if relatively few tokens actually end up routing to it, and/or tokens will be dropped from the computation if its capacity is exceeded.
We can achieve our goal of using a smaller compute budget per forward pass compared to a vanilla transformer by lowering the capacity of the computations. However, using a smaller compute budget haphazardly will result in a performance degradation. We hypothesize that certain tokens might not require as much processing as others, and these tokens can be identified through learning. Therefore, if the network learns to choose the right tokens to fill up its capacities, then it may preserve its performance. In the following we describe routing schemes that can be used for this purpose.
3.2 Routing around transformer blocks
We consider the setting whereby we route tokens to one of two computational paths: (1) self-attention and MLP blocks, and (2) a residual connection. The latter is computationally cheap, and results in a block output that is entirely determined by the value of its input. The former path is computationally expensive.
The total number of FLOPs per forward pass will be fewer than that in a vanilla transformer if we set the capacity for path (1) to be anything less than $T$ (the total number of tokens across the sequence and batch). For example, if we were to set a block's capacity to $\frac{T}{2}$ (i.e., half the number of tokens as would be the case in a vanilla transformer), then the query-times-key matrix multiplication during self-attention becomes $25\%$ as FLOP-intensive as in a vanilla transformer ( $(\frac{T}{2})^{2}$ vs. $T^{2}$ ). Similar calculations can determine the FLOP savings for the MLP.
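The 25% figure follows directly from the quadratic cost of the attention-score matmul; a quick sanity check, with an illustrative sequence length:

```python
# Attention-score FLOPs scale with the square of the number of tokens
# entering self-attention, so halving the capacity quarters this cost.
T = 2048                      # illustrative token count, not from the paper
full_cost = T ** 2            # vanilla block: all T tokens attend
half_cost = (T // 2) ** 2     # MoD block with capacity T/2
print(half_cost / full_cost)  # 0.25
```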
Intuitively, the total FLOPs per forward pass decrease (as does the time to complete a forward pass) in proportion to how aggressively we shrink the blocks' capacities. However, downstream performance is also affected by how aggressively we shrink the blocks' capacities, and by the routing algorithm we implement.
At one extreme, if we leave each block's capacity at $T$ and route every token to (rather than around) each block, then we recover a vanilla transformer. At the other extreme, if we set each block's capacity to $0$ and route all tokens around each block, then we're left with a very fast model that doesn't engage the vast majority of the transformer's parameters and undoubtedly has poor downstream performance. We hypothesize that somewhere between these two extremes is an optimal model that performs as well as a vanilla transformer, if not better, while being faster to step.
3.3 Routing schemes
Naively, one can leverage stochasticity to route tokens, akin to layer or block "dropout". We present this routing scheme as a control, and will show that it significantly under-performs relative to vanilla transformers.
We hypothesize that learned routing is preferable. Intuitively, the network should be able to learn which tokens require more or less processing than others. If we are correct that Transformers often expend more compute than they need to make their predictions, then it is an empirical question as to how aggressively we can shrink each blockās capacity, and hence, how many tokens we can afford to route around each block.
Figure 2: Routing schemes. Tokens are funnelled to the computational path of their choice when using token-choice routing (left). If a given path exceeds its capacity (e.g., more than two tokens in this example) then surplus tokens must be dropped (purple token). The exact token that is ultimately dropped depends on the precise implementation in the underlying code. For example, priority is often given to those tokens that come earlier in the sequence or batch order. With expert-choice routing (middle), precisely $k$ (in this case, two) tokens are chosen per path using a top- $k$ mechanism across the tokens' router weights. Here, tokens are dropped if they are not among the top- $k$ with respect to any given path (orange token), and some tokens may even be funnelled to multiple paths (yellow token). In this work we deploy expert-choice routing (right). However, because we use just a single path, we leverage the implicit knowledge that tokens will be dropped if $k$ is less than the sequence length so that we can route tokens away from the self-attention and MLP computations, thus expending fewer FLOPs in a given forward pass of the model.
There are two learned routing schemes we consider (see figure 2): token-choice and expert-choice. In token-choice routing, a router produces per-token probability distributions across computational paths (e.g., across expert identities in MoE transformers). Tokens are then shuttled to the path they prefer (i.e., that with the highest probability), and auxiliary losses ensure that all tokens don't converge to the same path. Token-choice routing can have load balancing problems since there isn't a guarantee that tokens divide themselves appropriately between the possible paths. "Expert-choice routing" flips this recipe on its head: rather than having tokens choose the path they prefer, each path instead chooses the top- $k$ tokens based on the tokens' preferences. This ensures a perfect load balance since $k$ tokens are guaranteed to be shuttled to each path. However, it could result in over- or under-processing of some tokens, since some tokens may be among the top- $k$ for multiple paths, or for none of them.
We decided to leverage expert-choice routing for a few reasons. First, it obviates the need for an auxiliary balancing loss. Second, since the top- $k$ operation depends on the magnitude of the router weights, this routing scheme allows for relative routing weights to help determine which tokens most need the block's computations; routers can try to ensure that the most critical tokens are among the top- $k$ by setting their weight appropriately, which is not possible with token-choice routing schemes. For our specific use-case, wherein one computational path is essentially a null operation, it might be critical that important tokens are routed away from the null operation. Third, because we only route through two paths, a single top- $k$ operation can efficiently split the tokens into two mutually exclusive sets, one for each computational path, preventing the over- or under-processing problem mentioned above.
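The third point, splitting tokens into two mutually exclusive sets with a single top-k operation, can be sketched as follows (a hypothetical `expert_choice_split` helper, not from the paper):

```python
import numpy as np

def expert_choice_split(router_weights, k):
    # One argsort yields both sets: the top-k tokens take the block's
    # computational path; every other token takes the residual path.
    order = np.argsort(-router_weights)
    processed = np.sort(order[:k])   # the two sets are disjoint by
    bypassed = np.sort(order[k:])    # construction, so no token is
    return processed, bypassed       # over- or under-processed

p, b = expert_choice_split(np.array([0.2, 0.9, 0.5, 0.1]), 2)
print(p, b)  # [1 2] [0 3]
```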
3.4 Routing implementation
As a reminder of the high-level intuition, each token is processed by a router to produce a scalar weight, and the top- $k$ weights are then used to choose the token identities that will route through a transformer's block, which comprises self-attention and the subsequent MLP.
Suppose we have the set of token embeddings in a sequence of length $S$ for a given layer $l$ ; that is, $X^{l}=\{x_{i}^{l} \mid i \text{ is an integer}, 1 \leq i \leq S\}$ . The router weight for a given token embedding is a scalar produced by a linear projection, $r_{i}^{l}=w_{\theta}^{T}x_{i}^{l}$ .
Our goal is to use these router weights to determine the output of a block's computation for each token. Suppose $P_{\beta}(R^{l})$ is the $\beta$ -th percentile of the set of router weights $R^{l}$ , where $\beta=1-C/S$ and $C$ is the user-defined capacity per batch element (an integer $<S$ that defines the number of tokens from a sequence that will be processed by a given function). A block's output for a given token is:
$$
\displaystyle x_{i}^{l+1}=\begin{cases}r_{i}^{l}f_{i}(\tilde{X}^{l})+x_{i}^{l},&\text{if }r_{i}^{l}>P_{\beta}(R^{l})\\
x_{i}^{l},&\text{if }r_{i}^{l}<P_{\beta}(R^{l})\\
\end{cases} \tag{1}
$$
Here, $\tilde{X}^{l}$ is the set of tokens whose router values satisfy $r_{i}^{l}>P_{\beta}(R^{l})$ (that is, the "top- $k$ " tokens), and $f$ comprises self-attention and the subsequent MLP. Note that the output for a given token $x_{i}^{l+1}$ might depend on other tokens $x_{j\neq i}^{l}$ because of the self-attention operation. The cardinality of $\tilde{X}^{l}$ is $C$ (or $k$ ): the user-defined capacity. Therefore, the mixture-of-depths transformer accrues compute savings relative to the baseline because the input to the block's computations $f$ comprises fewer tokens than usual ( $C<S$ ), rendering the self-attention and MLP less expensive.
Notably, we multiply the output of the function $f$ by the router weights. This puts the router weights along the "gradient path", thus subjecting them to the forces of gradient descent through the course of the language modeling task. (We experimented with versions where the router weights are also included along the computational path for those tokens that bypass the block's computations, but it seems to be sufficient, and implementationally simpler, to only include the router weights along the computational path for those tokens that do not bypass the block's computations.)
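Putting equation 1 and the gradient-path trick together, a single MoD block might be sketched as below; a numpy sketch under stated assumptions, with `block_fn` standing in for self-attention plus MLP over only the selected tokens, and all names illustrative:

```python
import numpy as np

def mod_block(x, w_router, block_fn, capacity):
    # x: (S, D) token embeddings; w_router: (D,) router projection.
    r = x @ w_router                           # scalar router weight per token
    top = np.sort(np.argsort(-r)[:capacity])   # top-k token indices, in order
    out = x.copy()                             # default: residual (bypass) path
    # Selected tokens get the gated residual update of equation 1; scaling
    # block_fn's output by r keeps the router weights on the gradient path.
    out[top] = r[top, None] * block_fn(x[top]) + x[top]
    return out
```

Only `capacity` rows ever enter `block_fn`, which is where the FLOP savings over a vanilla block come from.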
3.5 Sampling
While expert-choice routing has a number of advantages, it has one distinct problem: the top- $k$ operation is non-causal. This means that whether a given token's routing weight is among the top- $k$ for the sequence depends on the values of the routing weights for tokens that come after it, which we don't have access to when autoregressively sampling.
We tested two methods to work around this problem. The first introduces a simple auxiliary loss that empirically affects the primary language modeling objective by approximately $0.2-0.3\%$ , but allows us to sample from the model autoregressively. We use a binary cross-entropy loss wherein the router's outputs provide the logits, and the top- $k$ selections of these logits provide the targets (i.e. 1 if a token was among the top- $k$ , and 0 if not). Intuitively, this loss centers the sigmoid of the router's outputs around $0.5$ ; those tokens that are selected among the top- $k$ are pressured to produce router outputs above $0.5$ , and those not among the top- $k$ will be pressured to produce router outputs below $0.5$ . The second method introduces a small auxiliary MLP predictor (akin to a second router) that receives the same inputs as the router (with a stop gradient), but whose output is a prediction of whether that token will be among the top- $k$ or not in the sequence. This method does not affect the language modeling objective, and empirically does not significantly impact the step speed.
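The first method's auxiliary loss can be sketched as follows; this is a numpy sketch of the described binary cross-entropy, with `aux_router_loss` an illustrative name rather than one from the paper:

```python
import numpy as np

def aux_router_loss(router_logits, k):
    # Targets: 1 if the token is among the top-k router outputs, else 0.
    targets = np.zeros_like(router_logits)
    targets[np.argsort(-router_logits)[:k]] = 1.0
    # Binary cross-entropy on the router's own logits, pressuring
    # selected tokens' sigmoids above 0.5 and the rest below 0.5.
    p = 1.0 / (1.0 + np.exp(-router_logits))
    eps = 1e-9  # numerical guard against log(0)
    return -np.mean(targets * np.log(p + eps)
                    + (1.0 - targets) * np.log(1.0 - p + eps))
```

At sampling time, one would then route a token through the block whenever its sigmoid output exceeds $0.5$, a decision that needs no future tokens.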
Equipped with these new methods, we can sample autoregressively by choosing to route tokens to or around a block based on the router's output, which does not depend on any information from future tokens. We provide empirical evidence that this is a relatively easy auxiliary task that quickly achieves $99\%$ accuracy.
3.6 Training methods
All models use the same basic hyperparameter configurations (e.g., cosine schedules equal to $1\times$ the training steps, batch size of 128, sequence length of 2048), except for changes to the number of layers, heads, and embedding size to produce differently sized models during isoFLOP analyses.
4 Results
4.1 Training, isoFLOP comparisons
Figure 3: MoD hyperparameter tuning. Variants of the MoD transformer were trained for 6e18 FLOPs to determine the optimal hyperparameters for further isoFLOP analyses. On the left plot, the grey box indicates models that perform better than the isoFLOP optimal baseline. We found the best MoD variant to be one that has the option to route every other block and uses a top-$k$ of $256$ (so $256$ tokens, or 12.5% of the sequence, are processed by self-attention and the subsequent MLP, while the remaining $1792$ tokens, or $87.5\%$, route around the block). Shown on the right are the learning curves for a selection of models. Notably, model #3 achieves equal performance to the isoFLOP optimal baseline but steps 66% faster, due to the relatively fewer FLOPs needed per forward pass.
We first trained models with a relatively small FLOP budget (6e18) to determine optimal hyperparameters (see figure 3). In general, we found that MoD transformers drag the baseline isoFLOP curve "down and to the right". That is, the optimal MoD transformer achieves a lower loss than the optimal baseline, and also has more parameters. A fortunate consequence of this effect is that there exist smaller MoD models that, while not themselves isoFLOP optimal for their hyperparameter setting, nevertheless perform as well as or better than the optimal baseline model while being faster to step. For example, a 220M parameter MoD variant (figure 3, model #3) slightly outperforms the isoFLOP optimal baseline (also 220M, figure 3, model #1), but is upwards of 60% faster to step during training. Crucially, when run on equivalent hardware these two model variants take approximately the same amount of wall-clock time to train (figure 3).
We tested routing every block or every other block, using capacities from 12.5% to 95% of the total sequence. Routing every other block was crucial for strong performance, and aggressive capacity reduction worked best: loss improved gradually as capacity was reduced down to 12.5% of the total sequence (i.e., 87.5% of tokens routing around blocks), with performance degrading beyond that point. So, it seems the network is robust to significant capacity reductions as long as there is frequent opportunity for full-capacity self-attention and MLP computations.
Learned routing is crucial, as MoD transformers that use stochastic routing (implemented using a top- $k$ operation on router weights sampled from a Gaussian distribution) perform drastically worse than both the baseline and normal MoD transformer (figure 3).
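The per-block mechanics can be sketched in a few lines. Below is a minimal numpy sketch, not the paper's implementation: it assumes a linear router producing one scalar weight per token, with the top-$k$ tokens entering the block while the rest take only the residual path; the block output is scaled by the sigmoided router weight so that routing decisions remain on the gradient path. `block_fn` is a hypothetical stand-in for the block's self-attention + MLP.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def mod_block(x, w_router, block_fn, capacity):
    """One Mixture-of-Depths block (single sequence, batch omitted).

    x:        (seq_len, d_model) token activations
    w_router: (d_model,) linear router projection (an assumption here)
    block_fn: stand-in for the block's self-attention + MLP
    capacity: k, the number of tokens this block processes
    """
    r = x @ w_router                     # one scalar router weight per token
    topk = np.argsort(r)[-capacity:]     # the k tokens that enter the block
    out = x.copy()                       # everyone else: residual path only
    # Scale the block output by the router weight so routing decisions
    # receive gradients through the language-modeling loss.
    out[topk] = x[topk] + sigmoid(r[topk])[:, None] * block_fn(x[topk])
    return out
```

Because `capacity` is fixed a priori, the tensor shapes inside `block_fn` are static, which is what keeps the computation graph static even though the identities of the routed tokens are dynamic.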
<details>
<summary>2404.02258v1/mod-isoflop.png Details</summary>

### Visual Description
Left: loss vs. parameters (log scale, 50M–2B) for baseline and Mixture-of-Depths isoFLOP curves, with models #1–4 annotated. Right: normalized loss vs. normalized FLOPs per forward pass relative to the isoFLOP-optimal baseline (dashed line at 1.0), with marker size proportional to parameter count.
</details>
Figure 4: isoFLOP analysis. We used the 12.5% capacity MoD variant to perform an isoFLOP analysis for 6e18, 2e19, and 1e20 FLOPs, training models varying in size from 60M to 3B parameters. Depicted on the right are the relative FLOPs per forward pass (normalized to the isoFLOP optimal baseline). There exist MoD variants that are both faster to step (by virtue of requiring fewer FLOPs per forward pass) and better performing than the isoFLOP optimal baseline.
Depicted in figure 4 is an isoFLOP analysis for 6e18, 2e19, and 1e20 total FLOPs. The trend that FLOP-optimal MoD transformers have more parameters than the baseline continues for these larger FLOP budgets. Notably, there exist MoD variants that are appreciably faster to step than the isoFLOP-optimal baseline (measured as steps per second when training on equivalent hardware) while also achieving a lower loss. In figure 4 we depict normalized FLOPs per forward pass rather than wall-clock step time per se, but from our experiments the two are tightly correlated; a similar plot showing relative wall-clock step times exhibits the same basic trend.
Step-wise speed gains come from two sources. First, the FLOP-per-parameter ratio in MoD transformers is less than in the baselines because some proportion of tokens are routed around blocks. So, for a given model size, an MoD transformer requires fewer FLOPs per forward pass. Second, since isoFLOP-optimal MoD transformers are both bigger and achieve a lower loss than the isoFLOP-optimal baseline, there exist smaller MoD variants that perform as well or better than the isoFLOP-optimal baseline, and these variants are faster to step because they are smaller. Altogether, then, there exist MoD transformers that perform as well as isoFLOP-optimal baselines and are faster to step, both because they use fewer FLOPs per parameter and because they use fewer parameters.
Figure 4 also reveals another important finding: the optimal MoD transformer is that which uses as many FLOPs per forward pass as the isoFLOP optimal baseline. This finding allows one to directly predict which sized MoD transformer will perform optimally for a given isoFLOP training budget: one just needs to tune the model size for a given MoD configuration (i.e., capacity and routing frequency) to produce a model that uses as many FLOPs per forward pass as the isoFLOP-optimal baseline, and they will have the optimally performing MoD variant for that configuration. Empirically, we find that it is better to add depth than to add width when adding FLOPs to the model.
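This size-matching recipe can be made concrete with a rough FLOP count. The sketch below is an illustrative approximation, not the paper's accounting: it assumes the usual per-token matmul costs (QKV and output projections ~4d², a 4x-expansion MLP ~8d²), with the attention-map term shrinking quadratically in the number of routed tokens, and applies `capacity` to every `route_every`-th block. The demo then grows the MoD model in depth (the preference reported above) until its forward-pass FLOPs match a hypothetical baseline budget.

```python
def block_flops(d_model, seq_len, c=1.0):
    """Approximate forward-pass FLOPs for one transformer block when a
    fraction c of the sequence participates. Per-token matmuls scale
    linearly with c; the attention map (~seq^2 * d) scales quadratically."""
    tokens = c * seq_len
    matmuls = 2 * tokens * 12 * d_model ** 2          # 2 FLOPs per multiply-add
    attn_map = 2 * 2 * (c * seq_len) ** 2 * d_model   # scores + weighted values
    return matmuls + attn_map

def model_flops(n_layers, d_model, seq_len, capacity=1.0, route_every=2):
    """Total forward-pass FLOPs with MoD capacity on every route_every-th block."""
    return sum(
        block_flops(d_model, seq_len, capacity if i % route_every else 1.0)
        for i in range(n_layers)
    )

# Match a 12.5%-capacity MoD model to a hypothetical 12-layer baseline
# budget by adding depth until per-forward-pass FLOPs are comparable.
baseline = model_flops(n_layers=12, d_model=1024, seq_len=2048)
depth = 12
while model_flops(depth + 1, 1024, 2048, capacity=0.125) <= baseline:
    depth += 1
```

Under these assumptions the matched MoD model ends up considerably deeper than the baseline while spending the same FLOPs per forward pass, mirroring the "more parameters at equal FLOPs" trend described above.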
Nevertheless, while the FLOPs per forward pass determine which model will be isoFLOP optimal, they do not predict whether the optimal loss will improve upon the baseline (see figure 3). Rather, the optimal capacity appears to be empirically determinable: we found it best to use 12.5% capacity blocks, every other block.
We noticed that MoD transformers had memory savings relative to equivalently sized baseline models at larger sizes, with some variants requiring fewer total devices (i.e., a smaller TPU topology). We did not study this extensively, but we anticipate that as one scales to larger models, these savings could be an important consideration when choosing model variants to train, and could have significant positive effects in regards to the KV cache size during autoregressive sampling.
<details>
<summary>2404.02258v1/weight_analysis.png Details</summary>

### Visual Description
Left: per-token routing decisions over sequence step (x-axis) and layer (y-axis), blue = routed to block, green = routed around block, contrasting 100%-capacity interleaved blocks with 12.5%-capacity routed blocks. Right: histogram of router weights (log-scale counts), with "around block" mass concentrated below 0.5 and "to block" mass near 1.0.
</details>
Figure 5: Routing analysis. We trained an MoD transformer that interleaved $12.5\%$ capacity routing blocks with full-attention blocks. As expected, the number of tokens that route to (rather than around) a block is sparse in routing blocks, though the network does sometimes preferentially route certain tokens to each block along its depth. This can be seen in the left figure that depicts routing decisions, where we observe a vertical band of dark blue towards the end of the sequence. As expected, the distribution of router weights is as the auxiliary loss dictates: approximately $12.5\%$ of weights are above 0.5 and $87.5\%$ are below (histogram, right).
Figure 5 shows the routing decisions for an MoD transformer trained with interleaved routing blocks. Despite aggressive routing around the blocks, transformers are able to achieve performance improvements relative to baselines. We observe patterns that might warrant further study; namely, some tokens appear to engage each block along the transformer's depth, while others decide to route around blocks whenever possible. Preliminary analyses suggest that the tokens that engage with blocks more frequently are correlated with output predictions that have higher entropy, which possibly corresponds to predictions that are more difficult to make.
4.2 Auto-regressive Evaluation
<details>
<summary>2404.02258v1/autoregressive.png Details</summary>

### Visual Description
Left: normalized log-likelihood (measured on a separate test set) vs. normalized FLOPs per forward pass for the baseline, MoD using top-k routing, and MoD using the causal predictor; the three curves nearly coincide. Right: top-k prediction accuracy vs. training step, climbing to a high, stable plateau within roughly 10,000 steps.
</details>
Figure 6: Auto-regressive evaluation. Switching from the non-causal top- $k$ routing scheme in training to a causal predictor-based approach during auto-regressive sampling leads to minimal performance degradation. This is perhaps due to the ease of learning this prediction problem, which is upwards of 97% accurate soon into training.
We evaluated MoD variants during auto-regressive sampling (see figure 6). Each model was tested on exactly the same held-out data comprising $256{,}000$ sequences ($500$M tokens). When switching from the top-$k$ routing method to the predictor-based routing method we observed little performance degradation. As in the training setting, there exist MoD variants that are better performing than the isoFLOP-optimal baseline while requiring fewer FLOPs per forward pass. These results suggest that the compute savings offered by MoD transformers should translate beyond the training setting.
4.3 Mixture-of-Depths-and-Experts (MoDE)
<details>
<summary>2404.02258v1/mode.png Details</summary>

### Visual Description
Left: normalized loss vs. normalized FLOPs per FFW pass for the baseline (dashed black), MoE (solid blue), and MoDE (dashed teal); MoE and MoDE sit below the baseline across the FLOP range. Right: block diagrams of staged MoDE (MoD routing applied before self-attention, then expert routing) and integrated MoDE (a single routing operation choosing among experts E1–E4 and a no-op path).
</details>
Figure 7: Mixture-of-Depths-and-Experts (MoDE). The MoD technique can be implemented alongside MoE (together comprising MoDE models) in two straightforward manners: staged, which first implements MoD machinery prior to MoE machinery, and integrated, which uses one routing operation to funnel tokens to either experts or no-op operations.
The MoD technique can be naturally integrated with MoE models (together comprising MoDE models) in addition to vanilla transformers. In figure 7 we present results showing that the performance improvements offered by MoD compound with those of MoE. We tried two variants: staged MoDE, which routes tokens around or towards blocks prior to the self-attention step, and integrated MoDE, which implements MoD routing by integrating "no-op" experts among the conventional MLP experts. The former is advantageous because it allows for tokens to skip the self-attention step, while the latter is advantageous because it simplifies the routing machinery. We noticed that implementing MoDE in the integrated manner was distinctly better than simply reducing the capacity of experts in conventional MoE models, and relying on token dropping to implement residual routing. We believe this is because with the integrated MoDE machinery, tokens explicitly learn to choose the residual path around the experts, as opposed to preferring an expert but being dropped when implemented as a capacity reduction.
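The integrated variant can be sketched as an ordinary MoE layer whose router has one extra logit for a "no-op expert". The following is a minimal numpy illustration under those assumptions (top-1 routing, no capacity limits), not the paper's implementation:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def integrated_mode_layer(x, w_router, experts):
    """Integrated MoDE: one router scores the MLP experts *plus* a no-op.

    x:        (seq_len, d_model) token activations
    w_router: (d_model, n_experts + 1); the last column scores the no-op
    experts:  list of callables (the conventional MLP experts)
    """
    probs = softmax(x @ w_router)      # (seq_len, n_experts + 1)
    choice = probs.argmax(axis=-1)     # top-1 routing
    out = x.copy()                     # residual stream for every token
    for e, expert in enumerate(experts):
        sel = choice == e
        if sel.any():
            out[sel] += probs[sel, e][:, None] * expert(x[sel])
    # Tokens whose argmax is the final column take the no-op: they
    # contribute nothing beyond the residual, saving their expert FLOPs.
    return out
```

Because the no-op is just another routing choice, a token that "prefers" the residual path selects it explicitly, rather than preferring an expert and being dropped by a capacity limit.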
5 Discussion
Mixture-of-Depths transformers empirically demonstrate that one can improve on isoFLOP-optimal baseline performance with models that use fewer FLOPs per forward pass. This means that, for a given training FLOP budget, we can train models that are both faster and better performing than their baseline counterparts. Previously, to train models that are both faster and as- or better-performing than isoFLOP-optimal models, one would have to use surplus compute to overtrain smaller models (notably, this overtraining technique is still possible with MoD transformers, and speed gains should compound).
While MoD transformers require fewer FLOPs per forward pass, one cannot forego FLOPs indiscriminately. Rather, it is crucial to use learned routing decisions (much like in Mixture-of-Experts transformers) to determine whether a token should participate in self-attention and the subsequent MLP (requiring FLOPs) or not (saving FLOPs). We can then use any saved FLOPs by, for example, making the model bigger or training it for longer. Our results show that FLOPs may indeed be inefficiently used in vanilla transformer models, and that there may be more efficient ways for them to be expended.
Learned routing mechanisms are sometimes non-causal; that is, information about the future is used to determine a given token's routing decision. This is generally true for top-k routing mechanisms, which are useful because they forego the need for auxiliary balancing losses. However, top-k routing mechanisms present difficulties in post-training autoregressive sampling, where it is impossible to use information about future token identities to determine routing decisions. In this work we show that one can successfully use a top-k routing scheme during training without requiring it during later autoregressive sampling. Either a simple auxiliary classifier or an auxiliary loss on the router is sufficient to learn the top-$k$ routing decisions well enough to mimic them during autoregressive sampling, with minimal to no performance degradation.
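The auxiliary-classifier route can be sketched as a small per-token logistic predictor trained to reproduce the top-$k$ membership labels observed during training; at sampling time each token's routing decision then depends only on its own activation, and is therefore causal. A minimal numpy sketch under those assumptions (the paper's predictor architecture may differ):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fit_route_predictor(embeddings, topk_labels, lr=0.5, steps=1000):
    """Fit a per-token logistic classifier to mimic the (non-causal)
    top-k routing decisions recorded during training.

    embeddings:  (n_tokens, d) token activations entering a routed block
    topk_labels: (n_tokens,) 1.0 if the token made the top-k, else 0.0
    Returns (w, b) for a causal, per-token routing rule.
    """
    n, d = embeddings.shape
    w, b = np.zeros(d), 0.0
    for _ in range(steps):
        p = sigmoid(embeddings @ w + b)
        grad = p - topk_labels             # d(BCE)/d(logit)
        w -= lr * embeddings.T @ grad / n  # full-batch gradient descent
        b -= lr * grad.mean()
    return w, b

def route_token(x_t, w, b):
    """Routing decision for one token at sampling time: no future tokens,
    and hence no top-k over the full sequence, are needed."""
    return sigmoid(x_t @ w + b) > 0.5
```

A thresholded predictor like this no longer guarantees exactly $k$ tokens per block, but as the results above indicate, the prediction problem is easy enough that the mismatch costs little.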
Intuitively, a token might learn to route around blocks because the prediction being made at that step is easier, and hence, does not require as much compute. However, this strategy is undoubtedly not all that the network learns. If a token does not participate in self-attention at a certain block, then later tokens will also not be able to attend to it. Thus, whether tokens decide to route or not impacts both the current step's prediction and future predictions via causal self-attention, and how the network balances these effects is guided by their influence on the overall language modeling objective.
This insight opens the door to MoD variants that decouple the routing for queries, keys and values. For example, perhaps a token would prefer to be among the queries, but not the keys, for a given self-attention computation. One can imagine extending this idea even further into the domain of "long-term memory": perhaps there are tokens that would be extremely valuable as keys, regardless of whether it is useful for them to also be among the queries at the step of their occurrence. Learned routing could be a powerful mechanism for deciding which tokens these might be, perhaps funnelling them into a long-term memory buffer that is available during future self-attention. One advantage of such an approach to long-term memory is that tokens decide once, at the moment of "memory encoding", whether they should be retrieved in the future. This is more computationally efficient than performing a full content-based lookup across an entire memory buffer for each step in the future, and could be one step towards drastically increasing the context-length available for making a prediction.
Unlike MoE transformers that route between effectively the same computation (usually MLPs), MoD transformers demonstrate the value of routing among different types of computations. In this work the types were either the conventional transformer block, or a null computation (functionally equivalent to multiplying by zero). However, one can imagine extending this idea further by routing between even more types of computation. For example, perhaps some tokens are routed to "memory lookup" functions, and others are routed to "tool use" functions. In general, the routing machinery we deployed provides a knob for adjusting the types of computations available to the network and their relative cost (in total FLOPs); if one wants to introduce an expensive computation, then this can be offset by setting its capacity to some small amount, and hence, by routing only a small number of tokens to it.
Altogether, MoD transformers are another tool one can use to tune a model's compute per forward pass (and hence inference time). The machinery used to implement MoD is also generic, and opens the doors to many extensions and integration with other techniques, such as MoE.