
Experiments on counterfactual explanations for neural networks, based on the [latent shift method](https://arxiv.org/abs/2102.09475)


tabsplanation


Installation

To get started, create the environment with

conda env create -f environment.yml

or

mamba env create -f environment.yml

Then you'll need to install pytorch-lightning manually (don't ask me why!).

mamba activate tabsplanation
pip3 install torch==1.13.1+cu117 -f https://download.pytorch.org/whl/torch_stable.html
pip3 install lightning
pip3 install -e .

If the pip install -e step doesn't run automatically, activate the environment and run it manually. This is to ensure that tabsplanation/src is on sys.path when pytask is run.

You might also need a working LaTeX installation if you want LaTeX to be used in the plots.

Usage

In the root of the project, in the environment, run

python3 -m src.experiments.run <name-of-experiment>

to run a given experiment. The name is given by the name of the directory in src/experiments.

By default, all output is captured by pytask. Hence, to visualize the training of a model, it is recommended to use TensorBoard. In the root of the project, run

tensorboard --logdir=bld/models/lightning_logs

and follow the instructions.

Rationale

This repository contains various experiments exploring a technique called Latent Shift, coined by Cohen et al. in their paper "Gifsplanation".

The point of pytask is to offer an elegant way to make data-related workflows cacheable: when each individual part of a workflow produces outputs, these outputs are saved so they can be re-used later and the corresponding step can be skipped.

Previously, the same codebase was written with hydra in mind: hydra offers a fairly intuitive way to parametrize an experiment through a configuration file, usually written in YAML. In that setup, an experiment workflow would be written as follows:

# `cfg` contains all configuration information about the workflow
@hydra.main(config_path=".", config_name="my_config")
def my_experiment(cfg):
    # Step 1: generate the dataset
    data = create_data(cfg)

    # Step 2: load cached models, or train them from scratch
    if cfg.load_models:
        models = load_models(cfg.models_path)
    else:
        models = train_models(cfg, data)

    # Step 3: compute the quantities to plot
    plot_data = create_plot_data(models, data)

    # Step 4: build the plot
    plot = create_plot(plot_data)

    # Step 5: display it
    show_plot(plot)

You can see that the "get models" step is cacheable: there is an option to load models from a given path instead of training them. Indeed, of all the steps in the workflow, this one is by far the most time-consuming. However, the other steps are still re-run every time. What is more, when loading models, the current option is to pass a directory path, so one must be certain that the data loaded in step 1 is exactly the same data that was used to train the loaded models; this book-keeping has to be done manually, which is error-prone and frustrating.

Instead, it would be saner to divide up the workflow into each step, and let pytask handle the caching:

def task_create_dataset(depends_on, produces):
    # Load the config (depends_on["config"] is a path to the config file)
    cfg = load_config(depends_on["config"])

    # Create the dataset according to the config
    data = create_dataset(cfg)

    # Cache the result
    save(data, produces["data"])

and similarly for all the other steps. Now, instead of having hydra run my_experiment, you can just ask pytask to run the show_plot task with a given configuration file, and let it figure out what needs to be done to make that happen. In particular, if you ask it to run show_plot twice with the same config, the second run should be very quick because everything is cached; hence you can afford to re-run the whole workflow even if it's just to tweak the plot's visuals.
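The skip-if-cached behaviour can be illustrated with a stand-alone toy sketch (plain Python, not the actual pytask API): each task declares its input and output files, and a task is re-run only when an output is missing or older than an input. The function names here (run_task, create_dataset, create_plot_data) are illustrative, not taken from this codebase.

```python
import tempfile
from pathlib import Path

def outdated(inputs, outputs):
    """True if any output is missing or older than any input."""
    if not all(p.exists() for p in outputs):
        return True
    oldest_out = min(p.stat().st_mtime for p in outputs)
    newest_in = max((p.stat().st_mtime for p in inputs), default=0.0)
    return oldest_out < newest_in

def run_task(name, func, inputs, outputs):
    """Run `func` only when its outputs are stale; return whether it ran."""
    if outdated(inputs, outputs):
        func(inputs, outputs)
        return True
    return False

workdir = Path(tempfile.mkdtemp())
data_file = workdir / "data.csv"
plot_file = workdir / "plot_data.csv"

def create_dataset(inputs, outputs):
    outputs[0].write_text("1,2,3")

def create_plot_data(inputs, outputs):
    outputs[0].write_text(inputs[0].read_text())

# The first pass executes both tasks; the second finds everything cached.
for n in (1, 2):
    ran = [
        run_task("create_dataset", create_dataset, [], [data_file]),
        run_task("create_plot_data", create_plot_data, [data_file], [plot_file]),
    ]
    print(f"pass {n}: ran {sum(ran)} task(s)")
```

pytask does essentially this, but keyed on a dependency graph built from the task signatures, so an upstream change re-triggers exactly the downstream tasks that need it.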

Credits

This project was created with cookiecutter and the cookiecutter-pytask-project template.

TODO

  • The paths to files that are output by a task should be printed to the console when a task is run. The annoying thing is that pytask captures everything by default, and there is not yet enough granularity to surface a particular print.
  • There is no facility for using multiple config files at the same time; the only option is to change the appropriate YAML file every time.
