# FlashDecoding++: Faster Large Language Model Inference on GPUs
**Authors**:
- Qiuli Mao (Tsinghua University & Infinigence-AI)
- Kangdi Chen (Infinigence-AI)
- Xiuhong Li (Peking University)
- Yuhan Dong (Tsinghua University)
- Jun Liu (Shanghai Jiao Tong University & Infinigence-AI)
- Yu Wang✉ (Tsinghua University)
Abstract
As Large Language Models (LLMs) become increasingly important in various domains, the performance of LLM inference is crucial to massive LLM applications. However, the following challenges remain unsolved in accelerating LLM inference: (1) Synchronized partial softmax update. The softmax operation requires a synchronized update among partial softmax results, leading to $\sim$ 20% overhead for the attention computation in LLMs. (2) Under-utilized computation of flat GEMM. The matrices performing GEMM in LLM inference are flat, leading to under-utilized computation and $>$ 50% performance loss after padding with zeros in previous designs (e.g., cuBLAS, CUTLASS, etc.). (3) Performance loss due to static dataflow. Kernel performance in LLM inference depends on varied input data features, hardware configurations, etc. A single, static dataflow may lead to a 50.25% performance loss for GEMMs of different shapes in LLM inference.
We present FlashDecoding++, a fast LLM inference engine supporting mainstream LLMs and hardware back-ends. To tackle the above challenges, FlashDecoding++ creatively proposes: (1) Asynchronized softmax with unified max value. FlashDecoding++ introduces a unified max value technique for different partial softmax computations to avoid synchronization; fine-grained pipelining is proposed on this basis. (2) Flat GEMM optimization with double buffering. FlashDecoding++ points out that flat GEMMs with different shapes face varied bottlenecks, and introduces techniques like double buffering. (3) Heuristic dataflow with hardware resource adaptation. FlashDecoding++ heuristically optimizes the dataflow using different hardware resources (e.g., Tensor Cores or CUDA cores) considering input dynamics. Due to the versatility of its optimizations, FlashDecoding++ achieves up to 4.86 $×$ and 3.93 $×$ speedup on NVIDIA and AMD GPUs, respectively, compared to Hugging Face implementations. FlashDecoding++ also achieves an average speedup of 1.37 $×$ compared to state-of-the-art LLM inference engines on mainstream LLMs.

footnotetext: These authors contributed equally to this work.
footnotetext: Prof. Guohao Dai is the Chief Scientist at Infinigence-AI; Ke Hong, Jiaming Xu, Qiuli Mao, and Jun Liu are interns at Infinigence-AI.
footnotetext: Prof. Guohao Dai and Prof. Yu Wang are the corresponding authors of this paper.
1 Introduction
Figure 1: Overview of the comparison between FlashDecoding++ and state-of-the-art designs. The results in the figure are reported with the Llama2-7B model [1]. The left is with batch size=1 and input length=1K, where TensorRT-LLM and Hugging Face are the SOTA baselines for NVIDIA and AMD, respectively, according to our experimental results. The right shows a comprehensive comparison of both the first token latency and the each token latency.
As Large Language Models (LLMs) achieve unprecedented success in various domains [2, 3, 4, 5], the LLM inference workload is skyrocketing. For example, OpenAI reports that GPT-4 inference with 8K context length costs $0.03 per 1K input tokens and $0.06 per 1K output tokens [6]. Currently, OpenAI has 180.5 million users and receives over 10 million queries per day [7]. Consequently, operating OpenAI's models like ChatGPT costs approximately $7 million per day for the necessary computing hardware [8]. Thus, optimizing LLM inference performance will have a huge impact considering massive LLM inference scenarios. Many recent works have proposed techniques to accelerate LLM inference tasks, including DeepSpeed [9], FlexGen [10], vLLM [11], OpenPPL [12], FlashDecoding [13], TensorRT-LLM [14], etc. [15, 16, 17].
The LLM inference task generates tokens (e.g., words) from the input sequence autoregressively, and can be organized into two typical phases: the prefill phase and the decode phase. The prefill phase generates the first token by processing the input prompt, and previous research (e.g., FlashAttention [18, 19]) optimizes latency for this phase. The decode phase generates the following tokens sequentially, and many works [9, 10, 11, 15, 13, 14, 20] focus on improving the throughput of generating tokens (i.e., reducing latency of each token). The prefill phase dominates total time for scenarios of long-sequence input or generating short outputs [21, 22], while the decode phase constitutes a significant portion of the time when processing long output sequences [23].
Figure 2: Overview of Large Language Model inference dataflow. We show the dataflow comparison between the prefill phase and the decode phase. The prefill phase mainly involves the GEMM operation, while the decode phase mainly involves the GEMV/Flat GEMM operation.
Figure 2 shows the main dataflow of LLM inference with one transformer layer for both the prefill phase and the decode phase. A transformer layer can be divided into linear GEMM (General Matrix Multiplication) operations (e.g., the K, Q, V, O weight projections and the feedforward network) and the attention/softmax computation. For the attention computation, a softmax operation is applied to each row of the attention matrix. To improve parallelism, previous designs [18, 13] divide the attention matrices into smaller tiles, and rows are also split to compute partial softmax results. A synchronized softmax operation is adopted to update previous partial softmax results whenever a new partial softmax result is calculated. Such a synchronized partial softmax update accounts for 18.8% of the attention computation of Llama2-7B inference according to our profiling on an NVIDIA Tesla A100 GPU with 1024 input length, forming the first challenge for accelerating LLM inference. Secondly, computation resources are under-utilized for the flat GEMM operation during the decode phase. Because the decode phase generates tokens sequentially, the linear GEMM operations tend to be flat-shaped (even turning into GEMV (General Matrix-Vector Multiplication) operations when the batch size is 1). For small batch sizes (e.g., 8), previous designs [24, 25] pad the matrix with zeros to perform GEMMs of larger sizes (e.g., 64), leading to over 50% computation under-utilization. Thirdly, the performance of LLM inference suffers from a static dataflow given input dynamics and hardware configurations. For example, a small batch size makes the decode phase of LLM inference memory-bounded, while a large batch size makes it compute-bounded. A single, static dataflow may lead to 50.25% performance loss for GEMMs of different shapes in LLM inference.
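The padding waste behind the second challenge can be illustrated with a quick calculation (a sketch of ours; the tile size of 64 and the helper name are illustrative, not from any particular library):

```python
import math

def padded_utilization(m_rows: int, tile_m: int = 64) -> float:
    """Fraction of useful work when an (m_rows x K) activation matrix is
    zero-padded up to a multiple of tile_m before a tiled GEMM."""
    padded = math.ceil(m_rows / tile_m) * tile_m
    return m_rows / padded

print(padded_utilization(8))   # 0.125: a batch of 8 rows padded to a 64-row tile
```

For a decode batch of 8 rows padded to a 64-row tile, only 12.5% of the multiply-accumulates touch real data, consistent with the over-50% under-utilization noted above.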
To tackle these challenges and enable a faster Large Language Model (LLM) inference, we present FlashDecoding++ in this paper. FlashDecoding++ creatively proposes the following contributions:
- Asynchronized softmax with unified max value. FlashDecoding++ leverages a unified max value for different partial softmax computations. Each partial softmax result can be processed individually without synchronized update.
- Flat GEMM optimization with double buffering. FlashDecoding++ pads the matrix size only to 8 rather than 64 as in previous designs for flat-shaped GEMMs, improving computation utilization. We point out that flat GEMMs with different shapes face varied bottlenecks, and further improve kernel performance with techniques like double buffering.
- Heuristic dataflow with hardware resource adaptation. FlashDecoding++ takes both input dynamics and hardware configurations into consideration and dynamically applies kernel optimizations to the LLM inference dataflow.
Figure 3: FlashDecoding++ proposes three solutions for corresponding challenges in Large Language Model inference. (a) FlashDecoding++ proposes the asynchronized softmax with unified max value technique, avoiding synchronized update to previous partial attention results. (b) FlashDecoding++ optimizes flat GEMM by improving computation utilization. (c) FlashDecoding++ heuristically optimizes dataflow.
Because of the versatility of its optimizations, the effectiveness of FlashDecoding++ can be demonstrated on both NVIDIA and AMD GPUs. FlashDecoding++ achieves up to 4.86 $×$ and 3.93 $×$ speedup compared with Hugging Face implementations on NVIDIA and AMD GPUs, respectively. Our extensive results show that FlashDecoding++ achieves an average 1.37 $×$ speedup compared with FlashDecoding [13], a state-of-the-art LLM inference engine, on various LLMs (e.g., Llama2, ChatGLM2, etc.).
The rest of this paper is organized as follows. Section 2 introduces preliminaries of LLMs and related works on LLM inference acceleration. Our three techniques, the asynchronized softmax with unified max value, the flat GEMM optimization with double buffering, and the heuristic dataflow with hardware resource adaptation, are detailed in Sections 3, 4, and 5, respectively. Section 6 presents the evaluation results. Related works on LLM inference are introduced in Section 7, and Section 8 concludes the paper.
2 Background
2.1 LLM Inference Dataflow Overview
The task of LLM inference is to generate tokens from an input sequence, which can be used to complete a sentence or answer a question. An overview of the LLM inference dataflow is shown in Figure 2. The LLM inference dataflow can be organized into two typical phases with similar operations: one prefill phase and several decode phases. The prefill phase “understands” the input sequence (i.e., “What is the largest ocean?”). Each token (we treat one word as a token in Figure 2) is encoded as an embedding vector, and the input sequence is organized into a matrix. The main output of the prefill phase is a new token, predicted to be the next token after the input sequence (i.e., “Pacific” in this figure). The decode phase “generates” the output sequence (i.e., “Pacific”, “Ocean”, etc.). The output token of the prefill phase is taken as the input of the decode phase. The decode phase executes autoregressively, and each output token is used as the input token of the next decode phase (e.g., “Ocean” is further used as the input).
2.2 Operations in LLM Inference
The main operations in LLM inference are depicted as operations ① to ⑥ in Figure 2, including the linear projection (① and ⑤), the attention (②, ③, and ④), and the feedforward network (⑥). For simplicity, operations like position embedding [26], non-linear activation [27, 28, 29], mask [26], and others are not shown in the figure. Operations in the prefill phase and the decode phase differ in the shape of their data. Because only one token (batch size $=$ 1) or a few tokens (batch size $>$ 1) are processed at a time, input matrices in the decode phase are flat-shaped matrices or even vectors.
Linear Projection. The linear projection performs as a fully connected layer, multiplying the input with weight matrices (i.e., $W_{K},W_{Q},W_{V},W_{O}$, called the $K,Q,V$ projection and the $O$ projection). For the prefill phase, the $K,Q,V$ projection generates the matrices $K,Q,V$. For the decode phase, the $K,Q,V$ projection generates three corresponding vectors, which are concatenated with the $K$ and $V$ matrices from the prefill phase (i.e., the KV cache, yellow and light blue in Figure 2).
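The decode-phase projection and KV cache concatenation described above can be sketched as follows (a NumPy illustration of ours, with an illustrative hidden size and random weights; not the engine's implementation):

```python
import numpy as np

d = 8                                                    # illustrative hidden size
W_K = np.random.default_rng(0).standard_normal((d, d))   # illustrative K weights
K_cache = np.zeros((0, d))                               # would start from the prefill-phase K

def decode_step(x, K_cache):
    """Project a single decode token and append the row vector to the KV cache."""
    k = x @ W_K                              # flat projection: (1, d) x (d, d)
    return np.vstack([K_cache, k])           # concatenate with the cached K

x = np.ones((1, d))                          # one token per decode step
K_cache = decode_step(x, K_cache)
K_cache = decode_step(x, K_cache)
print(K_cache.shape)  # (2, 8): the cache grows by one row per generated token
```

This is why the decode-phase projection is a flat GEMM (or a GEMV at batch size 1): the activation has as many rows as tokens processed in that step.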
$$
softmax(Q\times K^{T})\times V \tag{1}
$$
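A minimal NumPy sketch of Eq. (1) (the dimensions and the stabilizing max-subtraction are our illustrative choices; the $1/\sqrt{d}$ scaling used in standard attention is omitted here, as in Eq. (1)):

```python
import numpy as np

def attention(Q, K, V):
    """softmax(Q @ K^T) @ V with a numerically stable row-wise softmax."""
    P = Q @ K.T                             # operation 2: Q x K^T
    P = P - P.max(axis=-1, keepdims=True)   # subtract the row max for stability
    E = np.exp(P)
    S = E / E.sum(axis=-1, keepdims=True)   # operation 3: row-wise softmax
    return S @ V                            # operation 4: Attention x V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((4, 8)) for _ in range(3))
O = attention(Q, K, V)
print(O.shape)  # (4, 8): one output row per query row
```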
Figure 4: Comparison of different softmax computation schemes. (a) Softmax computation for the whole vector. (b) Computing partial softmax for each partial vector, and a synchronized update operation is required for all partial softmax results. (c) Computing partial softmax using a unified max value, and each partial vector is processed individually without synchronized update.
Attention. The attention operation mainly consists of three operations (② to ④: $Q\times K^{T}$, $softmax$, $Attention\times V$), as shown in Eq. (1). For $P=Q\times K^{T}$, the softmax operation is performed on each row of the result matrix $P$. The detailed softmax computation is shown in Figure 4 (a). The maximum value $m(x)$ is first calculated. Then, the exponent of each element shifted by $m(x)$ (i.e., $e^{x_i-m(x)}$) is computed, forming $f(x)$. These exponents are normalized by the summation of all exponents (i.e., $l(x)$) to obtain the softmax result.
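In the notation of Figure 4 (a), the three steps can be sketched as follows (a minimal NumPy illustration of ours, not the paper's kernel):

```python
import numpy as np

def softmax_mfl(x):
    """Original softmax of Figure 4(a): global max, shifted exponentials, normalizer."""
    m = x.max()          # m(x): needs the whole vector, hence low parallelism
    f = np.exp(x - m)    # f(x): overflow-safe exponentials
    l = f.sum()          # l(x): denominator
    return f / l

x = np.array([1.0, 2.0, 3.0])
print(softmax_mfl(x))  # probabilities summing to 1, largest for the largest x_i
```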
Feedforward Network. The feedforward network primarily comprises two fully connected layers. The first one (⑥ $FFN_{1}$ ) expands the feature dimensions to enhance the representational capacity. The second one (⑥ $FFN_{2}$ ) restores the feature dimensions and serves as the output layer.
2.3 Attention Optimization
The softmax operation shown in Figure 4 (a) requires all global data to be calculated and stored before it can proceed, resulting in high memory consumption and low parallelism. Later works propose the partial softmax technique to reduce memory consumption [18, 19] or improve parallelism [13]. Figure 4 (b) shows the diagram of the partial softmax operation. The main idea is to divide the vector $x$ into partial vectors (i.e., $x^{\prime}$ and $x^{\prime\prime}$). The partial softmax results of $x^{\prime}$ and $x^{\prime\prime}$ are calculated separately according to Figure 4 (a), and then synchronously updated by each other. The detailed computation of this synchronized update is shown in Equation (2). With partial softmax, we achieve efficient parallelism while reducing the memory cost of attention computation.
$$
\begin{aligned}
m(x) &= \max\left(m(x^{\prime}),\, m(x^{\prime\prime})\right) \\
f(x) &= \left[e^{m(x^{\prime})-m(x)}f(x^{\prime}),\; e^{m(x^{\prime\prime})-m(x)}f(x^{\prime\prime})\right] \\
l(x) &= e^{m(x^{\prime})-m(x)}l(x^{\prime}) + e^{m(x^{\prime\prime})-m(x)}l(x^{\prime\prime}) \\
softmax([x^{\prime},x^{\prime\prime}]) &= f(x) \div l(x)
\end{aligned} \tag{2}
$$
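The synchronized update in Equation (2) can be checked numerically. The sketch below (ours, in NumPy) splits a vector into two partial vectors, computes partial $(m, f, l)$ statistics for each, rescales with the joint maximum, and reproduces the full softmax:

```python
import numpy as np

def partial(x):
    """Partial softmax statistics (m, f, l) as in Figure 4(a)."""
    m = x.max()
    f = np.exp(x - m)
    return m, f, f.sum()

def merged_softmax(x1, x2):
    """Synchronized update of two partial results, following Eq. (2)."""
    m1, f1, l1 = partial(x1)
    m2, f2, l2 = partial(x2)
    m = max(m1, m2)                                          # joint maximum
    f = np.concatenate([np.exp(m1 - m) * f1, np.exp(m2 - m) * f2])
    l = np.exp(m1 - m) * l1 + np.exp(m2 - m) * l2
    return f / l

rng = np.random.default_rng(1)
x = rng.standard_normal(16)
ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(merged_softmax(x[:8], x[8:]), ref)        # matches the full softmax
```

Note that the rescaling by $e^{m(x^{\prime})-m(x)}$ is exactly the synchronization the following paragraphs aim to eliminate: it cannot be applied until the joint maximum is known.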
However, since each partial softmax result must be updated according to other partial softmax results, data synchronization operations are unavoidably introduced. According to our profiling results, such a synchronized update operation leads to 18.8% overhead in the attention computation for Llama2-7B inference on an NVIDIA Tesla A100 GPU with 1024 input length.
3 Asynchronized Softmax with Unified Maximum Value
Motivation. The partial softmax operation requires synchronization among different partial vectors, leading to $\sim$ 20% overhead of the attention operation. As shown in Figure 3 (a), synchronization is required after the maximum value of a partial vector is calculated. The maximum value is used to update previous partial softmax results (i.e., recompute previous attention results). Thus, to reduce synchronization overhead, the key problem is how to compute each partial softmax result without requiring results from other partial softmax computations.
Challenge. Synchronization is required because the maximum value of each partial vector differs. The maximum value is used to avoid overflow of the exponent operation ($f(x)$ in Figure 4 (a)), and the exponents are summed ($l(x)$ in Figure 4 (a)) as the denominator of the softmax operation. Such a non-linear operation on each partial maximum value makes synchronization among partial softmax computations unavoidable.
Figure 5: The statistical distribution of $x_{i}$ (elements in the input vectors of softmax) in typical LLMs with different inputs.
Analysis and Insights. According to the formula of softmax computation, the maximum value is used as the scaling factor for both the numerator and the denominator (i.e., $f(x)$ and $l(x)$ in Figure 4 (a)). Our key insight is that, mathematically, the scaling factor can be an arbitrary number rather than the maximum value, as shown in Equation (3). When we set $\phi=0$ , it becomes the original softmax computation [30].
$$
softmax(x)=\frac{[e^{x_{1}-m(x)},...,e^{x_{d}-m(x)}]}{\sum_{i}e^{x_{i}-m(x)}}=\frac{[e^{x_{1}-\phi},...,e^{x_{d}-\phi}]}{\sum_{i}e^{x_{i}-\phi}},\quad\forall\phi\in\mathbb{R} \tag{3}
$$
However, the scaling factor cannot be an arbitrary number in practice because of overflow in the exponent computation. For the case where $x_{i}\gg\phi$ , $e^{x_{i}-\phi}$ overflows and cannot be represented using a fixed-width floating point number (e.g., float32 for exponent results in current LLM engines). For another case where $x_{i}\ll\phi$ , $e^{x_{i}-\phi}→ 0$ , leading to precision loss. Thus, a proper scaling factor $\phi$ should be carefully selected to avoid both cases. Figure 5 shows the statistical distribution of $x_{i}$ (elements in the input vectors of softmax) in typical LLMs with different inputs [31]. Our key insight is that $>99.99\%$ of the $x_{i}$ fall within a certain range. Specifically, for Llama2-7B, we have $-16.8<x_{i}<6.5$ for $>99.99\%$ of the $x_{i}$ . Because $e^{b-a}$ and $e^{a-b}$ can be represented in the float32 format for such a range $(a,b)$ , we can set $\phi=a$ in Equation (3). For OPT-6.7B, we do not apply the technique in this section because of its large range in Figure 5.
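As a quick numerical check of Equation (3), the following NumPy sketch (illustrative only, not part of FlashDecoding++; the function name is ours) verifies that any scaling factor $\phi$ in a safe range yields the same softmax as the max-based version:

```python
import numpy as np

def softmax_scaled(x, phi):
    """Softmax with an arbitrary scaling factor phi (Equation (3))."""
    e = np.exp(x - phi)  # phi need not be max(x), only close enough to avoid overflow
    return e / e.sum()

x = np.array([4.0, 5.0, 6.0, 7.0])
ref = softmax_scaled(x, x.max())  # phi = m(x): the standard safe softmax
alt = softmax_scaled(x, 6.0)      # a unified phi in the safe range gives the same result
assert np.allclose(ref, alt)
```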
Approach: Asynchronization. Based on the insights above, each partial softmax computation shares a unified maximum value, $\phi$ . After the softmax operation, an inner product operation is executed between the softmax result and a column of $V$ (i.e., $v$ ). Assuming that the input vector $x$ can be divided into $p$ partial vectors, $x=[x^{(1)},...,x^{(p)}]$ (and correspondingly $v=[v^{(1)},...,v^{(p)}]$ ), we have:
$$
\left\langle softmax(x),v\right\rangle=\frac{\sum_{i}e^{x_{i}-\phi}\cdot v_{i}}{\sum_{i}e^{x_{i}-\phi}}=\frac{\sum_{j=1}^{p}\sum_{i=1}^{d/p}e^{x_{i}^{(j)}-\phi}\cdot v_{i}^{(j)}}{\sum_{j=1}^{p}\sum_{i=1}^{d/p}e^{x_{i}^{(j)}-\phi}} \tag{4}
$$
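Equation (4) can be checked with a small NumPy sketch (an illustrative simplification; the function name is ours, and the per-thread asynchrony is replaced by a sequential loop):

```python
import numpy as np

def partial_softmax_dot(x, v, phi, p):
    """<softmax(x), v> accumulated from p partial vectors with a unified phi (Equation (4))."""
    num = den = 0.0
    for xj, vj in zip(np.split(x, p), np.split(v, p)):
        e = np.exp(xj - phi)  # each (num_j, den_j) pair is independent of the others
        num += e @ vj
        den += e.sum()
    return num / den

rng = np.random.default_rng(0)
x, v = rng.normal(size=8), rng.normal(size=8)
e = np.exp(x - x.max())
assert np.isclose(partial_softmax_dot(x, v, phi=0.0, p=2), (e / e.sum()) @ v)
```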
<details>
<summary>2311.01282v4/x6.png Details</summary>

### Visual Description
A two-part flowchart of the softmax(x)·vᵀ computation with a unified scaling factor φ = 6. Part (a): x = [x₁=4, x₂=5, x₃=6, x₄=7] and v = [V₁, V₂, V₃, V₄]; each partial vector accumulates its own numerator (e.g., e⁻²·V₁ + e⁻¹·V₂) and denominator (e.g., e⁻² + e⁻¹) independently, and the final division is performed after all partial vectors are processed. Part (b): y = [y₁=3, y₂=6, y₃=9, y₄=6]; the first partial vector is processed the same way, but y₃ = 9 yields the out-of-range exponent e³ (red dashed box), so the asynchronized computation is terminated and all partial softmax results are recomputed with the synchronized scheme.
</details>
Figure 6: Example of asynchronized partial softmax computation. (a) Each partial softmax result is processed individually without the synchronized update. (b) The recomputation process for all partial softmax computations is required when overflow happens.
The inner accumulations in both the numerator and the denominator take only the partial vectors $x^{(j)}$ and $v^{(j)}$ as input, so they can be processed asynchronously and individually. The outer accumulation is processed only after all partial vectors are processed. As shown in Figure 4 (c), each $f(x^{(j)})$ is calculated individually, and $softmax(x)$ is calculated after all $x^{(j)}$ are processed.
Approach: Recomputation. Without loss of generality, we assume $a<x_{i}-\phi<b$ for each $x_{i}$ to ensure precision and avoid overflow. Then, the partial softmax operation is processed individually. However, when $x_{i}-\phi≤ a$ or $x_{i}-\phi≥ b$ , the asynchronized partial softmax computation is terminated for the vector $x$ to which $x_{i}$ belongs. The softmax is then recomputed using the synchronized partial softmax scheme (used in FlashAttention [18, 19] and FlashDecoding [13]) shown in Figure 4 (b). Such a recomputation scheme avoids overflow while introducing negligible overheads, based on the statistical data shown in Figure 5.
Example. Figure 6 shows an example of the asynchronized softmax scheme. We set $a=-3,b=3,\phi=6$ . Two vectors $x$ and $y$ are calculated from $Q× K^{T}$ in Equation (1), and each is divided into 2 partial vectors. We omit the process from $Q× K^{T}$ to these partial vectors. Since $a<x_{i}-\phi<b$ holds for each $x_{i}$ , we compute $e^{x_{1}-\phi}· v_{1}+e^{x_{2}-\phi}· v_{2}$ and $e^{x_{1}-\phi}+e^{x_{2}-\phi}$ for the first partial vector of $x$ using two asynchronized threads. Then, each thread moves to the next partial vector for the corresponding computation (i.e., $e^{x_{3}-\phi}· v_{3}+e^{x_{4}-\phi}· v_{4}$ and $e^{x_{3}-\phi}+e^{x_{4}-\phi}$ ). The two threads are synchronized when all partial vectors are processed, and then perform the division operation in Equation (4). For $y$ , the first partial vector is processed similarly. However, since $y_{3}-\phi≥ b$ , the two threads are terminated and the first thread recomputes all partial vectors according to the synchronized partial softmax scheme in Figure 4 (b).
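The recomputation fallback can be mirrored in a short NumPy sketch (a simplification: the per-thread asynchrony is replaced by a sequential loop, and the function name is ours):

```python
import numpy as np

def async_softmax_dot(x, v, phi, a, b, p):
    """Asynchronized partial softmax with recomputation on overflow risk.

    a, b bound the safe exponent range; phi is the unified scaling factor.
    """
    if np.any(x - phi <= a) or np.any(x - phi >= b):
        # fall back to the synchronized max-based scheme (FlashAttention style)
        e = np.exp(x - x.max())
        return (e / e.sum()) @ v
    num = den = 0.0
    for xj, vj in zip(np.split(x, p), np.split(v, p)):
        e = np.exp(xj - phi)  # each partial pair is processed independently
        num += e @ vj
        den += e.sum()
    return num / den

v = np.array([1.0, 2.0, 3.0, 4.0])
fast = async_softmax_dot(np.array([4.0, 5.0, 6.0, 7.0]), v, phi=6.0, a=-3.0, b=3.0, p=2)
safe = async_softmax_dot(np.array([3.0, 6.0, 9.0, 6.0]), v, phi=6.0, a=-3.0, b=3.0, p=2)
```

With the Figure 6 inputs, x = [4, 5, 6, 7] takes the asynchronized path, while y = [3, 6, 9, 6] triggers the fallback because y₃ − φ = 3 ≥ b.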
4 Flat GEMM Optimization with Double Buffering
Motivation. The process of the decode phase is mainly composed of GEMV (batch size=1) or flat GEMM (batch size $>$ 1) operation. Without loss of generality, GEMV/GEMM operations can be represented using $M,N,K$ , where the sizes of two multiplied matrices are $M× K$ and $K× N$ . Previous LLM inference engines utilize Tensor Core to accelerate these operations using libraries like cuBLAS [24] and CUTLASS [25]. Although modern Tensor Core architectures [32] process GEMM with $M=8$ , these libraries usually tile the $M-$ dimension to 64 to hide memory latency. However, for GEMV or flat GEMM operations in the decode phase, we usually have $M\ll 64$ and the $M-$ dimension is padded to 64 with zeros. The padding leads to under-utilized computation, and the key problem is to process GEMV or flat GEMM operations with smaller tiles (i.e., padding to 8 corresponding to modern Tensor Core architectures) in the $M-$ dimension.
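The utilization loss from padding can be illustrated with a line of arithmetic (an illustrative sketch; the helper name is ours):

```python
# Useful fraction of a padded M-dimension tile: M real rows of work out of tile_m rows
def m_utilization(M, tile_m):
    return M / tile_m

# Padding a decode batch of M=8 to the library tile of 64 wastes 87.5% of the
# Tensor Core computation; padding only to 8 wastes nothing.
waste_lib = 1 - m_utilization(8, 64)  # 0.875
waste_tc = 1 - m_utilization(8, 8)    # 0.0
```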
Challenge. Processing GEMV or flat GEMM operations is non-trivial when the $M-$ dimension is padded to 8. The tiling technique in modern libraries like cuBLAS [24] and CUTLASS [25] can only be applied to the $N-$ dimension and the $K-$ dimension. Tiles on the $K-$ dimension are processed sequentially in a GPU block to avoid atomic operations during reduction. Tiling on the $N-$ dimension affects both parallelism and computation/memory ratio, which are both important for GEMV and flat GEMM acceleration.
<details>
<summary>2311.01282v4/x7.png Details</summary>

### Visual Description
Two side-by-side heatmaps of flat GEMM performance over the N-dimension size (y-axis: 1024 to 262144) and the N-dimension tiling size B_N (x-axis: 32 to 512), for K = 4096 (left panel) and K = 12288 (right panel). A ✓ marks the B_N with the best flat GEMM performance for each N. Teal dashed boxes mark parallelism-bounded regions (smaller N), and red dashed boxes mark memory-bounded regions (larger N). In both panels the ✓ marks form a diagonal: as N grows, the best B_N grows such that N/B_N stays roughly constant, until the workload becomes memory-bounded.
</details>
Figure 7: Normalized flat GEMM performance under different $N-$ dimension sizes and $N-$ dimension tiling sizes. We set $M=8$ and execute GEMM on the NVIDIA Tesla A100 GPU.
Analysis and Insights. Assume that tiling sizes of the $N-$ dimension and the $K-$ dimension are $B_{N}$ and $B_{K}$ , respectively. The computation of each GEMM tile is $2× M× B_{N}× B_{K}$ with total $B=\frac{N× K}{B_{N}× B_{K}}$ GEMM tiles. The total memory access is $(M× B_{K}+B_{N}× B_{K})× B+M× N$ . Thus, the computation/memory ratio is:
$$
\frac{2\times M\times B_{N}\times B_{K}\times B}{(M\times B_{K}+B_{N}\times B_{K})\times B+M\times N}=\frac{2\times M\times K}{K+\frac{M\times K}{B_{N}}+M} \tag{5}
$$
On the other hand, the parallelism is $\frac{N}{B_{N}}$ . Thus, the computation/memory ratio shows a positive correlation with $B_{N}$ while the parallelism shows a negative correlation with $B_{N}$ , exposing a contradiction on improving the performance of GEMV or flat GEMM. We depict the normalized performance of the flat GEMM in Figure 7 with different $N$ and $B_{N}$ . Our key insight is that for smaller $N$ , the flat GEMM is parallelism-bounded: with 108 Streaming Multiprocessors (SMs) on the NVIDIA Tesla A100, the best-performing $\frac{N}{B_{N}}$ tends to be a constant (e.g., 128 or 256) related to the hardware parallelism (the number of SMs). Another key insight is that for larger $N$ , the flat GEMM becomes memory-bounded, and the performance of these cases can be improved by hiding memory access latency.
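The trade-off in Equation (5) can be sketched numerically (illustrative only; the function name is ours):

```python
# Equation (5): the computation/memory ratio rises with B_N,
# while the parallelism N / B_N falls -- the contradiction noted above.
def compute_memory_ratio(M, K, B_N):
    return 2 * M * K / (K + M * K / B_N + M)

M, K, N = 8, 4096, 4096
for B_N in (32, 64, 128, 256):
    ratio = compute_memory_ratio(M, K, B_N)
    parallelism = N // B_N  # number of GPU blocks along the N-dimension
```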
Approach: Double Buffering. In order to hide memory access latency, we introduce the double buffering technique for the flat GEMM operation. We allocate two separate buffers in the shared memory. While the tiles in one buffer are used for the GEMM computation, the other buffer loads new tiles for the next GEMM computation. Thus, the computation and the memory access are overlapped. We apply this technique when $N$ is large in our practice.
Example. Figure 8 shows an example of our flat GEMM optimization with double buffering. For $M<8$ , the $M-$ dimension is first padded to 8 considering modern Tensor Core architectures. Workloads in the $K-$ dimension are processed within one GPU block (e.g., $A_{1},A_{2},A_{3},...$ ), while workloads in the $N-$ dimension are processed in parallel using different GPU blocks (e.g., $C_{1},C_{2},...$ ). Taking GPU Block ${}_{1}$ as an example, the first tiles of the two matrices in the $K-$ dimension (i.e., $A_{1}$ and $B_{1}$ ) are loaded to the left buffer in the shared memory. Then, the GEMM operation is performed between $A_{1}$ and $B_{1}$ ; meanwhile, $A_{2}$ and $B_{2}$ are loaded to the right buffer in the shared memory. The following tiles are processed similarly according to the double buffering scheme.
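The accumulation pattern above can be sketched in NumPy (a simplification: on a GPU the load of tile i+1 overlaps the GEMM on tile i, while here the ping-pong alternation is only emulated; the function name is ours):

```python
import numpy as np

def flat_gemm_double_buffered(A, B, B_K):
    """Flat GEMM with two alternating shared-memory-style buffers along K."""
    M, K = A.shape
    C = np.zeros((M, B.shape[1]))
    buffers = [None, None]
    buffers[0] = (A[:, :B_K], B[:B_K])  # preload the first K-tile pair
    for i, k in enumerate(range(0, K, B_K)):
        nxt = k + B_K
        if nxt < K:  # prefetch the next tile pair into the other buffer
            buffers[(i + 1) % 2] = (A[:, nxt:nxt + B_K], B[nxt:nxt + B_K])
        a_tile, b_tile = buffers[i % 2]
        C += a_tile @ b_tile  # compute on the current buffer
    return C

rng = np.random.default_rng(0)
A, B = rng.normal(size=(8, 64)), rng.normal(size=(64, 16))
assert np.allclose(flat_gemm_double_buffered(A, B, B_K=16), A @ B)
```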
<details>
<summary>2311.01282v4/x8.png Details</summary>

### Visual Description
A diagram of flat GEMM with double buffering. Matrix A is tiled along the K-dimension (A₁, A₂, A₃, ...), matrix B is tiled along both the K-dimension and the N-dimension, and each GPU block produces one output tile of C (e.g., C₁ = A₁·B₁ + A₂·B₂ + A₃·B₃ + ...). Within a GPU block, two shared-memory buffers alternate along the timeline: while the tiles in one buffer (e.g., A₁ and B₁) are consumed by the GEMM computation, the next tiles (A₂ and B₂) are loaded into the other buffer, so computation and memory access overlap and idle periods are eliminated.
</details>
Figure 8: Double buffering for flat GEMM when $N-$ dimension is large. The $M-$ dimension is padded to 8 and not tiled.
5 Heuristic Dataflow with Hardware Resource Adaptation
Motivation. Although FlashDecoding++ optimizes the flat GEMM operation in Section 4, it does not cover all operations (even only for GEMMs) in LLM inference. As mentioned in Figure 2, the shapes of GEMMs in different operations and the two phases vary. Thus, the GEMM workload in LLM inference can be GEMV (batch size=1 for the decode phase), flat GEMM (small batch size for the decode phase or short sequence length for the prefill phase), or conventional GEMM (large batch size or long sequence length for the prefill phase). In order to leverage the powerful computational ability of Tensor Core, current frameworks like FasterTransformer [33] and DeepSpeed [9] tend to utilize the highly optimized GEMM implementation from cuBLAS [24] to deal with different workloads. However, the Tensor Core implementation fails to achieve the best performance on the GEMV workload; the GEMV workload can instead be optimized by utilizing CUDA Core, as in previous designs like FastGEMV [34]. For a Llama2-7B linear layer in the decode phase, the Tensor Core implementation from cuBLAS achieves only 82.15% of the performance of the CUDA Core implementation using FastGEMV on an NVIDIA A100 GPU. On the other hand, using CUDA Core for the projection on a batch size=4 decoding input achieves only 49.75% of the performance of the Tensor Core implementation. Thus, in order to approach the optimal computation performance, a heuristic dataflow should be exploited for different workloads.
Challenge. Although a heuristic dataflow potentially exists in the implementation of different linear workloads, it is challenging to build the mapping from a certain workload to an optimal implementation. In the scenario of LLM inference, there are various factors that influence the implementation performance of linear workloads: (a) Input dynamics. The variety of the batch size and the input sequence length brings dynamic workloads. (b) Model diversity. The linear workload varies with different model structures and sizes. (c) GPU capacities. The relative performance between implementations changes with GPU characteristics, such as memory bandwidth, cache size, and computational ability. (d) Engineering effects. The engineering effort also highly impacts the kernel performance. All these influential factors build a large search space, making it non-trivial to generate an effective mapping between the linear workload and the corresponding optimal implementation.
Analysis and Insights. Although all influential factors form a large search space, the homogeneity of different layers in LLM significantly reduces the search space for operator optimization. Figure 2 shows four linear GEMV/GEMM operations in the prefill phase and the decode phase, i.e., $K,Q,V$ projection, $O$ projection, and two feedforward operations. Each GEMV/GEMM operation can be abstracted as a multiplication between an ( $M× K$ )-shaped matrix and a ( $K× N$ )-shaped matrix. Our key insight is, there are only four $[N,K]$ shapes for a certain LLM. Moreover, $M$ is only related to the input sequence length and the batch size for the prefill phase, and the batch size for the decode phase. Figure 9 (a) shows the limited shapes of GEMV/GEMM operations in LLM inference.
<details>
<summary>2311.01282v4/x9.png Details</summary>

### Visual Description
(a) A table of GEMM shapes [M, N, K] in LLM inference, where HD is the hidden dimension size, FD the dimension size after the first FFN, B the batch size, and SeqLen the input sequence length. In the prefill phase M = SeqLen×B; in the decode phase M = B. The four [N, K] shapes per model are: K, Q, V projection [HD×3, HD]; O projection [HD, HD]; FFN1 [FD, HD]; FFN2 [HD, FD] ("Only 4 shapes!"). (b) A decision flowchart: for each of the four [N, K] selections, M is increased while comparing ImplA (FastGEMV), ImplB (our flat GEMM), and ImplC (CUTLASS) to find the inflection points M₁ and M₂. (c) An example lookup table for [N, K] ∈ {[12288, 4096], [4096, 4096], [11008, 4096], [4096, 11008]}: M = 1 uses GEMV on CUDA Core (e.g., FastGEMV), small M (e.g., 2-8) uses our flat GEMM optimization, and larger M (e.g., above 9-17 depending on the shape) uses cuBLAS/CUTLASS.
</details>
Figure 9: Heuristic dataflow with hardware resource adaptation in FlashDecoding++. (a) Only four $[N,K]$ shapes exist for a certain LLM. (b) The decision flow. We traverse all $[N,K]$ selections and profile the performance of three representative implementations. $M$ is increased to find two inflection points for the runtime heuristic dataflow. (c) FlashDecoding++ heuristically utilizes Tensor Core/CUDA Core with the corresponding GEMV/GEMM implementation by referring to a lookup table.
Approach: Decision flow for inflection points. Because only four $[N,K]$ shapes exist for a certain LLM, we use three types of implementations for GEMV/GEMM operations as $M$ varies: FastGEMV for the GEMV and flat GEMM operations (ImplA), our flat GEMM optimization in Section 4 (ImplB), and the CUTLASS [25] library optimized for the conventional GEMM (ImplC). Thus, it is important to decide whether to apply ImplA or ImplB for a small $M$ , and ImplB or ImplC for a large $M$ . Figure 9 (b) shows the decision flow. FlashDecoding++ profiles the performance of ImplA and ImplB for a certain $M$ , and increases $M$ to find an inflection point $M_{1}$ where the performance of ImplB becomes better than that of ImplA. Another inflection point $M_{2}$ is found similarly, where the performance of ImplC becomes better than that of ImplB. Note that each $[N,K]$ gets its individual $M_{1}$ and $M_{2}$ .
Approach: Heuristic dataflow. For runtime LLM inference, FlashDecoding++ adopts ImplA using CUDA Core when $M<M_{1}$ , and ImplB/ImplC using Tensor Core when $M_{1}≤ M<M_{2}$ / $M_{2}≤ M$ . Note that the decision flow is executed offline, so it does not affect the performance of runtime LLM inference.
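A minimal sketch of the offline decision flow and the runtime lookup (the profile latencies below are hypothetical, and the function names are ours):

```python
# Offline: find the inflection points M1, M2 for one [N, K] shape (Figure 9 (b)).
def find_inflection_points(profile, m_values):
    """profile[impl][M] -> latency; returns (M1, M2) where ImplB then ImplC start to win."""
    M1 = next(M for M in m_values if profile["B"][M] < profile["A"][M])
    M2 = next(M for M in m_values if profile["C"][M] < profile["B"][M])
    return M1, M2

# Runtime: dispatch by M against the precomputed inflection points.
def dispatch(M, M1, M2):
    if M < M1:
        return "ImplA (FastGEMV, CUDA Core)"
    return "ImplB (flat GEMM, Tensor Core)" if M < M2 else "ImplC (CUTLASS, Tensor Core)"

m_values = [1, 2, 4, 8, 16, 32]
profile = {  # hypothetical latencies (us) per implementation and M
    "A": {1: 10, 2: 18, 4: 35, 8: 70, 16: 140, 32: 280},
    "B": {1: 14, 2: 16, 4: 20, 8: 28, 16: 45, 32: 80},
    "C": {1: 30, 2: 31, 4: 33, 8: 36, 16: 42, 32: 55},
}
M1, M2 = find_inflection_points(profile, m_values)
```

In practice one such (M₁, M₂) pair is stored per [N, K] shape, forming the lookup table of Figure 9 (c).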
Example. Figure 9 (c) shows an example of applying the heuristic dataflow for the Llama2-7B model. The four $[N,K]$ shapes are [12288, 4096] for the $K,Q,V$ projection, [4096, 4096] for the $O$ projection, and [11008, 4096] and [4096, 11008] for the FFN. For each $[N,K]$ , the inflection points are found based on the decision flow in Figure 9 (b). Then, a lookup table is formed, and each GEMV/GEMM operation is executed with the corresponding implementation during runtime. In this example, FastGEMV is adopted for the $K,Q,V$ projection when batch size=1 ( $M=1$ ) for the decode phase, and our flat GEMM optimization is applied when batch size=1 and input sequence length=8 for FFN ${}_{1}$ ( $M=8$ ).
6 Evaluation
6.1 Experiments Setup
We evaluate the performance of FlashDecoding++ on different GPUs with various Large Language Models. We compare the performance with several state-of-the-art LLM inference engines.
6.1.1 Hardware Platforms
We evaluate the performance of FlashDecoding++ and other LLM engines on both NVIDIA and AMD platforms to make a comprehensive comparison. We choose two different GPUs for each platform: Tesla A100 and RTX3090 for NVIDIA, MI210 and RX7900XTX for AMD. We show the detailed configuration in Table 1.
Table 1: Hardware Platforms
| GPU | Tesla A100 | RTX3090 | MI210 | RX7900XTX |
| --- | --- | --- | --- | --- |
| Memory | 80 GB | 24 GB | 64 GB | 24 GB |
| Software | CUDA 12.2 | CUDA 11.6 | ROCm 5.7 | ROCm 5.6 |
| CPU | Intel Xeon Silver 8358P | Intel Xeon Gold 6226R | AMD EPYC 7K62 | Intel Core i9-10940X |
| CPU Frequency | 2.60 GHz | 2.90 GHz | 2.60 GHz | 3.30 GHz |
6.1.2 LLM Engine Baselines
We implement FlashDecoding++ using a PyTorch-based front-end with a C++ and CUDA back-end for NVIDIA GPUs and a ROCm back-end for AMD GPUs. We compare the inference performance in both the prefill phase and the decode phase with the following LLM engine baselines: Hugging Face (HF) [35], vLLM [11], DeepSpeed [9], TensorRT-LLM [14], OpenPPL [12], and FlashAttention2/FlashDecoding [19, 13]. These baselines are introduced in Section 7.
Table 2: Model Configuration
| Model | Hidden Size | #Heads | #Layers | Context Length |
| --- | --- | --- | --- | --- |
| Llama2-7B | 4096 | 32 | 32 | 4k |
| Llama2-13B | 5120 | 40 | 40 | 4k |
| OPT-6.7B | 4096 | 32 | 32 | 2k |
| ChatGLM2-6B | 4096 | 32 | 32 | 32k |
6.1.3 Models
We evaluate the performance of FlashDecoding++ with other LLM inference engines on three typical Large Language Models: Llama2, OPT, and ChatGLM2. Table 2 shows the detailed configuration of these models. Note that there may be several models in one LLM (e.g., Llama2-7B, Llama2-13B) with different configurations (e.g., number of heads and layers).
- Llama2 [1] is a mainstream open-source LLM set released by Meta in 2023. It is a collection of pretrained and fine-tuned generative text models ranging in scale from 7B to 70B parameters.
- OPT [36] is a suite of decoder-only pre-trained transformers ranging from 125M to 175B parameters released by Meta AI.
- ChatGLM2 [37] is an open-source LLM supporting bilingual (Chinese-English) chat.
6.2 Comparison with State-of-the-art
We compare FlashDecoding++ with state-of-the-art LLM inference engines in Figure 10 and Figure 11 on NVIDIA GPUs, and in Figure 12 and Figure 13 on AMD GPUs. For the decode phase, FlashDecoding++ achieves up to 4.86 $×$ speedup compared with Hugging Face implementations on three LLMs and two GPUs. The average speedup over vLLM, DeepSpeed, TensorRT-LLM, OpenPPL, and FlashDecoding is 1.24 $×$ , 1.44 $×$ , 1.13 $×$ , 1.24 $×$ , and 1.21 $×$ (1.37 $×$ on Tesla A100 compared with FlashDecoding), respectively. For the prefill phase, FlashDecoding++ achieves up to 1.40 $×$ speedup compared with Hugging Face implementations. The average speedup over DeepSpeed, TensorRT-LLM, OpenPPL, FlashAttention2, and FlashDecoding is 1.05 $×$ , 1.06 $×$ , 1.08 $×$ , 1.09 $×$ , and 1.08 $×$ , respectively. We also show the decode results on two AMD GPUs. Currently, only the original Hugging Face implementation can be executed on AMD GPUs as the baseline. FlashDecoding++ achieves up to 2.27 $×$ and 3.93 $×$ speedup compared with the baseline on RX7900XTX and MI210, respectively.
Figure 10: Speedup of the decode phase on NVIDIA GPUs. Blank bars indicate that the model cannot be executed (e.g., OpenPPL does not support OPT-6.7B/ChatGLM2-6B, TensorRT-LLM fails to compile models with $>8$ K input length, etc.).
Figure 11: Speedup of the prefill phase on NVIDIA GPUs.
Figure 12: Speedup of the decode phase on AMD RX7900XTX.
Figure 13: Speedup of the decode phase on AMD MI210.
7 Related Works
Large language model inference acceleration has gained significant attention in recent research, with several notable approaches and techniques emerging in the field. DeepSpeed [9] is a comprehensive engine that optimizes both the training and inference phases of LLMs. It achieves robust inference performance through kernel fusion and efficient GPU memory management, with a particular focus on optimizing memory usage for the KV cache. vLLM [11] improves GPU memory utilization through efficient memory management techniques and the PagedAttention method, increasing the maximum batch size and raising the upper limit of inference performance. FlashAttention [18, 19] optimizes the self-attention computation during the prefill phase through improved parallelism and workload distribution. FlashDecoding [13] extends FlashAttention by splitting $K$ and $V$ to enhance parallelism, supporting efficient self-attention computation for long sequences during the decode phase. FasterTransformer [33] and OpenPPL [12] implement LLM inference engines in C++ to reduce the kernel scheduling overhead of Python implementations; they also employ memory management techniques and kernel fusion to achieve efficient LLM inference. TensorRT-LLM [14] is built upon TensorRT [38] and the FasterTransformer [33] engine (C++), incorporates cutting-edge open-source technologies such as FlashAttention [18, 19], and improves ease of use by providing a Python API.
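The split-$K$/$V$ parallelization in FlashDecoding described above can be sketched as follows. This is an illustrative NumPy reconstruction of decode-phase attention for a single query, where each K/V chunk computes a partial softmax independently and the partials are merged with max-rescaling; it is not code from any of the engines discussed.

```python
import numpy as np

def split_kv_attention(q, K, V, num_splits=4):
    """Decode-phase attention for one query vector q, with K and V split
    along the sequence axis and partial softmax results merged online."""
    d = q.shape[-1]
    m, l, acc = -np.inf, 0.0, np.zeros(d)  # running max, denominator, output
    for Kc, Vc in zip(np.array_split(K, num_splits),
                      np.array_split(V, num_splits)):
        s = Kc @ q / np.sqrt(d)    # partial attention scores for this chunk
        m_new = max(m, s.max())    # updated running max (numerical stability)
        scale = np.exp(m - m_new)  # rescale previously accumulated partials
        p = np.exp(s - m_new)      # partial softmax numerators
        l = l * scale + p.sum()
        acc = acc * scale + p @ Vc
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
q = rng.standard_normal(64)
K = rng.standard_normal((256, 64))
V = rng.standard_normal((256, 64))
s = K @ q / np.sqrt(64)
reference = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ V
assert np.allclose(split_kv_attention(q, K, V), reference)
```

The per-chunk rescaling step is exactly the synchronized partial-softmax update whose overhead the asynchronized softmax with unified max value in FlashDecoding++ is designed to avoid.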
8 Conclusion
We propose FlashDecoding++, a fast Large Language Model inference engine, in this paper. FlashDecoding++ accelerates mainstream LLMs with support for multiple hardware backends. FlashDecoding++ introduces three novel designs: the asynchronized softmax with unified max value, the flat GEMM optimization with double buffering, and the heuristic dataflow with hardware resource adaptation, achieving up to 4.86 $×$ and 3.93 $×$ speedup on NVIDIA and AMD GPUs compared with Hugging Face implementations. FlashDecoding++ also achieves an average speedup of 1.37 $×$ over the state-of-the-art inference engine FlashDecoding on various LLMs.
References
- [1] Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, et al. Llama 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288, 2023.
- [2] Arun James Thirunavukarasu, Darren Shu Jeng Ting, Kabilan Elangovan, Laura Gutierrez, Ting Fang Tan, and Daniel Shu Wei Ting. Large language models in medicine. Nature medicine, 29(8):1930–1940, 2023.
- [3] Rohan Anil, Andrew M. Dai, Orhan Firat, Melvin Johnson, Dmitry Lepikhin, Alexandre Passos, Siamak Shakeri, Emanuel Taropa, Paige Bailey, Zhifeng Chen, Eric Chu, Jonathan H. Clark, Laurent El Shafey, Yanping Huang, Kathy Meier-Hellstern, Gaurav Mishra, Erica Moreira, Mark Omernick, Kevin Robinson, Sebastian Ruder, Yi Tay, Kefan Xiao, Yuanzhong Xu, Yujing Zhang, Gustavo Hernandez Abrego, Junwhan Ahn, Jacob Austin, Paul Barham, Jan Botha, James Bradbury, Siddhartha Brahma, Kevin Brooks, Michele Catasta, Yong Cheng, Colin Cherry, Christopher A. Choquette-Choo, Aakanksha Chowdhery, Clément Crepy, Shachi Dave, Mostafa Dehghani, Sunipa Dev, Jacob Devlin, Mark Díaz, Nan Du, Ethan Dyer, Vlad Feinberg, Fangxiaoyu Feng, Vlad Fienber, Markus Freitag, Xavier Garcia, Sebastian Gehrmann, Lucas Gonzalez, Guy Gur-Ari, Steven Hand, Hadi Hashemi, Le Hou, Joshua Howland, Andrea Hu, Jeffrey Hui, Jeremy Hurwitz, Michael Isard, Abe Ittycheriah, Matthew Jagielski, Wenhao Jia, Kathleen Kenealy, Maxim Krikun, Sneha Kudugunta, Chang Lan, Katherine Lee, Benjamin Lee, Eric Li, Music Li, Wei Li, YaGuang Li, Jian Li, Hyeontaek Lim, Hanzhao Lin, Zhongtao Liu, Frederick Liu, Marcello Maggioni, Aroma Mahendru, Joshua Maynez, Vedant Misra, Maysam Moussalem, Zachary Nado, John Nham, Eric Ni, Andrew Nystrom, Alicia Parrish, Marie Pellat, Martin Polacek, Alex Polozov, Reiner Pope, Siyuan Qiao, Emily Reif, Bryan Richter, Parker Riley, Alex Castro Ros, Aurko Roy, Brennan Saeta, Rajkumar Samuel, Renee Shelby, Ambrose Slone, Daniel Smilkov, David R. So, Daniel Sohn, Simon Tokumine, Dasha Valter, Vijay Vasudevan, Kiran Vodrahalli, Xuezhi Wang, Pidong Wang, Zirui Wang, Tao Wang, John Wieting, Yuhuai Wu, Kelvin Xu, Yunhan Xu, Linting Xue, Pengcheng Yin, Jiahui Yu, Qiao Zhang, Steven Zheng, Ce Zheng, Weikang Zhou, Denny Zhou, Slav Petrov, and Yonghui Wu. Palm 2 technical report, 2023.
- [4] Jan Clusmann, Fiona R Kolbinger, Hannah Sophie Muti, Zunamys I Carrero, Jan-Niklas Eckardt, Narmin Ghaffari Laleh, Chiara Maria Lavinia Löffler, Sophie-Caroline Schwarzkopf, Michaela Unger, Gregory P Veldhuizen, et al. The future landscape of large language models in medicine. Communications Medicine, 3(1):141, 2023.
- [5] Can Cui, Yunsheng Ma, Xu Cao, Wenqian Ye, and Ziran Wang. Receive, reason, and react: Drive as you say with large language models in autonomous vehicles. arXiv preprint arXiv:2310.08034, 2023.
- [6] OpenAI. Openai pricing. [Online], 2023. https://openai.com/pricing.
- [7] Nerdynav. Up-to-date chatgpt statistics & user numbers [oct 2023]. [Online], 2023. https://nerdynav.com/chatgpt-statistics.
- [8] Afzal Ahmad and Dylan Patel. The inference cost of search disruption - large language model cost analysis. [Online], 2023. https://www.semianalysis.com/p/the-inference-cost-of-search-disruption.
- [9] Reza Yazdani Aminabadi, Samyam Rajbhandari, Ammar Ahmad Awan, Cheng Li, Du Li, Elton Zheng, Olatunji Ruwase, Shaden Smith, Minjia Zhang, Jeff Rasley, et al. Deepspeed-inference: enabling efficient inference of transformer models at unprecedented scale. In SC22: International Conference for High Performance Computing, Networking, Storage and Analysis, pages 1–15. IEEE, 2022.
- [10] Ying Sheng, Lianmin Zheng, Binhang Yuan, Zhuohan Li, Max Ryabinin, Beidi Chen, Percy Liang, Christopher Re, Ion Stoica, and Ce Zhang. Flexgen: High-throughput generative inference of large language models with a single gpu. In International Conference on Machine Learning, 2023.
- [11] Woosuk Kwon, Zhuohan Li, Siyuan Zhuang, Ying Sheng, Lianmin Zheng, Cody Hao Yu, Joseph Gonzalez, Hao Zhang, and Ion Stoica. Efficient memory management for large language model serving with pagedattention. In Proceedings of the 29th Symposium on Operating Systems Principles, pages 611–626, 2023.
- [12] Sensetime. Openppl: A high-performance deep learning inference platform. [Online], 2023. https://openppl.ai/home.
- [13] Tri Dao, Daniel Haziza, Francisco Massa, and Grigory Sizov. Flash-decoding for long-context inference. [Online], 2023. https://crfm.stanford.edu/2023/10/12/flashdecoding.html.
- [14] Neal Vaidya, Fred Oh, and Nick Comly. Optimizing inference on large language models with nvidia tensorrt-llm, now publicly available. [Online], 2023. https://github.com/NVIDIA/TensorRT-LLM.
- [15] Sensetime. A light and fast inference service for llm. [Online], 2023. https://github.com/ModelTC/lightllm.
- [16] Text generation inference: Fast inference optimize for llms. [Online], 2023. https://github.com/huggingface/text-generation-inference/.
- [17] Mlc llm: Machine learning compilation for large language models. [Online], 2023. https://github.com/mlc-ai/mlc-llm.
- [18] Tri Dao, Dan Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. Flashattention: Fast and memory-efficient exact attention with io-awareness. Advances in Neural Information Processing Systems, 35:16344–16359, 2022.
- [19] Tri Dao. Flashattention-2: Faster attention with better parallelism and work partitioning. arXiv preprint arXiv:2307.08691, 2023.
- [20] Aaron Pham, Chaoyu Yang, Sean Sheng, Shenyang Zhao, Sauyon Lee, Bo Jiang, Fog Dong, Xipeng Guan, and Frost Ming. OpenLLM: Operating LLMs in production, June 2023.
- [21] Zihang Dai, Zhilin Yang, Yiming Yang, Jaime Carbonell, Quoc V Le, and Ruslan Salakhutdinov. Transformer-xl: Attentive language models beyond a fixed-length context. arXiv preprint arXiv:1901.02860, 2019.
- [22] Z. Dong, T. Tang, L. Li, and W. X. Zhao. A survey on long text modeling with transformers. arXiv preprint arXiv:2302.14502, 2023.
- [23] Guangxuan Xiao, Yuandong Tian, Beidi Chen, Song Han, and Mike Lewis. Efficient streaming language models with attention sinks. arXiv preprint arXiv:2309.17453, 2023.
- [24] NVIDIA. cublas: Basic linear algebra on nvidia gpus. [Online], 2017. https://developer.nvidia.com/cublas.
- [25] NVIDIA. Cutlass: Cuda templates for linear algebra subroutines. [Online], 2017. https://github.com/NVIDIA/cutlass.
- [26] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.
- [27] Vinod Nair and Geoffrey E Hinton. Rectified linear units improve restricted boltzmann machines. In Proceedings of the 27th international conference on machine learning (ICML-10), pages 807–814, 2010.
- [28] Dan Hendrycks and Kevin Gimpel. Gaussian error linear units (gelus). arXiv preprint arXiv:1606.08415, 2016.
- [29] Prajit Ramachandran, Barret Zoph, and Quoc V Le. Searching for activation functions. arXiv preprint arXiv:1710.05941, 2017.
- [30] John Bridle. Training stochastic model recognition algorithms as networks can lead to maximum mutual information estimation of parameters. Advances in neural information processing systems, 2, 1989.
- [31] Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. Pointer sentinel mixture models, 2016.
- [32] NVIDIA. Nvidia tensor core. [Online], 2023. https://www.nvidia.com/en-us/data-center/tensor-cores/.
- [33] NVIDIA. Fastertransformer: About transformer related optimization, including bert, gpt. [Online], 2017. https://github.com/NVIDIA/FasterTransformer.
- [34] Siping Wang. Fastgemv: High-speed gemv kernels. [Online], 2023. https://github.com/wangsiping97/FastGEMV.
- [35] Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Clement Delangue, Anthony Moi, Pierric Cistac, Tim Rault, Remi Louf, Morgan Funtowicz, Joe Davison, Sam Shleifer, Patrick von Platen, Clara Ma, Yacine Jernite, Julien Plu, Canwen Xu, Teven Le Scao, Sylvain Gugger, Mariama Drame, Quentin Lhoest, and Alexander Rush. Transformers: State-of-the-art natural language processing. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: System Demonstrations, pages 38–45, Online, October 2020. Association for Computational Linguistics.
- [36] Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen, Christopher Dewan, Mona Diab, Xian Li, Xi Victoria Lin, Todor Mihaylov, Myle Ott, Sam Shleifer, Kurt Shuster, Daniel Simig, Punit Singh Koura, Anjali Sridhar, Tianlu Wang, and Luke Zettlemoyer. Opt: Open pre-trained transformer language models, 2022.
- [37] Zhengxiao Du, Yujie Qian, Xiao Liu, Ming Ding, Jiezhong Qiu, Zhilin Yang, and Jie Tang. Glm: General language model pretraining with autoregressive blank infilling. In Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 320–335, 2022.
- [38] NVIDIA. Nvidia tensorrt: An sdk for high-performance deep learning inference. [Online]. https://developer.nvidia.com/tensorrt.