# Diffusion Causal Models for Counterfactual Estimation
Pedro Sanchez (PEDRO.SANCHEZ@ED.AC.UK)
Sotirios A. Tsaftaris
The University of Edinburgh
Editors: Bernhard Schölkopf, Caroline Uhler and Kun Zhang
Figure 1: Counterfactuals on ImageNet 256×256 generated by Diff-SCM. From left to right: a random image sampled from the data distribution and its counterfactuals $do(\text{class})$, corresponding to 'how should the image change in order to be classified as another class?'.
Panels, left to right: `carbonara` (factual image), `do(cliff)`, `do(espresso maker)`, `do(waffle iron)`.
## Abstract
We consider the task of counterfactual estimation from observational imaging data given a known causal structure. In particular, quantifying the causal effect of interventions for high-dimensional data with neural networks remains an open challenge. Herein we propose Diff-SCM, a deep structural causal model that builds on recent advances in generative energy-based models. In our setting, inference is performed by iterative sampling with the gradients of the marginal and conditional distributions entailed by the causal model. Counterfactual estimation is achieved by first inferring latent variables with deterministic forward diffusion, then intervening on a reverse diffusion process using the gradients of an anti-causal predictor w.r.t. the input. Furthermore, we propose a metric for evaluating the generated counterfactuals. We find that Diff-SCM produces more realistic and minimal counterfactuals than baselines on MNIST data and can also be applied to ImageNet data. Code is available at https://github.com/vios-s/Diff-SCM.
## 1. Introduction
The notion of applying interventions in learned systems has been gaining significant attention in causal representation learning (Schölkopf et al., 2021). In causal inference, relationships between variables are directed: an intervention on the cause will change the effect, but not the other way around. This notion goes beyond learning conditional distributions $p(x^{(k)} \mid x^{(j)})$ from the data alone, as in the classical statistical learning framework (Vapnik, 1999). Building causal models implies capturing, in a model, the underlying physical mechanism that generated the data (Pearl, 2009). As a result, one should be able to quantify the causal effect of a given action. In particular, when an intervention is applied to a given instance, the model should be able to generate hypothetical scenarios. These are the so-called counterfactuals.
Building causal models that quantify the effect of a given action for a given causal structure and available data is referred to as causal estimation. However, estimating the effect of interventions for high-dimensional data remains an open problem (Pawlowski et al., 2020; Yang et al., 2021). While machine learning is a powerful tool for learning relationships between high-dimensional variables, most causal estimation methods using neural networks (Johansson et al., 2016; Louizos et al., 2017; Shi et al., 2019; Du et al., 2021) have only been applied to semi-synthetic low-dimensional datasets (Hill, 2012; Shimoni et al., 2018). Therefore, causal estimation with deep neural networks for high-dimensional variables remains an open goal. We show that we can estimate the effect of interventions by generating counterfactuals on imaging datasets, as illustrated in Fig. 1.
Herein, we leverage recent advances in generative energy-based models (EBMs) (Song et al., 2021b; Ho et al., 2020) to devise approaches for causal estimation. This formulation has two key advantages: (i) the stochasticity of the diffusion process relates to uncertainty-aware causal models; and (ii) the iterative sampling can be naturally extended to apply interventions. Additionally, we propose an algorithm for counterfactual inference and a metric for evaluating the results. In particular, we use neural networks that learn to reverse a diffusion process (Ho et al., 2020) via denoising. These models are trained to approximate the gradient of the log-likelihood of a distribution w.r.t. the input. We also employ neural networks that are learned in the anti-causal direction (Schölkopf et al., 2012; Kilbertus et al., 2018) to sample via the causal mechanisms. We use the gradients of these anti-causal predictors to apply interventions on specific variables during sampling. Counterfactual estimation is possible via a deterministic version of diffusion models (Song et al., 2021a), which recovers manipulable latent spaces from observations. Finally, the counterfactuals are generated iteratively using Markov Chain Monte Carlo (MCMC) algorithms.
In summary, we devise a framework for causal effect estimation with high-dimensional variables based on diffusion models, entitled Diff-SCM. Diff-SCM behaves as a structured generative model from which one can sample the interventional distribution as well as estimate counterfactuals. Our contributions are: (i) We propose a theoretical framework for causal modeling using generative diffusion models and anti-causal predictors (Sec. 3.2). (ii) We investigate how anti-causal predictors can be used for applying interventions in the causal direction (Sec. 3.3). (iii) We propose an algorithm for counterfactual estimation using Diff-SCM (Sec. 3.4). (iv) We propose a metric, termed counterfactual latent divergence, for evaluating the minimality of the generated counterfactuals (Sec. 5.2). We use this metric to compare our method with the selected baselines and for hyperparameter search (Sec. 5.3).
## 2. Background
## 2.1. Generative Energy-Based Models
A family of generative models based on diffusion processes (Sohl-Dickstein et al., 2015; Ho et al., 2020; Song et al., 2021b) has recently gained attention, even achieving state-of-the-art image generation quality (Dhariwal and Nichol, 2021).
In particular, Denoising Diffusion Probabilistic Models (DDPMs) (Ho et al., 2020) consist of learning to denoise images that were corrupted with Gaussian noise at different scales. DDPMs are defined in terms of a forward Markovian diffusion process. This process gradually adds Gaussian noise, with a time-dependent variance $\beta_t \in [0, 1]$, to a data point $x_0 \sim p_{\text{data}}(x)$. Thus, the latent variable $x_t$, with $t \in [0, T]$, corresponds to a version of $x_0$ perturbed by Gaussian noise following $p(x_t \mid x_0) = \mathcal{N}\big(x_t; \sqrt{\alpha_t}\, x_0, (1 - \alpha_t) I\big)$, where $\alpha_t := \prod_{j=0}^{t} (1 - \beta_j)$ and $I$ is the identity matrix.
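Because $p(x_t \mid x_0)$ is Gaussian, $x_t$ can be drawn in one shot rather than by running $t$ noising steps. The sketch below assumes a linear schedule for $\beta_t$ (the text does not fix one) and represents an image as a flat list of floats; `linear_beta_schedule`, `cumulative_alphas` and `q_sample` are hypothetical helper names. As a convention of this sketch, index 0 of the cumulative product is set to 1 so that $t = 0$ returns the clean data point.

```python
import math
import random


def linear_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
    """Assumed schedule: T noise variances beta_1..beta_T, linearly spaced."""
    return [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]


def cumulative_alphas(betas):
    """alpha_t = prod_{j<=t} (1 - beta_j); alphas[0] = 1.0 stands for clean data."""
    alphas = [1.0]
    for b in betas:
        alphas.append(alphas[-1] * (1.0 - b))
    return alphas


def q_sample(x0, t, alphas, rng):
    """Draw x_t ~ N(sqrt(alpha_t) x0, (1 - alpha_t) I) directly; also return the noise."""
    a = alphas[t]
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [math.sqrt(a) * xi + math.sqrt(1.0 - a) * ei for xi, ei in zip(x0, eps)]
    return x_t, eps
```

With these defaults and $T = 1000$, `alphas[T]` falls below $10^{-3}$, so $x_T$ is close to a zero-mean, unit-variance Gaussian whatever the starting $x_0$.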
As such, $p(x_t) = \int p_{\text{data}}(x)\, p(x_t \mid x)\, \mathrm{d}x$ should approximate the data distribution $p(x_0) \approx p_{\text{data}}$ at time $t = 0$ and a zero-mean Gaussian distribution at time $t = T$. Generative modelling is achieved by learning to reverse this process using a neural network $\epsilon_\theta$ trained to denoise images at different noise scales $\beta_t$. The denoising model effectively learns the gradient of the log-likelihood w.r.t. the observed variable, $\nabla_x \log p(x)$ (Hyvärinen, 2005).
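Training. With sufficient data and model capacity, the marginal gradient $\nabla_x \log p_t(x)$ can be recovered by training $\epsilon_\theta$ to approximate the conditional gradient $\nabla_{x_t} \log p_t(x_t \mid x_0)$. Since $x_t = \sqrt{\alpha_t}\, x_0 + \sqrt{1-\alpha_t}\, \epsilon$ gives $\nabla_{x_t} \log p_t(x_t \mid x_0) = -\epsilon / \sqrt{1-\alpha_t}$, this amounts, up to a time-dependent scale, to predicting the added noise $\epsilon$. The training procedure can be formalised as
$$\theta^{*} = \arg\min_{\theta}\; \mathbb{E}_{t,\, x_0,\, \epsilon}\Big[\big\|\epsilon_\theta\big(\sqrt{\alpha_t}\, x_0 + \sqrt{1-\alpha_t}\, \epsilon,\; t\big) - \epsilon\big\|_2^2\Big], \qquad \epsilon \sim \mathcal{N}(0, I). \tag{1}$$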
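For 1-D Gaussian data $x_0 \sim \mathcal{N}(\mu, \sigma^2)$ the minimiser of Eq. 1 is known in closed form, $\mathbb{E}[\epsilon \mid x_t] = \sqrt{1-\alpha_t}\,(x_t - \sqrt{\alpha_t}\,\mu)/v_t$ with $v_t = \alpha_t\sigma^2 + 1 - \alpha_t$, which makes it possible to check the objective without training a network. This is a sketch under stated assumptions: a single fixed noise level replaces the expectation over $t$, and `MU`, `SIGMA`, `optimal_eps` and `denoising_loss` are illustrative names.

```python
import math
import random

MU, SIGMA = 2.0, 0.5  # hypothetical 1-D data distribution x0 ~ N(MU, SIGMA^2)


def optimal_eps(x_t, a):
    """Minimiser of Eq. 1 for Gaussian data: E[eps | x_t] = -sqrt(1 - a) * grad log p_t(x_t)."""
    v = a * SIGMA ** 2 + 1.0 - a
    return math.sqrt(1.0 - a) * (x_t - math.sqrt(a) * MU) / v


def denoising_loss(eps_model, a, rng, n=20000):
    """Monte Carlo estimate of E_{x0, eps} (eps_model(x_t, a) - eps)^2 at noise level a."""
    total = 0.0
    for _ in range(n):
        x0 = rng.gauss(MU, SIGMA)
        eps = rng.gauss(0.0, 1.0)
        x_t = math.sqrt(a) * x0 + math.sqrt(1.0 - a) * eps
        total += (eps_model(x_t, a) - eps) ** 2
    return total / n
```

At $\alpha_t = 0.5$ the analytic minimum is $1 - (1-\alpha_t)/v_t = 0.2$, while a model that always predicts zero noise incurs a loss of 1.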
Inference. Once $\epsilon_\theta$ is learned using Eq. 1, generating samples consists of starting from $x_T \sim \mathcal{N}(0, I)$ and iteratively sampling from the reverse Markov chain following:
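$$x_{t-1} = \frac{1}{\sqrt{1-\beta_t}}\left[x_t - \frac{\beta_t}{\sqrt{1-\alpha_t}}\, \epsilon_\theta(x_t, t)\right] + \sqrt{\beta_t}\, z, \qquad z \sim \mathcal{N}(0, I). \tag{2}$$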
We note that, in the DDPM setting, z is re-sampled at each iteration. Diffusion models are Markovian and stochastic by nature. As such, they can be defined as a stochastic differential equation (SDE) (Song et al., 2021b). We adopt the time-dependent notation from Song et al. (2021b) as it will be useful for the connection with causal models in Sec. 3.2.
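The reverse chain above can be exercised on 1-D Gaussian data, where the ideal noise predictor is analytic, so a hypothetical `eps_model` stands in for the trained $\epsilon_\theta$. This is a minimal sketch, assuming a linear $\beta_t$ schedule and, following common practice rather than anything stated here, adding no noise at the final step.

```python
import math
import random

MU, SIGMA = 2.0, 0.5  # hypothetical data distribution x0 ~ N(MU, SIGMA^2)


def make_schedule(T, beta_start=1e-4, beta_end=0.05):
    """1-based lists: betas[t] and alphas[t] for t = 1..T (index 0 is clean data)."""
    betas = [0.0] + [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]
    alphas = [1.0]
    for t in range(1, T + 1):
        alphas.append(alphas[-1] * (1.0 - betas[t]))
    return betas, alphas


def eps_model(x_t, a):
    """Exact noise predictor for Gaussian data, standing in for a trained eps_theta."""
    v = a * SIGMA ** 2 + 1.0 - a
    return math.sqrt(1.0 - a) * (x_t - math.sqrt(a) * MU) / v


def ddpm_sample(betas, alphas, rng):
    """Ancestral sampling from x_T ~ N(0, 1) down to x_0 with the reverse Markov chain."""
    T = len(betas) - 1
    x = rng.gauss(0.0, 1.0)
    for t in range(T, 0, -1):
        eps = eps_model(x, alphas[t])
        x = (x - betas[t] / math.sqrt(1.0 - alphas[t]) * eps) / math.sqrt(1.0 - betas[t])
        if t > 1:
            x += math.sqrt(betas[t]) * rng.gauss(0.0, 1.0)  # fresh z at every step
    return x
```

Averaging many calls to `ddpm_sample` should give a mean near `MU` and a standard deviation near `SIGMA`.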
## 2.2. Causal Models
Counterfactuals can be understood from a formal perspective using the causal inference formalism (Pearl, 2009; Peters et al., 2017; Schölkopf et al., 2021). Structural Causal Models (SCMs) $G := (S, p_U)$ consist of a collection $S = (f^{(1)}, f^{(2)}, \dots, f^{(K)})$ of structural assignments (so-called mechanisms), defined as
$$x^{(k)} := f^{(k)}\big(\mathrm{pa}^{(k)}, u^{(k)}\big), \tag{3}$$
where $X = \{x^{(1)}, x^{(2)}, \dots, x^{(K)}\}$ are the known endogenous random variables, $\mathrm{pa}^{(k)}$ is the set of parents of $x^{(k)}$ (its direct causes) and $U = \{u^{(1)}, u^{(2)}, \dots, u^{(K)}\}$ are the exogenous variables. The distribution $p(U)$ of the exogenous variables represents the uncertainty associated with variables that were not taken into account by the causal model. Moreover, the variables in $U$ are mutually independent, following the joint distribution:
$$p(U) = \prod_{k=1}^{K} p\big(u^{(k)}\big). \tag{4}$$
These structural equations can be represented graphically as a directed acyclic graph. Vertices are the endogenous variables and edges represent (directional) causal relationships between them. In particular, there is a joint distribution $p_G(X) = \prod_{k=1}^{K} p\big(x^{(k)} \mid \mathrm{pa}^{(k)}\big)$ which is Markov relative to $G$. In other words, the SCM $G$ represents a joint distribution over the endogenous variables. A graphical example of an SCM is depicted on the left of Fig. 2. Finally, SCMs should comply with what is known as Pearl's Causal Hierarchy (see Appendix B for more details).
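The asymmetry of interventions can be seen on a toy SCM sampled in topological order. The mechanism below ($x^{(2)} := 2x^{(1)} + u^{(2)}$ with standard Gaussian noise) and the `do` dictionary are hypothetical choices for illustration; replacing a node's assignment by a constant is the graph surgery implied by the do-operator.

```python
import random


def sample_scm(rng, do=None):
    """Ancestral sampling from a hypothetical SCM x1 -> x2 in topological order.

    x1 := u1,           u1 ~ N(0, 1)
    x2 := 2 * x1 + u2,  u2 ~ N(0, 1)
    `do` maps a variable name to a constant that replaces its structural assignment.
    """
    do = do or {}
    u1, u2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)  # independent exogenous noise (Eq. 4)
    x1 = do.get("x1", u1)
    x2 = do.get("x2", 2.0 * x1 + u2)
    return x1, x2


def mean(values):
    return sum(values) / len(values)
```

Under `do={"x1": 3.0}` the average of `x2` moves to about 6, whereas under `do={"x2": 5.0}` the average of `x1` stays near 0: intervening on the effect leaves the cause untouched.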
## 3. Causal Modeling with Diffusion Processes
## 3.1. Problem Statement
In this work, we build a causal model capable of estimating counterfactuals of high-dimensional variables. We will base our work on three assumptions: (i) The SCM is known and the intervention is identifiable. (ii) The variables over which the counterfactuals will be estimated need to contain enough information to recover their causes; i.e. an anti-causal predictor can be trained. (iii) All endogenous variables in the training set are annotated.
Notation. $x^{(k)}_t$ denotes the $k$-th endogenous random variable in a causal graph $G$ at diffusion time $t$. $x^{(k)}_{t,i}$ is a sample $i \in \{\mathrm{CF}, \mathrm{F}\}$ (counterfactual and factual, respectively) from $x^{(k)}_t$. Whenever $t$ is omitted, it is taken to be zero, i.e., the sample is not corrupted with Gaussian noise. We write $\mathrm{an}^{(k)}$ for the ancestors of $x^{(k)}$, with $\mathrm{pa}^{(k)} \subset \mathrm{an}^{(k)}$, and $\mathrm{de}^{(k)}$ for its descendants in $G$.
## 3.2. Diff-SCM: Unifying Diffusion Processes and Causal Models
Figure 2: Illustration of a diffusion process as a weakening of causal relationships. Left: example of an SCM with endogenous variables $x^{(k)}$ and respective exogenous variables $u^{(k)}$. Right: the diffusion process weakens the relationships between endogenous variables until they become completely independent at $t = T$. Solid arrows indicate a causal relationship and its direction, while the thickness of an arrow indicates the strength of the relation. Note that time $t$ is a fictitious variable used as a reference for the diffusion process and is not a causal variable.
Panels: left, an SCM over $x^{(1)}, x^{(2)}, x^{(3)}$ with exogenous $u^{(1)}, u^{(2)}, u^{(3)}$; right, the diffusion from $t = 0$ through $x^{(k)}_t$ to $t = T$, where only the exogenous $u^{(k)}$ remain.
SCMs have been associated with ordinary (Mooij et al., 2013; Rubenstein et al., 2018) and stochastic (Sokol and Hansen, 2014; Bongers and Mooij, 2018) differential equations as well as other types of dynamical systems (Blom et al., 2020). In these cases, differential equations are useful for modeling time-dependent problems such as chemical kinetics or mass-spring systems. From the energy-based models perspective, Song et al. (2021b) unify denoising diffusion models (Sohl-Dickstein et al., 2015; Ho et al., 2020) and denoising score models (Song and Ermon, 2019) into a framework based on SDEs. In Song et al. (2021b), SDEs are used for formalising a diffusion process in a continuous manner where a model is learned to reverse the SDE in order to generate images.
Here, we unify the SDE framework with causal models. Diff-SCM models the dynamics of causal variables as an Itô process $x^{(k)}_t, \forall t \in [0, T]$ (Øksendal, 2003; Särkkä and Solin, 2019) going from an observed endogenous variable $x^{(k)}_0 = x^{(k)}$ to its respective exogenous noise $x^{(k)}_T = u^{(k)}$ and back. In other words, we formulate the forward diffusion as a gradual weakening of the causal relations between the variables of an SCM, as illustrated in Fig. 2.
The diffusion forces the exogenous noise $u^{(j)}$ corresponding to a variable of interest $x^{(j)}$ to be independent of the other $u^{(i)}, \forall i \neq j$, following the constraints from Eq. 4. The Brownian motion (diffusion) leads to a Gaussian distribution, which can be seen as a prior. Analogously, the original joint distribution entailed by the SCM, $p_G(X)$, diffuses to independent Gaussian distributions equivalent to $p(U)$. As such, the time-dependent joint distribution $p(X_t), \forall t \in [0, T]$, has as boundary conditions $p(X_T) = p(U)$ and $p(X_0) = p_G(X)$. Note that $p(X_t)$ refers to the time-dependent distribution over all causal variables $x^{(k)}$.
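The weakening of causal relations can be made quantitative on a hypothetical two-node SCM $x^{(1)} := u^{(1)}$, $x^{(2)} := 2x^{(1)} + u^{(2)}$ with standard Gaussian noise, diffusing each node with its own independent noise. Then $\mathrm{Cov}(x^{(1)}_t, x^{(2)}_t) = 2\alpha_t$, $\mathrm{Var}(x^{(1)}_t) = 1$ and $\mathrm{Var}(x^{(2)}_t) = 1 + 4\alpha_t$, so the correlation $2\alpha_t/\sqrt{1+4\alpha_t}$ decays to zero as $\alpha_t \to 0$. The sketch below checks this by simulation; the mechanism and function names are illustrative assumptions.

```python
import math
import random


def diffused_correlation(alpha, rng, n=20000):
    """Empirical corr(x1_t, x2_t) after diffusing each node of x1 -> x2 independently."""
    xs1, xs2 = [], []
    for _ in range(n):
        x1 = rng.gauss(0.0, 1.0)             # x1 := u1
        x2 = 2.0 * x1 + rng.gauss(0.0, 1.0)  # x2 := 2 x1 + u2 (hypothetical mechanism)
        # separate forward-diffusion noise for each node
        xs1.append(math.sqrt(alpha) * x1 + math.sqrt(1.0 - alpha) * rng.gauss(0.0, 1.0))
        xs2.append(math.sqrt(alpha) * x2 + math.sqrt(1.0 - alpha) * rng.gauss(0.0, 1.0))
    m1, m2 = sum(xs1) / n, sum(xs2) / n
    cov = sum((a - m1) * (b - m2) for a, b in zip(xs1, xs2)) / n
    v1 = sum((a - m1) ** 2 for a in xs1) / n
    v2 = sum((b - m2) ** 2 for b in xs2) / n
    return cov / math.sqrt(v1 * v2)


def analytic_correlation(alpha):
    """Cov = 2 alpha, Var(x1_t) = 1, Var(x2_t) = 1 + 4 alpha."""
    return 2.0 * alpha / math.sqrt(1.0 + 4.0 * alpha)
```

At $\alpha_t = 1$ (no diffusion) the correlation is $2/\sqrt{5} \approx 0.89$; at $\alpha_t = 0$ the two nodes are independent.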
We follow Song et al. (2021b) in defining the diffusion process from Sec. 2.1 in terms of an SDE. Since SDEs are stochastic processes, their solution follows a certain probability distribution instead of a deterministic value. By constraining this distribution to be the same as the distribution $p_G(X)$ entailed by an SCM $G$, we can define a deep structural causal model (DSCM) as a set of SDEs (one for each node $k$):
$$\mathrm{d}x^{(k)} = -\frac{1}{2}\beta_t\, x^{(k)}\, \mathrm{d}t + \sqrt{\beta_t}\, \mathrm{d}w, \quad \forall k \in [1, K], \tag{5}$$
Here, $w$ denotes the Wiener process (or Brownian motion). The first term of the SDE, $-\frac{1}{2}\beta_t x^{(k)}$, is known as the drift function (Särkkä and Solin, 2019)¹.
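The SDE can be simulated per node with the Euler-Maruyama scheme. In continuous time the discrete product $\prod_j (1-\beta_j)$ becomes $\alpha(t) = \exp\big(-\int_0^t \beta(s)\,\mathrm{d}s\big)$, and the marginal is again $\mathcal{N}\big(\sqrt{\alpha(t)}\,x_0, 1-\alpha(t)\big)$. This sketch assumes a linear $\beta(t)$ on $t \in [0, 1]$ with endpoints 0.1 and 20; those values and the function names are assumptions for illustration.

```python
import math
import random

BETA_MIN, BETA_MAX = 0.1, 20.0  # assumed linear schedule beta(t) on t in [0, 1]


def beta(t):
    return BETA_MIN + (BETA_MAX - BETA_MIN) * t


def alpha(t):
    """alpha(t) = exp(-int_0^t beta(s) ds), the continuous analogue of prod_j (1 - beta_j)."""
    return math.exp(-(BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t ** 2))


def forward_sde(x0, t_end, n_steps, rng):
    """Euler-Maruyama simulation of dx = -1/2 beta(t) x dt + sqrt(beta(t)) dw from t = 0."""
    dt = t_end / n_steps
    x = x0
    for i in range(n_steps):
        b = beta(i * dt)
        x += -0.5 * b * x * dt + math.sqrt(b * dt) * rng.gauss(0.0, 1.0)
    return x
```

The empirical mean and variance of many simulated paths should match $\sqrt{\alpha(t)}\,x_0$ and $1-\alpha(t)$ up to Monte Carlo and discretisation error.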
The generative process is the solution of the reverse-time SDE in Eq. 6, integrated backwards in time. Starting from the exogenous noise $x^{(k)}_T = u^{(k)}$, the process iteratively updates the variable with the gradient of the data distribution w.r.t. the input variable, $\nabla_{x^{(k)}_t} \log p(x^{(k)}_t)$, until it becomes $x^{(k)}_0 = x^{(k)}$:
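$$\mathrm{d}x^{(k)} = \Big[-\frac{1}{2}\beta_t\, x^{(k)} - \beta_t \nabla_{x^{(k)}_t} \log p\big(x^{(k)}_t\big)\Big]\, \mathrm{d}t + \sqrt{\beta_t}\, \mathrm{d}\bar{w}, \quad \forall k \in [1, K], \tag{6}$$
where $\bar{w}$ is a Wiener process running backwards in time and $\mathrm{d}t$ is a negative time step.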
The reverse SDE can, therefore, be seen as the process of strengthening the causal relations between variables. More importantly, the iterative nature of the generative process (reverse SDE) is well suited to a causal framework because interventions can be applied flexibly during sampling. We refer the reader to Song et al. (2021b) for a detailed description and proofs of the SDE formulation for score-based diffusion models.
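The reverse-time SDE of Eq. 6 can be integrated with Euler-Maruyama steps of negative size. In this sketch the learned gradient is replaced by the exact $\nabla_x \log p_t(x)$ of 1-D Gaussian data, and the linear $\beta(t)$ on $[0, 1]$, the data parameters `MU`, `SIGMA` and all function names are assumptions for illustration rather than the authors' settings.

```python
import math
import random

BETA_MIN, BETA_MAX = 0.1, 20.0  # assumed linear beta(t) on t in [0, 1]
MU, SIGMA = 2.0, 0.5            # hypothetical data distribution x_0 ~ N(MU, SIGMA^2)


def beta(t):
    return BETA_MIN + (BETA_MAX - BETA_MIN) * t


def alpha(t):
    return math.exp(-(BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t ** 2))


def score(x, t):
    """Exact grad_x log p_t(x) for Gaussian data under the forward SDE (stand-in for a model)."""
    a = alpha(t)
    return -(x - math.sqrt(a) * MU) / (a * SIGMA ** 2 + 1.0 - a)


def reverse_sde_sample(rng, n_steps=300):
    """Euler-Maruyama for the reverse-time SDE, stepping backwards from t = 1 to t = 0."""
    dt = 1.0 / n_steps
    x = rng.gauss(0.0, 1.0)  # x_T ~ N(0, 1), which approximates p_T
    for i in range(n_steps, 0, -1):
        t = i * dt
        drift = -0.5 * beta(t) * x - beta(t) * score(x, t)
        x = x - drift * dt + math.sqrt(beta(t) * dt) * rng.gauss(0.0, 1.0)
    return x
```

Samples drawn this way should have a mean close to `MU` and a standard deviation close to `SIGMA`; replacing `score` with a different function is where an intervention would enter.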
## 3.3. How to Apply Interventions with Anti-Causal Predictors?
An interesting result of Eq. 6 is that sampling only requires the gradients of the distribution entailed by the SCM, $p_G$. This allows one to learn anti-causal conditional distributions and still apply interventions through the causal mechanism, which can be useful when anti-causal learning is more straightforward (Schölkopf et al., 2012). In these cases, one would train classifiers in the anti-causal direction for each edge and diffusion models for each node in the graph over which one wants to measure the effect of interventions. One can then use the gradients of the classifiers and diffusion models to propagate the intervention over the nodes in the causal direction. Following this idea, Proposition 1 arises as a result of Eq. 6.
1. The drift function can potentially be used to define temporal relations between variables as in Rubenstein et al. (2018) and Blom et al. (2020).
Proposition 1 (Interventions as anti-causal gradient updates) Consider the SCM $G$ and a variable $x^{(j)} \in \mathrm{an}^{(k)}$. The effect on $x^{(k)}$ of an intervention on $x^{(j)}$, $p_G\big(x^{(k)} \mid do(x^{(j)} = x^{(j)})\big)$, is equivalent to solving a reverse-diffusion process for $x^{(k)}_t$. Since the sampling process takes into account the distribution entailed by $G$, it is guided by the gradient, w.r.t. the effect, of an anti-causal predictor evaluated at the value assigned to the cause:
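$$\nabla_{x^{(k)}_t} \log p_G\big(x^{(k)}_t \mid do(x^{(j)} = x^{(j)})\big) = \nabla_{x^{(k)}_t} \log p\big(x^{(k)}_t\big) + \nabla_{x^{(k)}_t} \log p\big(x^{(j)} \mid x^{(k)}_t\big). \tag{7}$$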
Proposition 1 respects the principle of independent causal mechanisms (ICM)² (Peters et al., 2017; Schölkopf et al., 2012), which implies independence between the distribution of the cause and the mechanism producing the effect distribution. As shown in Eq. 7, sampling with the causal mechanism does not require the distribution of the cause, $p(x^{(j)})$ (Schölkopf et al., 2021).
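Eq. 7 follows from Bayes' rule: $\log p(x \mid y) = \log p(y \mid x) + \log p(x) - \log p(y)$, and the last term vanishes under $\nabla_x$, which is why the prior of the cause is not needed. The sketch below checks this numerically for a hypothetical two-variable graph $y \to x$ with no confounding (so $p(x \mid do(y)) = p(x \mid y)$), a uniform binary cause and $x \mid y \sim \mathcal{N}(\mu_y, 1)$. Gradients are taken by central finite differences, and the exact posterior stands in for a trained predictor $p_\phi$; the scale `s` mirrors the hyperparameter used later in Alg. 1.

```python
import math

MU = (-2.0, 2.0)  # hypothetical class means; p(y) uniform; x | y ~ N(MU[y], 1)


def log_normal(x, m):
    return -0.5 * (x - m) ** 2 - 0.5 * math.log(2 * math.pi)


def log_p_x(x):
    """Marginal log p(x) of the two-component mixture."""
    return math.log(0.5 * math.exp(log_normal(x, MU[0])) + 0.5 * math.exp(log_normal(x, MU[1])))


def log_p_y_given_x(y, x):
    """Anti-causal predictor p(y | x) via Bayes' rule (exact here, standing in for p_phi)."""
    return log_normal(x, MU[y]) + math.log(0.5) - log_p_x(x)


def grad(f, x, h=1e-5):
    """Central finite-difference derivative of a scalar function."""
    return (f(x + h) - f(x - h)) / (2 * h)


def guided_score(x, y, s=1.0):
    """grad log p(x) + s * grad log p(y | x); with s = 1 this equals grad log p(x | do(y))."""
    return grad(log_p_x, x) + s * grad(lambda z: log_p_y_given_x(y, z), x)
```

With `s=1.0`, `guided_score(x, y)` matches the class-conditional gradient $-(x - \mu_y)$; with `s=0.0` only the marginal term remains, and larger `s` pushes harder towards the assigned class.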
## 3.4. Counterfactual Estimation with Diff-SCM
A powerful consequence of building causal models, following Pearl's Causal Hierarchy, is the estimation of counterfactuals. Counterfactuals are hypothetical scenarios for a given factual observation under a local intervention. Estimating counterfactuals differs from sampling from an interventional distribution because the changes are applied to a given observation. As detailed in Pearl (2016), Sec. 4.2.4, counterfactual estimation requires three steps: (i) abduction of exogenous noise, i.e., forward diffusion with the DDIM algorithm (Song et al., 2021a) following Alg. 3 in Appendix D; (ii) action, i.e., graph mutilation by erasing the edges between the intervened variable and its parents; (iii) prediction, i.e., reverse diffusion controlled by the gradients of an anti-causal classifier.
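The three steps can be seen in isolation on a hypothetical additive-noise SCM $x^{(2)} := 2x^{(1)} + u^{(2)}$, where abduction amounts to inverting the mechanism analytically. In Diff-SCM that inversion is instead performed by deterministic forward diffusion; this sketch only illustrates the logic of abduction, action and prediction, and `f2` and `counterfactual_x2` are illustrative names.

```python
def f2(x1, u2):
    """Hypothetical mechanism x2 := 2 * x1 + u2."""
    return 2.0 * x1 + u2


def counterfactual_x2(x1_f, x2_f, x1_cf):
    """Counterfactual value of x2 under do(x1 = x1_cf), given the factual pair (x1_f, x2_f)."""
    u2 = x2_f - 2.0 * x1_f   # (i) abduction: recover the exogenous noise of this instance
    # (ii) action: x1 is set to x1_cf, cutting it off from its own causes
    return f2(x1_cf, u2)     # (iii) prediction: rerun the mechanism with the abduced noise
```

For the factual observation $(x^{(1)}, x^{(2)}) = (1, 3)$ the abduced noise is 1, so $do(x^{(1)} = 0)$ yields a counterfactual $x^{(2)} = 1$, even though the interventional mean $\mathbb{E}[x^{(2)} \mid do(x^{(1)} = 0)]$ is 0: the counterfactual keeps the individual's own noise.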
Here, we are interested in estimating $x^{(k)}_{\mathrm{CF}}$ based on the observed (factual) $x^{(k)}_{\mathrm{F}}$ for the random variable $x^{(k)}$ after assigning a value $x^{(j)}_{\mathrm{CF}}$ to $x^{(j)} \in \mathrm{an}^{(k)}$, i.e., applying an intervention $do(x^{(j)} = x^{(j)}_{\mathrm{CF}})$. This is equivalent to sampling from the counterfactual distribution $p_G\big(x^{(k)} \mid do(x^{(j)} = x^{(j)}_{\mathrm{CF}}); x^{(k)} = x^{(k)}_{\mathrm{F}}\big)$. As a simplifying assumption for Alg. 1, we consider a setting where only $x^{(j)}$ and $x^{(k)}$ are present in the graph. Considering only two variables removes the need for the graph mutilation explained above; it is also the setting used in our experiments. We leave an extension to more complex SCMs for future work. We detail in Alg. 1 how abduction of exogenous noise and prediction are done.
Abduction of Exogenous Noise. The first step for estimating a counterfactual is the abduction of exogenous noise. Note from Eq. 3 that the value of a causal variable depends both on its parents and on its respective exogenous noise. From a deep learning perspective (Pawlowski et al., 2020), one might consider the exogenous noise $u^{(k)}$ an inferred latent variable. In Diff-SCM, the prior $p(u^{(k)})$ of $u^{(k)}$ is Gaussian, as detailed in Sec. 3.2.
With diffusion models, abduction can be done using a derivation by Song et al. (2021a) and Song et al. (2021b). Both works connect diffusion models to neural ODEs (Chen et al., 2018) and show that one can obtain a deterministic inference system while training with a diffusion process, which is stochastic by nature. This formulation makes the process invertible: a latent $u^{(k)}$ is recovered by performing the forward diffusion deterministically with the learned model. The algorithm for recovering $u^{(k)}$ is the first box in Alg. 1.
2. The principle states that 'The causal generative process of a system's variables is composed of autonomous modules that do not inform or influence each other.'
Prediction under Intervention. Once the exogenous noise $u^{(k)}$ has been inferred for a given factual observation $x^{(k)}_{\mathrm{F}}$, counterfactual estimation consists of applying an intervention in the reverse diffusion process with the gradients of an anti-causal predictor. In particular, we use the guided DDIM formulation from Dhariwal and Nichol (2021), which forms the second part of Alg. 1.
Controlling the Intervention. Three main factors contribute to the counterfactual estimation in Alg. 1: (i) the inferred $u^{(k)}$ keeps information about the factual observation; (ii) $\nabla_{x^{(k)}_t} \log p_\phi(x^{(j)}_{\mathrm{CF}} \mid x^{(k)}_t)$ guides the intervention towards the desired counterfactual class; and (iii) $\epsilon_\theta(x^{(k)}_t, t)$ forces the estimate to belong to the data distribution. We follow Dhariwal and Nichol (2021) in adding a hyperparameter $s$ which controls the scale of $\nabla_{x^{(k)}_t} \log p_\phi(x^{(j)}_{\mathrm{CF}} \mid x^{(k)}_t)$. High values of $s$ might result in counterfactuals that are too different from the factual data. We show this empirically and discuss the effects of this hyperparameter in Sec. 5.3.
Algorithm 1 Inference of a counterfactual for a variable $x^{(k)}$ from an intervention on $x^{(j)} \in \mathrm{an}(k)$
Models: trained diffusion model $\epsilon_\theta$ and anti-causal predictor $p_\phi(x^{(j)} \mid x^{(k)}_t)$
Input: factual variable $x^{(k)}_{0,F}$, target intervention $x^{(j)}_{0,CF}$, scale $s$
Output: counterfactual $x^{(k)}_{0,CF}$
Abduction of Exogenous Noise: recovering $u^{(k)}$ from $x^{(k)}_{0,F}$
$$\begin{aligned}
& x^{(k)}_0 \leftarrow x^{(k)}_{0,F} \\
& \textbf{for } t = 0, \dots, T-1 \textbf{ do} \\
& \qquad x^{(k)}_{t+1} \leftarrow \sqrt{\alpha_{t+1}} \left( \frac{x^{(k)}_t - \sqrt{1-\alpha_t}\, \epsilon_\theta(x^{(k)}_t, t)}{\sqrt{\alpha_t}} \right) + \sqrt{1-\alpha_{t+1}}\, \epsilon_\theta(x^{(k)}_t, t) \\
& \textbf{end for} \\
& u^{(k)} \leftarrow x^{(k)}_T
\end{aligned}$$
Generation under Intervention
$$\begin{aligned}
& x^{(k)}_T \leftarrow u^{(k)} \\
& \textbf{for } t = T, \dots, 1 \textbf{ do} \\
& \qquad \hat{\epsilon} \leftarrow \epsilon_\theta(x^{(k)}_t, t) - s \sqrt{1-\alpha_t}\, \nabla_{x^{(k)}_t} \log p_\phi(x^{(j)}_{CF} \mid x^{(k)}_t) \\
& \qquad x^{(k)}_{t-1} \leftarrow \sqrt{\alpha_{t-1}} \left( \frac{x^{(k)}_t - \sqrt{1-\alpha_t}\, \hat{\epsilon}}{\sqrt{\alpha_t}} \right) + \sqrt{1-\alpha_{t-1}}\, \hat{\epsilon} \\
& \textbf{end for} \\
& x^{(k)}_{0,CF} \leftarrow x^{(k)}_0
\end{aligned}$$
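To make the two phases of Alg. 1 concrete, here is a minimal Python sketch on a toy 1-D problem. It is not the authors' implementation; it rests on illustrative assumptions: data $x^{(k)} \sim \mathcal{N}(0, 1)$, for which the optimal (MMSE) noise predictor is $\epsilon_\theta(x_t, t) = \sqrt{1-\alpha_t}\, x_t$; a hypothetical logistic anti-causal predictor $p_\phi(x^{(j)} = 1 \mid x_t) = \mathrm{sigmoid}(W x_t)$; and a cosine-shaped cumulative schedule $\alpha_t = \cos^2(0.49\pi t / T)$. All function names and constants are hypothetical.

```python
import math

# Toy setting (assumptions, not the paper's code):
# - data x ~ N(0, 1), so the MMSE noise predictor is eps(x_t, t) = sqrt(1 - alpha_t) * x_t
# - hypothetical logistic anti-causal predictor p(y = 1 | x_t) = sigmoid(W * x_t)
# - alpha_t is the cumulative schedule used by DDIM, decreasing from 1 towards 0
T = 200
W = 2.0
ALPHAS = [math.cos(0.49 * math.pi * t / T) ** 2 for t in range(T + 1)]


def eps_theta(x, t):
    """Stand-in for a trained noise-prediction network."""
    return math.sqrt(1.0 - ALPHAS[t]) * x


def grad_log_p(target, x):
    """d/dx log p(target | x) for the hypothetical logistic predictor."""
    p1 = 1.0 / (1.0 + math.exp(-W * x))
    return W * (1.0 - p1) if target == 1 else -W * p1


def ddim_step(x, eps, a_from, a_to):
    """Deterministic DDIM update from noise level a_from to a_to."""
    x0_pred = (x - math.sqrt(1.0 - a_from) * eps) / math.sqrt(a_from)
    return math.sqrt(a_to) * x0_pred + math.sqrt(1.0 - a_to) * eps


def abduct(x0_factual):
    """Abduction of exogenous noise: run DDIM forward to recover u = x_T."""
    x = x0_factual
    for t in range(T):
        x = ddim_step(x, eps_theta(x, t), ALPHAS[t], ALPHAS[t + 1])
    return x


def generate(u, target, s):
    """Generation under intervention: classifier-guided DDIM from x_T = u to x_0."""
    x = u
    for t in range(T, 0, -1):
        eps = eps_theta(x, t) - s * math.sqrt(1.0 - ALPHAS[t]) * grad_log_p(target, x)
        x = ddim_step(x, eps, ALPHAS[t], ALPHAS[t - 1])
    return x


def counterfactual(x0_factual, target, s):
    """Alg. 1: abduct u from the factual value, then regenerate under do(target)."""
    return generate(abduct(x0_factual), target, s)
```

For example, `counterfactual(-1.0, 1, 0.0)` approximately reconstructs the factual value (the small gap comes from the discretization of the deterministic sampler), while `counterfactual(-1.0, 1, 1.0)` is pushed towards the region the toy predictor assigns to class 1, mirroring the role of $s$ discussed in Sec. 5.3.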
## 4. Related Work
Generative EBMs. Our generative framework is inspired by the energy-based model literature (Ho et al., 2020; Song et al., 2021b; Du and Mordatch, 2019; Grathwohl et al., 2020). In particular, we leverage the theory of denoising diffusion models (Sohl-Dickstein et al., 2015; Ho et al.,
2020; Nichol and Dhariwal, 2021). We take advantage of the non-Markovian formulation of DDIM (Song et al., 2021a), which allows faster sampling and recovering latent spaces from observations. Our theory connecting diffusion models and SDEs follows Song et al. (2021b), but from a different perspective. Although Du et al. (2020) do not focus on causal modeling, they also guide generation with the gradients of conditional energy models. Recently, Sinha et al. (2021) proposed a version of diffusion models for manipulable generation based on contrastive learning. Finally, Dhariwal and Nichol (2021) derive a conditional sampling process for DDIM that is used in this paper, as detailed in Sec. 3.3. Here, we reinterpret their generation algorithm from a causal perspective and add deterministic latent inference for counterfactual estimation. The main, and key, difference is that we add the abduction of exogenous noise. Without this abduction, we cannot ensure that the resulting image will match other aspects of the original image whilst altering only the intended aspect (i.e. where we want to intervene). With it, we can sample from a counterfactual distribution instead of the interventional distribution.
Counterfactuals. Designing causal models with deep learning components has allowed causal inference with high-dimensional variables (Pawlowski et al., 2020; Shen et al., 2020; Dash et al., 2020; Xia et al., 2021; Zečević et al., 2021). Given a factual observation, counterfactuals are obtained by measuring the effect of an intervention on one of the ancestral attributes. They have been used in a range of applications such as (i) explaining predictions (Verma et al., 2020; Goyal et al., 2019; Looveren and Klaise, 2021; Hvilshøj et al., 2021); (ii) defining fairness (Kusner et al., 2017); (iii) mitigating data biases (Denton et al., 2019); (iv) improving reinforcement learning (Lu et al., 2020); (v) predicting accuracy (Kaushik et al., 2020); and (vi) increasing robustness against spurious correlations (Sauer and Geiger, 2021). Most similar to our work, Schut et al. (2021) estimate counterfactuals via iterative updates using the gradients of a classifier. However, their method is based on adversarial updates computed via epistemic uncertainty, not on diffusion processes.
## 5. Experiments
Ground truth counterfactuals are, by definition, impossible to acquire: counterfactuals are hypothetical predictions. In an ideal scenario, the SCM of the problem is fully specified; in that case, one could verify whether unrelated causal variables kept their values.³ However, a complete causal graph is rarely known in practice. In this section, we (i) present ideas on how to evaluate counterfactuals without access to the complete causal graph or to semi-synthetic data; (ii) show with quantitative and qualitative experiments that our method is appropriate for counterfactual estimation; (iii) propose CLD, a metric for the quantitative evaluation of counterfactuals; and (iv) use CLD to fine-tune an important hyperparameter of our framework.
Causal Setup. We consider a causal model $\mathcal{G}_{image}$ with two variables, $x^{(1)} \leftarrow x^{(2)}$, following the example in Sec. 3.3. Here, $x^{(1)}$ represents an image and $x^{(2)}$ a class. In practice, the gradient of the marginal distribution of $x^{(1)}$ is learned with a diffusion model, which we refer to as $\epsilon_\theta$, as in Sec. 2.1. The anti-causal conditional distribution is also learned with a neural network, $p_\phi(x^{(2)} \mid x^{(1)})$. Our experiments aim at sampling from the counterfactual distribution $p_{\mathcal{G}}(x^{(1)} \mid do(x^{(2)} = x^{(2)}_{CF}); x^{(1)}_F)$. Additional experiments on sampling from the interventional distribution are in Appendix F.
3. Remember that interventions only change descendants in a causal graph.
Implementation. $\epsilon_\theta$ is implemented as an encoder-decoder architecture with skip connections, i.e. a U-Net-like network (Ronneberger et al., 2015). For anti-causal classification tasks, we use the encoder of $\epsilon_\theta$ followed by a pooling layer and a linear classifier. Both $\epsilon_\theta$ and $p_\phi(x^{(2)} \mid x^{(1)})$ depend on the diffusion time. The diffusion model and the anti-causal predictor are trained separately. Implementation details are in Appendix E.
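Alg. 1 queries the anti-causal predictor on noisy inputs $x^{(k)}_t$, which is why it depends on diffusion time. This section does not spell out how its training inputs are built; a common choice, used by Dhariwal and Nichol (2021), is to noise clean samples with the forward process and train on $(x_t, t)$ pairs. The sketch below shows only that sampling step, under this assumption; `noisy_training_example` is a hypothetical helper, and $\alpha_t$ is taken to be the cumulative schedule.

```python
import math
import random


def noisy_training_example(x0, label, alphas, rng):
    """Draw one (x_t, t, label) triple for training a time-dependent predictor.

    Assumption: x_t ~ q(x_t | x_0) = N(sqrt(alpha_t) * x_0, (1 - alpha_t) * I),
    with alphas[t] the cumulative schedule and t drawn uniformly from 0..T.
    """
    T = len(alphas) - 1
    t = rng.randint(0, T)  # inclusive on both ends
    a = alphas[t]
    x_t = [math.sqrt(a) * xi + math.sqrt(1.0 - a) * rng.gauss(0.0, 1.0) for xi in x0]
    return x_t, t, label
```

The predictor then receives both `x_t` and `t`, so that at sampling time its gradient is meaningful at every noise level visited by Alg. 1.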
Baselines. We consider Schut et al. (2021) and Looveren and Klaise (2021) because they (i) generate counterfactuals based on classifier decisions; and (ii) evaluate results with metrics tailored to counterfactual estimation on images.
Datasets. Considering the causal model G image described above, we compare our method quantitatively and qualitatively with baselines on MNIST data (Lecun et al., 1998). Furthermore, we show empirically that our approach works with more complex, higher-resolution images from the ImageNet dataset (Deng et al., 2009). We only perform qualitative evaluations on ImageNet since the baseline methods cannot generate counterfactuals for this dataset.
## 5.1. Evaluating Counterfactuals: Realism and Closeness to Data Manifold
Taking into account the causal model $\mathcal{G}_{image}$, we now employ the strategies for counterfactual estimation in Sec. 3.4. In particular, given an image $x^{(1)}_F$ drawn from the data and a target intervention $x^{(2)}_{CF}$ on the class variable, we wish to estimate the counterfactual $x^{(1)}_{CF}$ for the image $x^{(1)}_F$. We use two metrics proposed by Looveren and Klaise (2021), IM1 and IM2, which measure realism, interpretability and closeness to the data manifold based on the reconstruction loss of autoencoders trained on specific classes. See details in Appendix G.
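For reference, a sketch of IM1 and IM2 as we understand their definitions in Looveren and Klaise (2021): IM1 divides the reconstruction error of an autoencoder trained on the counterfactual class by that of an autoencoder trained on the factual class, and IM2 compares the reconstructions of the counterfactual-class autoencoder and an autoencoder trained on all classes, normalized by the ℓ1 norm of the counterfactual. The exact configuration used here is in Appendix G; the autoencoders are passed in as plain callables, and all names are illustrative.

```python
def sq_l2(a, b):
    """Squared Euclidean distance between two flat vectors."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def im1(x_cf, ae_target, ae_factual, eps=1e-8):
    """Reconstruction error under the target-class AE relative to the factual-class AE."""
    return sq_l2(x_cf, ae_target(x_cf)) / (sq_l2(x_cf, ae_factual(x_cf)) + eps)


def im2(x_cf, ae_target, ae_all, eps=1e-8):
    """Disagreement between the target-class AE and an AE trained on all classes."""
    l1_norm = sum(abs(v) for v in x_cf)
    return sq_l2(ae_target(x_cf), ae_all(x_cf)) / (l1_norm + eps)
```

Lower values mean the counterfactual is better explained by the target class than by the factual class (IM1) and lies closer to the overall data manifold (IM2).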
Experimental Setup. We run Alg. 1 over the test set with randomly sampled target counterfactual classes $x^{(2)}_{CF} \sim x^{(2)}$, with $x^{(2)}_{CF} \neq x^{(2)}_F$. For example, we generate counterfactuals of all MNIST classes for a given factual image, as illustrated in Appendix H. We evaluate the realism of Diff-SCM, Schut et al. and Looveren and Klaise using the IM1 and IM2 metrics. Diff-SCM achieves better results (lower is better) on both metrics,⁴ as shown in Tab. 1. We show qualitative results on ImageNet in Fig. 1 and on MNIST in Appendix H. A qualitative comparison between methods is depicted in Fig. 3(b).
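The target-class sampling in this setup can be sketched as follows. We assume targets are drawn uniformly among the classes that differ from the factual one; the text only states that targets are sampled randomly, so the uniform choice and the helper name are our assumptions.

```python
import random


def sample_target_class(factual_class, num_classes=10, rng=random):
    """Draw a counterfactual class uniformly among all classes except the factual one."""
    candidates = [c for c in range(num_classes) if c != factual_class]
    return rng.choice(candidates)
```

Excluding the factual class matters: an intervention to the class the image already has would trivially yield a reconstruction and would bias every metric towards good values.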
Table 1: Quantitative comparison between Diff-SCM and baselines. Lower is better for all metrics. Results are presented as mean ($\mu$) and standard deviation ($\sigma$) over the test set, in the format $\mu \pm \sigma$.
| Method | IM1 ↓ | IM2 ↓ | CLD ↓ |
|---------------------|---------------|---------------|---------------|
| Diff-SCM (ours) | 0.94 ± 0.02 | 0.04 ± 0.00 | 1.08 ± 0.03 |
| Looveren and Klaise | 1.10 ± 0.03 | 0.05 ± 0.00 | 1.25 ± 0.03 |
| Schut et al. | 1.05 ± 0.01 | 0.10 ± 0.00 | 1.19 ± 0.01 |
## 5.2. Counterfactual Latent Divergence (CLD)
Since one cannot measure changes in all variables of a real SCM, we leverage the sparse mechanism shift (SMS) hypothesis⁵ (Schölkopf et al., 2021) to justify a minimality property of counterfactuals. In our setting, SMS translates to: an intervention will not change many elements of the observed data. Therefore, an important property of counterfactuals is minimality, or proximity to the factual observation. We propose a new metric, counterfactual latent divergence (CLD), illustrated in Fig. 3(a), that estimates minimality.
4. We highlight that our setting is slightly different from the baseline works, where the target counterfactual classes were similar to the factual classes, e.g. transforming MNIST digits from 2 → [3, 7] or 4 → [1, 9]. Since we sample target classes randomly, their metric values will look lower than in their respective papers.
5. SMS states that 'small distribution changes tend to manifest themselves in a sparse or local way in the causal factorization, that is, they should usually not affect all factors simultaneously.'
Figure 3: (a) A t-SNE visualization of the 20-dimensional latent vectors of a variational autoencoder (VAE) over all MNIST samples. Each point represents an MNIST image, and colors represent the ground-truth label of each sample. CLD estimates the relative similarity between the factual data and the counterfactual: the distance between the generated counterfactual do(0) and the factual observation is compared with the distances between the factual observation and all other data points from the factual and counterfactual classes. (b) Qualitative comparison with baseline approaches for counterfactual estimation. Each column represents one method and each row a different intervention on the digit class. The train. column shows training samples belonging to the target intervention class.
<details>
<summary>Image 3 Details</summary>

### Visual Description
## [Composite Technical Figure]: Counterfactual Latent Divergence (CLD) Intuition and Qualitative Comparison
### Overview
The image is a two-part technical figure from a research paper in machine learning and causal inference. Part (a) visually explains the intuition behind CLD (counterfactual latent divergence), a metric for evaluating counterfactuals, and part (b) provides a qualitative comparison of counterfactuals generated by different methods on a handwritten digit dataset (MNIST). Part (a) illustrates the intuition using a scatter plot and example images. Part (b) presents a grid of generated images to compare the results of different techniques.
### Components/Axes
**Part (a): CLD Intuition**
* **Central Element:** A scatter plot (a t-SNE projection of VAE latent vectors, per the caption) showing clusters of data points.
* **Legend:** Positioned to the right of the scatter plot. It is a vertical list mapping colors to digit classes:
* Blue dot: 0
* Orange dot: 1
* Green dot: 2
* Red dot: 3
* Purple dot: 4
* Brown dot: 5
* Pink dot: 6
* Gray dot: 7
* Olive dot: 8
* Cyan dot: 9
* **Annotated Images:** Three example handwritten digit images are connected to the scatter plot via dashed lines:
* **Top-Left:** Labeled `do(0)`. Shows a white digit "0" on a black background. A dashed line connects it to a point within the blue cluster (digit 0).
* **Top-Right:** Labeled `factual`. Shows a white digit "3" on a black background. A dashed line connects it to a point within the red cluster (digit 3).
* **Bottom-Left:** Labeled `train`. Shows a white digit "0" on a black background. A dashed line connects it to a different point within the blue cluster (digit 0).
* **Caption:** `(a) CLD Intuition` is centered below this section.
**Part (b): Qualitative Comparison**
* **Structure:** A 4-row by 5-column grid of small grayscale images.
* **Row Labels (Left Side):** Each row is labeled with a causal intervention command:
* Row 1: `do(8)`
* Row 2: `do(3)`
* Row 3: `do(9)`
* Row 4: `do(4)`
* **Column Labels (Bottom):** Each column is labeled with a method or data source:
* Column 1: `orig.` (Original)
* Column 2: `Diff-SCM (ours)` (The authors' proposed method)
* Column 3: `Schut et al.` (A competing method)
* Column 4: `Looveren & Klaise` (Another competing method)
* Column 5: `train.` (Training data sample)
* **Caption:** `(b) Qualitative comparison` is centered below this section.
### Detailed Analysis
**Part (a) Analysis:**
* The scatter plot shows distinct, non-overlapping clusters for each digit class (0-9), indicating that the VAE's latent space effectively separates different digits.
* The `do(0)` and `train` images are both examples of the digit 0, connected to the same blue cluster, but to different specific points within it. This suggests they are different samples from the same class.
* The `factual` image (digit 3) is connected to the red cluster, demonstrating the link between a specific data point and its class cluster in the latent space.
**Part (b) Analysis - Row by Row Trend Verification:**
* **Row `do(8)`:** The `orig.` image is a clear "8". The `Diff-SCM (ours)` output is a slightly noisier but recognizable "8". `Schut et al.` produces a very noisy, distorted "8". `Looveren & Klaise` generates a blurry, less distinct "8". The `train.` sample is a clear "8".
* **Row `do(3)`:** The `orig.` is a clear "3". `Diff-SCM (ours)` produces a clear "3". `Schut et al.` generates a noisy "3" with artifacts. `Looveren & Klaise` produces a very noisy, almost unrecognizable "3". The `train.` sample is a clear "3".
* **Row `do(9)`:** The `orig.` is a clear "9". `Diff-SCM (ours)` produces a clear "9". `Schut et al.` generates a very noisy, fragmented "9". `Looveren & Klaise` produces a noisy "9" with a disconnected loop. The `train.` sample is a clear "9".
* **Row `do(4)`:** The `orig.` is a clear "4". `Diff-SCM (ours)` produces a clear "4". `Schut et al.` generates a noisy "4". `Looveren & Klaise` produces a very noisy, distorted "4". The `train.` sample is a clear "4".
**General Trend in (b):** Across all rows, the `Diff-SCM (ours)` column consistently produces digit images that are the clearest and most faithful to the `orig.` and `train.` examples among the three methods compared. The methods by `Schut et al.` and `Looveren & Klaise` consistently introduce significant noise, blurriness, or structural distortions.
### Key Observations
1. **Method Superiority:** The proposed `Diff-SCM` method visually outperforms the two baseline methods (`Schut et al.` and `Looveren & Klaise`) in generating clean, recognizable digits after a causal intervention (`do`-operation).
2. **Latent Space Structure:** Part (a) shows that the VAE has learned a well-structured latent space where digits of the same class are grouped together.
3. **Intervention Consistency:** The `do` operation appears to successfully target a specific digit class, as all outputs in a given row of part (b) attempt to generate that digit, albeit with varying success.
4. **Noise Profiles:** The failure modes of the baseline methods are distinct: `Schut et al.` tends to produce high-frequency noise and speckles, while `Looveren & Klaise` often results in blurriness and loss of structural integrity.
### Interpretation
This figure serves a dual purpose in a research context. **First**, part (a) provides an intuitive, visual explanation of CLD: in the latent space of a VAE, where data points (digits) are clustered by class, the divergence between the factual image and its counterfactual (`do(0)`) is compared with the divergences between the factual image and other samples, such as a training sample (`train`). The dashed lines explicitly map concrete image examples to their positions in this latent space, bridging the gap between data and model representation.
**Second**, part (b) acts as empirical, qualitative evidence for the method's effectiveness. By placing its results side-by-side with established baselines, the original images (`orig.`) and training samples (`train.`), the authors demonstrate that their method generates higher-fidelity images. This is crucial for tasks like counterfactual image generation or algorithmic fairness, where producing realistic and accurate modifications is essential. The clear visual superiority of `Diff-SCM` suggests it better preserves the underlying data distribution while successfully applying the desired causal intervention, making it a more reliable tool for causal reasoning in image domains.
</details>
Note that the metrics IM1 and IM2 from Sec. 5.1 do not take minimality into account. In addition, previous work (Wachter et al., 2018; Schut et al., 2021) only used the mean absolute error, or ℓ1 distance, in the data space to measure minimality. However, measuring similarity at the pixel level can be misleading, as an intervention might change the structure of the image whilst keeping other factors unchanged. In this case, a pixel-level comparison might not be informative about the other factors of variation.
Latent Similarity. Therefore, we choose to measure similarity between latent representations, and we want a representation that captures all factors of variation in the input data. In particular, we train a variational autoencoder (VAE) (Kingma and Welling, 2014) to recover probabilistic latent representations that capture all factors of variation in the data. The latent representations computed with the VAE's encoder $E_\phi$ are denoted $(\mu_i, \sigma_i) = E_\phi(x^{(1)}_i)$, where the subscript $i$ indexes different samples of $x^{(1)}$ (at $t = 0$). We use the Kullback-Leibler (KL) divergence to measure distances between latents. The divergence for a given pair of counterfactual estimate and factual observation $(x^{(1)}_{CF}, x^{(1)}_F)$ can, therefore, be written as
$$\mathrm{div} = D\big(x^{(1)}_{CF}, x^{(1)}_{F}\big), \qquad D\big(x^{(1)}_i, x^{(1)}_j\big) = D_{KL}\big(\mathcal{N}(\mu_i, \sigma_i) \,\|\, \mathcal{N}(\mu_j, \sigma_j)\big).$$
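With a diagonal Gaussian encoder, this KL divergence has a closed form, summed over latent dimensions. A minimal sketch, assuming $\sigma_i$ holds per-dimension standard deviations (many VAE implementations output log-variances instead, which would need converting first); the function name is illustrative.

```python
import math


def kl_diag_gaussians(mu_p, sigma_p, mu_q, sigma_q):
    """KL( N(mu_p, diag(sigma_p^2)) || N(mu_q, diag(sigma_q^2)) ), summed over dimensions."""
    total = 0.0
    for mp, sp, mq, sq in zip(mu_p, sigma_p, mu_q, sigma_q):
        total += math.log(sq / sp) + (sp ** 2 + (mp - mq) ** 2) / (2.0 * sq ** 2) - 0.5
    return total
```

The divergence is asymmetric, so the argument order (counterfactual first, factual second) has to be kept fixed across all the comparisons that CLD makes.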
Relative Measure. Absolute similarity measures, however, give limited information. We therefore leverage class information to measure minimality whilst making sure that the counterfactual is far enough from the factual class. A relative measure is obtained by estimating the probability that divergences between the factual observation and other data points in the dataset (formalized in Eq. 9) are less than, or greater than, div. In particular, we compare div with the set $S_{class}$ of divergences between the factual observation $x^{(1)}_F$ and all data points $x^{(1)}$ in a dataset $\mathcal{D} = \{(x^{(1)}, x^{(2)}) \mid x^{(1)} \in \mathbb{R}^2, x^{(2)} \in \mathbb{N}\}$ whose class $x^{(2)}$ equals $class$. In set-builder notation⁶:
$$S_{class} = \big\{ D\big(x^{(1)}, x^{(1)}_F\big) \;\big|\; (x^{(1)}, x^{(2)}) \in \mathcal{D},\ x^{(2)} = class \big\} \tag{9}$$
The sets $S_{CF}$ and $S_F$ are obtained by replacing $class$ in $S_{class}$ with the target class of the counterfactual and the class of the factual observation, respectively.
The relative measures are: (i) $P(S_{CF} \le \mathrm{div})$, which compares div with the divergences between the factual image and all data points of the counterfactual class; and (ii) $P(S_F \ge \mathrm{div})$, which compares div with the divergences between the factual image and all other data points of the factual class. We aim for counterfactuals with low $P(S_{CF} \le \mathrm{div})$, enforcing minimality, and low $P(S_F \ge \mathrm{div})$, enforcing a larger distance from the factual class.
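The two relative measures can be estimated as empirical frequencies over the sets $S_{CF}$ and $S_F$. A minimal sketch, assuming the divergences in each set have already been computed (for instance, with a closed-form KL between the VAE posteriors); `relative_measures` is an illustrative name.

```python
def relative_measures(div, s_cf, s_f):
    """Empirical estimates of P(S_CF <= div) and P(S_F >= div).

    div:  divergence between the counterfactual and the factual observation
    s_cf: divergences between the factual observation and counterfactual-class samples
    s_f:  divergences between the factual observation and other factual-class samples
    """
    p_cf = sum(1 for d in s_cf if d <= div) / len(s_cf)
    p_f = sum(1 for d in s_f if d >= div) / len(s_f)
    return p_cf, p_f
```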
CLD. We highlight the competing nature of the two measures $P(S_{CF} \le \mathrm{div})$ and $P(S_F \ge \mathrm{div})$ in the counterfactual setting. For example, if the intervention is too minimal, i.e. $P(S_{CF} \le \mathrm{div})$ is low, the counterfactual will still resemble observations from the factual class, i.e. $P(S_F \ge \mathrm{div})$ is high. The goal is therefore to find the best balance between the two measures. Finally, we define the counterfactual latent divergence (CLD) metric as the LogSumExp of the two probability measures. The LogSumExp operation acts as a smooth approximation of the maximum function and, compared with a simple sum, it penalizes a peak in either measure more strongly. We denote CLD as:
$$\mathrm{CLD} = \log\Big(\exp\big(P(S_{CF} \le \mathrm{div})\big) + \exp\big(P(S_F \ge \mathrm{div})\big)\Big). \tag{10}$$
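Eq. 10 can be computed with a numerically stable LogSumExp. The sketch below (function name illustrative) also makes the balancing behaviour concrete: for a fixed sum of the two probabilities, a balanced pair yields a lower CLD than a peaked one.

```python
import math


def cld(p_cf, p_f):
    """CLD = log(exp(P(S_CF <= div)) + exp(P(S_F >= div))), via a stable LogSumExp."""
    m = max(p_cf, p_f)
    return m + math.log(math.exp(p_cf - m) + math.exp(p_f - m))
```

For instance, `cld(0.5, 0.5)` is about 1.19, whereas `cld(1.0, 0.0)` is about 1.31, even though both pairs sum to 1.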
Using the same experimental setup as in Sec. 5.1, we show that Diff-SCM achieves a lower CLD than the baseline methods, as reported in Tab. 1.
## 5.3. Tuning the Hyperparameter s with CLD
We now use CLD, the proposed metric, to fine-tune $s$, the scale hyperparameter of our framework detailed in Sec. 3.4. The model whose hyperparameters achieve the best CLD also outperforms previous methods on the other metrics (see Tab. 1) and produces the best qualitative results (see Fig. 3(b)). This result further supports that our metric is suited to counterfactual evaluation.
6. We use the following set-builder notation: $\text{MY SET} = \{ \text{function}(\text{input}) \mid \text{input domain} \}$.
Figure 4: Scale hyperparameter search using CLD (lower is better). The line plot shows the mean and 95% confidence interval. We found that s = 0.7 is the best value.
<details>
<summary>Image 4 Details</summary>

### Visual Description
## Line Chart with Confidence Band: CLD vs. Scale
### Overview
The image displays a line chart plotting a variable labeled "CLD" against a variable labeled "Scale." The chart features a single data series represented by a solid blue line, accompanied by a light blue shaded region indicating a confidence interval or range of uncertainty around the central trend.
### Components/Axes
* **X-Axis (Horizontal):**
* **Label:** "Scale"
* **Range:** 0.0 to 3.0
* **Major Tick Marks:** 0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0
* **Y-Axis (Vertical):**
* **Label:** "CLD"
* **Range:** Approximately 1.05 to 1.30
* **Major Tick Marks:** 1.05, 1.10, 1.15, 1.20, 1.25, 1.30
* **Data Series:**
* A single solid blue line representing the primary trend.
* A light blue shaded area surrounding the line, representing the confidence band or uncertainty range.
* **Legend:** No separate legend is present. The single data series and its associated uncertainty band are the only plotted elements.
* **Grid:** A faint gray grid is present in the background, aligned with the major tick marks on both axes.
### Detailed Analysis
**Trend Verification:** The blue line exhibits a distinct non-linear trend. It begins at a high value on the left, descends steeply to a minimum, and then ascends gradually before plateauing.
**Key Data Points (Approximate):**
* At **Scale = 0.0**, CLD is at its maximum, approximately **1.30**.
* The line descends sharply. At **Scale = 0.5**, CLD has fallen to approximately **1.10**.
* The minimum (trough) of the curve occurs between Scale 0.5 and 1.0. The lowest point appears to be at approximately **Scale = 0.75**, where CLD is approximately **1.06**.
* After the minimum, the line begins a gradual ascent. At **Scale = 1.0**, CLD is approximately **1.07**.
* At **Scale = 1.5**, CLD is approximately **1.10**.
* The ascent continues but the slope decreases. At **Scale = 2.0**, CLD is approximately **1.12**.
* From Scale 2.0 to 3.0, the line shows a very slight upward trend, appearing to plateau. At **Scale = 3.0**, CLD is approximately **1.125**.
**Confidence Band Analysis:**
* The shaded uncertainty band is narrowest at the start (Scale 0.0) and end (Scale 3.0) of the plotted range.
* The band is **widest around the minimum point** (Scale ~0.75), indicating the greatest uncertainty or variability in the CLD measurement at this scale.
* The band narrows again as the scale increases beyond 1.5.
### Key Observations
1. **Non-Monotonic Relationship:** The relationship between Scale and CLD is not linear or simply increasing/decreasing. It features a clear minimum.
2. **Steep Initial Decline:** The most dramatic change in CLD occurs for small Scale values (0.0 to ~0.75).
3. **Asymptotic Behavior:** For Scale values greater than 2.0, the CLD value changes very little, suggesting an asymptotic approach to a stable value around 1.12-1.13.
4. **Variable Uncertainty:** The precision of the CLD estimate (as indicated by the confidence band width) is not constant; it is poorest at the point of the minimum CLD value.
### Interpretation
This chart likely illustrates the result of an optimization or sensitivity analysis. The "Scale" parameter is being varied to observe its effect on a metric called "CLD."
* **What the data suggests:** There exists an **optimal Scale value around 0.75** that minimizes the CLD metric. Deviating from this value in either direction (smaller or larger Scale) results in a higher CLD.
* **How elements relate:** The steep descent indicates that CLD is highly sensitive to Scale when Scale is small. The plateau suggests that beyond a certain point (Scale > 2.0), further increases in Scale have a negligible impact on CLD.
* **Notable anomaly/feature:** The **widening of the confidence band at the minimum** is a critical observation. It implies that while the *average* CLD is lowest at Scale ~0.75, the *predictability* or *consistency* of achieving that low value is also at its worst. This could indicate a region of instability or high sensitivity to other uncontrolled variables in the underlying process.
* **Practical implication:** If minimizing CLD is the goal, a Scale near 0.75 is the target. However, the high uncertainty at this point might necessitate additional controls or a trade-off consideration, perhaps opting for a slightly higher Scale (e.g., 1.0-1.5) where the CLD is only marginally higher but the outcome is more predictable (narrower confidence band).
</details>
Experimental Setup. We run Alg. 1 on MNIST while varying the scale hyperparameter s in the interval [0.0, 3.0], as depicted in Fig. 4. When s = 0, the classifier does not influence generation, so the counterfactuals are reconstructions of the factual data, resulting in a high CLD.
When s = 3 (too high), the diffusion model contributes much less than the classifier, so the counterfactuals are driven towards the desired class while ignoring the exogenous noise of the given observation. High values of s correspond to strong interventions that do not satisfy the minimality property, also resulting in a high CLD. The optimal s is therefore an intermediate value at which CLD is minimal. Following this hyperparameter search, all MNIST experiments use s = 0.7. See Appendix I for qualitative results.
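The search over s can be sketched as a grid sweep that summarizes per-image CLD values at each scale and keeps the minimizer of the mean. This is an illustrative sketch: `cld_samples_fn` is a hypothetical callable that would run Alg. 1 over the test set for a given s and return per-image CLD values, and since the text does not state how the 95% interval in Fig. 4 was computed, the normal approximation (1.96 standard errors) used here is our assumption.

```python
import math
import statistics


def sweep_scale(scales, cld_samples_fn):
    """Return the scale with the lowest mean CLD and a {s: (mean, 95% CI half-width)} summary.

    cld_samples_fn(s) must return at least two CLD values for scale s (hypothetical).
    The interval uses a normal approximation: 1.96 * stdev / sqrt(n).
    """
    summary = {}
    for s in scales:
        values = cld_samples_fn(s)
        mean = statistics.fmean(values)
        half_width = 1.96 * statistics.stdev(values) / math.sqrt(len(values))
        summary[s] = (mean, half_width)
    best = min(summary, key=lambda s: summary[s][0])
    return best, summary
```

Selecting by the mean alone mirrors the choice of s = 0.7 from Fig. 4; the half-widths are kept so the curve and its band can be plotted as in that figure.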
## 6. Conclusions
We propose a theoretical framework for causal estimation using generative diffusion models, called Diff-SCM. Diff-SCM unifies recent advances in generative energy-based models and structural causal models. Our key idea is to use gradients of the marginal and conditional distributions entailed by an SCM for causal estimation. The main benefit of using only the distributions' gradients is that one can learn an anti-causal mechanism and use its gradients as a causal mechanism for generation. We show empirically how it can be applied to a two-variable causal model. We leave the extension to more complex causal models to future work.
Furthermore, we present an algorithm for performing interventions and estimating counterfactuals with Diff-SCM. We acknowledge the difficulty of evaluating counterfactuals and propose a metric entitled counterfactual latent divergence (CLD). CLD measures the distance, in a latent space, between the observation and the generated counterfactual by comparison with other distances between samples in the dataset. We use CLD for comparison with baseline methods and for hyperparameter search. Finally, we show that the proposed Diff-SCM achieves better quantitative and qualitative results compared to state-of-the-art methods for counterfactual generation on MNIST.
Limitations and future work. Our empirical setting specifies only two variables; therefore, applying an intervention on $x^{(2)}$ means changing all the variables correlated with it within this dataset. Applying Diff-SCM to more complex causal models would require additional techniques. For instance, consider the SCM depicted in Fig. 2: a classifier naively trained to predict $x^{(2)}$ (class) from $x^{(1)}$ (image) would be biased by the confounder $x^{(3)}$, and so would the gradient of the classifier w.r.t. the image. This would make the intervention $do(x^{(2)})$ incorrect: the graph mutilation (removing the edges from the parents of the intervened node) would not happen, because the gradients from the classifier would carry information about $x^{(3)}$. We leave this extension for future work.
## 7. Acknowledgement
We thank Spyridon Thermos, Xiao Liu, Jeremy Voisey, Grzegorz Jacenkow and Alison O'Neil for their input on the manuscript and research support. This work was supported by the University of Edinburgh, the Royal Academy of Engineering and Canon Medical Research Europe via Pedro Sanchez's PhD studentship. This work was partially supported by the Alan Turing Institute under the EPSRC grant EP/N510129/1. We thank Nvidia for donating a TitanX GPU. S.A. Tsaftaris acknowledges the support of Canon Medical and the Royal Academy of Engineering and the Research Chairs and Senior Research Fellowships scheme (grant RCSRF1819\8\25).
## References
- E Bareinboim, J Correa, D Ibeling, and T Icard. On Pearl's Hierarchy and the Foundations of Causal Inference, 2020.
- Tineke Blom, Stephan Bongers, and Joris M Mooij. Beyond Structural Causal Models: Causal Constraints Models. In Proc. 35th Uncertainty in Artificial Intelligence Conference, pages 585-594, 2020.
- Stephan Bongers and Joris M Mooij. From Random Differential Equations to Structural Causal Models: the stochastic case. arXiv preprint, 2018.
- Ricky T Q Chen, Yulia Rubanova, Jesse Bettencourt, and David Duvenaud. Neural Ordinary Differential Equations. In Advances in Neural Information Processing Systems, 2018.
- Saloni Dash, Vineeth N Balasubramanian, and Amit Sharma. Evaluating and Mitigating Bias in Image Classifiers: A Causal Perspective Using Counterfactuals. arXiv preprint, 2020.
- Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In Proc. of Conference on Computer Vision and Pattern Recognition, pages 248-255. IEEE, 2009.
- Emily Denton, Ben Hutchinson, Margaret Mitchell, Timnit Gebru, and Andrew Zaldivar. Image Counterfactual Sensitivity Analysis for Detecting Unintended Bias. arXiv preprint, 12 2019.
- Prafulla Dhariwal and Alex Nichol. Diffusion Models Beat GANs on Image Synthesis. In Advances in Neural Information Processing Systems, 2021.
- Xin Du, Lei Sun, Wouter Duivesteijn, Alexander Nikolaev, and Mykola Pechenizkiy. Adversarial balancing-based representation learning for causal effect inference with observational data. Data Mining and Knowledge Discovery, 35(4):1713-1738, 12 2021.
- Yilun Du and Igor Mordatch. Implicit Generation and Generalization in Energy-Based Models. In Advances in Neural Information Processing Systems, 12 2019.
- Yilun Du, Shuang Li, and Igor Mordatch. Compositional Visual Generation with Energy Based Models. In Advances in Neural Information Processing Systems, 2020.
- Yash Goyal, Ziyan Wu, Jan Ernst, Dhruv Batra, Devi Parikh, and Stefan Lee. Counterfactual Visual Explanations. Proc. of 36th International Conference on Machine Learning, pages 4254-4262, 2019.
- Will Grathwohl, Kuan-Chieh Wang, Jörn-Henrik Jacobsen, David Duvenaud, Kevin Swersky, and Mohammad Norouzi. Your Classifier is Secretly an Energy Based Model and You Should Treat it Like One. In Proc. of International Conference on Learning Representations, 2020.
- Jennifer Hill. Bayesian Nonparametric Modeling for Causal Inference. Journal of Computational and Graphical Statistics, 20(1):217-240, 2012.
- Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising Diffusion Probabilistic Models. In Advances in Neural Information Processing Systems, 2020.
- Frederik Hvilshøj, Alexandros Iosifidis, and Ira Assent. ECINN: Efficient Counterfactuals from Invertible Neural Networks. 2021.
- Aapo Hyvärinen. Estimation of Non-Normalized Statistical Models by Score Matching. Journal of Machine Learning Research, 6:695-709, 2005.
- Fredrik D Johansson, Uri Shalit, and David Sontag. Learning Representations for Counterfactual Inference. In Proc. of International Conference on Machine Learning, 2016.
- Divyansh Kaushik, Amrith Setlur, Eduard Hovy, and Zachary C Lipton. Explaining the Efficacy of Counterfactually Augmented Data. 2020.
- N Kilbertus, G Parascandolo, and B Schölkopf. Generalization in anti-causal learning. In NeurIPS Workshop on Critiquing and Correcting Trends in Machine Learning, 2018.
- Diederik P Kingma and Max Welling. Auto-Encoding Variational Bayes. 2nd International Conference on Learning Representations, 2014.
- Matt Kusner, Joshua Loftus, Chris Russell, and Ricardo Silva. Counterfactual Fairness. In Advances in Neural Information Processing Systems, 2017.
- Y LeCun, L Bottou, Y Bengio, and P Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278-2324, 1998.
- Arnaud Van Looveren and Janis Klaise. Interpretable Counterfactual Explanations Guided by Prototypes. In European Conference on Machine Learning and Principles and Practice of Knowledge Discovery in Databases, 2021. arXiv:1907.02584.
- Christos Louizos, Uri Shalit, Joris Mooij, David Sontag, Richard Zemel, and Max Welling. Causal Effect Inference with Deep Latent-Variable Models. In Advances in Neural Information Processing Systems, 2017.
- Chaochao Lu, Biwei Huang, Ke Wang, José Miguel Hernández-Lobato, Kun Zhang, and Bernhard Schölkopf. Sample-Efficient Reinforcement Learning via Counterfactual-Based Data Augmentation. arXiv pre-print, 2020.
- Joris M Mooij, Dominik Janzing, and Bernhard Schölkopf. From Ordinary Differential Equations to Structural Causal Models: The Deterministic Case. In Proceedings of the Twenty-Ninth Conference on Uncertainty in Artificial Intelligence, pages 440-448, 2013.
- Alex Nichol and Prafulla Dhariwal. Improved Denoising Diffusion Probabilistic Models. arXiv pre-print, 2021.
- Bernt Øksendal. Stochastic Differential Equations: An Introduction with Applications. Springer, fifth edition, 2003. ISBN 978-3-642-14394-6.
- Nick Pawlowski, Daniel C Castro, and Ben Glocker. Deep Structural Causal Models for Tractable Counterfactual Inference. In Advances in Neural Information Processing Systems, 2020.
- Judea Pearl. Causality. Cambridge University Press, 2009. doi: 10.1017/CBO9780511803161.
- Judea Pearl. Causal inference in statistics: a primer. John Wiley & Sons Ltd, Chichester, West Sussex, UK, 2016. ISBN 978-1-119-18684-7.
- Judea Pearl and Dana Mackenzie. The Book of Why: The New Science of Cause and Effect. Basic Books, 2018.
- Jonas Peters, Dominik Janzing, and Bernhard Schölkopf. Elements of causal inference. MIT Press, 2017.
- O Ronneberger, P Fischer, and T Brox. U-Net: Convolutional Networks for Biomedical Image Segmentation. In Proc. of Medical Image Computing and Computer-Assisted Intervention, volume 9351, pages 234-241. Springer, 2015.
- Paul K Rubenstein, Stephan Bongers, Bernhard Schölkopf, and Joris M Mooij. From Deterministic ODEs to Dynamic Structural Causal Models. In Proceedings of the 34th Annual Conference on Uncertainty in Artificial Intelligence (UAI-18), 2018.
- Simo Särkkä and Arno Solin. Applied Stochastic Differential Equations, volume 10. Cambridge University Press, 2019.
- Axel Sauer and Andreas Geiger. Counterfactual Generative Networks. In Proc. of International Conference on Learning Representations, 2021.
- Bernhard Schölkopf, Dominik Janzing, Jonas Peters, Eleni Sgouritsa, Kun Zhang, and Joris Mooij. On Causal and Anticausal Learning. In Proc. of the International Conference on Machine Learning, 2012.
- Bernhard Schölkopf, Francesco Locatello, Stefan Bauer, Nan Rosemary Ke, Nal Kalchbrenner, Anirudh Goyal, and Yoshua Bengio. Toward Causal Representation Learning. Proceedings of the IEEE, 2021.
- Lisa Schut, Oscar Key, Rory McGrath, Luca Costabello, Bogdan Sacaleanu, Medb Corcoran, and Yarin Gal. Generating Interpretable Counterfactual Explanations By Implicit Minimisation of Epistemic and Aleatoric Uncertainties. In Proc. of The 24th International Conference on Artificial Intelligence and Statistics, pages 1756-1764, 2021.
- Xinwei Shen, Furui Liu, Hanze Dong, Qing Lian, Zhitang Chen, and Tong Zhang. Disentangled Generative Causal Representation Learning. arXiv pre-print, 2020.
- Claudia Shi, David M Blei, and Victor Veitch. Adapting Neural Networks for the Estimation of Treatment Effects. In Proc. of Neural Information Processing Systems, 2019.
- Yishai Shimoni, Chen Yanover, Ehud Karavani, and Yaara Goldschmidt. Benchmarking Framework for Performance-Evaluation of Causal Inference Analysis. arXiv pre-print, 2018.
- Abhishek Sinha, Jiaming Song, Chenlin Meng, and Stefano Ermon. D2C: Diffusion-Decoding Models for Few-Shot Conditional Generation. arXiv pre-print, 2021.
- Jascha Sohl-Dickstein, Eric A Weiss, Niru Maheswaranathan, and Surya Ganguli. Deep Unsupervised Learning using Nonequilibrium Thermodynamics. Proc. of 32nd International Conference on Machine Learning, 3:2246-2255, 2015.
- Alexander Sokol and Niels Richard Hansen. Causal interpretation of stochastic differential equations. Electronic Journal of Probability, 19:1-24, 2014.
- Jiaming Song, Chenlin Meng, and Stefano Ermon. Denoising Diffusion Implicit Models. In Proc. of International Conference on Learning Representations, 2021a.
- Yang Song and Stefano Ermon. Generative Modeling by Estimating Gradients of the Data Distribution. Advances in Neural Information Processing Systems, 32, 2019.
- Yang Song, Jascha Sohl-Dickstein, Diederik P Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. Score-Based Generative Modeling Through Stochastic Differential Equations. In ICLR, 2021b.
- Vladimir N Vapnik. An overview of statistical learning theory. IEEE Transactions on Neural Networks, 10(5):988-999, 1999.
- Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention Is All You Need. In Advances in Neural Information Processing Systems, pages 5998-6008, 2017.
- Sahil Verma, John Dickerson, and Keegan Hines. Counterfactual Explanations for Machine Learning: A Review. arXiv pre-print, 2020.
- Sandra Wachter, Brent Mittelstadt, and Chris Russell. Counterfactual Explanations without Opening the Black Box: Automated Decisions and the GDPR. Harvard Journal of Law & Technology, 31(2), 2018.
- Kevin Xia, Kai-Zhan Lee, Yoshua Bengio, and Elias Bareinboim. The Causal-Neural Connection: Expressiveness, Learnability, and Inference. 2021.
- Mengyue Yang, Furui Liu, Zhitang Chen, Xinwei Shen, Jianye Hao, and Jun Wang. CausalVAE: Disentangled Representation Learning via Neural Structural Causal Models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 9593-9602, 2021.
- Matej Zečević, Devendra Singh Dhami, Petar Veličković, and Kristian Kersting. Relating Graph Neural Networks to Structural Causal Models. arXiv pre-print, 2021.
## Appendix A. Theory for Training Diffusion Models
We now review the formulation of Denoising Diffusion Probabilistic Models (DDPMs) (Ho et al., 2020) in more detail. In DDPM, samples are generated by using a neural network to reverse a diffusion process, starting from a Gaussian prior distribution. We begin by defining our data distribution $x_0 \sim p(x_0)$ and a Markovian noising process. This process gradually adds noise to the data to produce noised samples $x_1$ up to $x_T$. In particular, each step of the noising process adds Gaussian noise according to a variance schedule given by $\beta_t$:
$$p(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t; \sqrt{1 - \beta_t}\, x_{t-1}, \beta_t I\right)$$
In addition, it is possible to sample $x_t$ directly from $x_0$ without repeatedly sampling from $x_t \sim p(x_t \mid x_{t-1})$. Instead, $p(x_t \mid x_0)$ can be expressed as a Gaussian distribution by defining the cumulative product $\alpha_t := \prod_{j=1}^{t} (1 - \beta_j)$, so that $1 - \alpha_t$ is the variance of the noise at an arbitrary timestep $t$. We therefore define
$$\begin{aligned}
p(x_t \mid x_0) &= \mathcal{N}\left(x_t; \sqrt{\alpha_t}\, x_0, (1 - \alpha_t) I\right) \\
x_t &= \sqrt{\alpha_t}\, x_0 + \sqrt{1 - \alpha_t}\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)
\end{aligned}$$
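The closed form can be checked numerically. The NumPy sketch below is only an illustration. It assumes a linear $\beta_t$ schedule from $10^{-4}$ to $0.02$ over $T = 1000$ steps and the convention $\alpha_0 = 1$. The helper names `cumulative_alphas`, `q_sample` and `iterated_moments` are hypothetical. The sketch applies the one-step noising process $t$ times and compares the resulting mean coefficient and variance with $\sqrt{\alpha_t}$ and $1 - \alpha_t$.

```python
import numpy as np


def cumulative_alphas(betas):
    """alpha_t := prod_{j=1}^t (1 - beta_j); index 0 holds alpha_0 = 1 (clean data)."""
    return np.concatenate([[1.0], np.cumprod(1.0 - betas)])


def q_sample(x0, t, alphas, rng):
    """Draw x_t ~ p(x_t | x_0) in one shot: x_t = sqrt(alpha_t) x_0 + sqrt(1 - alpha_t) eps."""
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alphas[t]) * x0 + np.sqrt(1.0 - alphas[t]) * eps, eps


def iterated_moments(betas, t):
    """Mean coefficient and variance of x_t | x_0 after t one-step updates
    x_t = sqrt(1 - beta_t) x_{t-1} + sqrt(beta_t) n, with n ~ N(0, I)."""
    coef, var = 1.0, 0.0
    for b in betas[:t]:
        coef *= np.sqrt(1.0 - b)
        var = (1.0 - b) * var + b
    return coef, var


betas = np.linspace(1e-4, 0.02, 1000)  # illustrative linear schedule (assumption)
alphas = cumulative_alphas(betas)
coef, var = iterated_moments(betas, 500)
print(np.isclose(coef, np.sqrt(alphas[500])), np.isclose(var, 1.0 - alphas[500]))  # -> True True
```

`q_sample` needs only one Gaussian draw for any $t$. Training can therefore pick a random timestep per example instead of simulating the whole chain.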
However, we are interested in a generative process that performs the reverse diffusion, going from noise $x_T$ to data $x_0$. As such, the model with parameters $\theta$ should approximate the conditional distribution $p_\theta(x_{t-1} \mid x_t)$.
Using Bayes' theorem, one finds that the posterior $p(x_{t-1} \mid x_t, x_0)$ is also a Gaussian with mean $\tilde{\mu}_t(x_t, x_0)$ and variance $\tilde{\beta}_t$, defined as follows:
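$$\begin{aligned}
\tilde{\mu}_t(x_t, x_0) &:= \frac{\sqrt{\alpha_{t-1}}\, \beta_t}{1 - \alpha_t}\, x_0 + \frac{\sqrt{1 - \beta_t}\, (1 - \alpha_{t-1})}{1 - \alpha_t}\, x_t \\
\tilde{\beta}_t &:= \frac{1 - \alpha_{t-1}}{1 - \alpha_t}\, \beta_t \\
p(x_{t-1} \mid x_t, x_0) &= \mathcal{N}\left(x_{t-1}; \tilde{\mu}_t(x_t, x_0), \tilde{\beta}_t I\right)
\end{aligned}$$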
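The following sketch sanity-checks these expressions by deriving the posterior a second way. It uses the same illustrative linear schedule and the $\alpha_0 = 1$ convention, and the function names are hypothetical. It conditions the joint Gaussian of $(x_{t-1}, x_t)$ given $x_0$, using $\mathrm{Cov}(x_{t-1}, x_t \mid x_0) = \sqrt{1-\beta_t}\,(1-\alpha_{t-1})$, and compares the result with the closed-form coefficients of $\tilde{\mu}_t$ and $\tilde{\beta}_t$.

```python
import numpy as np


def posterior_params(alphas, betas, t):
    """(c0, ct, beta_tilde) with mu_tilde = c0 * x_0 + ct * x_t, for t >= 1.
    alphas[0] = 1 and betas[t - 1] stores beta_t."""
    a_t, a_prev, b_t = alphas[t], alphas[t - 1], betas[t - 1]
    c0 = np.sqrt(a_prev) * b_t / (1.0 - a_t)
    ct = np.sqrt(1.0 - b_t) * (1.0 - a_prev) / (1.0 - a_t)
    beta_tilde = (1.0 - a_prev) / (1.0 - a_t) * b_t
    return c0, ct, beta_tilde


def posterior_by_conditioning(alphas, betas, t, x0, xt):
    """Same posterior from the generic Gaussian conditioning formula on (x_{t-1}, x_t) | x_0."""
    a_t, a_prev, b_t = alphas[t], alphas[t - 1], betas[t - 1]
    m_prev, v_prev = np.sqrt(a_prev) * x0, 1.0 - a_prev   # moments of x_{t-1} | x_0
    m_t, v_t = np.sqrt(a_t) * x0, 1.0 - a_t               # moments of x_t | x_0
    cov = np.sqrt(1.0 - b_t) * v_prev                     # Cov(x_{t-1}, x_t | x_0)
    mean = m_prev + cov / v_t * (xt - m_t)
    var = v_prev - cov ** 2 / v_t
    return mean, var


betas = np.linspace(1e-4, 0.02, 1000)  # illustrative schedule (assumption)
alphas = np.concatenate([[1.0], np.cumprod(1.0 - betas)])
c0, ct, bt = posterior_params(alphas, betas, 250)
mean, var = posterior_by_conditioning(alphas, betas, 250, x0=0.7, xt=-0.3)
print(np.isclose(mean, c0 * 0.7 + ct * -0.3), np.isclose(var, bt))  # -> True True
```

The posterior variance $\tilde{\beta}_t$ never exceeds $\beta_t$, since $(1-\alpha_{t-1})/(1-\alpha_t) \le 1$.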
To train $p_\theta(x_{t-1} \mid x_t)$ such that $p_\theta(x_0)$ learns the true data distribution, one can optimize the following variational bound $L_{\text{vlb}}$ on the negative log-likelihood $-\log p_\theta(x_0)$:
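$$\begin{aligned}
L_{\text{vlb}} &:= L_0 + L_1 + \dots + L_{T-1} + L_T \\
L_0 &:= -\log p_\theta(x_0 \mid x_1) \\
L_{t-1} &:= D_{\mathrm{KL}}\left(p(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t)\right) \\
L_T &:= D_{\mathrm{KL}}\left(p(x_T \mid x_0) \,\|\, p(x_T)\right)
\end{aligned}$$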
Ho et al. (2020) considered a simplified, reweighted version of Eq. 15 for training $p_\theta(x_{t-1} \mid x_t)$ efficiently. Instead of directly parameterizing $\mu_\theta(x_t, t)$ as a neural network, they train a model $\epsilon_\theta(x_t, t)$ to predict the noise $\epsilon$ from Equation 13. This simplified objective is defined as follows:
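$$L_{\text{simple}} = \mathbb{E}_{t \sim \mathcal{U}\{1, \dots, T\},\, x_0 \sim p_{\text{data}},\, \epsilon \sim \mathcal{N}(0, I)}\left[\left\| \epsilon - \epsilon_\theta(x_t, t) \right\|^2\right]$$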
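Below is a minimal sketch of one Monte Carlo estimate of $L_{\text{simple}}$. It assumes `eps_model` is any callable standing in for $\epsilon_\theta(x_t, t)$ and that `alphas` stores $\alpha_0 = 1, \alpha_1, \dots, \alpha_T$. The names are illustrative. A real training loop would backpropagate through a neural network rather than evaluate NumPy arrays.

```python
import numpy as np


def l_simple(x0_batch, eps_model, alphas, rng):
    """One Monte Carlo estimate of E_{t, x_0, eps} || eps - eps_theta(x_t, t) ||^2."""
    n = x0_batch.shape[0]
    T = len(alphas) - 1                                    # alphas[0] = 1 is the clean level
    t = rng.integers(1, T + 1, size=n)                     # t ~ Uniform{1, ..., T}
    eps = rng.standard_normal(x0_batch.shape)              # eps ~ N(0, I)
    a = alphas[t].reshape((n,) + (1,) * (x0_batch.ndim - 1))
    xt = np.sqrt(a) * x0_batch + np.sqrt(1.0 - a) * eps    # closed-form x_t given x_0
    err = (eps - eps_model(xt, t)).reshape(n, -1)
    return float(np.mean(np.sum(err ** 2, axis=1)))


betas = np.linspace(1e-4, 0.02, 1000)                     # illustrative schedule (assumption)
alphas = np.concatenate([[1.0], np.cumprod(1.0 - betas)])
x0 = np.random.default_rng(0).standard_normal((4096, 3))  # toy "images" of dimension 3
zero_model = lambda xt, t: np.zeros_like(xt)              # predicts no noise at all
print(l_simple(x0, zero_model, alphas, np.random.default_rng(1)))
```

A predictor that always outputs zero yields a loss close to the data dimension (here 3), because $\mathbb{E}\|\epsilon\|^2$ equals the dimension. An oracle that recovers $\epsilon$ exactly from $x_t$ and the known $x_0$ drives the estimate to zero.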
## Appendix B. Pearl's Causal Hierarchy
Bareinboim et al. (2020) use the nomenclature Pearl's Causal Hierarchy (PCH), after Pearl's seminal work on causality, which Pearl and Mackenzie (2018) illustrate as the Ladder of Causation. PCH states that structural causal models should be able to sample from a collection of three distributions (Peters et al. (2017), Ch. 6), each related to a cognitive capability:
1. The observational ('seeing') distribution $p_{\mathcal{G}}(x^{(k)})$.
2. The do-calculus (Pearl, 2009) formalizes sampling from the interventional ('doing') distribution $p_{\mathcal{G}}(x^{(k)} \mid do(x^{(j)} = x^{(j)}))$. The $do(\cdot)$ operator means that an intervention on a specific variable is propagated only through its descendants in the SCM $\mathcal{G}$. The causal structure ensures that only the descendants of the intervened variable are modified by a given action.
3. Sampling from a counterfactual ('imagining') distribution $p_{\mathcal{G}}(x^{(k)} \mid do(x^{(j)} = x^{(j)}); x^{(k)})$ involves applying an intervention $do(x^{(j)} = x^{(j)})$ on a given instance $x^{(k)}$. In contrast to the factual observation, a counterfactual corresponds to a hypothetical scenario.
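The three rungs can be made concrete with a hypothetical linear SCM with $x^{(2)} \rightarrow x^{(1)}$, $x^{(1)} := A\,x^{(2)} + u^{(1)}$, and standard Gaussian exogenous noise. The coefficient $A = 2$ and all function names are illustrative assumptions, not part of Diff-SCM. The counterfactual follows the standard abduction-action-prediction recipe: recover $u^{(1)}$ from the observed instance, set $x^{(2)}$, and push the same noise through the mechanism.

```python
import numpy as np

A = 2.0  # hypothetical causal coefficient of the mechanism x1 := A * x2 + u1


def sample_observational(n, rng):
    """Rung 1 ('seeing'): sample both variables from their own mechanisms."""
    u1, u2 = rng.standard_normal(n), rng.standard_normal(n)
    x2 = u2
    x1 = A * x2 + u1
    return x1, x2


def sample_interventional(n, x2_value, rng):
    """Rung 2 ('doing'): do(x2 = x2_value) replaces x2's mechanism; x1 keeps its own
    mechanism with fresh noise, so only the descendant x1 is affected."""
    u1 = rng.standard_normal(n)
    x2 = np.full(n, x2_value)
    return A * x2 + u1, x2


def counterfactual(x1_obs, x2_obs, x2_value):
    """Rung 3 ('imagining'): abduction, action and prediction for one observed instance."""
    u1 = x1_obs - A * x2_obs       # abduction: recover the exogenous noise of this instance
    return A * x2_value + u1       # action + prediction: same noise, intervened parent


rng = np.random.default_rng(0)
x1_do, _ = sample_interventional(100_000, 1.0, rng)
print(round(float(x1_do.mean()), 1), counterfactual(3.0, 1.0, 0.0))
```

The interventional mean is close to $A \cdot 1 = 2$. The counterfactual for the instance $(x^{(1)}, x^{(2)}) = (3, 1)$ under $do(x^{(2)} = 0)$ is exactly $1$, because it reuses that instance's noise $u^{(1)} = 1$ rather than drawing fresh noise.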
## Appendix C. Example of Anti-causal Intervention
We illustrate Prop. 1 in a case with two variables, which is also used in the experiments. Consider a variable $x^{(1)}$ caused by $x^{(2)}$, i.e. $x^{(1)} \leftarrow x^{(2)}$. Following the causal direction, the joint distribution can be factorised as $p(x^{(1)}, x^{(2)}) = p(x^{(1)} \mid x^{(2)})\, p(x^{(2)})$. Applying an intervention with the SDE framework, however, one would only need $\nabla_{x^{(1)}} \log p_t(x^{(1)} \mid x^{(2)} = x^{(2)})$, as in Eq. 6. By applying Bayes' rule, one can derive $p(x^{(1)} \mid x^{(2)}) = p(x^{(2)} \mid x^{(1)})\, p(x^{(1)}) / p(x^{(2)})$. Since $p(x^{(2)})$ does not depend on $x^{(1)}$, its gradient with respect to $x^{(1)}$ vanishes. Therefore, the sampling process would be done with
$$\nabla_{x^{(1)}} \log p_t(x^{(1)} \mid x^{(2)}) = \nabla_{x^{(1)}} \log p_t(x^{(2)} \mid x^{(1)}) + \nabla_{x^{(1)}} \log p_t(x^{(1)}).$$
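The decomposition can be verified on a hypothetical one-dimensional Gaussian example, which is not from the paper. In it, $x^{(2)} \sim \mathcal{N}(0, 1)$ and $x^{(1)} = x^{(2)} + \sigma n$ with $n \sim \mathcal{N}(0, 1)$, so every score is available in closed form. The noise level $t$ is ignored (equivalently, $t = 0$), and $\sigma = 0.5$ is an arbitrary choice. The anti-causal term plays the role of the classifier gradient, and the prior term plays the role of the unconditional diffusion model.

```python
SIGMA = 0.5            # noise scale of the toy mechanism x1 = x2 + SIGMA * n (assumption)
V = 1.0 + SIGMA ** 2   # marginal variance of x1


def score_conditional(x1, x2):
    """grad_{x1} log p(x1 | x2), with p(x1 | x2) = N(x2, SIGMA^2): the causal direction."""
    return -(x1 - x2) / SIGMA ** 2


def score_anticausal(x1, x2):
    """grad_{x1} log p(x2 | x1), with p(x2 | x1) = N(x1 / V, SIGMA^2 / V): the 'classifier' term."""
    return (x2 - x1 / V) / SIGMA ** 2


def score_prior(x1):
    """grad_{x1} log p(x1), with p(x1) = N(0, V): the unconditional term."""
    return -x1 / V


x1, x2 = 0.8, -0.4
print(score_conditional(x1, x2), score_anticausal(x1, x2) + score_prior(x1))
```

Both printed values agree up to floating-point rounding. The marginal $p(x^{(2)})$ never has to be evaluated.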
## Appendix D. DDIM sampling procedure
Denoising Diffusion Implicit Models (DDIM, Song et al. (2021a)) provide a variation of the DDPM (Ho et al., 2020) sampling procedure. DDIM formulates an alternative non-Markovian noising process that allows a deterministic mapping from latents to images. Because the mapping is deterministic, the noise term in Eq. 2 is no longer needed for sampling. This process has the same forward marginals as DDPM, so it can be trained in the same manner. This approach was used for sampling throughout the paper, as explained in Sec. 3.4.
Alg. 2 describes DDIM's deterministic sampling procedure from $x_T \sim \mathcal{N}(0, I)$ (the exogenous noise distribution) to $x_0$ (the data distribution). This formulation has two main advantages:
(i) it allows a near-invertible mapping between $x_T$ and $x_0$, as shown in Alg. 3; and
(ii) it allows efficient sampling with fewer iterations, even when the model is trained with the same diffusion discretization. This is done by sampling on a sub-sequence of the timesteps in the $[0, T]$ interval.
## Algorithm 2 Sampling with DDIM - Image Generation
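Models:
trained diffusion model $\epsilon_\theta$.
Input:
$x_T \sim \mathcal{N}(0, I)$
Output:
$x_0$ - Image
for $t \leftarrow T$ to $1$ do
$$x_{t-1} \leftarrow \sqrt{\alpha_{t-1}} \left( \frac{x_t - \sqrt{1 - \alpha_t}\, \epsilon_\theta(x_t, t)}{\sqrt{\alpha_t}} \right) + \sqrt{1 - \alpha_{t-1}}\, \epsilon_\theta(x_t, t)$$
end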
## Algorithm 3 Reverse-Sampling with DDIM - Inferring the Noisy Latent
Models:
trained diffusion model $\epsilon_\theta$.
Input:
$x_0$ - Image
Output:
$x_T$ - Latent Space
for $t \leftarrow 0$ to $T - 1$ do
$$x_{t+1} \leftarrow \sqrt{\alpha_{t+1}} \left( \frac{x_t - \sqrt{1 - \alpha_t}\, \epsilon_\theta(x_t, t)}{\sqrt{\alpha_t}} \right) + \sqrt{1 - \alpha_{t+1}}\, \epsilon_\theta(x_t, t)$$
end
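Algorithms 2 and 3 share a single update. It first predicts $x_0$ from the current noise estimate and then re-noises that prediction to the target level, without drawing fresh noise. The NumPy sketch below assumes $\alpha_0 = 1$, an arbitrary callable `eps_model(x, t)` in place of $\epsilon_\theta$, and hypothetical function names. Passing a decreasing list of timesteps gives Alg. 2, and an increasing list gives Alg. 3. Using a sub-sequence of timesteps illustrates the faster sampling mentioned above.

```python
import numpy as np


def cumulative_alphas(betas):
    """alpha_t := prod_{j=1}^t (1 - beta_j), with alpha_0 = 1 for clean data."""
    return np.concatenate([[1.0], np.cumprod(1.0 - betas)])


def ddim_step(x, eps, a_from, a_to):
    """Deterministic DDIM move from noise level a_from to a_to (either direction)."""
    x0_pred = (x - np.sqrt(1.0 - a_from) * eps) / np.sqrt(a_from)   # predicted clean image
    return np.sqrt(a_to) * x0_pred + np.sqrt(1.0 - a_to) * eps       # re-noise, no fresh noise


def ddim_run(x, eps_model, alphas, timesteps):
    """Alg. 2 if timesteps decreases to 0 (sampling); Alg. 3 if it increases from 0 (inversion)."""
    for t, t_next in zip(timesteps[:-1], timesteps[1:]):
        x = ddim_step(x, eps_model(x, t), alphas[t], alphas[t_next])
    return x


betas = np.linspace(1e-4, 0.02, 1000)                  # illustrative schedule (assumption)
alphas = cumulative_alphas(betas)
const_eps = lambda x, t: np.array([0.3, -1.2, 0.5])    # toy noise model that ignores x and t
x0 = np.array([1.0, 2.0, -0.5])
up = list(range(0, 1001, 100))                         # 11 timesteps instead of 1001
xT = ddim_run(x0, const_eps, alphas, up)               # Alg. 3: infer the latent
print(np.allclose(ddim_run(xT, const_eps, alphas, up[::-1]), x0))  # -> True
```

The round trip is exact here only because the toy noise model returns the same value at every step. A learned $\epsilon_\theta$ is evaluated at slightly different points in the two directions, so the mapping is only near-invertible.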
## Appendix E. Implementation Details
For each dataset, we train two models separately:
(i) $\epsilon_\theta$, implemented as an encoder-decoder architecture with skip-connections, i.e. a U-Net-like network (Ronneberger et al., 2015);
(ii) an anti-causal classifier that uses the encoder of $\epsilon_\theta$, followed by a pooling layer and a linear classifier.

All models are time-conditioned. Time, which is a scalar, is embedded using the transformer's sinusoidal position embedding (Vaswani et al., 2017). The embedding is incorporated into the convolutional models through an Adaptive Group Normalization layer in each residual block (Nichol and Dhariwal, 2021). Our architectures and training procedure follow Dhariwal and Nichol (2021), who performed an extensive ablation study of important components of DDPM (Ho et al., 2020) and improved overall image quality and log-likelihoods on many image benchmarks. We use the same hyperparameters as Dhariwal and Nichol (2021) for ImageNet and choose our own for MNIST. The specific hyperparameters for the diffusion and classification models are listed in Tab. 2. We train all of our models using Adam with $\beta_1 = 0.9$ and $\beta_2 = 0.999$. We train in 16-bit precision using loss-scaling, but maintain 32-bit weights, EMA, and optimizer state. We use an EMA rate of 0.9999 for all experiments.
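The scalar time enters the network through two pieces: a sinusoidal timestep embedding, and an adaptive group normalization that turns the embedding into a per-channel scale and shift. The NumPy sketch below illustrates both, in the spirit of Dhariwal and Nichol (2021). The `max_period = 10000` constant, the `(1 + scale)` form, the tensor shapes and all names are assumptions for illustration, not the exact implementation used here.

```python
import numpy as np


def timestep_embedding(t, dim, max_period=10000.0):
    """Sinusoidal embedding of scalar timesteps -> array of shape (len(t), dim); dim assumed even."""
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)   # geometric frequency ladder
    args = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)


def group_norm(h, groups, eps=1e-5):
    """Normalise h of shape (batch, channels, height, width) over each channel group and all pixels."""
    b, c, hh, ww = h.shape
    g = h.reshape(b, groups, c // groups, hh, ww)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(b, c, hh, ww)


def ada_group_norm(h, emb, w, bias, groups):
    """Adaptive GroupNorm: project the time embedding to a per-channel (scale, shift) pair."""
    scale, shift = np.split(emb @ w + bias, 2, axis=-1)            # each (batch, channels)
    return group_norm(h, groups) * (1.0 + scale[:, :, None, None]) + shift[:, :, None, None]


emb = timestep_embedding([0, 10, 999], dim=128)
h = np.random.default_rng(0).standard_normal((3, 32, 4, 4))       # (batch, channels, H, W)
w = np.random.default_rng(1).standard_normal((128, 64)) * 0.01    # hypothetical learned projection
out = ada_group_norm(h, emb, w, np.zeros(64), groups=8)
print(emb.shape, out.shape)  # -> (3, 128) (3, 32, 4, 4)
```

With a zero projection the layer reduces to plain group normalization. The time conditioning therefore acts only through the learned `w` and `bias`.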
We use DDIM sampling with 1000 timesteps for all experiments, with the same noise schedule as in training. Even though DDIM allows faster sampling with fewer timesteps, we found that reducing the number of timesteps does not work well for counterfactuals.
Table 2: Hyperparameters for models.
| dataset | ImageNet 256 | ImageNet 256 | MNIST | MNIST |
|----------------------|----------------|----------------|-----------|------------|
| model | diffusion | classifier | diffusion | classifier |
| Diffusion steps | 1000 | 1000 | 1000 | 1000 |
| Model size | 554M | 54M | 2M | 500K |
| Channels | 256 | 128 | 64 | 32 |
| Depth | 2 | 2 | 1 | 1 |
| Channel multipliers | 1,1,2,2,4,4 | 1,1,2,2,4,4 | 1,2,4 | 1,2,4,4 |
| Attention resolution | 32,16,8 | 32,16,8 | - | - |
| Batch size | 256 | 256 | 256 | 256 |
| Iterations | ≈2M | ≈500K | 30K | 3K |
| Learning Rate | 1e-4 | 3e-4 | 1e-4 | 1e-4 |
## Appendix F. Sampling from The Interventional Distribution
In this section, we show that our method complies with the second level of Pearl's Causal Hierarchy (details in Appendix B). Diff-SCM can be used to sample efficiently from the interventional distribution $p_{\mathcal{G}_{\text{image}}}(x^{(1)} \mid do(x^{(2)} = x^{(2)}))$. This is done with the second part ('Generation with Intervention') of Alg. 1, but sampling $u^{(k)}$ from a Gaussian prior instead of inferring the latent space (using 'Abduction of Exogenous Noise'). This formulation is identical to Dhariwal and Nichol (2021) with guided DDIM (Song et al., 2021a) (details in Appendix D). Dhariwal and Nichol (2021) achieve state-of-the-art image quality in generation while providing faster sampling than DDPM. Since they already compare the image-synthesis capabilities of this approach against other generative models, we restrict ourselves to presenting qualitative results on ImageNet 256x256.
Experimental Setup. Our experiment, depicted in Fig. 5, consists of sampling a single latent $u^{(1)}$ from a Gaussian distribution and generating samples for different classes. Since all images are generated from the same latent, this allows visualization of the effect of the classifier guidance for different classes. This setup differs from the experiments in Dhariwal and Nichol (2021), where each image presented was generated from a different sample of $u^{(1)}$. Here, by sampling $u^{(1)}$ only once, we isolate the contribution of the causal mechanism from the sampling of the exogenous noise $u^{(1)}$. We use the scale hyperparameter $s = 5$ for these experiments.
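The guided DDIM update referred to above can be sketched with the classifier-guidance rule of Dhariwal and Nichol (2021) for DDIM, $\hat{\epsilon} = \epsilon_\theta(x_t, t) - s\sqrt{1-\alpha_t}\,\nabla_{x_t}\log p_\phi(y \mid x_t)$. The toy noise model and the toy classifier gradient below are illustrative assumptions; the gradient simply pulls $x_t$ towards a hypothetical class mean. In Diff-SCM these roles are played by the trained diffusion model and the anti-causal classifier. As in Fig. 5, the same exogenous noise `u` is reused for every intervention.

```python
import numpy as np


def guided_ddim_sample(u, eps_model, grad_log_classifier, y, alphas, timesteps, s):
    """Deterministic DDIM from u = x_T down to x_0 with classifier guidance of scale s.
    timesteps: decreasing sub-sequence ending at 0; alphas[0] = 1."""
    x = u
    for t, t_next in zip(timesteps[:-1], timesteps[1:]):
        a_t, a_next = alphas[t], alphas[t_next]
        # guided noise estimate: eps - s * sqrt(1 - alpha_t) * grad_x log p(y | x_t)
        eps = eps_model(x, t) - s * np.sqrt(1.0 - a_t) * grad_log_classifier(x, t, y)
        x0_pred = (x - np.sqrt(1.0 - a_t) * eps) / np.sqrt(a_t)
        x = np.sqrt(a_next) * x0_pred + np.sqrt(1.0 - a_next) * eps
    return x


# Toy stand-ins (assumptions): a linear noise model and a classifier whose
# log-probability gradient pulls x towards a per-class mean.
CLASS_MEANS = {0: np.array([-1.0, 0.0]), 1: np.array([1.0, 1.0])}
toy_eps = lambda x, t: 0.1 * x
toy_grad = lambda x, t, y: CLASS_MEANS[y] - x

betas = np.linspace(1e-4, 0.02, 1000)                     # illustrative schedule (assumption)
alphas = np.concatenate([[1.0], np.cumprod(1.0 - betas)])
steps = list(range(1000, -1, -100))
u = np.random.default_rng(0).standard_normal(2)           # one shared exogenous noise
samples = {y: guided_ddim_sample(u, toy_eps, toy_grad, y, alphas, steps, s=5.0) for y in (0, 1)}
```

With $s = 0$ the guidance term vanishes, so both interventions give the same output from the same `u`. Increasing $s$ makes the outputs for different classes diverge.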
Figure 5: Sampling ImageNet images from the interventional distribution. All images originate from the same initial noise $u^{(k)}$, but different interventions are applied at inference time.
<details>
<summary>Image 5 Details</summary>

### Visual Description
## Composite Figure: Labeled Image Panels
### Overview
The image is a horizontal composite of five distinct rectangular panels arranged side-by-side. Each panel consists of a text label at the top and a corresponding image below it. Per the caption, the first panel is the shared initial noise and the other four are samples generated from that same noise under different class interventions.
### Components/Axes
The image is structured as five vertical panels. Each panel has a consistent layout:
1. **Top Region (Label):** A line of text centered above the image.
2. **Main Region (Image):** A square or nearly square image filling the rest of the panel.
**Panel Labels (from left to right):**
1. `u^(k) [noise]`
2. `do(chimpanzee)`
3. `do(mushroom)`
4. `do(bookshop)`
5. `do(goose)`
### Detailed Analysis
**Panel 1 (Far Left):**
* **Label:** `u^(k) [noise]`
* **Image Content:** A uniform, textureless field of a dark teal or blue-green color. It appears to be digital noise or a blank initialization state, with no discernible objects or patterns.
**Panel 2:**
* **Label:** `do(chimpanzee)`
* **Image Content:** A clear, photographic image of a chimpanzee. The chimpanzee is sitting, facing slightly to the left, with its mouth open as if vocalizing. It is in a natural, outdoor setting with blurred green foliage in the background.
**Panel 3:**
* **Label:** `do(mushroom)`
* **Image Content:** A clear, photographic image of a single mushroom. It has a distinctive red cap with white spots (resembling an *Amanita muscaria*) and a pale stem. It is centered in the frame against a blurred background of green grass and foliage.
**Panel 4:**
* **Label:** `do(bookshop)`
* **Image Content:** A clear, photographic image of the interior of a bookshop or library. The view is dominated by tall, densely packed wooden bookshelves filled with books. The perspective looks down an aisle, creating a sense of depth.
**Panel 5 (Far Right):**
* **Label:** `do(goose)`
* **Image Content:** A clear, photographic image of a Canada goose. The goose is swimming in water, with its body angled to the left and its head turned to look back towards the right. The water shows gentle ripples.
### Key Observations
1. **Pattern:** The first panel is labeled as "noise" and contains a non-representational image. The subsequent four panels are labeled with the `do(...)` syntax and contain clear, recognizable photographic images matching the label's subject.
2. **Visual Consistency:** The four photographic panels (2-5) share similar qualities: they are well-lit, in-focus, and feature a single primary subject centered in the frame against a natural or context-appropriate background.
3. **Label Syntax:** The labels use a consistent format. The first uses mathematical notation (`u^(k)`). The others use a function-like notation (`do(...)`), which in causal inference literature often denotes an intervention.
### Interpretation
This figure illustrates sampling from the interventional distribution with Diff-SCM. The first panel is the exogenous noise (`u^(k)`). Each `do(...)` panel is generated from that same noise, with a different class intervention applied at inference time.
* **What it demonstrates:** The `do(...)` operator denotes an intervention on the class variable. Applying it to the same noise yields an image of the requested class (chimpanzee, mushroom, bookshop, goose).
* **Relationship between elements:** All four outputs share the same initial noise. Differences between them therefore isolate the effect of the intervention from the randomness of the exogenous noise.
* **Notable aspect:** This is a categorical comparison rather than a data chart, so there are no numerical trends. The observable pattern is that each intervention produces a coherent image matching its label.
</details>
## Appendix G. IM1 and IM2
Looveren and Klaise (2021) propose IM1 and IM2 for measuring the realism of counterfactuals, i.e. their closeness to the data manifold. These metrics are based on the reconstruction losses of autoencoders trained on specific classes:
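$$\begin{aligned}
\mathrm{IM1}(AE_{x^{(2)}_{CF}}, AE_{x^{(2)}_{F}}, x_{CF}) &= \frac{\left\| x_{CF} - AE_{x^{(2)}_{CF}}(x_{CF}) \right\|_2^2}{\left\| x_{CF} - AE_{x^{(2)}_{F}}(x_{CF}) \right\|_2^2 + \epsilon} \\
\mathrm{IM2}(AE_{x^{(2)}_{CF}}, AE, x_{CF}) &= \frac{\left\| AE_{x^{(2)}_{CF}}(x_{CF}) - AE(x_{CF}) \right\|_2^2}{\left\| x_{CF} \right\|_1 + \epsilon}
\end{aligned}$$
where $AE_{x^{(2)}}$ denotes an autoencoder trained only on instances from class $x^{(2)}$ (the counterfactual class $x^{(2)}_{CF}$ or the factual class $x^{(2)}_{F}$), $AE$ is an autoencoder trained on data from all classes, and $\epsilon$ is a small constant. IM1 is the ratio of the reconstruction loss of an autoencoder trained on the counterfactual class to the loss of an autoencoder trained on the factual class. IM2 is the normalized difference between the reconstructions of the counterfactual under an autoencoder trained on the counterfactual class and one trained on all classes. Lower values of both metrics indicate counterfactuals closer to the data manifold of the counterfactual class.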
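The two metrics can be transcribed directly in NumPy. This sketch assumes each autoencoder is given as a callable mapping an image array to its reconstruction. The names `ae_target`, `ae_reference` and `ae_all` are hypothetical, and `ae_reference` is whichever class-specific autoencoder the IM1 denominator uses. The trained autoencoders themselves are not shown.

```python
import numpy as np


def im1(x_cf, ae_target, ae_reference, eps=1e-8):
    """Squared reconstruction error of x_cf under the counterfactual-class autoencoder,
    divided by its error under the reference autoencoder of the IM1 denominator."""
    num = np.sum((x_cf - ae_target(x_cf)) ** 2)
    den = np.sum((x_cf - ae_reference(x_cf)) ** 2)
    return num / (den + eps)


def im2(x_cf, ae_target, ae_all, eps=1e-8):
    """Squared difference between the counterfactual-class and all-class reconstructions,
    normalised by the L1 norm of x_cf."""
    num = np.sum((ae_target(x_cf) - ae_all(x_cf)) ** 2)
    return num / (np.sum(np.abs(x_cf)) + eps)


identity = lambda z: z                 # toy "perfect" autoencoder
zero = lambda z: np.zeros_like(z)      # toy autoencoder that reconstructs nothing
x = np.array([1.0, -2.0, 3.0])
print(im1(x, identity, zero), im2(x, identity, zero))
```

A counterfactual that the target-class autoencoder reconstructs perfectly gets IM1 = 0. IM2 is zero whenever the target-class and all-class autoencoders agree.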
## Appendix H. More MNIST Counterfactuals
Here, we show in Fig. 6 that we can generate counterfactuals for all MNIST classes, given a factual image. We use the scale hyperparameter $s = 0.7$ for these experiments.
orig. rec. do(0) do(1) do(2) do(3) do(4) do(5) do(6) do(7) do(8) do(9)
<details>
<summary>Image 6 Details</summary>

### Visual Description
## Grid of Handwritten Digit Images
### Overview
The image displays a grid of handwritten digits, presented as white strokes on a black background. The digits are arranged in three horizontal rows, each containing a sequence of individual digit samples. Per the caption, these are MNIST counterfactuals: each row starts with an original image and its reconstruction, followed by counterfactuals for the intervened digit classes.
### Components/Axes
- **Layout**: The image is divided into three horizontal strips (rows). Each row contains a series of digit images placed side-by-side.
- **Grouping**: Within each row, the digits are split into two distinct groups separated by a noticeable gap (approximately one digit-width of empty space).
- **Digit Cells**: Each digit is contained within an implied rectangular cell. The digits are rendered in a pixelated, grayscale style typical of low-resolution digit datasets.
- **Labels/Text**: There are no explicit axis titles, legends, or numerical scales. The only textual content is the handwritten digits themselves.
### Detailed Analysis
**Row 1 (Top Strip):**
- **Left Group (7 digits)**: 5, 5, 0, 1, 2, 3, 4
- **Right Group (4 digits)**: 6, 7, 8, 9
- **Total Digits**: 11
- **Observations**: The first two digits are both '5' but exhibit different handwriting styles. The sequence progresses from 0 to 9, with '5' repeated and '6' through '9' placed in the right group.
**Row 2 (Middle Strip):**
- **Left Group (8 digits)**: 6, 6, 0, 1, 2, 3, 4, 5
- **Right Group (3 digits)**: 7, 8, 9
- **Total Digits**: 11
- **Observations**: The first two digits are both '6' with stylistic variation. The sequence includes all digits from 0 to 9, with '6' repeated and '7' through '9' in the right group.
**Row 3 (Bottom Strip):**
- **Left Group (10 digits)**: 8, 8, 0, 1, 2, 3, 8, 5, 6, 7
- **Right Group (1 digit)**: 8
- **Total Digits**: 11
- **Observations**: The first two digits are '8', and another '8' appears later in the left group. The sequence is less ordered, with digits 0, 1, 2, 3, 5, 6, 7, and multiple '8's. The right group contains a single '8'.
### Key Observations
1. **Intra-class Variation**: For digits that appear multiple times (e.g., '5' in Row 1, '6' in Row 2, '8' in Row 3), there are clear differences in handwriting style, stroke width, and shape.
2. **Grouping Pattern**: Each row consistently splits its digit sequence into a left group and a right group, with the gap position varying slightly. The right group generally contains the higher digits (6-9), but this is not strict (e.g., Row 3's right group has only '8').
3. **Digit Coverage**: All digits from 0 to 9 are present across the three rows, though not in a perfectly sequential order in each row.
4. **Visual Style**: The digits are low-resolution, anti-aliased, and appear to be normalized to a fixed size within their cells.
### Interpretation
Per the caption, this figure shows MNIST counterfactuals. In each row, the first two digits are the original image and its reconstruction, which explains the repeated class at the start of each row. The remaining digits are counterfactuals generated by intervening on the digit class. The mostly ascending order of these later digits follows the order of the interventions do(0) to do(9) rather than a random draw. Departures from that order in the transcription (e.g. the additional '8' in Row 3) may correspond to counterfactuals that do not clearly show the intended class.
</details>
Figure 6: MNIST counterfactuals. From left to right, one can observe the original image (orig.), the reconstruction (rec., obtained by running Algorithm 1 without the anti-causal predictor), and the resulting counterfactuals for each of the digit classes in the dataset.
## Appendix I. Qualitative influence of classifier scale
Here, we show in Fig. 7 the influence of changing the classifier's scale $s$ qualitatively. If $s$ is too low, the intervention has only a mild effect. On the other hand, if $s$ is too high, the intervention neglects the information present in the exogenous noise, so the counterfactual preserves fewer factors of the original image.
Figure 7: MNIST counterfactuals. From top to bottom, one can observe the original image (orig.), the reconstruction (rec.), and the resulting counterfactuals for the intervention $do(5)$ over three scales. As shown in Fig. 4, $s = 0.7$ is the optimal scale for MNIST data.
<details>
<summary>Image 7 Details</summary>

### Visual Description
## Diagram: Handwritten Digit Grid with Intervention Labels
### Overview
The image displays a 5-row by 10-column grid of handwritten digits (white on a black background), likely from the MNIST dataset. The grid is annotated on the left with labels indicating different processing conditions or interventions applied to the digits. The primary purpose appears to be a visual comparison of original digits, their reconstructions, and the effect of a specific intervention ("do(5)") at varying strength levels.
### Components/Axes
* **Left-side Labels (Text):**
* Row 1: `orig.`
* Row 2: `rec.`
* Rows 3-5: Grouped under a vertical line labeled `do(5)`, with individual row labels:
* Row 3: `s 0.1`
* Row 4: `s 0.7`
* Row 5: `s 2.0`
* **Grid Content:** Each cell contains a single handwritten digit image. The digits are not textual data but visual representations.
### Detailed Analysis
**Row-by-Row Digit Content (Visual Transcription):**
* **Row 1 (`orig.`):** 7, 3, 1, 2, 9, 7, 9, 6, 0, 0
* **Row 2 (`rec.`):** 7, 3, 1, 2, 9, 7, 9, 6, 0, 0
* **Row 3 (`do(5) s 0.1`):** 7, 3, 1, 5, 9, 7, 5, 6, 0, 0
* **Row 4 (`do(5) s 0.7`):** 5, 5, 5, 5, 9, 5, 5, 5, 5, 5
* **Row 5 (`do(5) s 2.0`):** 5, 5, 5, 5, 5, 5, 5, 5, 5, 5
**Trend Verification:**
* **`orig.` to `rec.`:** The reconstruction is visually identical to the original, indicating a high-fidelity model.
* **`do(5)` Intervention Trend:** As the scale parameter `s` increases from 0.1 to 2.0, there is a clear and progressive transformation of the digits towards the digit "5".
* At `s=0.1`, only two digits (columns 4 and 7) have changed to "5".
* At `s=0.7`, nine out of ten digits are "5", with only the digit in column 5 (originally a "9") remaining unchanged.
* At `s=2.0`, all ten digits have become "5".
### Key Observations
1. **Intervention Efficacy:** The `do(5)` operation is highly effective at forcing the output towards the digit "5", with its strength controlled by the parameter `s`.
2. **Resistance Point:** The digit "9" in the 5th column shows notable resistance to the intervention. It remains a "9" at `s=0.1` and `s=0.7`, only succumbing to become a "5" at the highest strength (`s=2.0`).
3. **Reconstruction Fidelity:** The `rec.` row is a perfect visual match for the `orig.` row, suggesting the underlying generative or autoencoder model has very low reconstruction error on this sample.
### Interpretation
This diagram illustrates a controlled experiment on the Diff-SCM diffusion model. The label `do(5)` denotes a **causal intervention** (using the `do`-operator from causal inference) that sets the digit class to "5".
* **What the data demonstrates:** The intervention works, and its effect depends on the classifier scale `s`. A small `s` (0.1) causes minor changes, a medium `s` (0.7) converts nearly all digits, and a large `s` (2.0) converts every digit to the target class.
* **Relationship between elements:** The `orig.` and `rec.` rows establish a baseline of model performance. The `do(5)` rows show the model's behavior under active manipulation. The grid format allows for direct, column-wise comparison of how each specific original digit responds to the same intervention.
* **Notable anomaly:** The "9" in column 5 remains unchanged at `s=0.7` and only becomes a "5" at `s=2.0`. This is consistent with the trade-off described in the text: lower scales preserve more of the factual image, so some instances need a stronger intervention before their class changes.
</details>