## Diagram: FairPFN Pre-training Framework
### Overview
This image is a technical diagram illustrating the pre-training framework for a model called "FairPFN." The diagram is divided into three sequential stages (a, b, c), described in text blocks at the top, and a larger visual flow below that maps these stages to a "Real-world Inference" pipeline. The overall purpose is to show how a Structural Causal Model (SCM) is used to generate data for training a transformer model (FairPFN) to make fair predictions, i.e., to learn a mapping from biased observables to fair outcomes.
### Components/Axes
The diagram is segmented into several key components:
1. **Top Text Blocks (Process Description):**
* **a) Data generation:** Describes generating an SCM and sampling a dataset `D` with a protected attribute `A`, biased observables `X_b`, and biased outcome `Y_b`. A fair outcome `Y_f` is also sampled by removing outgoing edges from `A`.
* **b) Transformer input:** Describes partitioning the observational dataset `D` into training and validation splits. The transformer uses in-context examples `D_train` to make predictions on the inference set `D_val = (A_val, X_val)`.
* **c) Fair prediction:** Describes the transformer making predictions `Ŷ_f` on the validation set. The pre-training loss is calculated with respect to the fair outcomes `Y_f`, teaching the model the mapping `X_b → Y_f`.
2. **Main Visual Flow (Left to Right):**
* **Left: Structural Causal Model (SCM):** A directed acyclic graph with nodes and arrows.
* **Nodes:** `A0` (blue circle), `U2` (green circle, appears twice), `X1`, `X2`, `X3` (purple circles), `Y_b` (orange circle), `Y_f` (yellow circle with red outline).
* **Arrows:** Black arrows indicate causal relationships. Red arrows originate from `A0` and point to `X2` and `X3`, indicating a biased or protected attribute's influence.
* **Center-Left: Observational Dataset:** A table with columns labeled `A` (blue header), `X1`, `X2`, `X3` (purple headers), and `Y_b` (orange header). The table contains 5 rows of colored cells (blue, purple, orange, light beige) representing data samples.
* **Center: FairPFN:** A large, green, abstract block diagram representing the transformer model. Above it are three small copies of the SCM graph. Below it is a mathematical equation: `p(y_f | x_b, D_b) ∝ ∫_φ p(y_f | x_b, φ) p(D_b | φ) p(φ) dφ`.
* **Center-Right: Pre-training Loss:** Two vertical bars.
* Left bar: Labeled `Ŷ_f` (predicted fair outcome), with a grayscale gradient from white (top) to black (bottom).
* Right bar: Labeled `Y_f` (true fair outcome), with a color gradient from light beige (top) to yellow (bottom).
* **Arrows:** Large, light green arrows show the data flow from the SCM to the dataset, from the dataset to FairPFN, and from FairPFN to the loss calculation. A final green arrow loops from the loss back to the SCM, indicating the pre-training cycle.
3. **Title:** The entire lower section is titled "**FairPFN Pre-training**" in a large, bold, serif font.
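
The data-generation stage (a) can be sketched as a toy simulation of the depicted graph. All coefficients, noise scales, and the binary protected attribute below are invented for illustration; FairPFN's actual prior over SCMs is more general:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000

# Protected attribute, exogenous variable, and shared noise terms
# (sharing the noise across both simulations makes Y_f a counterfactual
# of Y_b rather than an independent sample).
A = rng.integers(0, 2, size=n).astype(float)   # protected attribute A
U = rng.normal(size=n)                         # unobserved exogenous variable
e1, e2, e3, ey = (rng.normal(scale=0.1, size=n) for _ in range(4))

def simulate(a_effect):
    """Structural equations of the depicted graph;
    a_effect=0.0 removes the outgoing (red) edges of A."""
    X1 = U + e1
    X2 = 0.8 * X1 + a_effect * 1.5 * A + e2   # A -> X2 (biased edge)
    X3 = a_effect * 1.2 * A + e3              # A -> X3 (biased edge)
    Y = X1 + X2 + X3 + ey                     # X1, X2, X3 -> Y
    return X1, X2, X3, Y

X1, X2, X3, Y_b = simulate(a_effect=1.0)   # biased observables and outcome
_, _, _, Y_f = simulate(a_effect=0.0)      # fair outcome: A's edges removed
D = np.column_stack([A, X1, X2, X3, Y_b])  # observational dataset from (a)
```

The fair target `Y_f` is then what the transformer is trained to predict from the biased columns of `D`.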
### Detailed Analysis
* **SCM Node Relationships:**
* `A0` has direct red arrows to `X2` and `X3`.
* `U2` (top) has arrows to `X1` and `Y_b`.
* `U2` (bottom) has an arrow to `X1`.
* `X1` has arrows to `X2` and `Y_b`.
* `X2` has an arrow to `Y_b`.
* `X3` has an arrow to `Y_b`.
* `Y_f` is shown as a separate node, derived from the SCM by "removing the outgoing edges of A" as per the text in (a).
* **Data Flow & Process Mapping:**
1. The **SCM** (Stage a) is used to generate the **Observational Dataset**.
2. The dataset is fed into the **FairPFN** model (Stages b & c).
3. The model produces predictions `Ŷ_f`.
4. The **Pre-training Loss** is computed by comparing `Ŷ_f` to the true fair outcome `Y_f`.
5. The loss is used to update the model, completing the pre-training loop.
* **Mathematical Equation:** The equation below FairPFN expresses the model's predictive distribution as a proportionality (`∝`) involving an integral over `φ`, the latent parameters of the data-generating mechanism (the SCM). It combines the likelihood of the fair outcome given the biased input and `φ`, the likelihood of the biased data given `φ`, and a prior over `φ`; in other words, it marginalizes over the causal mechanisms consistent with the observed data.
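
The proportionality can be read as a Bayesian marginalization, which a simple importance-sampling scheme makes concrete. The scalar `φ`, Gaussian likelihoods, and toy data below are invented stand-ins, not FairPFN internals:

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy Monte Carlo view of the integral: draw parameters phi from a prior,
# weight each draw by the likelihood of the observed biased data D_b, and
# average the fair-outcome predictive densities.

# One scalar "parameter" phi: slope of y on x, with y ~ N(phi * x, 1).
phis = rng.normal(size=5000)           # p(phi): standard normal prior

# Observed biased dataset D_b: pairs (x_i, y_i).
x_obs = np.array([0.5, 1.0, 1.5])
y_obs = np.array([1.1, 2.1, 2.9])

def gauss(y, mu):
    """Standard-normal density of y around mean mu."""
    return np.exp(-0.5 * (y - mu) ** 2) / np.sqrt(2 * np.pi)

# p(D_b | phi): product of per-example likelihoods for each phi draw.
lik = np.prod(gauss(y_obs, phis[:, None] * x_obs), axis=1)
w = lik / lik.sum()                    # normalized importance weights

# p(y_f | x_b, D_b) ≈ sum_j w_j * p(y_f | x_b, phi_j) at a query point.
x_query, y_grid = 2.0, np.linspace(-2, 8, 201)
ppd = (w[:, None] * gauss(y_grid, phis[:, None] * x_query)).sum(axis=0)
```

As more biased data arrives, the weights concentrate on the `φ` values that explain `D_b`, so the averaged predictive density approaches the Bayes-optimal prediction the equation describes.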
### Key Observations
* **Color Coding is Systematic:** Blue is consistently used for the protected attribute `A`. Purple is used for the observables `X1, X2, X3`. Orange represents the biased outcome `Y_b`. Yellow represents the fair outcome `Y_f`.
* **Bias Representation:** The red arrows from `A0` in the SCM visually highlight the pathways of bias that the framework aims to remove to achieve `Y_f`.
* **Model Abstraction:** The FairPFN is represented as a generic green block, emphasizing its role as a black-box transformer model within this causal framework.
* **Loss Visualization:** The loss bars visually contrast the model's grayscale predictions (`Ŷ_f`) against the colored true fair targets (`Y_f`), implying the goal is to align the prediction distribution with the fair outcome distribution.
### Interpretation
This diagram outlines a methodology for instilling fairness into a predictive model through causal pre-training. The core idea is to use a known or assumed causal structure (the SCM) to generate synthetic data where the direct effect of a protected attribute (`A`) on the outcome is removed, creating a "fair" target (`Y_f`). A transformer model (FairPFN) is then trained not on the real-world biased outcome (`Y_b`), but on this constructed fair outcome.
The process is "pre-training" because it happens before the model is applied to real-world tasks. By learning the mapping from biased observables (`X_b`) to fair outcomes (`Y_f`) in a controlled, causal environment, the model is intended to internalize a fair prediction rule. The integral in the equation suggests a Bayesian approach, marginalizing over candidate causal mechanisms `φ` rather than committing to a single fitted model. The framework implies that real-world inference (using the pre-trained model) will then produce fair predictions even when only biased data is available, as the model has learned to "ignore" the biased pathways from `A`. The key assumption is the validity of the causal structures used for data generation.
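
The claimed behavior, that pre-training against fair targets drives the model to ignore the protected attribute, can be illustrated with a minimal stand-in. The linear "model", trivial random "SCMs", and all hyperparameters below are assumptions for illustration; the actual FairPFN is a transformer trained on SCMs sampled from a prior:

```python
import numpy as np

rng = np.random.default_rng(2)
w = np.zeros(4)          # weights over [A, X1, X2, X3]
lr = 0.01

for step in range(2000):
    # (a) sample a fresh synthetic task: a trivial random "SCM" where
    # the fair target Y_f depends only on X1..X3, never on A.
    A = rng.integers(0, 2, size=64).astype(float)
    X = rng.normal(size=(64, 3))
    Y_f = X.sum(axis=1)                  # fair target ignores A
    feats = np.column_stack([A, X])      # the model still *sees* A

    # (b, c) predict and take a gradient step on the fair-target MSE.
    pred = feats @ w
    grad = feats.T @ (pred - Y_f) / len(Y_f)
    w -= lr * grad
```

After training, the learned weight on `A` is near zero even though the model always receives `A` as input, mirroring the claim that the model learns to "ignore" the biased pathways.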