Amortized Causal Discovery

This repo contains the official PyTorch implementation of:

Sindy Löwe*, David Madras*, Richard Zemel, Max Welling - Amortized Causal Discovery: Learning to Infer Causal Graphs from Time-Series Data

With Amortized Causal Discovery, we learn to infer causal relations from samples that have different underlying causal graphs but share the same dynamics. This lets us generalize across samples, so our performance improves as the training set grows.

*equal contribution

What is Amortized Causal Discovery?

With Amortized Causal Discovery, we separate causal relation prediction from dynamics modelling. Our amortized encoder learns to infer causal relations across samples with different underlying graphs. Our decoder learns to model the shared dynamics of the predicted relations.

This separation allows us to train a joint model for samples with different underlying causal graphs. This is in contrast to previous approaches, which need to refit a new model whenever they encounter samples with a different underlying causal graph.
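
To make the split between graph inference and dynamics modelling concrete, here is a toy sketch in plain Python. It is not the ACD model. The linear "move toward your causes" dynamics and the names `shared_decoder`, `simulate`, `prediction_error` and `brute_force_encoder` are hypothetical, and the brute-force search over graphs only stands in for the learned encoder. What it does show is the interface: one dynamics function shared by all samples, and a per-sample graph that is judged by how well that shared function predicts the next time step.

```python
import itertools
from typing import List

State = List[float]
Graph = List[List[int]]  # graph[i][j] == 1 means "node i causes node j"


def shared_decoder(state: State, graph: Graph, coupling: float = 0.1) -> State:
    """Toy shared dynamics (hypothetical): each node moves toward its causes.

    The same function and the same `coupling` are reused for every sample,
    whatever that sample's causal graph is.
    """
    n = len(state)
    return [
        state[j] + coupling * sum(graph[i][j] * (state[i] - state[j]) for i in range(n))
        for j in range(n)
    ]


def simulate(x0: State, graph: Graph, steps: int) -> List[State]:
    """Roll out the shared dynamics for one sample with its own graph."""
    traj = [x0]
    for _ in range(steps):
        traj.append(shared_decoder(traj[-1], graph))
    return traj


def prediction_error(traj: List[State], graph: Graph) -> float:
    """Summed squared one-step prediction error of the shared decoder under `graph`."""
    err = 0.0
    for x_t, x_next in zip(traj, traj[1:]):
        pred = shared_decoder(x_t, graph)
        err += sum((p - y) ** 2 for p, y in zip(pred, x_next))
    return err


def brute_force_encoder(traj: List[State]) -> Graph:
    """Stand-in for the amortized encoder: score every graph, keep the best.

    This exhaustive search is NOT amortized; it only illustrates how a
    candidate graph is scored through the shared decoder.
    """
    n = len(traj[0])
    off_diag = [(i, j) for i in range(n) for j in range(n) if i != j]
    best: Graph = []
    best_err = float("inf")
    for bits in itertools.product([0, 1], repeat=len(off_diag)):
        graph = [[0] * n for _ in range(n)]
        for (i, j), b in zip(off_diag, bits):
            graph[i][j] = b
        err = prediction_error(traj, graph)
        if err < best_err:
            best, best_err = graph, err
    return best
```

Two samples with different graphs (for example `[[0, 1, 0], [0, 0, 1], [0, 0, 0]]` and `[[0, 0, 0], [1, 0, 0], [1, 0, 0]]`) are both explained by the same `shared_decoder`. ACD replaces the exponential search with an encoder network trained across many samples, so inferring the graph of a new sample costs a single forward pass.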

What we found exciting is that this separation lets causal discovery performance improve substantially as the training set grows. Amortized Causal Discovery (ACD) outperforms previous causal discovery approaches with as few as 50 training samples; with 50,000 samples, it outperforms them by more than 30 percentage points.
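
The README does not name the metric behind these numbers. Assuming it is the area under the ROC curve (AUROC) of the predicted edge probabilities, a common choice for scoring inferred causal graphs, it can be computed as sketched below; `edge_auroc` is a hypothetical helper, not part of the codebase. "Percentage points" denotes an absolute difference: going from a hypothetical AUROC of 0.60 to 0.90 is a gain of 30 percentage points, or a 50% relative improvement.

```python
from typing import Sequence


def edge_auroc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """AUROC via its pairwise (Mann-Whitney U) definition.

    scores: predicted probability that each candidate edge exists.
    labels: 1 if the edge is in the ground-truth graph, else 0.
    A tie between a true edge and a non-edge counts as half a correct pair.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one true edge and one non-edge")
    correct = 0.0
    for p in pos:
        for q in neg:
            if p > q:
                correct += 1.0
            elif p == q:
                correct += 0.5
    # Fraction of (true edge, non-edge) pairs ranked in the right order.
    return correct / (len(pos) * len(neg))
```

The pairwise loop is quadratic in the number of edges, which is fine for graphs with a handful of nodes; for large graphs a sort-based rank computation is the usual choice.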

How to run the code

Dependencies

  • Python and Conda

  • Set up the conda environment ACD by running:

    bash setup_dependencies.sh

    If you want to use your GPU, you may have to install a CUDA-enabled PyTorch version manually. Use the appropriate command provided here.

  • Don't forget to activate the environment and cd into the codebase directory whenever you work with the code later on:

    source activate ACD
    cd codebase
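
With the ACD environment active, you can check whether the installed PyTorch build can see a GPU through `torch.cuda.is_available()`. The helper name `cuda_status` below is hypothetical; it also reports cleanly when PyTorch is not installed at all.

```python
import importlib.util


def cuda_status() -> str:
    """Describe whether the active environment's PyTorch can use a CUDA GPU."""
    if importlib.util.find_spec("torch") is None:
        return "PyTorch is not installed in this environment"
    import torch

    if torch.cuda.is_available():
        # Name of the first visible CUDA device.
        return f"CUDA available: {torch.cuda.get_device_name(0)}"
    return f"PyTorch {torch.__version__} found, but CUDA is not available (CPU only)"


if __name__ == "__main__":
    print(cuda_status())
```

If this reports CPU only on a machine with a GPU, that usually points to a CPU-only PyTorch build, which is the case the manual installation note above is meant to address.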

Datasets

  • To generate the springs dataset (particles connected by springs) from our paper, run

    python -m data.generate_dataset
  • To generate a particles dataset with varying latent temperature, run

    python -m data.generate_dataset --temperature_dist --temperature_alpha 2 --temperature_num_cats 3
  • To generate the Kuramoto dataset from our paper, run

    python -m data.generate_ODE_dataset
  • The Netsim dataset is available here
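
Each springs sample pairs its own random interaction graph with shared physics (Hooke's law). The sketch below illustrates that idea in plain Python; it is not `data.generate_dataset`. The connection probability, the initial-condition scales, the step size, the missing box walls, and modelling the latent temperature as a multiplier on the spring constant are all assumptions of this sketch, and the names `sample_spring_graph` and `simulate_springs` are hypothetical.

```python
import random
from typing import List, Tuple

Graph = List[List[int]]
Frame = List[Tuple[float, float]]  # one (x, y) position per particle


def sample_spring_graph(n: int, p: float, rng: random.Random) -> Graph:
    """Symmetric adjacency: each pair of particles gets a spring with probability p."""
    adj = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            if rng.random() < p:
                adj[i][j] = adj[j][i] = 1
    return adj


def simulate_springs(
    adj: Graph,
    steps: int,
    dt: float = 0.01,
    k: float = 1.0,
    temperature: float = 1.0,
    seed: int = 0,
) -> List[Frame]:
    """Semi-implicit Euler rollout of 2D unit-mass particles coupled by springs.

    Force on particle i: F_i = -temperature * k * sum_j adj[i][j] * (x_i - x_j).
    Treating `temperature` as a spring-strength multiplier is an assumption.
    """
    rng = random.Random(seed)
    n = len(adj)
    pos = [[rng.gauss(0.0, 0.5), rng.gauss(0.0, 0.5)] for _ in range(n)]
    vel = [[rng.gauss(0.0, 0.5), rng.gauss(0.0, 0.5)] for _ in range(n)]
    traj: List[Frame] = []
    for _ in range(steps):
        traj.append([(p[0], p[1]) for p in pos])
        # First update every velocity from the current positions.
        for i in range(n):
            for d in range(2):
                force = -temperature * k * sum(
                    adj[i][j] * (pos[i][d] - pos[j][d]) for j in range(n)
                )
                vel[i][d] += dt * force
        # Then move every particle with its updated velocity.
        for i in range(n):
            for d in range(2):
                pos[i][d] += dt * vel[i][d]
    return traj
```

Calling `simulate_springs(sample_spring_graph(5, 0.5, random.Random(1)), steps=100)` yields one 5-particle sample; drawing a fresh graph for every sample gives the "different graphs, shared dynamics" setting described above.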

Experiments

  • Run the Springs experiment with

    python -m train --suffix _springs5

    the Kuramoto experiment with

    python -m train --suffix _kuramoto5 --encoder cnn

    and the Netsim experiment with

    python -m train --suffix netsim
  • To run the experiment with an unobserved temperature variable, run

    python -m train --suffix _springs5 --encoder cnn --decoder sim --global_temp --load_temperatures
  • To run the experiment with an unobserved time-series, run

    python -m train --suffix _springs5 --unobserved 1
  • View all possible command-line options by running

    python -m train --help
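
To queue several of the runs above from one script, you can wrap the commands with Python's `subprocess` module. The flag sets are copied from the commands listed above; the dictionary keys and the helpers `build_command` and `run_all` are hypothetical, and the script assumes it is started from inside the `codebase` directory with the ACD environment active.

```python
import subprocess
import sys
from typing import Dict, List

# Flag sets copied from the README commands (keys are illustrative names).
EXPERIMENTS: Dict[str, List[str]] = {
    "springs": ["--suffix", "_springs5"],
    "kuramoto": ["--suffix", "_kuramoto5", "--encoder", "cnn"],
    "netsim": ["--suffix", "netsim"],
    "springs_latent_temperature": [
        "--suffix", "_springs5", "--encoder", "cnn", "--decoder", "sim",
        "--global_temp", "--load_temperatures",
    ],
    "springs_unobserved": ["--suffix", "_springs5", "--unobserved", "1"],
}


def build_command(name: str) -> List[str]:
    """Return the argv for one experiment, using the current Python interpreter."""
    return [sys.executable, "-m", "train", *EXPERIMENTS[name]]


def run_all(names: List[str], dry_run: bool = True) -> None:
    """Print each command; unless dry_run, also execute it, stopping on failure."""
    for name in names:
        cmd = build_command(name)
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(list(EXPERIMENTS), dry_run=True)
```

Keep `dry_run=True` to only print the commands; pass `dry_run=False` to run them one after another, with `check=True` raising `CalledProcessError` at the first failing run.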

Cite

Please cite our paper if you use this code in your own work:

@article{lowe2022amortized,
  title={Amortized Causal Discovery: Learning to Infer Causal Graphs from Time-Series Data},
  author={L{\"o}we, Sindy and Madras, David and Zemel, Richard and Welling, Max},
  journal={Causal Learning and Reasoning (CLeaR)},
  year={2022}
}


Acknowledgements

The Robert Bosch GmbH is acknowledged for financial support.