This repository contains code and resources for testing Large Causal Models (LCMs). The aim is to discover temporal causal relationships in time-series datasets using pretrained deep neural network models. The LCM takes a temporal dataset as input.
Install the dependencies from the requirements.txt file, either in your base environment or in an existing conda environment, using:
pip install -r requirements.txt
Obtain the LCM weights from the following URLs and place them into the res folder:
- deep_CI_12_3_fine_tuned_frozen_sim_82k_pre_joint_220k (210 MB): download link (external links will be available soon)
- lcm_CI_RH_12_3_merged_290k (4.6 GB): download link (external links will be available soon)
First, import the necessary modules for data generation, model prediction, and result visualization:
from pathlib import Path
from utils.model_wrapper import Architecture_PL
from utils.cp_utils import set_seed, create_example_data, run_cp_and_parse_res
from utils.plotting_utils import plot_summary_from_pred
We generate synthetic data with 1000 time samples, where each column represents a different time series. The data are min-max normalized, and the random seed is set to 42 for reproducibility.
set_seed(42)
df = create_example_data(n=1000)
variable_names = list(df.columns)
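To apply the same kind of preprocessing to your own data, min-max normalization can be sketched as below. This is a minimal sketch and not necessarily what create_example_data does internally; the helper name min_max_normalize is hypothetical, and it assumes a numeric pandas DataFrame with one time series per column.

```python
import pandas as pd


def min_max_normalize(df: pd.DataFrame) -> pd.DataFrame:
    """Rescale every column to [0, 1] (hypothetical helper, not part of this repository)."""
    col_min = df.min()
    col_range = df.max() - col_min
    # Constant columns have zero range; divide by 1 instead so they map to 0.
    col_range = col_range.replace(0, 1)
    return (df - col_min) / col_range
```

Normalizing each column independently keeps series with very different scales comparable before they are passed to the model.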
Load the pretrained .ckpt model for causal prediction:
models_path = 'res'
model_name = 'lcm_CI_RH_12_3_merged_290k'
model = Architecture_PL.load_from_checkpoint(Path(models_path) / f"{model_name}.ckpt")
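Because the weights have to be downloaded separately, it can help to check that the checkpoint file exists before loading it. The following is a minimal sketch using only the standard library; the helper name resolve_checkpoint is hypothetical, and it only assumes the res/<model_name>.ckpt layout used above.

```python
from pathlib import Path


def resolve_checkpoint(models_path: str, model_name: str) -> Path:
    """Return the checkpoint path, or raise a descriptive error if it is missing.

    Hypothetical helper, not part of this repository.
    """
    path = Path(models_path) / f"{model_name}.ckpt"
    if not path.is_file():
        raise FileNotFoundError(
            f"Checkpoint not found at {path}. Download the weights and place them in '{models_path}'."
        )
    return path
```

The returned path could then be passed to Architecture_PL.load_from_checkpoint in place of the inline path expression.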
Run run_cp_and_parse_res to perform causal discovery on the data generated above. The max_lag parameter specifies the maximum number of time lags considered when analyzing causal relationships:
# Run causal discovery with a maximum lag of 2
pred = run_cp_and_parse_res(model_name, model=model, df=df, max_lag=2, seed=42)
The result is a lagged adjacency tensor of shape (N, N, max_lag), where:
- N is the number of input time series
- pred[i, j, k] represents the probability that the j-th time series at time t-(max_lag - k) causes the i-th time series at time t.
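The index convention of the lagged adjacency tensor can be made concrete by listing the entries that exceed a threshold. This is a minimal sketch, assuming pred is available as (or has been converted to) a NumPy array of shape (N, N, max_lag); the helper name lagged_edges and the default threshold are hypothetical.

```python
import numpy as np


def lagged_edges(pred: np.ndarray, names: list[str], threshold: float = 0.5):
    """List (cause, effect, lag, probability) for all entries above the threshold.

    Hypothetical helper, not part of this repository.
    """
    n_effects, n_causes, max_lag = pred.shape
    edges = []
    for i in range(n_effects):          # i indexes the effect at time t
        for j in range(n_causes):       # j indexes the candidate cause
            for k in range(max_lag):
                if pred[i, j, k] > threshold:
                    lag = max_lag - k   # index k refers to time t - (max_lag - k)
                    edges.append((names[j], names[i], lag, float(pred[i, j, k])))
    return edges
```

With max_lag=2, for example, index k=1 corresponds to lag 1 and index k=0 to lag 2.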
The predicted causal relationships can be visualized using plot_summary_from_pred. The plt_thr parameter controls the density of the graph: higher values result in fewer edges being displayed.
plot_summary_from_pred(pred, variable_names, plt_thr=0.5)
In the resulting graph, an edge from time series A to B marked as t-1 means that time series A at time t-1 caused time series B at time t.
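The effect of a threshold on graph density can be illustrated by collapsing the lag axis into a summary adjacency matrix. This is a minimal sketch and not necessarily how plot_summary_from_pred is implemented; the helper name summary_adjacency is hypothetical, and it assumes pred can be converted to a NumPy array.

```python
import numpy as np


def summary_adjacency(pred, threshold: float = 0.5) -> np.ndarray:
    """Boolean (N, N) matrix whose entry [i, j] is True if j -> i exceeds the threshold at any lag.

    Hypothetical helper, not part of this repository.
    """
    return (np.asarray(pred) > threshold).any(axis=-1)


# Toy tensor with two variables and two lags (illustrative values only).
toy = np.array([[[0.1, 0.6], [0.9, 0.2]],
                [[0.3, 0.4], [0.7, 0.8]]])
edges_low = int(summary_adjacency(toy, threshold=0.5).sum())
edges_high = int(summary_adjacency(toy, threshold=0.85).sum())
```

Raising the threshold can only remove edges from the summary graph, never add them.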
The above example can be found in simple_example.py.
We assume the Causal Markov Condition and Faithfulness throughout. The following assumptions are also made:
- Causal inference of up to 12 variables and 3 time lags: The models can handle inputs of up to 12 variables and $\ell_\text{max}=1,2,3$.
- No contemporaneous effects: All cause-effect pairs are assumed to occur with lag $\ell > 0$.
- No unobserved confounders: The model assumes that there are no latent confounders (unobserved common causes) that could influence the relationships between the variables.
- Causal stationarity: The graph structure and noise distribution of the SCM do not change over time.
- Time-series stationarity: The input time series are assumed to be stationary.
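The size limits listed above can be checked before calling the model. This is a minimal sketch; the helper name validate_lcm_input and the constant names are hypothetical, and only the limits themselves (12 variables, lags 1 to 3) come from the assumptions above. Stationarity and the absence of confounders are not checked here.

```python
MAX_VARIABLES = 12  # upper bound on the number of input time series
MAX_LAG = 3         # largest supported value of max_lag


def validate_lcm_input(n_variables: int, max_lag: int) -> list[str]:
    """Return a list of violated size limits (empty if the input is within bounds).

    Hypothetical helper, not part of this repository.
    """
    problems = []
    if not 1 <= n_variables <= MAX_VARIABLES:
        problems.append(f"expected between 1 and {MAX_VARIABLES} variables, got {n_variables}")
    if not 1 <= max_lag <= MAX_LAG:
        problems.append(f"expected max_lag between 1 and {MAX_LAG}, got {max_lag}")
    return problems
```

For a DataFrame df with one series per column, this could be called as validate_lcm_input(df.shape[1], max_lag=2).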